try_accounts.rs 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. use crate::codegen::accounts::{constraints, generics, ParsedGenerics};
  2. use crate::{AccountField, AccountsStruct, Field, SysvarTy, Ty};
  3. use proc_macro2::TokenStream;
  4. use quote::quote;
  5. use syn::Expr;
  6. // Generates the `Accounts` trait implementation.
  7. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  8. let name = &accs.ident;
  9. let ParsedGenerics {
  10. combined_generics,
  11. trait_generics,
  12. struct_generics,
  13. where_clause,
  14. } = generics(accs);
  15. // Deserialization for each field
  16. let deser_fields: Vec<proc_macro2::TokenStream> = accs
  17. .fields
  18. .iter()
  19. .map(|af: &AccountField| {
  20. match af {
  21. AccountField::CompositeField(s) => {
  22. let name = &s.ident;
  23. let ty = &s.raw_field.ty;
  24. quote! {
  25. #[cfg(feature = "anchor-debug")]
  26. ::solana_program::log::sol_log(stringify!(#name));
  27. let #name: #ty = anchor_lang::Accounts::try_accounts(program_id, accounts, ix_data)?;
  28. }
  29. }
  30. AccountField::Field(f) => {
  31. // Associated fields are *first* deserialized into
  32. // AccountInfos, and then later deserialized into
  33. // ProgramAccounts in the "constraint check" phase.
  34. if is_pda_init(af) {
  35. let name = &f.ident;
  36. quote!{
  37. let #name = &accounts[0];
  38. *accounts = &accounts[1..];
  39. }
  40. } else {
  41. let name = typed_ident(f);
  42. match f.constraints.is_init() {
  43. false => quote! {
  44. #[cfg(feature = "anchor-debug")]
  45. ::solana_program::log::sol_log(stringify!(#name));
  46. let #name = anchor_lang::Accounts::try_accounts(program_id, accounts, ix_data)?;
  47. },
  48. true => quote! {
  49. #[cfg(feature = "anchor-debug")]
  50. ::solana_program::log::sol_log(stringify!(#name));
  51. let #name = anchor_lang::AccountsInit::try_accounts_init(program_id, accounts)?;
  52. },
  53. }
  54. }
  55. }
  56. }
  57. })
  58. .collect();
  59. let constraints = generate_constraints(accs);
  60. let accounts_instance = generate_accounts_instance(accs);
  61. let ix_de = match &accs.instruction_api {
  62. None => quote! {},
  63. Some(ix_api) => {
  64. let strct_inner = &ix_api;
  65. let field_names: Vec<proc_macro2::TokenStream> = ix_api
  66. .iter()
  67. .map(|expr: &Expr| match expr {
  68. Expr::Type(expr_type) => {
  69. let field = &expr_type.expr;
  70. quote! {
  71. #field
  72. }
  73. }
  74. _ => panic!("Invalid instruction declaration"),
  75. })
  76. .collect();
  77. quote! {
  78. let mut ix_data = ix_data;
  79. #[derive(anchor_lang::AnchorSerialize, anchor_lang::AnchorDeserialize)]
  80. struct __Args {
  81. #strct_inner
  82. }
  83. let __Args {
  84. #(#field_names),*
  85. } = __Args::deserialize(&mut ix_data)
  86. .map_err(|_| anchor_lang::__private::ErrorCode::InstructionDidNotDeserialize)?;
  87. }
  88. }
  89. };
  90. quote! {
  91. #[automatically_derived]
  92. impl<#combined_generics> anchor_lang::Accounts<#trait_generics> for #name<#struct_generics> #where_clause {
  93. #[inline(never)]
  94. fn try_accounts(
  95. program_id: &anchor_lang::solana_program::pubkey::Pubkey,
  96. accounts: &mut &[anchor_lang::solana_program::account_info::AccountInfo<'info>],
  97. ix_data: &[u8],
  98. ) -> std::result::Result<Self, anchor_lang::solana_program::program_error::ProgramError> {
  99. // Deserialize instruction, if declared.
  100. #ix_de
  101. // Deserialize each account.
  102. #(#deser_fields)*
  103. // Execute accounts constraints.
  104. #constraints
  105. // Success. Return the validated accounts.
  106. Ok(#accounts_instance)
  107. }
  108. }
  109. }
  110. }
  111. // Returns true if the given AccountField has an associated init constraint.
  112. fn is_pda_init(af: &AccountField) -> bool {
  113. match af {
  114. AccountField::CompositeField(_s) => false,
  115. AccountField::Field(f) => {
  116. f.constraints
  117. .associated
  118. .as_ref()
  119. .map(|f| f.is_init)
  120. .unwrap_or(false)
  121. || f.constraints
  122. .seeds
  123. .as_ref()
  124. .map(|f| f.is_init)
  125. .unwrap_or(false)
  126. }
  127. }
  128. }
  129. fn typed_ident(field: &Field) -> TokenStream {
  130. let name = &field.ident;
  131. let ty = match &field.ty {
  132. Ty::AccountInfo => quote! { AccountInfo },
  133. Ty::ProgramState(ty) => {
  134. let account = &ty.account_type_path;
  135. quote! {
  136. ProgramState<#account>
  137. }
  138. }
  139. Ty::CpiState(ty) => {
  140. let account = &ty.account_type_path;
  141. quote! {
  142. CpiState<#account>
  143. }
  144. }
  145. Ty::ProgramAccount(ty) => {
  146. let account = &ty.account_type_path;
  147. quote! {
  148. ProgramAccount<#account>
  149. }
  150. }
  151. Ty::Loader(ty) => {
  152. let account = &ty.account_type_path;
  153. quote! {
  154. Loader<#account>
  155. }
  156. }
  157. Ty::CpiAccount(ty) => {
  158. let account = &ty.account_type_path;
  159. quote! {
  160. CpiAccount<#account>
  161. }
  162. }
  163. Ty::Sysvar(ty) => {
  164. let account = match ty {
  165. SysvarTy::Clock => quote! {Clock},
  166. SysvarTy::Rent => quote! {Rent},
  167. SysvarTy::EpochSchedule => quote! {EpochSchedule},
  168. SysvarTy::Fees => quote! {Fees},
  169. SysvarTy::RecentBlockhashes => quote! {RecentBlockhashes},
  170. SysvarTy::SlotHashes => quote! {SlotHashes},
  171. SysvarTy::SlotHistory => quote! {SlotHistory},
  172. SysvarTy::StakeHistory => quote! {StakeHistory},
  173. SysvarTy::Instructions => quote! {Instructions},
  174. SysvarTy::Rewards => quote! {Rewards},
  175. };
  176. quote! {
  177. Sysvar<#account>
  178. }
  179. }
  180. };
  181. quote! {
  182. #name: #ty
  183. }
  184. }
  185. pub fn generate_constraints(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  186. // All fields without an `#[account(associated)]` attribute.
  187. let non_associated_fields: Vec<&AccountField> =
  188. accs.fields.iter().filter(|af| !is_pda_init(af)).collect();
  189. // Deserialization for each *associated* field. This must be after
  190. // the inital extraction from the accounts slice and before access_checks.
  191. let init_associated_fields: Vec<proc_macro2::TokenStream> = accs
  192. .fields
  193. .iter()
  194. .filter_map(|af| match af {
  195. AccountField::CompositeField(_s) => None,
  196. AccountField::Field(f) => match is_pda_init(af) {
  197. false => None,
  198. true => Some(f),
  199. },
  200. })
  201. .map(constraints::generate)
  202. .collect();
  203. // Constraint checks for each account fields.
  204. let access_checks: Vec<proc_macro2::TokenStream> = non_associated_fields
  205. .iter()
  206. .map(|af: &&AccountField| match af {
  207. AccountField::Field(f) => constraints::generate(f),
  208. AccountField::CompositeField(s) => constraints::generate_composite(s),
  209. })
  210. .collect();
  211. quote! {
  212. #(#init_associated_fields)*
  213. #(#access_checks)*
  214. }
  215. }
  216. pub fn generate_accounts_instance(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  217. let name = &accs.ident;
  218. // Each field in the final deserialized accounts struct.
  219. let return_tys: Vec<proc_macro2::TokenStream> = accs
  220. .fields
  221. .iter()
  222. .map(|f: &AccountField| {
  223. let name = match f {
  224. AccountField::CompositeField(s) => &s.ident,
  225. AccountField::Field(f) => &f.ident,
  226. };
  227. quote! {
  228. #name
  229. }
  230. })
  231. .collect();
  232. quote! {
  233. #name {
  234. #(#return_tys),*
  235. }
  236. }
  237. }