try_accounts.rs 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. use crate::codegen::accounts::{bumps, constraints, generics, ParsedGenerics};
  2. use crate::{AccountField, AccountsStruct};
  3. use quote::quote;
  4. use syn::Expr;
  5. // Generates the `Accounts` trait implementation.
  6. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  7. let name = &accs.ident;
  8. let ParsedGenerics {
  9. combined_generics,
  10. trait_generics,
  11. struct_generics,
  12. where_clause,
  13. } = generics(accs);
  14. // Deserialization for each field
  15. let deser_fields: Vec<proc_macro2::TokenStream> = accs
  16. .fields
  17. .iter()
  18. .map(|af: &AccountField| {
  19. match af {
  20. AccountField::CompositeField(s) => {
  21. let name = &s.ident;
  22. let ty = &s.raw_field.ty;
  23. quote! {
  24. #[cfg(feature = "anchor-debug")]
  25. ::solana_program::log::sol_log(stringify!(#name));
  26. let #name: #ty = anchor_lang::Accounts::try_accounts(__program_id, __accounts, __ix_data, &mut __bumps.#name, __reallocs)?;
  27. }
  28. }
  29. AccountField::Field(f) => {
  30. // `init` and `zero` acccounts are special cased as they are
  31. // deserialized by constraints. Here, we just take out the
  32. // AccountInfo for later use at constraint validation time.
  33. if is_init(af) || f.constraints.zeroed.is_some() {
  34. let name = &f.ident;
  35. // Optional accounts have slightly different behavior here and
  36. // we can't leverage the try_accounts implementation for zero and init.
  37. if f.is_optional {
  38. // Thus, this block essentially reimplements the try_accounts
  39. // behavior with optional accounts minus the deserialziation.
  40. let empty_behavior = if cfg!(feature = "allow-missing-optionals") {
  41. quote!{ None }
  42. } else {
  43. quote!{ return Err(anchor_lang::error::ErrorCode::AccountNotEnoughKeys.into()); }
  44. };
  45. quote! {
  46. let #name = if __accounts.is_empty() {
  47. #empty_behavior
  48. } else if __accounts[0].key == __program_id {
  49. *__accounts = &__accounts[1..];
  50. None
  51. } else {
  52. let account = &__accounts[0];
  53. *__accounts = &__accounts[1..];
  54. Some(account)
  55. };
  56. }
  57. } else {
  58. quote!{
  59. if __accounts.is_empty() {
  60. return Err(anchor_lang::error::ErrorCode::AccountNotEnoughKeys.into());
  61. }
  62. let #name = &__accounts[0];
  63. *__accounts = &__accounts[1..];
  64. }
  65. }
  66. } else {
  67. let name = f.ident.to_string();
  68. let typed_name = f.typed_ident();
  69. quote! {
  70. #[cfg(feature = "anchor-debug")]
  71. ::solana_program::log::sol_log(stringify!(#typed_name));
  72. let #typed_name = anchor_lang::Accounts::try_accounts(__program_id, __accounts, __ix_data, __bumps, __reallocs)
  73. .map_err(|e| e.with_account_name(#name))?;
  74. }
  75. }
  76. }
  77. }
  78. })
  79. .collect();
  80. let constraints = generate_constraints(accs);
  81. let accounts_instance = generate_accounts_instance(accs);
  82. let bumps_struct_name = bumps::generate_bumps_name(&accs.ident);
  83. let ix_de = match &accs.instruction_api {
  84. None => quote! {},
  85. Some(ix_api) => {
  86. let strct_inner = &ix_api;
  87. let field_names: Vec<proc_macro2::TokenStream> = ix_api
  88. .iter()
  89. .map(|expr: &Expr| match expr {
  90. Expr::Type(expr_type) => {
  91. let field = &expr_type.expr;
  92. quote! {
  93. #field
  94. }
  95. }
  96. _ => panic!("Invalid instruction declaration"),
  97. })
  98. .collect();
  99. quote! {
  100. let mut __ix_data = __ix_data;
  101. #[derive(anchor_lang::AnchorSerialize, anchor_lang::AnchorDeserialize)]
  102. struct __Args {
  103. #strct_inner
  104. }
  105. let __Args {
  106. #(#field_names),*
  107. } = __Args::deserialize(&mut __ix_data)
  108. .map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotDeserialize)?;
  109. }
  110. }
  111. };
  112. quote! {
  113. #[automatically_derived]
  114. impl<#combined_generics> anchor_lang::Accounts<#trait_generics, #bumps_struct_name> for #name<#struct_generics> #where_clause {
  115. #[inline(never)]
  116. fn try_accounts(
  117. __program_id: &anchor_lang::solana_program::pubkey::Pubkey,
  118. __accounts: &mut &#trait_generics [anchor_lang::solana_program::account_info::AccountInfo<#trait_generics>],
  119. __ix_data: &[u8],
  120. __bumps: &mut #bumps_struct_name,
  121. __reallocs: &mut std::collections::BTreeSet<anchor_lang::solana_program::pubkey::Pubkey>,
  122. ) -> anchor_lang::Result<Self> {
  123. // Deserialize instruction, if declared.
  124. #ix_de
  125. // Deserialize each account.
  126. #(#deser_fields)*
  127. // Execute accounts constraints.
  128. #constraints
  129. // Success. Return the validated accounts.
  130. Ok(#accounts_instance)
  131. }
  132. }
  133. }
  134. }
  135. pub fn generate_constraints(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  136. let non_init_fields: Vec<&AccountField> =
  137. accs.fields.iter().filter(|af| !is_init(af)).collect();
  138. // Deserialization for each pda init field. This must be after
  139. // the inital extraction from the accounts slice and before access_checks.
  140. let init_fields: Vec<proc_macro2::TokenStream> = accs
  141. .fields
  142. .iter()
  143. .filter_map(|af| match af {
  144. AccountField::CompositeField(_s) => None,
  145. AccountField::Field(f) => match is_init(af) {
  146. false => None,
  147. true => Some(f),
  148. },
  149. })
  150. .map(|f| constraints::generate(f, accs))
  151. .collect();
  152. // Constraint checks for each account fields.
  153. let access_checks: Vec<proc_macro2::TokenStream> = non_init_fields
  154. .iter()
  155. .map(|af: &&AccountField| match af {
  156. AccountField::Field(f) => constraints::generate(f, accs),
  157. AccountField::CompositeField(s) => constraints::generate_composite(s),
  158. })
  159. .collect();
  160. quote! {
  161. #(#init_fields)*
  162. #(#access_checks)*
  163. }
  164. }
  165. pub fn generate_accounts_instance(accs: &AccountsStruct) -> proc_macro2::TokenStream {
  166. let name = &accs.ident;
  167. // Each field in the final deserialized accounts struct.
  168. let return_tys: Vec<proc_macro2::TokenStream> = accs
  169. .fields
  170. .iter()
  171. .map(|f: &AccountField| {
  172. let name = match f {
  173. AccountField::CompositeField(s) => &s.ident,
  174. AccountField::Field(f) => &f.ident,
  175. };
  176. quote! {
  177. #name
  178. }
  179. })
  180. .collect();
  181. quote! {
  182. #name {
  183. #(#return_tys),*
  184. }
  185. }
  186. }
  187. fn is_init(af: &AccountField) -> bool {
  188. match af {
  189. AccountField::CompositeField(_s) => false,
  190. AccountField::Field(f) => f.constraints.init.is_some(),
  191. }
  192. }