constraints.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670
  1. use crate::*;
  2. use proc_macro2_diagnostics::SpanDiagnosticExt;
  3. use quote::quote;
  4. use syn::Expr;
  5. pub fn generate(f: &Field) -> proc_macro2::TokenStream {
  6. let constraints = linearize(&f.constraints);
  7. let rent = constraints
  8. .iter()
  9. .any(|c| matches!(c, Constraint::RentExempt(ConstraintRentExempt::Enforce)))
  10. .then(|| quote! { let __anchor_rent = Rent::get()?; })
  11. .unwrap_or_else(|| quote! {});
  12. let checks: Vec<proc_macro2::TokenStream> = constraints
  13. .iter()
  14. .map(|c| generate_constraint(f, c))
  15. .collect();
  16. quote! {
  17. #rent
  18. #(#checks)*
  19. }
  20. }
  21. pub fn generate_composite(f: &CompositeField) -> proc_macro2::TokenStream {
  22. let checks: Vec<proc_macro2::TokenStream> = linearize(&f.constraints)
  23. .iter()
  24. .filter_map(|c| match c {
  25. Constraint::Raw(_) => Some(c),
  26. Constraint::Literal(_) => Some(c),
  27. _ => panic!("Invariant violation: composite constraints can only be raw or literals"),
  28. })
  29. .map(|c| generate_constraint_composite(f, c))
  30. .collect();
  31. quote! {
  32. #(#checks)*
  33. }
  34. }
  35. // Linearizes the constraint group so that constraints with dependencies
  36. // run after those without.
  37. pub fn linearize(c_group: &ConstraintGroup) -> Vec<Constraint> {
  38. let ConstraintGroup {
  39. init,
  40. zeroed,
  41. mutable,
  42. signer,
  43. has_one,
  44. literal,
  45. raw,
  46. owner,
  47. rent_exempt,
  48. seeds,
  49. executable,
  50. state,
  51. close,
  52. address,
  53. } = c_group.clone();
  54. let mut constraints = Vec::new();
  55. if let Some(c) = zeroed {
  56. constraints.push(Constraint::Zeroed(c));
  57. }
  58. if let Some(c) = init {
  59. constraints.push(Constraint::Init(c));
  60. }
  61. if let Some(c) = seeds {
  62. constraints.push(Constraint::Seeds(c));
  63. }
  64. if let Some(c) = mutable {
  65. constraints.push(Constraint::Mut(c));
  66. }
  67. if let Some(c) = signer {
  68. constraints.push(Constraint::Signer(c));
  69. }
  70. constraints.append(&mut has_one.into_iter().map(Constraint::HasOne).collect());
  71. constraints.append(&mut literal.into_iter().map(Constraint::Literal).collect());
  72. constraints.append(&mut raw.into_iter().map(Constraint::Raw).collect());
  73. if let Some(c) = owner {
  74. constraints.push(Constraint::Owner(c));
  75. }
  76. if let Some(c) = rent_exempt {
  77. constraints.push(Constraint::RentExempt(c));
  78. }
  79. if let Some(c) = executable {
  80. constraints.push(Constraint::Executable(c));
  81. }
  82. if let Some(c) = state {
  83. constraints.push(Constraint::State(c));
  84. }
  85. if let Some(c) = close {
  86. constraints.push(Constraint::Close(c));
  87. }
  88. if let Some(c) = address {
  89. constraints.push(Constraint::Address(c));
  90. }
  91. constraints
  92. }
  93. fn generate_constraint(f: &Field, c: &Constraint) -> proc_macro2::TokenStream {
  94. match c {
  95. Constraint::Init(c) => generate_constraint_init(f, c),
  96. Constraint::Zeroed(c) => generate_constraint_zeroed(f, c),
  97. Constraint::Mut(c) => generate_constraint_mut(f, c),
  98. Constraint::HasOne(c) => generate_constraint_has_one(f, c),
  99. Constraint::Signer(c) => generate_constraint_signer(f, c),
  100. Constraint::Literal(c) => generate_constraint_literal(c),
  101. Constraint::Raw(c) => generate_constraint_raw(c),
  102. Constraint::Owner(c) => generate_constraint_owner(f, c),
  103. Constraint::RentExempt(c) => generate_constraint_rent_exempt(f, c),
  104. Constraint::Seeds(c) => generate_constraint_seeds(f, c),
  105. Constraint::Executable(c) => generate_constraint_executable(f, c),
  106. Constraint::State(c) => generate_constraint_state(f, c),
  107. Constraint::Close(c) => generate_constraint_close(f, c),
  108. Constraint::Address(c) => generate_constraint_address(f, c),
  109. }
  110. }
  111. fn generate_constraint_composite(_f: &CompositeField, c: &Constraint) -> proc_macro2::TokenStream {
  112. match c {
  113. Constraint::Raw(c) => generate_constraint_raw(c),
  114. Constraint::Literal(c) => generate_constraint_literal(c),
  115. _ => panic!("Invariant violation"),
  116. }
  117. }
  118. fn generate_constraint_address(f: &Field, c: &ConstraintAddress) -> proc_macro2::TokenStream {
  119. let field = &f.ident;
  120. let addr = &c.address;
  121. quote! {
  122. if #field.to_account_info().key != &#addr {
  123. return Err(anchor_lang::__private::ErrorCode::ConstraintAddress.into());
  124. }
  125. }
  126. }
  127. pub fn generate_constraint_init(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream {
  128. generate_constraint_init_group(f, c)
  129. }
  130. pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macro2::TokenStream {
  131. let field = &f.ident;
  132. let (account_ty, account_wrapper_ty, _) = parse_ty(f);
  133. quote! {
  134. let #field: #account_wrapper_ty<#account_ty> = {
  135. let mut __data: &[u8] = &#field.try_borrow_data()?;
  136. let mut __disc_bytes = [0u8; 8];
  137. __disc_bytes.copy_from_slice(&__data[..8]);
  138. let __discriminator = u64::from_le_bytes(__disc_bytes);
  139. if __discriminator != 0 {
  140. return Err(anchor_lang::__private::ErrorCode::ConstraintZero.into());
  141. }
  142. #account_wrapper_ty::try_from_unchecked(
  143. program_id,
  144. &#field,
  145. )?
  146. };
  147. }
  148. }
  149. pub fn generate_constraint_close(f: &Field, c: &ConstraintClose) -> proc_macro2::TokenStream {
  150. let field = &f.ident;
  151. let target = &c.sol_dest;
  152. quote! {
  153. if #field.to_account_info().key == #target.to_account_info().key {
  154. return Err(anchor_lang::__private::ErrorCode::ConstraintClose.into());
  155. }
  156. }
  157. }
  158. pub fn generate_constraint_mut(f: &Field, _c: &ConstraintMut) -> proc_macro2::TokenStream {
  159. let ident = &f.ident;
  160. quote! {
  161. if !#ident.to_account_info().is_writable {
  162. return Err(anchor_lang::__private::ErrorCode::ConstraintMut.into());
  163. }
  164. }
  165. }
  166. pub fn generate_constraint_has_one(f: &Field, c: &ConstraintHasOne) -> proc_macro2::TokenStream {
  167. let target = c.join_target.clone();
  168. let ident = &f.ident;
  169. let field = match &f.ty {
  170. Ty::Loader(_) => quote! {#ident.load()?},
  171. _ => quote! {#ident},
  172. };
  173. quote! {
  174. if &#field.#target != #target.to_account_info().key {
  175. return Err(anchor_lang::__private::ErrorCode::ConstraintHasOne.into());
  176. }
  177. }
  178. }
  179. pub fn generate_constraint_signer(f: &Field, _c: &ConstraintSigner) -> proc_macro2::TokenStream {
  180. let ident = &f.ident;
  181. let info = match f.ty {
  182. Ty::AccountInfo => quote! { #ident },
  183. Ty::ProgramAccount(_) => quote! { #ident.to_account_info() },
  184. Ty::Loader(_) => quote! { #ident.to_account_info() },
  185. Ty::CpiAccount(_) => quote! { #ident.to_account_info() },
  186. _ => panic!("Invalid syntax: signer cannot be specified."),
  187. };
  188. quote! {
  189. // Don't enforce on CPI, since usually a program is signing and so
  190. // the `try_accounts` deserializatoin will fail *if* the one
  191. // tries to manually invoke it.
  192. //
  193. // This check will be performed on the other end of the invocation.
  194. if cfg!(not(feature = "cpi")) {
  195. if !#info.to_account_info().is_signer {
  196. return Err(anchor_lang::__private::ErrorCode::ConstraintSigner.into());
  197. }
  198. }
  199. }
  200. }
  201. pub fn generate_constraint_literal(c: &ConstraintLiteral) -> proc_macro2::TokenStream {
  202. let lit: proc_macro2::TokenStream = {
  203. let lit = &c.lit;
  204. let constraint = lit.value().replace("\"", "");
  205. let message = format!(
  206. "Deprecated. Should be used with constraint: #[account(constraint = {})]",
  207. constraint,
  208. );
  209. lit.span().warning(message).emit_as_item_tokens();
  210. constraint.parse().unwrap()
  211. };
  212. quote! {
  213. if !(#lit) {
  214. return Err(anchor_lang::__private::ErrorCode::Deprecated.into());
  215. }
  216. }
  217. }
  218. pub fn generate_constraint_raw(c: &ConstraintRaw) -> proc_macro2::TokenStream {
  219. let raw = &c.raw;
  220. quote! {
  221. if !(#raw) {
  222. return Err(anchor_lang::__private::ErrorCode::ConstraintRaw.into());
  223. }
  224. }
  225. }
  226. pub fn generate_constraint_owner(f: &Field, c: &ConstraintOwner) -> proc_macro2::TokenStream {
  227. let ident = &f.ident;
  228. let owner_target = c.owner_target.clone();
  229. quote! {
  230. if #ident.to_account_info().owner != #owner_target.to_account_info().key {
  231. return Err(anchor_lang::__private::ErrorCode::ConstraintOwner.into());
  232. }
  233. }
  234. }
  235. pub fn generate_constraint_rent_exempt(
  236. f: &Field,
  237. c: &ConstraintRentExempt,
  238. ) -> proc_macro2::TokenStream {
  239. let ident = &f.ident;
  240. let info = quote! {
  241. #ident.to_account_info()
  242. };
  243. match c {
  244. ConstraintRentExempt::Skip => quote! {},
  245. ConstraintRentExempt::Enforce => quote! {
  246. if !__anchor_rent.is_exempt(#info.lamports(), #info.try_data_len()?) {
  247. return Err(anchor_lang::__private::ErrorCode::ConstraintRentExempt.into());
  248. }
  249. },
  250. }
  251. }
  252. fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream {
  253. let payer = {
  254. let p = &c.payer;
  255. quote! {
  256. let payer = #p.to_account_info();
  257. }
  258. };
  259. let seeds_with_nonce = match &c.seeds {
  260. None => quote! {},
  261. Some(c) => {
  262. let s = &c.seeds;
  263. let inner = match c.bump.as_ref() {
  264. // Bump target not given. Use the canonical bump.
  265. None => {
  266. quote! {
  267. [
  268. #s,
  269. &[
  270. Pubkey::find_program_address(
  271. &[#s],
  272. program_id,
  273. ).1
  274. ]
  275. ]
  276. }
  277. }
  278. // Bump target given. Use it.
  279. Some(b) => quote! {
  280. [#s, &[#b]]
  281. },
  282. };
  283. quote! {
  284. &#inner[..]
  285. }
  286. }
  287. };
  288. generate_pda(f, seeds_with_nonce, payer, &c.space, &c.kind)
  289. }
  290. fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
  291. let name = &f.ident;
  292. let s = &c.seeds;
  293. // If the bump is provided with init *and target*, then force it to be the
  294. // canonical bump.
  295. if c.is_init && c.bump.is_some() {
  296. let b = c.bump.as_ref().unwrap();
  297. quote! {
  298. let (__program_signer, __bump) = anchor_lang::solana_program::pubkey::Pubkey::find_program_address(
  299. &[#s],
  300. program_id,
  301. );
  302. if #name.to_account_info().key != &__program_signer {
  303. return Err(anchor_lang::__private::ErrorCode::ConstraintSeeds.into());
  304. }
  305. if __bump != #b {
  306. return Err(anchor_lang::__private::ErrorCode::ConstraintSeeds.into());
  307. }
  308. }
  309. } else {
  310. let seeds = match c.bump.as_ref() {
  311. // Bump target not given. Find it.
  312. None => {
  313. quote! {
  314. [
  315. #s,
  316. &[
  317. Pubkey::find_program_address(
  318. &[#s],
  319. program_id,
  320. ).1
  321. ]
  322. ]
  323. }
  324. }
  325. // Bump target given. Use it.
  326. Some(b) => {
  327. quote! {
  328. [#s, &[#b]]
  329. }
  330. }
  331. };
  332. quote! {
  333. let __program_signer = Pubkey::create_program_address(
  334. &#seeds[..],
  335. program_id,
  336. ).map_err(|_| anchor_lang::__private::ErrorCode::ConstraintSeeds)?;
  337. if #name.to_account_info().key != &__program_signer {
  338. return Err(anchor_lang::__private::ErrorCode::ConstraintSeeds.into());
  339. }
  340. }
  341. }
  342. }
  343. fn parse_ty(f: &Field) -> (proc_macro2::TokenStream, proc_macro2::TokenStream, bool) {
  344. match &f.ty {
  345. Ty::ProgramAccount(ty) => {
  346. let ident = &ty.account_type_path;
  347. (
  348. quote! {
  349. #ident
  350. },
  351. quote! {
  352. anchor_lang::ProgramAccount
  353. },
  354. false,
  355. )
  356. }
  357. Ty::Loader(ty) => {
  358. let ident = &ty.account_type_path;
  359. (
  360. quote! {
  361. #ident
  362. },
  363. quote! {
  364. anchor_lang::Loader
  365. },
  366. true,
  367. )
  368. }
  369. Ty::CpiAccount(ty) => {
  370. let ident = &ty.account_type_path;
  371. (
  372. quote! {
  373. #ident
  374. },
  375. quote! {
  376. anchor_lang::CpiAccount
  377. },
  378. false,
  379. )
  380. }
  381. Ty::AccountInfo => (
  382. quote! {
  383. AccountInfo
  384. },
  385. quote! {},
  386. false,
  387. ),
  388. _ => panic!("Invalid type for initializing a program derived address"),
  389. }
  390. }
  391. pub fn generate_pda(
  392. f: &Field,
  393. seeds_with_nonce: proc_macro2::TokenStream,
  394. payer: proc_macro2::TokenStream,
  395. space: &Option<Expr>,
  396. kind: &InitKind,
  397. ) -> proc_macro2::TokenStream {
  398. let field = &f.ident;
  399. let (account_ty, account_wrapper_ty, is_zero_copy) = parse_ty(f);
  400. let (combined_account_ty, try_from) = match f.ty {
  401. Ty::AccountInfo => (
  402. quote! {
  403. AccountInfo
  404. },
  405. quote! {
  406. #field.to_account_info()
  407. },
  408. ),
  409. _ => (
  410. quote! {
  411. #account_wrapper_ty<#account_ty>
  412. },
  413. quote! {
  414. #account_wrapper_ty::try_from_unchecked(
  415. program_id,
  416. &#field.to_account_info(),
  417. )?
  418. },
  419. ),
  420. };
  421. match kind {
  422. InitKind::Token { owner, mint } => {
  423. let create_account = generate_create_account(
  424. field,
  425. quote! {anchor_spl::token::TokenAccount::LEN},
  426. quote! {token_program.to_account_info().key},
  427. seeds_with_nonce,
  428. );
  429. quote! {
  430. let #field: #combined_account_ty = {
  431. // Define payer variable.
  432. #payer
  433. // Create the account with the system program.
  434. #create_account
  435. // Initialize the token account.
  436. let cpi_program = token_program.to_account_info();
  437. let accounts = anchor_spl::token::InitializeAccount {
  438. account: #field.to_account_info(),
  439. mint: #mint.to_account_info(),
  440. authority: #owner.to_account_info(),
  441. rent: rent.to_account_info(),
  442. };
  443. let cpi_ctx = CpiContext::new(cpi_program, accounts);
  444. anchor_spl::token::initialize_account(cpi_ctx)?;
  445. anchor_lang::CpiAccount::try_from_unchecked(
  446. &#field.to_account_info(),
  447. )?
  448. };
  449. }
  450. }
  451. InitKind::Mint { owner, decimals } => {
  452. let create_account = generate_create_account(
  453. field,
  454. quote! {anchor_spl::token::Mint::LEN},
  455. quote! {token_program.to_account_info().key},
  456. seeds_with_nonce,
  457. );
  458. quote! {
  459. let #field: #combined_account_ty = {
  460. // Define payer variable.
  461. #payer
  462. // Create the account with the system program.
  463. #create_account
  464. // Initialize the mint account.
  465. let cpi_program = token_program.to_account_info();
  466. let accounts = anchor_spl::token::InitializeMint {
  467. mint: #field.to_account_info(),
  468. rent: rent.to_account_info(),
  469. };
  470. let cpi_ctx = CpiContext::new(cpi_program, accounts);
  471. anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, None)?;
  472. anchor_lang::CpiAccount::try_from_unchecked(
  473. &#field.to_account_info(),
  474. )?
  475. };
  476. }
  477. }
  478. InitKind::Program { owner } => {
  479. let space = match space {
  480. // If no explicit space param was given, serialize the type to bytes
  481. // and take the length (with +8 for the discriminator.)
  482. None => match is_zero_copy {
  483. false => {
  484. quote! {
  485. let space = 8 + #account_ty::default().try_to_vec().unwrap().len();
  486. }
  487. }
  488. true => {
  489. quote! {
  490. let space = 8 + anchor_lang::__private::bytemuck::bytes_of(&#account_ty::default()).len();
  491. }
  492. }
  493. },
  494. // Explicit account size given. Use it.
  495. Some(s) => quote! {
  496. let space = #s;
  497. },
  498. };
  499. // Owner of the account being created. If not specified,
  500. // default to the currently executing program.
  501. let owner = match owner {
  502. None => quote! {
  503. program_id
  504. },
  505. Some(o) => quote! {
  506. &#o
  507. },
  508. };
  509. let create_account =
  510. generate_create_account(field, quote! {space}, owner, seeds_with_nonce);
  511. quote! {
  512. let #field = {
  513. #space
  514. #payer
  515. #create_account
  516. let mut pa: #combined_account_ty = #try_from;
  517. pa
  518. };
  519. }
  520. }
  521. }
  522. }
  523. // Generated code to create an account with with system program with the
  524. // given `space` amount of data, owned by `owner`.
  525. //
  526. // `seeds_with_nonce` should be given for creating PDAs. Otherwise it's an
  527. // empty stream.
  528. pub fn generate_create_account(
  529. field: &Ident,
  530. space: proc_macro2::TokenStream,
  531. owner: proc_macro2::TokenStream,
  532. seeds_with_nonce: proc_macro2::TokenStream,
  533. ) -> proc_macro2::TokenStream {
  534. quote! {
  535. // If the account being initialized already has lamports, then
  536. // return them all back to the payer so that the account has
  537. // zero lamports when the system program's create instruction
  538. // is eventually called.
  539. let __current_lamports = #field.to_account_info().lamports();
  540. if __current_lamports == 0 {
  541. // Create the token account with right amount of lamports and space, and the correct owner.
  542. let lamports = __anchor_rent.minimum_balance(#space);
  543. anchor_lang::solana_program::program::invoke_signed(
  544. &anchor_lang::solana_program::system_instruction::create_account(
  545. payer.to_account_info().key,
  546. #field.to_account_info().key,
  547. lamports,
  548. #space as u64,
  549. #owner,
  550. ),
  551. &[
  552. payer.to_account_info(),
  553. #field.to_account_info(),
  554. system_program.to_account_info().clone(),
  555. ],
  556. &[#seeds_with_nonce],
  557. )?;
  558. } else {
  559. // Fund the account for rent exemption.
  560. let required_lamports = __anchor_rent
  561. .minimum_balance(#space)
  562. .max(1)
  563. .saturating_sub(__current_lamports);
  564. if required_lamports > 0 {
  565. anchor_lang::solana_program::program::invoke(
  566. &anchor_lang::solana_program::system_instruction::transfer(
  567. payer.to_account_info().key,
  568. #field.to_account_info().key,
  569. required_lamports,
  570. ),
  571. &[
  572. payer.to_account_info(),
  573. #field.to_account_info(),
  574. system_program.to_account_info().clone(),
  575. ],
  576. )?;
  577. }
  578. // Allocate space.
  579. anchor_lang::solana_program::program::invoke_signed(
  580. &anchor_lang::solana_program::system_instruction::allocate(
  581. #field.to_account_info().key,
  582. #space as u64,
  583. ),
  584. &[
  585. #field.to_account_info(),
  586. system_program.clone(),
  587. ],
  588. &[#seeds_with_nonce],
  589. )?;
  590. // Assign to the spl token program.
  591. anchor_lang::solana_program::program::invoke_signed(
  592. &anchor_lang::solana_program::system_instruction::assign(
  593. #field.to_account_info().key,
  594. #owner,
  595. ),
  596. &[
  597. #field.to_account_info(),
  598. system_program.to_account_info(),
  599. ],
  600. &[#seeds_with_nonce],
  601. )?;
  602. }
  603. }
  604. }
  605. pub fn generate_constraint_executable(
  606. f: &Field,
  607. _c: &ConstraintExecutable,
  608. ) -> proc_macro2::TokenStream {
  609. let name = &f.ident;
  610. quote! {
  611. if !#name.to_account_info().executable {
  612. return Err(anchor_lang::__private::ErrorCode::ConstraintExecutable.into());
  613. }
  614. }
  615. }
  616. pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2::TokenStream {
  617. let program_target = c.program_target.clone();
  618. let ident = &f.ident;
  619. let account_ty = match &f.ty {
  620. Ty::CpiState(ty) => &ty.account_type_path,
  621. _ => panic!("Invalid state constraint"),
  622. };
  623. quote! {
  624. // Checks the given state account is the canonical state account for
  625. // the target program.
  626. if #ident.to_account_info().key != &anchor_lang::CpiState::<#account_ty>::address(#program_target.to_account_info().key) {
  627. return Err(anchor_lang::__private::ErrorCode::ConstraintState.into());
  628. }
  629. if #ident.to_account_info().owner != #program_target.to_account_info().key {
  630. return Err(anchor_lang::__private::ErrorCode::ConstraintState.into());
  631. }
  632. }
  633. }