accounts.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. use anyhow::{anyhow, Result};
  2. use proc_macro2::TokenStream;
  3. use quote::{quote, ToTokens};
  4. use super::common::{get_idl_module_path, get_no_docs};
  5. use crate::{AccountField, AccountsStruct, Field, InitKind, Ty};
  6. /// Generate the IDL build impl for the Accounts struct.
  7. pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
  8. let resolution = option_env!("ANCHOR_IDL_BUILD_RESOLUTION")
  9. .map(|val| val == "TRUE")
  10. .unwrap_or_default();
  11. let no_docs = get_no_docs();
  12. let idl = get_idl_module_path();
  13. let ident = &accounts.ident;
  14. let (impl_generics, ty_generics, where_clause) = accounts.generics.split_for_impl();
  15. let (accounts, defined) = accounts
  16. .fields
  17. .iter()
  18. .map(|acc| match acc {
  19. AccountField::Field(acc) => {
  20. let name = acc.ident.to_string();
  21. let writable = acc.constraints.is_mutable();
  22. let signer = match acc.ty {
  23. Ty::Signer => true,
  24. _ => acc.constraints.is_signer(),
  25. };
  26. let optional = acc.is_optional;
  27. let docs = match &acc.docs {
  28. Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
  29. _ => quote! { vec![] },
  30. };
  31. let (address, pda, relations) = if resolution {
  32. (
  33. get_address(acc),
  34. get_pda(acc, accounts),
  35. get_relations(acc, accounts),
  36. )
  37. } else {
  38. (quote! { None }, quote! { None }, quote! { vec![] })
  39. };
  40. let acc_type_path = match &acc.ty {
  41. Ty::Account(ty)
  42. // Skip `UpgradeableLoaderState` type for now until `bincode` serialization
  43. // is supported.
  44. //
  45. // TODO: Remove this once either `bincode` serialization is supported or
  46. // we wrap the type in order to implement `IdlBuild` in `anchor-lang`.
  47. if !ty
  48. .account_type_path
  49. .path
  50. .to_token_stream()
  51. .to_string()
  52. .contains("UpgradeableLoaderState") =>
  53. {
  54. Some(&ty.account_type_path)
  55. }
  56. Ty::AccountLoader(ty) => Some(&ty.account_type_path),
  57. Ty::InterfaceAccount(ty) => Some(&ty.account_type_path),
  58. _ => None,
  59. };
  60. (
  61. quote! {
  62. #idl::IdlInstructionAccountItem::Single(#idl::IdlInstructionAccount {
  63. name: #name.into(),
  64. docs: #docs,
  65. writable: #writable,
  66. signer: #signer,
  67. optional: #optional,
  68. address: #address,
  69. pda: #pda,
  70. relations: #relations,
  71. })
  72. },
  73. acc_type_path,
  74. )
  75. }
  76. AccountField::CompositeField(comp_f) => {
  77. let ty = if let syn::Type::Path(path) = &comp_f.raw_field.ty {
  78. // some::path::Foo<'info> -> some::path::Foo
  79. let mut res = syn::Path {
  80. leading_colon: path.path.leading_colon,
  81. segments: syn::punctuated::Punctuated::new(),
  82. };
  83. for segment in &path.path.segments {
  84. let s = syn::PathSegment {
  85. ident: segment.ident.clone(),
  86. arguments: syn::PathArguments::None,
  87. };
  88. res.segments.push(s);
  89. }
  90. res
  91. } else {
  92. panic!(
  93. "Compose field type must be a path but received: {:?}",
  94. comp_f.raw_field.ty
  95. )
  96. };
  97. let name = comp_f.ident.to_string();
  98. (
  99. quote! {
  100. #idl::IdlInstructionAccountItem::Composite(#idl::IdlInstructionAccounts {
  101. name: #name.into(),
  102. accounts: <#ty>::__anchor_private_gen_idl_accounts(accounts, types),
  103. })
  104. },
  105. None,
  106. )
  107. }
  108. })
  109. .unzip::<_, _, Vec<_>, Vec<_>>();
  110. let defined = defined.into_iter().flatten().collect::<Vec<_>>();
  111. quote! {
  112. impl #impl_generics #ident #ty_generics #where_clause {
  113. pub fn __anchor_private_gen_idl_accounts(
  114. accounts: &mut std::collections::BTreeMap<String, #idl::IdlAccount>,
  115. types: &mut std::collections::BTreeMap<String, #idl::IdlTypeDef>,
  116. ) -> Vec<#idl::IdlInstructionAccountItem> {
  117. #(
  118. if let Some(ty) = <#defined>::create_type() {
  119. let account = #idl::IdlAccount {
  120. name: ty.name.clone(),
  121. discriminator: #defined::DISCRIMINATOR.into(),
  122. };
  123. accounts.insert(account.name.clone(), account);
  124. types.insert(ty.name.clone(), ty);
  125. <#defined>::insert_types(types);
  126. }
  127. );*
  128. vec![#(#accounts),*]
  129. }
  130. }
  131. }
  132. }
  133. fn get_address(acc: &Field) -> TokenStream {
  134. match &acc.ty {
  135. Ty::Program(ty) => ty
  136. .account_type_path
  137. .path
  138. .segments
  139. .last()
  140. .map(|seg| &seg.ident)
  141. .map(|ident| quote! { Some(#ident::id().to_string()) })
  142. .unwrap_or_else(|| quote! { None }),
  143. Ty::Sysvar(_) => {
  144. let ty = acc.account_ty();
  145. let sysvar_id_trait = quote!(anchor_lang::solana_program::sysvar::SysvarId);
  146. quote! { Some(<#ty as #sysvar_id_trait>::id().to_string()) }
  147. }
  148. _ => acc
  149. .constraints
  150. .address
  151. .as_ref()
  152. .map(|constraint| &constraint.address)
  153. .filter(|address| {
  154. match address {
  155. // Allow constants (assume the identifier follows the Rust naming convention)
  156. // e.g. `crate::ID`
  157. syn::Expr::Path(expr) => expr
  158. .path
  159. .segments
  160. .last()
  161. .unwrap()
  162. .ident
  163. .to_string()
  164. .chars()
  165. .all(|c| c.is_uppercase() || c == '_'),
  166. // Allow `const fn`s (assume any stand-alone function call without an argument)
  167. // e.g. `crate::id()`
  168. syn::Expr::Call(expr) => expr.args.is_empty(),
  169. _ => false,
  170. }
  171. })
  172. .map(|address| quote! { Some(#address.to_string()) })
  173. .unwrap_or_else(|| quote! { None }),
  174. }
  175. }
  176. fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
  177. let idl = get_idl_module_path();
  178. let parse_default = |expr: &syn::Expr| parse_seed(expr, accounts);
  179. // Seeds
  180. let seed_constraints = acc.constraints.seeds.as_ref();
  181. let pda = seed_constraints
  182. .map(|seed| seed.seeds.iter().map(parse_default))
  183. .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok())
  184. .map(|seeds| {
  185. let program = seed_constraints
  186. .and_then(|seed| seed.program_seed.as_ref())
  187. .and_then(|program| parse_default(program).ok())
  188. .map(|program| quote! { Some(#program) })
  189. .unwrap_or_else(|| quote! { None });
  190. quote! {
  191. Some(
  192. #idl::IdlPda {
  193. seeds: vec![#(#seeds),*],
  194. program: #program,
  195. }
  196. )
  197. }
  198. });
  199. if let Some(pda) = pda {
  200. return pda;
  201. }
  202. // Associated token
  203. let pda = acc
  204. .constraints
  205. .init
  206. .as_ref()
  207. .and_then(|init| match &init.kind {
  208. InitKind::AssociatedToken {
  209. owner,
  210. mint,
  211. token_program,
  212. } => Some((owner, mint, token_program)),
  213. _ => None,
  214. })
  215. .or_else(|| {
  216. acc.constraints
  217. .associated_token
  218. .as_ref()
  219. .map(|ata| (&ata.wallet, &ata.mint, &ata.token_program))
  220. })
  221. .and_then(|(wallet, mint, token_program)| {
  222. // ATA constraints have implicit `.key()` call
  223. let parse_expr = |ts| parse_default(&syn::parse2(ts).unwrap()).ok();
  224. let parse_ata = |expr| parse_expr(quote! { #expr.key().as_ref() });
  225. let wallet = parse_ata(wallet);
  226. let mint = parse_ata(mint);
  227. let token_program = token_program
  228. .as_ref()
  229. .and_then(parse_ata)
  230. .or_else(|| parse_expr(quote!(anchor_spl::token::ID)));
  231. let seeds = match (wallet, mint, token_program) {
  232. (Some(w), Some(m), Some(tp)) => quote! { vec![#w, #tp, #m] },
  233. _ => return None,
  234. };
  235. let program = parse_expr(quote!(anchor_spl::associated_token::ID))
  236. .map(|program| quote! { Some(#program) })
  237. .unwrap();
  238. Some(quote! {
  239. Some(
  240. #idl::IdlPda {
  241. seeds: #seeds,
  242. program: #program,
  243. }
  244. )
  245. })
  246. });
  247. if let Some(pda) = pda {
  248. return pda;
  249. }
  250. quote! { None }
  251. }
/// Parse a seeds constraint, extracting the `IdlSeed` types.
///
/// Note: This implementation makes assumptions about the types that can be used (e.g., no
/// program-defined function calls in seeds).
///
/// This probably doesn't cover all cases. An unsupported seed makes this function return an
/// `Err`. `get_pda` in this file discards that error and omits the PDA from the IDL, so
/// clients will simply fail to automatically populate the PDA accounts. To support a new
/// form, add a new case here.
///
/// # Seed assumptions
///
/// Seeds must be of one of the following forms:
///
/// - Constant
/// - Instruction argument
/// - Account key or field
///
/// # Errors
///
/// Returns an error for expression kinds other than method calls, paths, literals and
/// references. It also returns an error when `SeedPath::new` rejects a method-call seed.
fn parse_seed(seed: &syn::Expr, accounts: &AccountsStruct) -> Result<TokenStream> {
    let idl = get_idl_module_path();
    let args = accounts.instruction_args().unwrap_or_default();
    match seed {
        syn::Expr::MethodCall(_) => {
            let seed_path = SeedPath::new(seed)?;

            if args.contains_key(&seed_path.name) {
                // The seed starts from an instruction argument.
                let path = seed_path.path();

                Ok(quote! {
                    #idl::IdlSeed::Arg(
                        #idl::IdlSeedArg {
                            path: #path.into(),
                        }
                    )
                })
            } else if let Some(account_field) = accounts
                .fields
                .iter()
                .find(|field| *field.ident() == seed_path.name)
            {
                // The seed starts from an account of this struct. The account type name is
                // only emitted when a subfield of the account's data is read.
                let path = seed_path.path();
                let account = match account_field.ty_name() {
                    Some(name) if !seed_path.subfields.is_empty() => {
                        quote! { Some(#name.into()) }
                    }
                    _ => quote! { None },
                };

                Ok(quote! {
                    #idl::IdlSeed::Account(
                        #idl::IdlSeedAccount {
                            path: #path.into(),
                            account: #account,
                        }
                    )
                })
            } else if seed_path.name.contains('"') {
                // String literal receiver, e.g. `b"seed".as_ref()`: strip the leading `b"` and
                // the surrounding `"` and embed the remaining text as the constant.
                let seed = seed_path.name.trim_start_matches("b\"").trim_matches('"');
                Ok(quote! {
                    #idl::IdlSeed::Const(
                        #idl::IdlSeedConst {
                            value: #seed.into(),
                        }
                    )
                })
            } else {
                // Anything else is embedded as-is and converted with `.into()`.
                Ok(quote! {
                    #idl::IdlSeed::Const(
                        #idl::IdlSeedConst {
                            value: #seed.into(),
                        }
                    )
                })
            }
        }
        syn::Expr::Path(path) => {
            let seed = path
                .path
                .get_ident()
                .map(|ident| ident.to_string())
                .filter(|ident| args.contains_key(ident))
                .map(|path| {
                    quote! {
                        #idl::IdlSeed::Arg(
                            #idl::IdlSeedArg {
                                path: #path.into(),
                            }
                        )
                    }
                })
                .unwrap_or_else(|| {
                    // Not all types can be converted to `Vec<u8>` with `.into` call e.g. `Pubkey`.
                    // This is problematic for `seeds::program` but a hacky way to handle this
                    // scenario is to check whether the last segment of the path ends with `ID`.
                    //
                    // `#seed` in the `map`/`unwrap_or_else` below is the function's `seed`
                    // parameter: neither `let seed` binding is in scope inside its own
                    // initializer.
                    let seed = path
                        .path
                        .segments
                        .last()
                        .filter(|seg| seg.ident.to_string().ends_with("ID"))
                        .map(|_| quote! { #seed.as_ref() })
                        .unwrap_or_else(|| quote! { #seed });

                    quote! {
                        #idl::IdlSeed::Const(
                            #idl::IdlSeedConst {
                                value: #seed.into(),
                            }
                        )
                    }
                });
            Ok(seed)
        }
        // Literals are embedded as-is and converted with `.into()`.
        syn::Expr::Lit(_) => Ok(quote! {
            #idl::IdlSeed::Const(
                #idl::IdlSeedConst {
                    value: #seed.into(),
                }
            )
        }),
        // `&expr` is parsed as `expr`.
        syn::Expr::Reference(rf) => parse_seed(&rf.expr, accounts),
        _ => Err(anyhow!("Unexpected seed: {seed:?}")),
    }
}
  369. /// SeedPath represents the deconstructed syntax of a single pda seed,
  370. /// consisting of a variable name and a vec of all the sub fields accessed
  371. /// on that variable name. For example, if a seed is `my_field.my_data.as_ref()`,
  372. /// then the field name is `my_field` and the vec of sub fields is `[my_data]`.
  373. struct SeedPath {
  374. /// Seed name
  375. name: String,
  376. /// All path components for the subfields accessed on this seed
  377. subfields: Vec<String>,
  378. }
  379. impl SeedPath {
  380. /// Extract the seed path from a single seed expression.
  381. fn new(seed: &syn::Expr) -> Result<Self> {
  382. // Convert the seed into the raw string representation.
  383. let seed_str = seed.to_token_stream().to_string();
  384. // Check unsupported cases e.g. `&(account.field + 1).to_le_bytes()`
  385. if !seed_str.contains('"')
  386. && seed_str.contains(|c: char| matches!(c, '+' | '-' | '*' | '/' | '%' | '^'))
  387. {
  388. return Err(anyhow!("Seed expression not supported: {seed:#?}"));
  389. }
  390. // Break up the seed into each subfield component.
  391. let mut components = seed_str.split('.').collect::<Vec<_>>();
  392. if components.len() <= 1 {
  393. return Err(anyhow!("Seed is in unexpected format: {seed:#?}"));
  394. }
  395. // The name of the variable (or field).
  396. let name = components.remove(0).to_owned();
  397. // The path to the seed (only if the `name` type is a struct).
  398. let mut path = Vec::new();
  399. while !components.is_empty() {
  400. let subfield = components.remove(0);
  401. if subfield.contains("()") {
  402. break;
  403. }
  404. path.push(subfield.into());
  405. }
  406. if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
  407. path = Vec::new();
  408. }
  409. Ok(SeedPath {
  410. name,
  411. subfields: path,
  412. })
  413. }
  414. /// Get the full path to the data this seed represents.
  415. fn path(&self) -> String {
  416. match self.subfields.len() {
  417. 0 => self.name.to_owned(),
  418. _ => format!("{}.{}", self.name, self.subfields.join(".")),
  419. }
  420. }
  421. }
  422. fn get_relations(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
  423. let relations = accounts
  424. .fields
  425. .iter()
  426. .filter_map(|af| match af {
  427. AccountField::Field(f) => f
  428. .constraints
  429. .has_one
  430. .iter()
  431. .filter_map(|c| match &c.join_target {
  432. syn::Expr::Path(path) => path
  433. .path
  434. .segments
  435. .first()
  436. .filter(|seg| seg.ident == acc.ident)
  437. .map(|_| Some(f.ident.to_string())),
  438. _ => None,
  439. })
  440. .collect::<Option<Vec<_>>>(),
  441. _ => None,
  442. })
  443. .flatten()
  444. .collect::<Vec<_>>();
  445. quote! { vec![#(#relations.into()),*] }
  446. }