// accounts.rs: IDL generation for `Accounts` structs.
  use anyhow::{anyhow, Result};
  use proc_macro2::TokenStream;
  use quote::{quote, ToTokens};

  use super::common::{get_idl_module_path, get_no_docs};
  use crate::{AccountField, AccountsStruct, ConstraintSeedsGroup, Field, InitKind, Ty};
  6. /// Generate the IDL build impl for the Accounts struct.
  7. pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
  8. let resolution = option_env!("ANCHOR_IDL_BUILD_RESOLUTION")
  9. .map(|val| val == "TRUE")
  10. .unwrap_or_default();
  11. let no_docs = get_no_docs();
  12. let idl = get_idl_module_path();
  13. let ident = &accounts.ident;
  14. let (impl_generics, ty_generics, where_clause) = accounts.generics.split_for_impl();
  15. let (accounts, defined) = accounts
  16. .fields
  17. .iter()
  18. .map(|acc| match acc {
  19. AccountField::Field(acc) => {
  20. let name = acc.ident.to_string();
  21. let writable = acc.constraints.is_mutable();
  22. let signer = match acc.ty {
  23. Ty::Signer => true,
  24. _ => acc.constraints.is_signer(),
  25. };
  26. let optional = acc.is_optional;
  27. let docs = match &acc.docs {
  28. Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
  29. _ => quote! { vec![] },
  30. };
  31. let (address, pda, relations) = if resolution {
  32. (
  33. get_address(acc),
  34. get_pda(acc, accounts),
  35. get_relations(acc, accounts),
  36. )
  37. } else {
  38. (quote! { None }, quote! { None }, quote! { vec![] })
  39. };
  40. let acc_type_path = match &acc.ty {
  41. Ty::Account(ty)
  42. // Skip `UpgradeableLoaderState` type for now until `bincode` serialization
  43. // is supported.
  44. //
  45. // TODO: Remove this once either `bincode` serialization is supported or
  46. // we wrap the type in order to implement `IdlBuild` in `anchor-lang`.
  47. if !ty
  48. .account_type_path
  49. .path
  50. .to_token_stream()
  51. .to_string()
  52. .contains("UpgradeableLoaderState") =>
  53. {
  54. Some(&ty.account_type_path)
  55. }
  56. Ty::LazyAccount(ty) => Some(&ty.account_type_path),
  57. Ty::AccountLoader(ty) => Some(&ty.account_type_path),
  58. Ty::InterfaceAccount(ty) => Some(&ty.account_type_path),
  59. _ => None,
  60. };
  61. (
  62. quote! {
  63. #idl::IdlInstructionAccountItem::Single(#idl::IdlInstructionAccount {
  64. name: #name.into(),
  65. docs: #docs,
  66. writable: #writable,
  67. signer: #signer,
  68. optional: #optional,
  69. address: #address,
  70. pda: #pda,
  71. relations: #relations,
  72. })
  73. },
  74. acc_type_path,
  75. )
  76. }
  77. AccountField::CompositeField(comp_f) => {
  78. let ty = if let syn::Type::Path(path) = &comp_f.raw_field.ty {
  79. // some::path::Foo<'info> -> some::path::Foo
  80. let mut res = syn::Path {
  81. leading_colon: path.path.leading_colon,
  82. segments: syn::punctuated::Punctuated::new(),
  83. };
  84. for segment in &path.path.segments {
  85. let s = syn::PathSegment {
  86. ident: segment.ident.clone(),
  87. arguments: syn::PathArguments::None,
  88. };
  89. res.segments.push(s);
  90. }
  91. res
  92. } else {
  93. panic!(
  94. "Compose field type must be a path but received: {:?}",
  95. comp_f.raw_field.ty
  96. )
  97. };
  98. let name = comp_f.ident.to_string();
  99. (
  100. quote! {
  101. #idl::IdlInstructionAccountItem::Composite(#idl::IdlInstructionAccounts {
  102. name: #name.into(),
  103. accounts: <#ty>::__anchor_private_gen_idl_accounts(accounts, types),
  104. })
  105. },
  106. None,
  107. )
  108. }
  109. })
  110. .unzip::<_, _, Vec<_>, Vec<_>>();
  111. let defined = defined.into_iter().flatten().collect::<Vec<_>>();
  112. quote! {
  113. impl #impl_generics #ident #ty_generics #where_clause {
  114. pub fn __anchor_private_gen_idl_accounts(
  115. accounts: &mut std::collections::BTreeMap<String, #idl::IdlAccount>,
  116. types: &mut std::collections::BTreeMap<String, #idl::IdlTypeDef>,
  117. ) -> Vec<#idl::IdlInstructionAccountItem> {
  118. #(
  119. if let Some(ty) = <#defined>::create_type() {
  120. let account = #idl::IdlAccount {
  121. name: ty.name.clone(),
  122. discriminator: #defined::DISCRIMINATOR.into(),
  123. };
  124. accounts.insert(account.name.clone(), account);
  125. types.insert(ty.name.clone(), ty);
  126. <#defined>::insert_types(types);
  127. }
  128. );*
  129. vec![#(#accounts),*]
  130. }
  131. }
  132. }
  133. }
  134. fn get_address(acc: &Field) -> TokenStream {
  135. match &acc.ty {
  136. Ty::Program(_) | Ty::Sysvar(_) => {
  137. let ty = acc.account_ty();
  138. let id_trait = matches!(acc.ty, Ty::Program(_))
  139. .then(|| quote!(anchor_lang::Id))
  140. .unwrap_or_else(|| quote!(anchor_lang::solana_program::sysvar::SysvarId));
  141. quote! { Some(<#ty as #id_trait>::id().to_string()) }
  142. }
  143. _ => acc
  144. .constraints
  145. .address
  146. .as_ref()
  147. .map(|constraint| &constraint.address)
  148. .filter(|address| {
  149. match address {
  150. // Allow constants (assume the identifier follows the Rust naming convention)
  151. // e.g. `crate::ID`
  152. syn::Expr::Path(expr) => expr
  153. .path
  154. .segments
  155. .last()
  156. .unwrap()
  157. .ident
  158. .to_string()
  159. .chars()
  160. .all(|c| c.is_uppercase() || c == '_'),
  161. // Allow `const fn`s (assume any stand-alone function call without an argument)
  162. // e.g. `crate::id()`
  163. syn::Expr::Call(expr) => expr.args.is_empty(),
  164. _ => false,
  165. }
  166. })
  167. .map(|address| quote! { Some(#address.to_string()) })
  168. .unwrap_or_else(|| quote! { None }),
  169. }
  170. }
  171. fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
  172. let idl = get_idl_module_path();
  173. let parse_default = |expr: &syn::Expr| parse_seed(expr, accounts);
  174. // Seeds
  175. let seed_constraints = acc.constraints.seeds.as_ref();
  176. let pda = seed_constraints
  177. .map(|seed| seed.seeds.iter().map(parse_default))
  178. .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok())
  179. .and_then(|seeds| {
  180. let program = match seed_constraints {
  181. Some(ConstraintSeedsGroup {
  182. program_seed: Some(program),
  183. ..
  184. }) => parse_default(program)
  185. .map(|program| quote! { Some(#program) })
  186. .ok()?,
  187. _ => quote! { None },
  188. };
  189. Some(quote! {
  190. Some(
  191. #idl::IdlPda {
  192. seeds: vec![#(#seeds),*],
  193. program: #program,
  194. }
  195. )
  196. })
  197. });
  198. if let Some(pda) = pda {
  199. return pda;
  200. }
  201. // Associated token
  202. let pda = acc
  203. .constraints
  204. .init
  205. .as_ref()
  206. .and_then(|init| match &init.kind {
  207. InitKind::AssociatedToken {
  208. owner,
  209. mint,
  210. token_program,
  211. } => Some((owner, mint, token_program)),
  212. _ => None,
  213. })
  214. .or_else(|| {
  215. acc.constraints
  216. .associated_token
  217. .as_ref()
  218. .map(|ata| (&ata.wallet, &ata.mint, &ata.token_program))
  219. })
  220. .and_then(|(wallet, mint, token_program)| {
  221. // ATA constraints have implicit `.key()` call
  222. let parse_expr = |ts| parse_default(&syn::parse2(ts).unwrap()).ok();
  223. let parse_ata = |expr| parse_expr(quote! { #expr.key().as_ref() });
  224. let wallet = parse_ata(wallet);
  225. let mint = parse_ata(mint);
  226. let token_program = token_program
  227. .as_ref()
  228. .and_then(parse_ata)
  229. .or_else(|| parse_expr(quote!(anchor_spl::token::ID)));
  230. let seeds = match (wallet, mint, token_program) {
  231. (Some(w), Some(m), Some(tp)) => quote! { vec![#w, #tp, #m] },
  232. _ => return None,
  233. };
  234. let program = parse_expr(quote!(anchor_spl::associated_token::ID))
  235. .map(|program| quote! { Some(#program) })
  236. .unwrap();
  237. Some(quote! {
  238. Some(
  239. #idl::IdlPda {
  240. seeds: #seeds,
  241. program: #program,
  242. }
  243. )
  244. })
  245. });
  246. if let Some(pda) = pda {
  247. return pda;
  248. }
  249. quote! { None }
  250. }
  251. /// Parse a seeds constraint, extracting the `IdlSeed` types.
  252. ///
  253. /// Note: This implementation makes assumptions about the types that can be used (e.g., no
  254. /// program-defined function calls in seeds).
  255. ///
  256. /// This probably doesn't cover all cases. If you see a warning log, you can add a new case here.
  257. /// In the worst case, we miss a seed and the parser will treat the given seeds as empty and so
  258. /// clients will simply fail to automatically populate the PDA accounts.
  259. ///
  260. /// # Seed assumptions
  261. ///
  262. /// Seeds must be of one of the following forms:
  263. ///
  264. /// - Constant
  265. /// - Instruction argument
  266. /// - Account key or field
  267. fn parse_seed(seed: &syn::Expr, accounts: &AccountsStruct) -> Result<TokenStream> {
  268. let idl = get_idl_module_path();
  269. let args = accounts.instruction_args().unwrap_or_default();
  270. match seed {
  271. syn::Expr::MethodCall(_) => {
  272. let seed_path = SeedPath::new(seed)?;
  273. if args.contains_key(&seed_path.name) {
  274. let path = seed_path.path();
  275. Ok(quote! {
  276. #idl::IdlSeed::Arg(
  277. #idl::IdlSeedArg {
  278. path: #path.into(),
  279. }
  280. )
  281. })
  282. } else if let Some(account_field) = accounts
  283. .fields
  284. .iter()
  285. .find(|field| *field.ident() == seed_path.name)
  286. {
  287. let path = seed_path.path();
  288. let account = match account_field.ty_name() {
  289. Some(name) if !seed_path.subfields.is_empty() => {
  290. quote! { Some(#name.into()) }
  291. }
  292. _ => quote! { None },
  293. };
  294. Ok(quote! {
  295. #idl::IdlSeed::Account(
  296. #idl::IdlSeedAccount {
  297. path: #path.into(),
  298. account: #account,
  299. }
  300. )
  301. })
  302. } else if seed_path.name.contains('"') {
  303. let seed = seed_path.name.trim_start_matches("b\"").trim_matches('"');
  304. Ok(quote! {
  305. #idl::IdlSeed::Const(
  306. #idl::IdlSeedConst {
  307. value: #seed.into(),
  308. }
  309. )
  310. })
  311. } else {
  312. Ok(quote! {
  313. #idl::IdlSeed::Const(
  314. #idl::IdlSeedConst {
  315. value: #seed.into(),
  316. }
  317. )
  318. })
  319. }
  320. }
  321. // Support call expressions that don't have any arguments e.g. `System::id()`
  322. syn::Expr::Call(call) if call.args.is_empty() => Ok(quote! {
  323. #idl::IdlSeed::Const(
  324. #idl::IdlSeedConst {
  325. value: AsRef::<[u8]>::as_ref(&#seed).into(),
  326. }
  327. )
  328. }),
  329. syn::Expr::Path(path) => {
  330. let seed = path
  331. .path
  332. .get_ident()
  333. .map(|ident| ident.to_string())
  334. .filter(|ident| args.contains_key(ident))
  335. .map(|path| {
  336. quote! {
  337. #idl::IdlSeed::Arg(
  338. #idl::IdlSeedArg {
  339. path: #path.into(),
  340. }
  341. )
  342. }
  343. })
  344. .unwrap_or_else(|| {
  345. quote! {
  346. #idl::IdlSeed::Const(
  347. #idl::IdlSeedConst {
  348. value: AsRef::<[u8]>::as_ref(&#seed).into(),
  349. }
  350. )
  351. }
  352. });
  353. Ok(seed)
  354. }
  355. syn::Expr::Lit(_) => Ok(quote! {
  356. #idl::IdlSeed::Const(
  357. #idl::IdlSeedConst {
  358. value: #seed.into(),
  359. }
  360. )
  361. }),
  362. syn::Expr::Reference(rf) => parse_seed(&rf.expr, accounts),
  363. _ => Err(anyhow!("Unexpected seed: {seed:?}")),
  364. }
  365. }
  366. /// SeedPath represents the deconstructed syntax of a single pda seed,
  367. /// consisting of a variable name and a vec of all the sub fields accessed
  368. /// on that variable name. For example, if a seed is `my_field.my_data.as_ref()`,
  369. /// then the field name is `my_field` and the vec of sub fields is `[my_data]`.
  370. struct SeedPath {
  371. /// Seed name
  372. name: String,
  373. /// All path components for the subfields accessed on this seed
  374. subfields: Vec<String>,
  375. }
  376. impl SeedPath {
  377. /// Extract the seed path from a single seed expression.
  378. fn new(seed: &syn::Expr) -> Result<Self> {
  379. // Convert the seed into the raw string representation.
  380. let seed_str = seed.to_token_stream().to_string();
  381. // Check unsupported cases e.g. `&(account.field + 1).to_le_bytes()`
  382. if !seed_str.contains('"')
  383. && seed_str.contains(|c: char| matches!(c, '+' | '-' | '*' | '/' | '%' | '^'))
  384. {
  385. return Err(anyhow!("Seed expression not supported: {seed:#?}"));
  386. }
  387. // Break up the seed into each subfield component.
  388. let mut components = seed_str.split('.').collect::<Vec<_>>();
  389. if components.len() <= 1 {
  390. return Err(anyhow!("Seed is in unexpected format: {seed:#?}"));
  391. }
  392. // The name of the variable (or field).
  393. let name = components.remove(0).to_owned();
  394. // The path to the seed (only if the `name` type is a struct).
  395. let mut path = Vec::new();
  396. while !components.is_empty() {
  397. let subfield = components.remove(0);
  398. if subfield.contains("()") {
  399. break;
  400. }
  401. path.push(subfield.into());
  402. }
  403. if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
  404. path = Vec::new();
  405. }
  406. Ok(SeedPath {
  407. name,
  408. subfields: path,
  409. })
  410. }
  411. /// Get the full path to the data this seed represents.
  412. fn path(&self) -> String {
  413. match self.subfields.len() {
  414. 0 => self.name.to_owned(),
  415. _ => format!("{}.{}", self.name, self.subfields.join(".")),
  416. }
  417. }
  418. }
  419. fn get_relations(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
  420. let relations = accounts
  421. .fields
  422. .iter()
  423. .filter_map(|af| match af {
  424. AccountField::Field(f) => f
  425. .constraints
  426. .has_one
  427. .iter()
  428. .filter_map(|c| match &c.join_target {
  429. syn::Expr::Path(path) => path
  430. .path
  431. .segments
  432. .first()
  433. .filter(|seg| seg.ident == acc.ident)
  434. .map(|_| Some(f.ident.to_string())),
  435. _ => None,
  436. })
  437. .collect::<Option<Vec<_>>>(),
  438. _ => None,
  439. })
  440. .flatten()
  441. .collect::<Vec<_>>();
  442. quote! { vec![#(#relations.into()),*] }
  443. }