123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- use std::{
- collections::BTreeMap,
- env, mem,
- path::Path,
- process::{Command, Stdio},
- };
- use anyhow::{anyhow, Result};
- use regex::Regex;
- use serde::Deserialize;
- use crate::types::{Idl, IdlEvent, IdlTypeDef};
- /// A trait that types must implement in order to include the type in the IDL definition.
- ///
- /// This trait is automatically implemented for Anchor all types that use the `AnchorSerialize`
- /// proc macro. Note that manually implementing the `AnchorSerialize` trait does **NOT** have the
- /// same effect.
- ///
- /// Types that don't implement this trait will cause a compile error during the IDL generation.
- ///
- /// The default implementation of the trait allows the program to compile but the type does **NOT**
- /// get included in the IDL.
- pub trait IdlBuild {
- /// Create an IDL type definition for the type.
- ///
- /// The type is only included in the IDL if this method returns `Some`.
- fn create_type() -> Option<IdlTypeDef> {
- None
- }
- /// Insert all types that are included in the current type definition to the given map.
- fn insert_types(_types: &mut BTreeMap<String, IdlTypeDef>) {}
- /// Get the full module path of the type.
- ///
- /// The full path will be used in the case of a conflicting type definition, e.g. when there
- /// are multiple structs with the same name.
- ///
- /// The default implementation covers most cases.
- fn get_full_path() -> String {
- std::any::type_name::<Self>().into()
- }
- }
- /// Generate IDL via compilation.
- pub fn build_idl(
- program_path: impl AsRef<Path>,
- resolution: bool,
- skip_lint: bool,
- no_docs: bool,
- ) -> Result<Idl> {
- build_idl_with_cargo_args(program_path, resolution, skip_lint, no_docs, &[])
- }
- /// Generate IDL via compilation with passing cargo arguments.
- pub fn build_idl_with_cargo_args(
- program_path: impl AsRef<Path>,
- resolution: bool,
- skip_lint: bool,
- no_docs: bool,
- cargo_args: &[String],
- ) -> Result<Idl> {
- let idl = build(
- program_path.as_ref(),
- resolution,
- skip_lint,
- no_docs,
- cargo_args,
- )?;
- let idl = convert_module_paths(idl);
- let idl = sort(idl);
- verify(&idl)?;
- Ok(idl)
- }
- /// Build IDL.
- fn build(
- program_path: &Path,
- resolution: bool,
- skip_lint: bool,
- no_docs: bool,
- cargo_args: &[String],
- ) -> Result<Idl> {
- // `nightly` toolchain is currently required for building the IDL.
- let toolchain = std::env::var("RUSTUP_TOOLCHAIN")
- .map(|toolchain| format!("+{}", toolchain))
- .unwrap_or_else(|_| "+nightly".to_string());
- install_toolchain_if_needed(&toolchain)?;
- let output = Command::new("cargo")
- .args([
- &toolchain,
- "test",
- "__anchor_private_print_idl",
- "--features",
- "idl-build",
- ])
- .args(cargo_args)
- .args(["--", "--show-output", "--quiet"])
- .env(
- "ANCHOR_IDL_BUILD_NO_DOCS",
- if no_docs { "TRUE" } else { "FALSE" },
- )
- .env(
- "ANCHOR_IDL_BUILD_RESOLUTION",
- if resolution { "TRUE" } else { "FALSE" },
- )
- .env(
- "ANCHOR_IDL_BUILD_SKIP_LINT",
- if skip_lint { "TRUE" } else { "FALSE" },
- )
- .env("ANCHOR_IDL_BUILD_PROGRAM_PATH", program_path)
- .env("RUSTFLAGS", "--cfg procmacro2_semver_exempt")
- .current_dir(program_path)
- .stderr(Stdio::inherit())
- .output()?;
- if !output.status.success() {
- return Err(anyhow!("Building IDL failed"));
- }
- enum State {
- Pass,
- Address,
- Constants(Vec<String>),
- Events(Vec<String>),
- Errors(Vec<String>),
- Program(Vec<String>),
- }
- let mut address = String::new();
- let mut events = vec![];
- let mut error_codes = vec![];
- let mut constants = vec![];
- let mut types = BTreeMap::new();
- let mut idl: Option<Idl> = None;
- let output = String::from_utf8_lossy(&output.stdout);
- if env::var("ANCHOR_LOG").is_ok() {
- println!("{}", output);
- }
- let mut state = State::Pass;
- for line in output.lines() {
- match &mut state {
- State::Pass => match line {
- "--- IDL begin address ---" => state = State::Address,
- "--- IDL begin const ---" => state = State::Constants(vec![]),
- "--- IDL begin event ---" => state = State::Events(vec![]),
- "--- IDL begin errors ---" => state = State::Errors(vec![]),
- "--- IDL begin program ---" => state = State::Program(vec![]),
- _ => {
- if line.starts_with("test result: ok")
- && !line.starts_with("test result: ok. 0 passed; 0 failed; 0")
- {
- if let Some(idl) = idl.as_mut() {
- idl.address = mem::take(&mut address);
- idl.constants = mem::take(&mut constants);
- idl.events = mem::take(&mut events);
- idl.errors = mem::take(&mut error_codes);
- idl.types = {
- let prog_ty = mem::take(&mut idl.types);
- let mut types = mem::take(&mut types);
- types.extend(prog_ty.into_iter().map(|ty| (ty.name.clone(), ty)));
- types.into_values().collect()
- };
- }
- }
- }
- },
- State::Address => {
- address = line.replace(|c: char| !c.is_alphanumeric(), "");
- state = State::Pass;
- continue;
- }
- State::Constants(lines) => {
- if line == "--- IDL end const ---" {
- let constant = serde_json::from_str(&lines.join("\n"))?;
- constants.push(constant);
- state = State::Pass;
- continue;
- }
- lines.push(line.to_owned());
- }
- State::Events(lines) => {
- if line == "--- IDL end event ---" {
- #[derive(Deserialize)]
- struct IdlBuildEventPrint {
- event: IdlEvent,
- types: Vec<IdlTypeDef>,
- }
- let event = serde_json::from_str::<IdlBuildEventPrint>(&lines.join("\n"))?;
- events.push(event.event);
- types.extend(event.types.into_iter().map(|ty| (ty.name.clone(), ty)));
- state = State::Pass;
- continue;
- }
- lines.push(line.to_owned());
- }
- State::Errors(lines) => {
- if line == "--- IDL end errors ---" {
- error_codes = serde_json::from_str(&lines.join("\n"))?;
- state = State::Pass;
- continue;
- }
- lines.push(line.to_owned());
- }
- State::Program(lines) => {
- if line == "--- IDL end program ---" {
- idl = Some(serde_json::from_str(&lines.join("\n"))?);
- state = State::Pass;
- continue;
- }
- lines.push(line.to_owned());
- }
- }
- }
- idl.ok_or_else(|| anyhow!("IDL doesn't exist"))
- }
- /// Install the given toolchain if it's not already installed.
- fn install_toolchain_if_needed(toolchain: &str) -> Result<()> {
- let is_installed = Command::new("cargo")
- .arg(toolchain)
- .output()?
- .status
- .success();
- if !is_installed {
- Command::new("rustup")
- .args(["toolchain", "install", toolchain.trim_start_matches('+')])
- .spawn()?
- .wait()?;
- }
- Ok(())
- }
- /// Convert paths to name if there are no conflicts.
- fn convert_module_paths(idl: Idl) -> Idl {
- let idl = serde_json::to_string(&idl).unwrap();
- let idl = Regex::new(r#""((\w+::)+)(\w+)""#)
- .unwrap()
- .captures_iter(&idl.clone())
- .fold(idl, |acc, cur| {
- let path = cur.get(0).unwrap().as_str();
- let name = cur.get(3).unwrap().as_str();
- // Replace path with name
- let replaced_idl = acc.replace(path, &format!(r#""{name}""#));
- // Check whether there is a conflict
- let has_conflict = replaced_idl.contains(&format!(r#"::{name}""#));
- if has_conflict {
- acc
- } else {
- replaced_idl
- }
- });
- serde_json::from_str(&idl).expect("Invalid IDL")
- }
- /// Alphabetically sort fields for consistency.
- fn sort(mut idl: Idl) -> Idl {
- idl.accounts.sort_by(|a, b| a.name.cmp(&b.name));
- idl.constants.sort_by(|a, b| a.name.cmp(&b.name));
- idl.events.sort_by(|a, b| a.name.cmp(&b.name));
- idl.instructions.sort_by(|a, b| a.name.cmp(&b.name));
- idl.types.sort_by(|a, b| a.name.cmp(&b.name));
- idl
- }
- /// Verify IDL is valid.
- fn verify(idl: &Idl) -> Result<()> {
- // Check full path accounts
- if let Some(account) = idl
- .accounts
- .iter()
- .find(|account| account.name.contains("::"))
- {
- return Err(anyhow!(
- "Conflicting accounts names are not allowed.\nProgram: `{}`\nAccount: `{}`",
- idl.metadata.name,
- account.name
- ));
- }
- // Check potential discriminator collisions
- macro_rules! check_discriminator_collision {
- ($field:ident) => {
- if let Some((outer, inner)) = idl.$field.iter().find_map(|outer| {
- idl.$field
- .iter()
- .filter(|inner| inner.name != outer.name)
- .find(|inner| outer.discriminator.starts_with(&inner.discriminator))
- .map(|inner| (outer, inner))
- }) {
- return Err(anyhow!(
- "Ambiguous discriminators for {} `{}` and `{}`",
- stringify!($field),
- outer.name,
- inner.name
- ));
- }
- };
- }
- check_discriminator_collision!(accounts);
- check_discriminator_collision!(events);
- check_discriminator_collision!(instructions);
- Ok(())
- }
|