constraints.rs 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134
  1. use quote::quote;
  2. use std::collections::HashSet;
  3. use syn::Expr;
  4. use crate::*;
  5. pub fn generate(f: &Field, accs: &AccountsStruct) -> proc_macro2::TokenStream {
  6. let constraints = linearize(&f.constraints);
  7. let rent = constraints
  8. .iter()
  9. .any(|c| matches!(c, Constraint::RentExempt(ConstraintRentExempt::Enforce)))
  10. .then(|| quote! { let __anchor_rent = Rent::get()?; })
  11. .unwrap_or_else(|| quote! {});
  12. let checks: Vec<proc_macro2::TokenStream> = constraints
  13. .iter()
  14. .map(|c| generate_constraint(f, c, accs))
  15. .collect();
  16. let mut all_checks = quote! {#(#checks)*};
  17. // If the field is optional we do all the inner checks as if the account
  18. // wasn't optional. If the account is init we also need to return an Option
  19. // by wrapping the resulting value with Some or returning None if it doesn't exist.
  20. if f.is_optional && !constraints.is_empty() {
  21. let ident = &f.ident;
  22. let ty_decl = f.ty_decl(false);
  23. all_checks = match &constraints[0] {
  24. Constraint::Init(_) | Constraint::Zeroed(_) => {
  25. quote! {
  26. let #ident: #ty_decl = if let Some(#ident) = #ident {
  27. #all_checks
  28. Some(#ident)
  29. } else {
  30. None
  31. };
  32. }
  33. }
  34. _ => {
  35. quote! {
  36. if let Some(#ident) = &#ident {
  37. #all_checks
  38. }
  39. }
  40. }
  41. };
  42. }
  43. quote! {
  44. #rent
  45. #all_checks
  46. }
  47. }
  48. pub fn generate_composite(f: &CompositeField) -> proc_macro2::TokenStream {
  49. let checks: Vec<proc_macro2::TokenStream> = linearize(&f.constraints)
  50. .iter()
  51. .filter_map(|c| match c {
  52. Constraint::Raw(_) => Some(c),
  53. _ => panic!("Invariant violation: composite constraints can only be raw or literals"),
  54. })
  55. .map(|c| generate_constraint_composite(f, c))
  56. .collect();
  57. quote! {
  58. #(#checks)*
  59. }
  60. }
  61. // Linearizes the constraint group so that constraints with dependencies
  62. // run after those without.
  63. pub fn linearize(c_group: &ConstraintGroup) -> Vec<Constraint> {
  64. let ConstraintGroup {
  65. init,
  66. zeroed,
  67. mutable,
  68. signer,
  69. has_one,
  70. raw,
  71. owner,
  72. rent_exempt,
  73. seeds,
  74. executable,
  75. close,
  76. address,
  77. associated_token,
  78. token_account,
  79. mint,
  80. realloc,
  81. } = c_group.clone();
  82. let mut constraints = Vec::new();
  83. if let Some(c) = zeroed {
  84. constraints.push(Constraint::Zeroed(c));
  85. }
  86. if let Some(c) = init {
  87. constraints.push(Constraint::Init(c));
  88. }
  89. if let Some(c) = realloc {
  90. constraints.push(Constraint::Realloc(c));
  91. }
  92. if let Some(c) = seeds {
  93. constraints.push(Constraint::Seeds(c));
  94. }
  95. if let Some(c) = associated_token {
  96. constraints.push(Constraint::AssociatedToken(c));
  97. }
  98. if let Some(c) = mutable {
  99. constraints.push(Constraint::Mut(c));
  100. }
  101. if let Some(c) = signer {
  102. constraints.push(Constraint::Signer(c));
  103. }
  104. constraints.append(&mut has_one.into_iter().map(Constraint::HasOne).collect());
  105. constraints.append(&mut raw.into_iter().map(Constraint::Raw).collect());
  106. if let Some(c) = owner {
  107. constraints.push(Constraint::Owner(c));
  108. }
  109. if let Some(c) = rent_exempt {
  110. constraints.push(Constraint::RentExempt(c));
  111. }
  112. if let Some(c) = executable {
  113. constraints.push(Constraint::Executable(c));
  114. }
  115. if let Some(c) = close {
  116. constraints.push(Constraint::Close(c));
  117. }
  118. if let Some(c) = address {
  119. constraints.push(Constraint::Address(c));
  120. }
  121. if let Some(c) = token_account {
  122. constraints.push(Constraint::TokenAccount(c));
  123. }
  124. if let Some(c) = mint {
  125. constraints.push(Constraint::Mint(c));
  126. }
  127. constraints
  128. }
  129. fn generate_constraint(
  130. f: &Field,
  131. c: &Constraint,
  132. accs: &AccountsStruct,
  133. ) -> proc_macro2::TokenStream {
  134. match c {
  135. Constraint::Init(c) => generate_constraint_init(f, c, accs),
  136. Constraint::Zeroed(c) => generate_constraint_zeroed(f, c),
  137. Constraint::Mut(c) => generate_constraint_mut(f, c),
  138. Constraint::HasOne(c) => generate_constraint_has_one(f, c, accs),
  139. Constraint::Signer(c) => generate_constraint_signer(f, c),
  140. Constraint::Raw(c) => generate_constraint_raw(&f.ident, c),
  141. Constraint::Owner(c) => generate_constraint_owner(f, c),
  142. Constraint::RentExempt(c) => generate_constraint_rent_exempt(f, c),
  143. Constraint::Seeds(c) => generate_constraint_seeds(f, c),
  144. Constraint::Executable(c) => generate_constraint_executable(f, c),
  145. Constraint::Close(c) => generate_constraint_close(f, c, accs),
  146. Constraint::Address(c) => generate_constraint_address(f, c),
  147. Constraint::AssociatedToken(c) => generate_constraint_associated_token(f, c, accs),
  148. Constraint::TokenAccount(c) => generate_constraint_token_account(f, c, accs),
  149. Constraint::Mint(c) => generate_constraint_mint(f, c, accs),
  150. Constraint::Realloc(c) => generate_constraint_realloc(f, c, accs),
  151. }
  152. }
  153. fn generate_constraint_composite(f: &CompositeField, c: &Constraint) -> proc_macro2::TokenStream {
  154. match c {
  155. Constraint::Raw(c) => generate_constraint_raw(&f.ident, c),
  156. _ => panic!("Invariant violation"),
  157. }
  158. }
  159. fn generate_constraint_address(f: &Field, c: &ConstraintAddress) -> proc_macro2::TokenStream {
  160. let field = &f.ident;
  161. let addr = &c.address;
  162. let error = generate_custom_error(
  163. field,
  164. &c.error,
  165. quote! { ConstraintAddress },
  166. &Some(&(quote! { actual }, quote! { expected })),
  167. );
  168. quote! {
  169. {
  170. let actual = #field.key();
  171. let expected = #addr;
  172. if actual != expected {
  173. return #error;
  174. }
  175. }
  176. }
  177. }
  178. pub fn generate_constraint_init(
  179. f: &Field,
  180. c: &ConstraintInitGroup,
  181. accs: &AccountsStruct,
  182. ) -> proc_macro2::TokenStream {
  183. generate_constraint_init_group(f, c, accs)
  184. }
  185. pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macro2::TokenStream {
  186. let field = &f.ident;
  187. let name_str = field.to_string();
  188. let ty_decl = f.ty_decl(true);
  189. let from_account_info = f.from_account_info(None, false);
  190. quote! {
  191. let #field: #ty_decl = {
  192. let mut __data: &[u8] = &#field.try_borrow_data()?;
  193. let mut __disc_bytes = [0u8; 8];
  194. __disc_bytes.copy_from_slice(&__data[..8]);
  195. let __discriminator = u64::from_le_bytes(__disc_bytes);
  196. if __discriminator != 0 {
  197. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintZero).with_account_name(#name_str));
  198. }
  199. #from_account_info
  200. };
  201. }
  202. }
  203. pub fn generate_constraint_close(
  204. f: &Field,
  205. c: &ConstraintClose,
  206. accs: &AccountsStruct,
  207. ) -> proc_macro2::TokenStream {
  208. let field = &f.ident;
  209. let name_str = field.to_string();
  210. let target = &c.sol_dest;
  211. let target_optional_check =
  212. OptionalCheckScope::new_with_field(accs, field).generate_check(target);
  213. quote! {
  214. {
  215. #target_optional_check
  216. if #field.key() == #target.key() {
  217. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintClose).with_account_name(#name_str));
  218. }
  219. }
  220. }
  221. }
  222. pub fn generate_constraint_mut(f: &Field, c: &ConstraintMut) -> proc_macro2::TokenStream {
  223. let ident = &f.ident;
  224. let error = generate_custom_error(ident, &c.error, quote! { ConstraintMut }, &None);
  225. quote! {
  226. if !#ident.to_account_info().is_writable {
  227. return #error;
  228. }
  229. }
  230. }
  231. pub fn generate_constraint_has_one(
  232. f: &Field,
  233. c: &ConstraintHasOne,
  234. accs: &AccountsStruct,
  235. ) -> proc_macro2::TokenStream {
  236. let target = &c.join_target;
  237. let ident = &f.ident;
  238. let field = match &f.ty {
  239. Ty::AccountLoader(_) => quote! {#ident.load()?},
  240. _ => quote! {#ident},
  241. };
  242. let error = generate_custom_error(
  243. ident,
  244. &c.error,
  245. quote! { ConstraintHasOne },
  246. &Some(&(quote! { my_key }, quote! { target_key })),
  247. );
  248. let target_optional_check =
  249. OptionalCheckScope::new_with_field(accs, &field).generate_check(target);
  250. quote! {
  251. {
  252. #target_optional_check
  253. let my_key = #field.#target;
  254. let target_key = #target.key();
  255. if my_key != target_key {
  256. return #error;
  257. }
  258. }
  259. }
  260. }
  261. pub fn generate_constraint_signer(f: &Field, c: &ConstraintSigner) -> proc_macro2::TokenStream {
  262. let ident = &f.ident;
  263. let info = match f.ty {
  264. Ty::AccountInfo => quote! { #ident },
  265. Ty::Account(_) => quote! { #ident.to_account_info() },
  266. Ty::AccountLoader(_) => quote! { #ident.to_account_info() },
  267. _ => panic!("Invalid syntax: signer cannot be specified."),
  268. };
  269. let error = generate_custom_error(ident, &c.error, quote! { ConstraintSigner }, &None);
  270. quote! {
  271. if !#info.is_signer {
  272. return #error;
  273. }
  274. }
  275. }
  276. pub fn generate_constraint_raw(ident: &Ident, c: &ConstraintRaw) -> proc_macro2::TokenStream {
  277. let raw = &c.raw;
  278. let error = generate_custom_error(ident, &c.error, quote! { ConstraintRaw }, &None);
  279. quote! {
  280. if !(#raw) {
  281. return #error;
  282. }
  283. }
  284. }
  285. pub fn generate_constraint_owner(f: &Field, c: &ConstraintOwner) -> proc_macro2::TokenStream {
  286. let ident = &f.ident;
  287. let owner_address = &c.owner_address;
  288. let error = generate_custom_error(
  289. ident,
  290. &c.error,
  291. quote! { ConstraintOwner },
  292. &Some(&(quote! { *my_owner }, quote! { owner_address })),
  293. );
  294. quote! {
  295. {
  296. let my_owner = AsRef::<AccountInfo>::as_ref(&#ident).owner;
  297. let owner_address = #owner_address;
  298. if my_owner != &owner_address {
  299. return #error;
  300. }
  301. }
  302. }
  303. }
  304. pub fn generate_constraint_rent_exempt(
  305. f: &Field,
  306. c: &ConstraintRentExempt,
  307. ) -> proc_macro2::TokenStream {
  308. let ident = &f.ident;
  309. let name_str = ident.to_string();
  310. let info = quote! {
  311. #ident.to_account_info()
  312. };
  313. match c {
  314. ConstraintRentExempt::Skip => quote! {},
  315. ConstraintRentExempt::Enforce => quote! {
  316. if !__anchor_rent.is_exempt(#info.lamports(), #info.try_data_len()?) {
  317. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintRentExempt).with_account_name(#name_str));
  318. }
  319. },
  320. }
  321. }
  322. fn generate_constraint_realloc(
  323. f: &Field,
  324. c: &ConstraintReallocGroup,
  325. accs: &AccountsStruct,
  326. ) -> proc_macro2::TokenStream {
  327. let field = &f.ident;
  328. let account_name = field.to_string();
  329. let new_space = &c.space;
  330. let payer = &c.payer;
  331. let zero = &c.zero;
  332. let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, field);
  333. let payer_optional_check = optional_check_scope.generate_check(payer);
  334. let system_program_optional_check =
  335. optional_check_scope.generate_check(quote! {system_program});
  336. quote! {
  337. // Blocks duplicate account reallocs in a single instruction to prevent accidental account overwrites
  338. // and to ensure the calculation of the change in bytes is based on account size at program entry
  339. // which inheritantly guarantee idempotency.
  340. if __reallocs.contains(&#field.key()) {
  341. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::AccountDuplicateReallocs).with_account_name(#account_name));
  342. }
  343. let __anchor_rent = anchor_lang::prelude::Rent::get()?;
  344. let __field_info = #field.to_account_info();
  345. let __new_rent_minimum = __anchor_rent.minimum_balance(#new_space);
  346. let __delta_space = (::std::convert::TryInto::<isize>::try_into(#new_space).unwrap())
  347. .checked_sub(::std::convert::TryInto::try_into(__field_info.data_len()).unwrap())
  348. .unwrap();
  349. if __delta_space != 0 {
  350. #payer_optional_check
  351. if __delta_space > 0 {
  352. #system_program_optional_check
  353. if ::std::convert::TryInto::<usize>::try_into(__delta_space).unwrap() > anchor_lang::solana_program::entrypoint::MAX_PERMITTED_DATA_INCREASE {
  354. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::AccountReallocExceedsLimit).with_account_name(#account_name));
  355. }
  356. if __new_rent_minimum > __field_info.lamports() {
  357. anchor_lang::system_program::transfer(
  358. anchor_lang::context::CpiContext::new(
  359. system_program.to_account_info(),
  360. anchor_lang::system_program::Transfer {
  361. from: #payer.to_account_info(),
  362. to: __field_info.clone(),
  363. },
  364. ),
  365. __new_rent_minimum.checked_sub(__field_info.lamports()).unwrap(),
  366. )?;
  367. }
  368. } else {
  369. let __lamport_amt = __field_info.lamports().checked_sub(__new_rent_minimum).unwrap();
  370. **#payer.to_account_info().lamports.borrow_mut() = #payer.to_account_info().lamports().checked_add(__lamport_amt).unwrap();
  371. **__field_info.lamports.borrow_mut() = __field_info.lamports().checked_sub(__lamport_amt).unwrap();
  372. }
  373. #field.to_account_info().realloc(#new_space, #zero)?;
  374. __reallocs.insert(#field.key());
  375. }
  376. }
  377. }
  378. fn generate_constraint_init_group(
  379. f: &Field,
  380. c: &ConstraintInitGroup,
  381. accs: &AccountsStruct,
  382. ) -> proc_macro2::TokenStream {
  383. let field = &f.ident;
  384. let name_str = f.ident.to_string();
  385. let ty_decl = f.ty_decl(true);
  386. let if_needed = if c.if_needed {
  387. quote! {true}
  388. } else {
  389. quote! {false}
  390. };
  391. let space = &c.space;
  392. let payer = &c.payer;
  393. // Convert from account info to account context wrapper type.
  394. let from_account_info = f.from_account_info(Some(&c.kind), true);
  395. let from_account_info_unchecked = f.from_account_info(Some(&c.kind), false);
  396. // PDA bump seeds.
  397. let (find_pda, seeds_with_bump) = match &c.seeds {
  398. None => (quote! {}, quote! {}),
  399. Some(c) => {
  400. let seeds = &mut c.seeds.clone();
  401. // If the seeds came with a trailing comma, we need to chop it off
  402. // before we interpolate them below.
  403. if let Some(pair) = seeds.pop() {
  404. seeds.push_value(pair.into_value());
  405. }
  406. let maybe_seeds_plus_comma = (!seeds.is_empty()).then(|| {
  407. quote! { #seeds, }
  408. });
  409. let validate_pda = {
  410. // If the bump is provided with init *and target*, then force it to be the
  411. // canonical bump.
  412. //
  413. // Note that for `#[account(init, seeds)]`, find_program_address has already
  414. // been run in the init constraint find_pda variable.
  415. if c.bump.is_some() {
  416. let b = c.bump.as_ref().unwrap();
  417. quote! {
  418. if #field.key() != __pda_address {
  419. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#field.key(), __pda_address)));
  420. }
  421. if __bump != #b {
  422. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_values((__bump, #b)));
  423. }
  424. }
  425. } else {
  426. // Init seeds but no bump. We already used the canonical to create bump so
  427. // just check the address.
  428. //
  429. // Note that for `#[account(init, seeds)]`, find_program_address has already
  430. // been run in the init constraint find_pda variable.
  431. quote! {
  432. if #field.key() != __pda_address {
  433. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#field.key(), __pda_address)));
  434. }
  435. }
  436. }
  437. };
  438. (
  439. quote! {
  440. let (__pda_address, __bump) = Pubkey::find_program_address(
  441. &[#maybe_seeds_plus_comma],
  442. program_id,
  443. );
  444. __bumps.insert(#name_str.to_string(), __bump);
  445. #validate_pda
  446. },
  447. quote! {
  448. &[
  449. #maybe_seeds_plus_comma
  450. &[__bump][..]
  451. ][..]
  452. },
  453. )
  454. }
  455. };
  456. // Optional check idents
  457. let system_program = &quote! {system_program};
  458. let token_program = &quote! {token_program};
  459. let associated_token_program = &quote! {associated_token_program};
  460. let rent = &quote! {rent};
  461. let mut check_scope = OptionalCheckScope::new_with_field(accs, field);
  462. match &c.kind {
  463. InitKind::Token { owner, mint } => {
  464. let owner_optional_check = check_scope.generate_check(owner);
  465. let mint_optional_check = check_scope.generate_check(mint);
  466. let system_program_optional_check = check_scope.generate_check(system_program);
  467. let token_program_optional_check = check_scope.generate_check(token_program);
  468. let rent_optional_check = check_scope.generate_check(rent);
  469. let optional_checks = quote! {
  470. #system_program_optional_check
  471. #token_program_optional_check
  472. #rent_optional_check
  473. #owner_optional_check
  474. #mint_optional_check
  475. };
  476. let payer_optional_check = check_scope.generate_check(payer);
  477. let create_account = generate_create_account(
  478. field,
  479. quote! {anchor_spl::token::TokenAccount::LEN},
  480. quote! {&token_program.key()},
  481. quote! {#payer},
  482. seeds_with_bump,
  483. );
  484. quote! {
  485. // Define the bump and pda variable.
  486. #find_pda
  487. let #field: #ty_decl = {
  488. // Checks that all the required accounts for this operation are present.
  489. #optional_checks
  490. if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
  491. #payer_optional_check
  492. // Create the account with the system program.
  493. #create_account
  494. // Initialize the token account.
  495. let cpi_program = token_program.to_account_info();
  496. let accounts = anchor_spl::token::InitializeAccount3 {
  497. account: #field.to_account_info(),
  498. mint: #mint.to_account_info(),
  499. authority: #owner.to_account_info(),
  500. };
  501. let cpi_ctx = anchor_lang::context::CpiContext::new(cpi_program, accounts);
  502. anchor_spl::token::initialize_account3(cpi_ctx)?;
  503. }
  504. let pa: #ty_decl = #from_account_info_unchecked;
  505. if #if_needed {
  506. if pa.mint != #mint.key() {
  507. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenMint).with_account_name(#name_str).with_pubkeys((pa.mint, #mint.key())));
  508. }
  509. if pa.owner != #owner.key() {
  510. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenOwner).with_account_name(#name_str).with_pubkeys((pa.owner, #owner.key())));
  511. }
  512. }
  513. pa
  514. };
  515. }
  516. }
  517. InitKind::AssociatedToken { owner, mint } => {
  518. let owner_optional_check = check_scope.generate_check(owner);
  519. let mint_optional_check = check_scope.generate_check(mint);
  520. let system_program_optional_check = check_scope.generate_check(system_program);
  521. let token_program_optional_check = check_scope.generate_check(token_program);
  522. let associated_token_program_optional_check =
  523. check_scope.generate_check(associated_token_program);
  524. let rent_optional_check = check_scope.generate_check(rent);
  525. let optional_checks = quote! {
  526. #system_program_optional_check
  527. #token_program_optional_check
  528. #associated_token_program_optional_check
  529. #rent_optional_check
  530. #owner_optional_check
  531. #mint_optional_check
  532. };
  533. let payer_optional_check = check_scope.generate_check(payer);
  534. quote! {
  535. // Define the bump and pda variable.
  536. #find_pda
  537. let #field: #ty_decl = {
  538. // Checks that all the required accounts for this operation are present.
  539. #optional_checks
  540. if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
  541. #payer_optional_check
  542. let cpi_program = associated_token_program.to_account_info();
  543. let cpi_accounts = anchor_spl::associated_token::Create {
  544. payer: #payer.to_account_info(),
  545. associated_token: #field.to_account_info(),
  546. authority: #owner.to_account_info(),
  547. mint: #mint.to_account_info(),
  548. system_program: system_program.to_account_info(),
  549. token_program: token_program.to_account_info(),
  550. };
  551. let cpi_ctx = anchor_lang::context::CpiContext::new(cpi_program, cpi_accounts);
  552. anchor_spl::associated_token::create(cpi_ctx)?;
  553. }
  554. let pa: #ty_decl = #from_account_info_unchecked;
  555. if #if_needed {
  556. if pa.mint != #mint.key() {
  557. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenMint).with_account_name(#name_str).with_pubkeys((pa.mint, #mint.key())));
  558. }
  559. if pa.owner != #owner.key() {
  560. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenOwner).with_account_name(#name_str).with_pubkeys((pa.owner, #owner.key())));
  561. }
  562. if pa.key() != anchor_spl::associated_token::get_associated_token_address(&#owner.key(), &#mint.key()) {
  563. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::AccountNotAssociatedTokenAccount).with_account_name(#name_str));
  564. }
  565. }
  566. pa
  567. };
  568. }
  569. }
  570. InitKind::Mint {
  571. owner,
  572. decimals,
  573. freeze_authority,
  574. } => {
  575. let owner_optional_check = check_scope.generate_check(owner);
  576. let freeze_authority_optional_check = match freeze_authority {
  577. Some(fa) => check_scope.generate_check(fa),
  578. None => quote! {},
  579. };
  580. let system_program_optional_check = check_scope.generate_check(system_program);
  581. let token_program_optional_check = check_scope.generate_check(token_program);
  582. let rent_optional_check = check_scope.generate_check(rent);
  583. let optional_checks = quote! {
  584. #system_program_optional_check
  585. #token_program_optional_check
  586. #rent_optional_check
  587. #owner_optional_check
  588. #freeze_authority_optional_check
  589. };
  590. let payer_optional_check = check_scope.generate_check(payer);
  591. let create_account = generate_create_account(
  592. field,
  593. quote! {anchor_spl::token::Mint::LEN},
  594. quote! {&token_program.key()},
  595. quote! {#payer},
  596. seeds_with_bump,
  597. );
  598. let freeze_authority = match freeze_authority {
  599. Some(fa) => quote! { Option::<&anchor_lang::prelude::Pubkey>::Some(&#fa.key()) },
  600. None => quote! { Option::<&anchor_lang::prelude::Pubkey>::None },
  601. };
  602. quote! {
  603. // Define the bump and pda variable.
  604. #find_pda
  605. let #field: #ty_decl = {
  606. // Checks that all the required accounts for this operation are present.
  607. #optional_checks
  608. if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
  609. // Define payer variable.
  610. #payer_optional_check
  611. // Create the account with the system program.
  612. #create_account
  613. // Initialize the mint account.
  614. let cpi_program = token_program.to_account_info();
  615. let accounts = anchor_spl::token::InitializeMint2 {
  616. mint: #field.to_account_info(),
  617. };
  618. let cpi_ctx = anchor_lang::context::CpiContext::new(cpi_program, accounts);
  619. anchor_spl::token::initialize_mint2(cpi_ctx, #decimals, &#owner.key(), #freeze_authority)?;
  620. }
  621. let pa: #ty_decl = #from_account_info_unchecked;
  622. if #if_needed {
  623. if pa.mint_authority != anchor_lang::solana_program::program_option::COption::Some(#owner.key()) {
  624. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintMintMintAuthority).with_account_name(#name_str));
  625. }
  626. if pa.freeze_authority
  627. .as_ref()
  628. .map(|fa| #freeze_authority.as_ref().map(|expected_fa| fa != *expected_fa).unwrap_or(true))
  629. .unwrap_or(#freeze_authority.is_some()) {
  630. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintMintFreezeAuthority).with_account_name(#name_str));
  631. }
  632. if pa.decimals != #decimals {
  633. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintMintDecimals).with_account_name(#name_str).with_values((pa.decimals, #decimals)));
  634. }
  635. }
  636. pa
  637. };
  638. }
  639. }
  640. InitKind::Program { owner } => {
  641. // Define the space variable.
  642. let space = quote! {let space = #space;};
  643. let system_program_optional_check = check_scope.generate_check(system_program);
  644. // Define the owner of the account being created. If not specified,
  645. // default to the currently executing program.
  646. let (owner, owner_optional_check) = match owner {
  647. None => (
  648. quote! {
  649. program_id
  650. },
  651. quote! {},
  652. ),
  653. Some(o) => {
  654. // We clone the `check_scope` here to avoid collisions with the
  655. // `payer_optional_check`, which is in a separate scope
  656. let owner_optional_check = check_scope.clone().generate_check(o);
  657. (
  658. quote! {
  659. &#o
  660. },
  661. owner_optional_check,
  662. )
  663. }
  664. };
  665. let payer_optional_check = check_scope.generate_check(payer);
  666. let optional_checks = quote! {
  667. #system_program_optional_check
  668. };
  669. // CPI to the system program to create the account.
  670. let create_account = generate_create_account(
  671. field,
  672. quote! {space},
  673. owner.clone(),
  674. quote! {#payer},
  675. seeds_with_bump,
  676. );
  677. // Put it all together.
  678. quote! {
  679. // Define the bump variable.
  680. #find_pda
  681. let #field = {
  682. // Checks that all the required accounts for this operation are present.
  683. #optional_checks
  684. let actual_field = #field.to_account_info();
  685. let actual_owner = actual_field.owner;
  686. // Define the account space variable.
  687. #space
  688. // Create the account. Always do this in the event
  689. // if needed is not specified or the system program is the owner.
  690. let pa: #ty_decl = if !#if_needed || actual_owner == &anchor_lang::solana_program::system_program::ID {
  691. #payer_optional_check
  692. // CPI to the system program to create.
  693. #create_account
  694. // Convert from account info to account context wrapper type.
  695. #from_account_info_unchecked
  696. } else {
  697. // Convert from account info to account context wrapper type.
  698. #from_account_info
  699. };
  700. // Assert the account was created correctly.
  701. if #if_needed {
  702. #owner_optional_check
  703. if space != actual_field.data_len() {
  704. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSpace).with_account_name(#name_str).with_values((space, actual_field.data_len())));
  705. }
  706. if actual_owner != #owner {
  707. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintOwner).with_account_name(#name_str).with_pubkeys((*actual_owner, *#owner)));
  708. }
  709. {
  710. let required_lamports = __anchor_rent.minimum_balance(space);
  711. if pa.to_account_info().lamports() < required_lamports {
  712. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintRentExempt).with_account_name(#name_str));
  713. }
  714. }
  715. }
  716. // Done.
  717. pa
  718. };
  719. }
  720. }
  721. }
  722. }
  723. fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
  724. if c.is_init {
  725. // Note that for `#[account(init, seeds)]`, the seed generation and checks is checked in
  726. // the init constraint find_pda/validate_pda block, so we don't do anything here and
  727. // return nothing!
  728. quote! {}
  729. } else {
  730. let name = &f.ident;
  731. let name_str = name.to_string();
  732. let s = &mut c.seeds.clone();
  733. let deriving_program_id = c
  734. .program_seed
  735. .clone()
  736. // If they specified a seeds::program to use when deriving the PDA, use it.
  737. .map(|program_id| quote! { #program_id.key() })
  738. // Otherwise fall back to the current program's program_id.
  739. .unwrap_or(quote! { program_id });
  740. // If the seeds came with a trailing comma, we need to chop it off
  741. // before we interpolate them below.
  742. if let Some(pair) = s.pop() {
  743. s.push_value(pair.into_value());
  744. }
  745. let maybe_seeds_plus_comma = (!s.is_empty()).then(|| {
  746. quote! { #s, }
  747. });
  748. // Not init here, so do all the checks.
  749. let define_pda = match c.bump.as_ref() {
  750. // Bump target not given. Find it.
  751. None => quote! {
  752. let (__pda_address, __bump) = Pubkey::find_program_address(
  753. &[#maybe_seeds_plus_comma],
  754. &#deriving_program_id,
  755. );
  756. __bumps.insert(#name_str.to_string(), __bump);
  757. },
  758. // Bump target given. Use it.
  759. Some(b) => quote! {
  760. let __pda_address = Pubkey::create_program_address(
  761. &[#maybe_seeds_plus_comma &[#b][..]],
  762. &#deriving_program_id,
  763. ).map_err(|_| anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str))?;
  764. },
  765. };
  766. quote! {
  767. // Define the PDA.
  768. #define_pda
  769. // Check it.
  770. if #name.key() != __pda_address {
  771. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#name.key(), __pda_address)));
  772. }
  773. }
  774. }
  775. }
  776. fn generate_constraint_associated_token(
  777. f: &Field,
  778. c: &ConstraintAssociatedToken,
  779. accs: &AccountsStruct,
  780. ) -> proc_macro2::TokenStream {
  781. let name = &f.ident;
  782. let name_str = name.to_string();
  783. let wallet_address = &c.wallet;
  784. let spl_token_mint_address = &c.mint;
  785. let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
  786. let wallet_address_optional_check = optional_check_scope.generate_check(wallet_address);
  787. let spl_token_mint_address_optional_check =
  788. optional_check_scope.generate_check(spl_token_mint_address);
  789. let optional_checks = quote! {
  790. #wallet_address_optional_check
  791. #spl_token_mint_address_optional_check
  792. };
  793. quote! {
  794. {
  795. #optional_checks
  796. let my_owner = #name.owner;
  797. let wallet_address = #wallet_address.key();
  798. if my_owner != wallet_address {
  799. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenOwner).with_account_name(#name_str).with_pubkeys((my_owner, wallet_address)));
  800. }
  801. let __associated_token_address = anchor_spl::associated_token::get_associated_token_address(&wallet_address, &#spl_token_mint_address.key());
  802. let my_key = #name.key();
  803. if my_key != __associated_token_address {
  804. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintAssociated).with_account_name(#name_str).with_pubkeys((my_key, __associated_token_address)));
  805. }
  806. }
  807. }
  808. }
  809. fn generate_constraint_token_account(
  810. f: &Field,
  811. c: &ConstraintTokenAccountGroup,
  812. accs: &AccountsStruct,
  813. ) -> proc_macro2::TokenStream {
  814. let name = &f.ident;
  815. let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
  816. let authority_check = match &c.authority {
  817. Some(authority) => {
  818. let authority_optional_check = optional_check_scope.generate_check(authority);
  819. quote! {
  820. #authority_optional_check
  821. if #name.owner != #authority.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenOwner.into()); }
  822. }
  823. }
  824. None => quote! {},
  825. };
  826. let mint_check = match &c.mint {
  827. Some(mint) => {
  828. let mint_optional_check = optional_check_scope.generate_check(mint);
  829. quote! {
  830. #mint_optional_check
  831. if #name.mint != #mint.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenMint.into()); }
  832. }
  833. }
  834. None => quote! {},
  835. };
  836. quote! {
  837. {
  838. #authority_check
  839. #mint_check
  840. }
  841. }
  842. }
  843. fn generate_constraint_mint(
  844. f: &Field,
  845. c: &ConstraintTokenMintGroup,
  846. accs: &AccountsStruct,
  847. ) -> proc_macro2::TokenStream {
  848. let name = &f.ident;
  849. let decimal_check = match &c.decimals {
  850. Some(decimals) => quote! {
  851. if #name.decimals != #decimals {
  852. return Err(anchor_lang::error::ErrorCode::ConstraintMintDecimals.into());
  853. }
  854. },
  855. None => quote! {},
  856. };
  857. let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
  858. let mint_authority_check = match &c.mint_authority {
  859. Some(mint_authority) => {
  860. let mint_authority_optional_check = optional_check_scope.generate_check(mint_authority);
  861. quote! {
  862. #mint_authority_optional_check
  863. if #name.mint_authority != anchor_lang::solana_program::program_option::COption::Some(#mint_authority.key()) {
  864. return Err(anchor_lang::error::ErrorCode::ConstraintMintMintAuthority.into());
  865. }
  866. }
  867. }
  868. None => quote! {},
  869. };
  870. let freeze_authority_check = match &c.freeze_authority {
  871. Some(freeze_authority) => {
  872. let freeze_authority_optional_check =
  873. optional_check_scope.generate_check(freeze_authority);
  874. quote! {
  875. #freeze_authority_optional_check
  876. if #name.freeze_authority != anchor_lang::solana_program::program_option::COption::Some(#freeze_authority.key()) {
  877. return Err(anchor_lang::error::ErrorCode::ConstraintMintFreezeAuthority.into());
  878. }
  879. }
  880. }
  881. None => quote! {},
  882. };
  883. quote! {
  884. {
  885. #decimal_check
  886. #mint_authority_check
  887. #freeze_authority_check
  888. }
  889. }
  890. }
  891. #[derive(Clone, Debug)]
  892. pub struct OptionalCheckScope<'a> {
  893. seen: HashSet<String>,
  894. accounts: &'a AccountsStruct,
  895. }
  896. impl<'a> OptionalCheckScope<'a> {
  897. pub fn new(accounts: &'a AccountsStruct) -> Self {
  898. Self {
  899. seen: HashSet::new(),
  900. accounts,
  901. }
  902. }
  903. pub fn new_with_field(accounts: &'a AccountsStruct, field: impl ToString) -> Self {
  904. let mut check_scope = Self::new(accounts);
  905. check_scope.seen.insert(field.to_string());
  906. check_scope
  907. }
  908. pub fn generate_check(&mut self, field: impl ToTokens) -> TokenStream {
  909. let field_name = tts_to_string(&field);
  910. if self.seen.contains(&field_name) {
  911. quote! {}
  912. } else {
  913. self.seen.insert(field_name.clone());
  914. if self.accounts.is_field_optional(&field) {
  915. quote! {
  916. let #field = if let Some(ref account) = #field {
  917. account
  918. } else {
  919. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintAccountIsNone).with_account_name(#field_name));
  920. };
  921. }
  922. } else {
  923. quote! {}
  924. }
  925. }
  926. }
  927. }
  928. // Generated code to create an account with with system program with the
  929. // given `space` amount of data, owned by `owner`.
  930. //
  931. // `seeds_with_nonce` should be given for creating PDAs. Otherwise it's an
  932. // empty stream.
  933. //
  934. // This should only be run within scopes where `system_program` is not Optional
  935. fn generate_create_account(
  936. field: &Ident,
  937. space: proc_macro2::TokenStream,
  938. owner: proc_macro2::TokenStream,
  939. payer: proc_macro2::TokenStream,
  940. seeds_with_nonce: proc_macro2::TokenStream,
  941. ) -> proc_macro2::TokenStream {
  942. // Field, payer, and system program are already validated to not be an Option at this point
  943. quote! {
  944. // If the account being initialized already has lamports, then
  945. // return them all back to the payer so that the account has
  946. // zero lamports when the system program's create instruction
  947. // is eventually called.
  948. let __current_lamports = #field.lamports();
  949. if __current_lamports == 0 {
  950. // Create the token account with right amount of lamports and space, and the correct owner.
  951. let lamports = __anchor_rent.minimum_balance(#space);
  952. let cpi_accounts = anchor_lang::system_program::CreateAccount {
  953. from: #payer.to_account_info(),
  954. to: #field.to_account_info()
  955. };
  956. let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
  957. anchor_lang::system_program::create_account(cpi_context.with_signer(&[#seeds_with_nonce]), lamports, #space as u64, #owner)?;
  958. } else {
  959. require_keys_neq!(#payer.key(), #field.key(), anchor_lang::error::ErrorCode::TryingToInitPayerAsProgramAccount);
  960. // Fund the account for rent exemption.
  961. let required_lamports = __anchor_rent
  962. .minimum_balance(#space)
  963. .max(1)
  964. .saturating_sub(__current_lamports);
  965. if required_lamports > 0 {
  966. let cpi_accounts = anchor_lang::system_program::Transfer {
  967. from: #payer.to_account_info(),
  968. to: #field.to_account_info(),
  969. };
  970. let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
  971. anchor_lang::system_program::transfer(cpi_context, required_lamports)?;
  972. }
  973. // Allocate space.
  974. let cpi_accounts = anchor_lang::system_program::Allocate {
  975. account_to_allocate: #field.to_account_info()
  976. };
  977. let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
  978. anchor_lang::system_program::allocate(cpi_context.with_signer(&[#seeds_with_nonce]), #space as u64)?;
  979. // Assign to the spl token program.
  980. let cpi_accounts = anchor_lang::system_program::Assign {
  981. account_to_assign: #field.to_account_info()
  982. };
  983. let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
  984. anchor_lang::system_program::assign(cpi_context.with_signer(&[#seeds_with_nonce]), #owner)?;
  985. }
  986. }
  987. }
  988. pub fn generate_constraint_executable(
  989. f: &Field,
  990. _c: &ConstraintExecutable,
  991. ) -> proc_macro2::TokenStream {
  992. let name = &f.ident;
  993. let name_str = name.to_string();
  994. // because we are only acting on the field, we know it isnt optional at this point
  995. // as it was unwrapped in `generate_constraint`
  996. quote! {
  997. if !#name.to_account_info().executable {
  998. return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintExecutable).with_account_name(#name_str));
  999. }
  1000. }
  1001. }
  1002. fn generate_custom_error(
  1003. account_name: &Ident,
  1004. custom_error: &Option<Expr>,
  1005. error: proc_macro2::TokenStream,
  1006. compared_values: &Option<&(proc_macro2::TokenStream, proc_macro2::TokenStream)>,
  1007. ) -> proc_macro2::TokenStream {
  1008. let account_name = account_name.to_string();
  1009. let mut error = match custom_error {
  1010. Some(error) => {
  1011. quote! { anchor_lang::error::Error::from(#error).with_account_name(#account_name) }
  1012. }
  1013. None => {
  1014. quote! { anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::#error).with_account_name(#account_name) }
  1015. }
  1016. };
  1017. let compared_values = match compared_values {
  1018. Some((left, right)) => quote! { .with_pubkeys((#left, #right)) },
  1019. None => quote! {},
  1020. };
  1021. error.extend(compared_values);
  1022. quote! {
  1023. Err(#error)
  1024. }
  1025. }