|
@@ -1,9 +1,11 @@
|
|
|
-use crate::*;
|
|
|
use proc_macro2_diagnostics::SpanDiagnosticExt;
|
|
|
use quote::quote;
|
|
|
+use std::collections::HashSet;
|
|
|
use syn::Expr;
|
|
|
|
|
|
-pub fn generate(f: &Field) -> proc_macro2::TokenStream {
|
|
|
+use crate::*;
|
|
|
+
|
|
|
+pub fn generate(f: &Field, accs: &AccountsStruct) -> proc_macro2::TokenStream {
|
|
|
let constraints = linearize(&f.constraints);
|
|
|
|
|
|
let rent = constraints
|
|
@@ -14,12 +16,41 @@ pub fn generate(f: &Field) -> proc_macro2::TokenStream {
|
|
|
|
|
|
let checks: Vec<proc_macro2::TokenStream> = constraints
|
|
|
.iter()
|
|
|
- .map(|c| generate_constraint(f, c))
|
|
|
+ .map(|c| generate_constraint(f, c, accs))
|
|
|
.collect();
|
|
|
|
|
|
+ let mut all_checks = quote! {#(#checks)*};
|
|
|
+
|
|
|
+ // If the field is optional we do all the inner checks as if the account
|
|
|
+ // wasn't optional. If the account is init we also need to return an Option
|
|
|
+ // by wrapping the resulting value with Some or returning None if it doesn't exist.
|
|
|
+ if f.is_optional && !constraints.is_empty() {
|
|
|
+ let ident = &f.ident;
|
|
|
+ let ty_decl = f.ty_decl(false);
|
|
|
+ all_checks = match &constraints[0] {
|
|
|
+ Constraint::Init(_) | Constraint::Zeroed(_) => {
|
|
|
+ quote! {
|
|
|
+ let #ident: #ty_decl = if let Some(#ident) = #ident {
|
|
|
+ #all_checks
|
|
|
+ Some(#ident)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ };
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ => {
|
|
|
+ quote! {
|
|
|
+ if let Some(#ident) = &#ident {
|
|
|
+ #all_checks
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
quote! {
|
|
|
#rent
|
|
|
- #(#checks)*
|
|
|
+ #all_checks
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -115,12 +146,16 @@ pub fn linearize(c_group: &ConstraintGroup) -> Vec<Constraint> {
|
|
|
constraints
|
|
|
}
|
|
|
|
|
|
-fn generate_constraint(f: &Field, c: &Constraint) -> proc_macro2::TokenStream {
|
|
|
+fn generate_constraint(
|
|
|
+ f: &Field,
|
|
|
+ c: &Constraint,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
match c {
|
|
|
- Constraint::Init(c) => generate_constraint_init(f, c),
|
|
|
+ Constraint::Init(c) => generate_constraint_init(f, c, accs),
|
|
|
Constraint::Zeroed(c) => generate_constraint_zeroed(f, c),
|
|
|
Constraint::Mut(c) => generate_constraint_mut(f, c),
|
|
|
- Constraint::HasOne(c) => generate_constraint_has_one(f, c),
|
|
|
+ Constraint::HasOne(c) => generate_constraint_has_one(f, c, accs),
|
|
|
Constraint::Signer(c) => generate_constraint_signer(f, c),
|
|
|
Constraint::Literal(c) => generate_constraint_literal(&f.ident, c),
|
|
|
Constraint::Raw(c) => generate_constraint_raw(&f.ident, c),
|
|
@@ -128,13 +163,13 @@ fn generate_constraint(f: &Field, c: &Constraint) -> proc_macro2::TokenStream {
|
|
|
Constraint::RentExempt(c) => generate_constraint_rent_exempt(f, c),
|
|
|
Constraint::Seeds(c) => generate_constraint_seeds(f, c),
|
|
|
Constraint::Executable(c) => generate_constraint_executable(f, c),
|
|
|
- Constraint::State(c) => generate_constraint_state(f, c),
|
|
|
- Constraint::Close(c) => generate_constraint_close(f, c),
|
|
|
+ Constraint::State(c) => generate_constraint_state(f, c, accs),
|
|
|
+ Constraint::Close(c) => generate_constraint_close(f, c, accs),
|
|
|
Constraint::Address(c) => generate_constraint_address(f, c),
|
|
|
- Constraint::AssociatedToken(c) => generate_constraint_associated_token(f, c),
|
|
|
- Constraint::TokenAccount(c) => generate_constraint_token_account(f, c),
|
|
|
- Constraint::Mint(c) => generate_constraint_mint(f, c),
|
|
|
- Constraint::Realloc(c) => generate_constraint_realloc(f, c),
|
|
|
+ Constraint::AssociatedToken(c) => generate_constraint_associated_token(f, c, accs),
|
|
|
+ Constraint::TokenAccount(c) => generate_constraint_token_account(f, c, accs),
|
|
|
+ Constraint::Mint(c) => generate_constraint_mint(f, c, accs),
|
|
|
+ Constraint::Realloc(c) => generate_constraint_realloc(f, c, accs),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -166,14 +201,18 @@ fn generate_constraint_address(f: &Field, c: &ConstraintAddress) -> proc_macro2:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn generate_constraint_init(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream {
|
|
|
- generate_constraint_init_group(f, c)
|
|
|
+pub fn generate_constraint_init(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintInitGroup,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
+ generate_constraint_init_group(f, c, accs)
|
|
|
}
|
|
|
|
|
|
pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macro2::TokenStream {
|
|
|
let field = &f.ident;
|
|
|
let name_str = field.to_string();
|
|
|
- let ty_decl = f.ty_decl();
|
|
|
+ let ty_decl = f.ty_decl(true);
|
|
|
let from_account_info = f.from_account_info(None, false);
|
|
|
quote! {
|
|
|
let #field: #ty_decl = {
|
|
@@ -189,13 +228,22 @@ pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macr
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn generate_constraint_close(f: &Field, c: &ConstraintClose) -> proc_macro2::TokenStream {
|
|
|
+pub fn generate_constraint_close(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintClose,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
let field = &f.ident;
|
|
|
let name_str = field.to_string();
|
|
|
let target = &c.sol_dest;
|
|
|
+ let target_optional_check =
|
|
|
+ OptionalCheckScope::new_with_field(accs, field).generate_check(target);
|
|
|
quote! {
|
|
|
- if #field.key() == #target.key() {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintClose).with_account_name(#name_str));
|
|
|
+ {
|
|
|
+ #target_optional_check
|
|
|
+ if #field.key() == #target.key() {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintClose).with_account_name(#name_str));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -210,8 +258,12 @@ pub fn generate_constraint_mut(f: &Field, c: &ConstraintMut) -> proc_macro2::Tok
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn generate_constraint_has_one(f: &Field, c: &ConstraintHasOne) -> proc_macro2::TokenStream {
|
|
|
- let target = c.join_target.clone();
|
|
|
+pub fn generate_constraint_has_one(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintHasOne,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
+ let target = &c.join_target;
|
|
|
let ident = &f.ident;
|
|
|
let field = match &f.ty {
|
|
|
Ty::Loader(_) => quote! {#ident.load()?},
|
|
@@ -224,8 +276,12 @@ pub fn generate_constraint_has_one(f: &Field, c: &ConstraintHasOne) -> proc_macr
|
|
|
quote! { ConstraintHasOne },
|
|
|
&Some(&(quote! { my_key }, quote! { target_key })),
|
|
|
);
|
|
|
+ let target_optional_check =
|
|
|
+ OptionalCheckScope::new_with_field(accs, &field).generate_check(target);
|
|
|
+
|
|
|
quote! {
|
|
|
{
|
|
|
+ #target_optional_check
|
|
|
let my_key = #field.#target;
|
|
|
let target_key = #target.key();
|
|
|
if my_key != target_key {
|
|
@@ -325,13 +381,22 @@ pub fn generate_constraint_rent_exempt(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-fn generate_constraint_realloc(f: &Field, c: &ConstraintReallocGroup) -> proc_macro2::TokenStream {
|
|
|
+fn generate_constraint_realloc(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintReallocGroup,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
let field = &f.ident;
|
|
|
let account_name = field.to_string();
|
|
|
let new_space = &c.space;
|
|
|
let payer = &c.payer;
|
|
|
let zero = &c.zero;
|
|
|
|
|
|
+ let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, field);
|
|
|
+ let payer_optional_check = optional_check_scope.generate_check(payer);
|
|
|
+ let system_program_optional_check =
|
|
|
+ optional_check_scope.generate_check(quote! {system_program});
|
|
|
+
|
|
|
quote! {
|
|
|
// Blocks duplicate account reallocs in a single instruction to prevent accidental account overwrites
|
|
|
// and to ensure the calculation of the change in bytes is based on account size at program entry
|
|
@@ -349,7 +414,9 @@ fn generate_constraint_realloc(f: &Field, c: &ConstraintReallocGroup) -> proc_ma
|
|
|
.unwrap();
|
|
|
|
|
|
if __delta_space != 0 {
|
|
|
+ #payer_optional_check
|
|
|
if __delta_space > 0 {
|
|
|
+ #system_program_optional_check
|
|
|
if ::std::convert::TryInto::<usize>::try_into(__delta_space).unwrap() > anchor_lang::solana_program::entrypoint::MAX_PERMITTED_DATA_INCREASE {
|
|
|
return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::AccountReallocExceedsLimit).with_account_name(#account_name));
|
|
|
}
|
|
@@ -378,10 +445,14 @@ fn generate_constraint_realloc(f: &Field, c: &ConstraintReallocGroup) -> proc_ma
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream {
|
|
|
+fn generate_constraint_init_group(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintInitGroup,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
let field = &f.ident;
|
|
|
let name_str = f.ident.to_string();
|
|
|
- let ty_decl = f.ty_decl();
|
|
|
+ let ty_decl = f.ty_decl(true);
|
|
|
let if_needed = if c.if_needed {
|
|
|
quote! {true}
|
|
|
} else {
|
|
@@ -389,13 +460,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
};
|
|
|
let space = &c.space;
|
|
|
|
|
|
- // Payer for rent exemption.
|
|
|
- let payer = {
|
|
|
- let p = &c.payer;
|
|
|
- quote! {
|
|
|
- let payer = #p.to_account_info();
|
|
|
- }
|
|
|
- };
|
|
|
+ let payer = &c.payer;
|
|
|
|
|
|
// Convert from account info to account context wrapper type.
|
|
|
let from_account_info = f.from_account_info(Some(&c.kind), true);
|
|
@@ -417,6 +482,36 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
quote! { #seeds, }
|
|
|
});
|
|
|
|
|
|
+ let validate_pda = {
|
|
|
+ // If the bump is provided with init *and target*, then force it to be the
|
|
|
+ // canonical bump.
|
|
|
+ //
|
|
|
+ // Note that for `#[account(init, seeds)]`, find_program_address has already
|
|
|
+ // been run in the init constraint find_pda variable.
|
|
|
+ if c.bump.is_some() {
|
|
|
+ let b = c.bump.as_ref().unwrap();
|
|
|
+ quote! {
|
|
|
+ if #field.key() != __pda_address {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#field.key(), __pda_address)));
|
|
|
+ }
|
|
|
+ if __bump != #b {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_values((__bump, #b)));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // Init seeds but no bump. We already used the canonical to create bump so
|
|
|
+ // just check the address.
|
|
|
+ //
|
|
|
+ // Note that for `#[account(init, seeds)]`, find_program_address has already
|
|
|
+ // been run in the init constraint find_pda variable.
|
|
|
+ quote! {
|
|
|
+ if #field.key() != __pda_address {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#field.key(), __pda_address)));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
(
|
|
|
quote! {
|
|
|
let (__pda_address, __bump) = Pubkey::find_program_address(
|
|
@@ -424,6 +519,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
program_id,
|
|
|
);
|
|
|
__bumps.insert(#name_str.to_string(), __bump);
|
|
|
+ #validate_pda
|
|
|
},
|
|
|
quote! {
|
|
|
&[
|
|
@@ -435,22 +531,50 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+ // Optional check idents
|
|
|
+ let system_program = "e! {system_program};
|
|
|
+ let token_program = "e! {token_program};
|
|
|
+ let associated_token_program = "e! {associated_token_program};
|
|
|
+ let rent = "e! {rent};
|
|
|
+
|
|
|
+ let mut check_scope = OptionalCheckScope::new_with_field(accs, field);
|
|
|
match &c.kind {
|
|
|
InitKind::Token { owner, mint } => {
|
|
|
+ let owner_optional_check = check_scope.generate_check(owner);
|
|
|
+ let mint_optional_check = check_scope.generate_check(mint);
|
|
|
+
|
|
|
+ let system_program_optional_check = check_scope.generate_check(system_program);
|
|
|
+ let token_program_optional_check = check_scope.generate_check(token_program);
|
|
|
+ let rent_optional_check = check_scope.generate_check(rent);
|
|
|
+
|
|
|
+ let optional_checks = quote! {
|
|
|
+ #system_program_optional_check
|
|
|
+ #token_program_optional_check
|
|
|
+ #rent_optional_check
|
|
|
+ #owner_optional_check
|
|
|
+ #mint_optional_check
|
|
|
+ };
|
|
|
+
|
|
|
+ let payer_optional_check = check_scope.generate_check(payer);
|
|
|
+
|
|
|
let create_account = generate_create_account(
|
|
|
field,
|
|
|
quote! {anchor_spl::token::TokenAccount::LEN},
|
|
|
quote! {&token_program.key()},
|
|
|
+ quote! {#payer},
|
|
|
seeds_with_bump,
|
|
|
);
|
|
|
+
|
|
|
quote! {
|
|
|
// Define the bump and pda variable.
|
|
|
#find_pda
|
|
|
|
|
|
let #field: #ty_decl = {
|
|
|
+ // Checks that all the required accounts for this operation are present.
|
|
|
+ #optional_checks
|
|
|
+
|
|
|
if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
|
|
|
- // Define payer variable.
|
|
|
- #payer
|
|
|
+ #payer_optional_check
|
|
|
|
|
|
// Create the account with the system program.
|
|
|
#create_account
|
|
@@ -480,17 +604,40 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
}
|
|
|
}
|
|
|
InitKind::AssociatedToken { owner, mint } => {
|
|
|
+ let owner_optional_check = check_scope.generate_check(owner);
|
|
|
+ let mint_optional_check = check_scope.generate_check(mint);
|
|
|
+
|
|
|
+ let system_program_optional_check = check_scope.generate_check(system_program);
|
|
|
+ let token_program_optional_check = check_scope.generate_check(token_program);
|
|
|
+ let associated_token_program_optional_check =
|
|
|
+ check_scope.generate_check(associated_token_program);
|
|
|
+ let rent_optional_check = check_scope.generate_check(rent);
|
|
|
+
|
|
|
+ let optional_checks = quote! {
|
|
|
+ #system_program_optional_check
|
|
|
+ #token_program_optional_check
|
|
|
+ #associated_token_program_optional_check
|
|
|
+ #rent_optional_check
|
|
|
+ #owner_optional_check
|
|
|
+ #mint_optional_check
|
|
|
+ };
|
|
|
+
|
|
|
+ let payer_optional_check = check_scope.generate_check(payer);
|
|
|
+
|
|
|
quote! {
|
|
|
// Define the bump and pda variable.
|
|
|
#find_pda
|
|
|
|
|
|
let #field: #ty_decl = {
|
|
|
+ // Checks that all the required accounts for this operation are present.
|
|
|
+ #optional_checks
|
|
|
+
|
|
|
if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
|
|
|
- #payer
|
|
|
+ #payer_optional_check
|
|
|
|
|
|
let cpi_program = associated_token_program.to_account_info();
|
|
|
let cpi_accounts = anchor_spl::associated_token::Create {
|
|
|
- payer: payer.to_account_info(),
|
|
|
+ payer: #payer.to_account_info(),
|
|
|
associated_token: #field.to_account_info(),
|
|
|
authority: #owner.to_account_info(),
|
|
|
mint: #mint.to_account_info(),
|
|
@@ -522,24 +669,50 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
decimals,
|
|
|
freeze_authority,
|
|
|
} => {
|
|
|
+ let owner_optional_check = check_scope.generate_check(owner);
|
|
|
+ let freeze_authority_optional_check = match freeze_authority {
|
|
|
+ Some(fa) => check_scope.generate_check(fa),
|
|
|
+ None => quote! {},
|
|
|
+ };
|
|
|
+
|
|
|
+ let system_program_optional_check = check_scope.generate_check(system_program);
|
|
|
+ let token_program_optional_check = check_scope.generate_check(token_program);
|
|
|
+ let rent_optional_check = check_scope.generate_check(rent);
|
|
|
+
|
|
|
+ let optional_checks = quote! {
|
|
|
+ #system_program_optional_check
|
|
|
+ #token_program_optional_check
|
|
|
+ #rent_optional_check
|
|
|
+ #owner_optional_check
|
|
|
+ #freeze_authority_optional_check
|
|
|
+ };
|
|
|
+
|
|
|
+ let payer_optional_check = check_scope.generate_check(payer);
|
|
|
+
|
|
|
let create_account = generate_create_account(
|
|
|
field,
|
|
|
quote! {anchor_spl::token::Mint::LEN},
|
|
|
quote! {&token_program.key()},
|
|
|
+ quote! {#payer},
|
|
|
seeds_with_bump,
|
|
|
);
|
|
|
+
|
|
|
let freeze_authority = match freeze_authority {
|
|
|
Some(fa) => quote! { Option::<&anchor_lang::prelude::Pubkey>::Some(&#fa.key()) },
|
|
|
None => quote! { Option::<&anchor_lang::prelude::Pubkey>::None },
|
|
|
};
|
|
|
+
|
|
|
quote! {
|
|
|
// Define the bump and pda variable.
|
|
|
#find_pda
|
|
|
|
|
|
let #field: #ty_decl = {
|
|
|
+ // Checks that all the required accounts for this operation are present.
|
|
|
+ #optional_checks
|
|
|
+
|
|
|
if !#if_needed || AsRef::<AccountInfo>::as_ref(&#field).owner == &anchor_lang::solana_program::system_program::ID {
|
|
|
// Define payer variable.
|
|
|
- #payer
|
|
|
+ #payer_optional_check
|
|
|
|
|
|
// Create the account with the system program.
|
|
|
#create_account
|
|
@@ -575,20 +748,45 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
// Define the space variable.
|
|
|
let space = quote! {let space = #space;};
|
|
|
|
|
|
+ let system_program_optional_check = check_scope.generate_check(system_program);
|
|
|
+
|
|
|
// Define the owner of the account being created. If not specified,
|
|
|
// default to the currently executing program.
|
|
|
- let owner = match owner {
|
|
|
- None => quote! {
|
|
|
- program_id
|
|
|
- },
|
|
|
- Some(o) => quote! {
|
|
|
- &#o
|
|
|
- },
|
|
|
+ let (owner, owner_optional_check) = match owner {
|
|
|
+ None => (
|
|
|
+ quote! {
|
|
|
+ program_id
|
|
|
+ },
|
|
|
+ quote! {},
|
|
|
+ ),
|
|
|
+
|
|
|
+ Some(o) => {
|
|
|
+ // We clone the `check_scope` here to avoid collisions with the
|
|
|
+ // `payer_optional_check`, which is in a separate scope
|
|
|
+ let owner_optional_check = check_scope.clone().generate_check(o);
|
|
|
+ (
|
|
|
+ quote! {
|
|
|
+ &#o
|
|
|
+ },
|
|
|
+ owner_optional_check,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ let payer_optional_check = check_scope.generate_check(payer);
|
|
|
+
|
|
|
+ let optional_checks = quote! {
|
|
|
+ #system_program_optional_check
|
|
|
};
|
|
|
|
|
|
// CPI to the system program to create the account.
|
|
|
- let create_account =
|
|
|
- generate_create_account(field, quote! {space}, owner.clone(), seeds_with_bump);
|
|
|
+ let create_account = generate_create_account(
|
|
|
+ field,
|
|
|
+ quote! {space},
|
|
|
+ owner.clone(),
|
|
|
+ quote! {#payer},
|
|
|
+ seeds_with_bump,
|
|
|
+ );
|
|
|
|
|
|
// Put it all together.
|
|
|
quote! {
|
|
@@ -596,6 +794,9 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
#find_pda
|
|
|
|
|
|
let #field = {
|
|
|
+ // Checks that all the required accounts for this operation are present.
|
|
|
+ #optional_checks
|
|
|
+
|
|
|
let actual_field = #field.to_account_info();
|
|
|
let actual_owner = actual_field.owner;
|
|
|
|
|
@@ -605,8 +806,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
// Create the account. Always do this in the event
|
|
|
// if needed is not specified or the system program is the owner.
|
|
|
let pa: #ty_decl = if !#if_needed || actual_owner == &anchor_lang::solana_program::system_program::ID {
|
|
|
- // Define the payer variable.
|
|
|
- #payer
|
|
|
+ #payer_optional_check
|
|
|
|
|
|
// CPI to the system program to create.
|
|
|
#create_account
|
|
@@ -620,6 +820,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
|
|
|
// Assert the account was created correctly.
|
|
|
if #if_needed {
|
|
|
+ #owner_optional_check
|
|
|
if space != actual_field.data_len() {
|
|
|
return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSpace).with_account_name(#name_str).with_values((space, actual_field.data_len())));
|
|
|
}
|
|
@@ -645,59 +846,36 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
|
|
|
}
|
|
|
|
|
|
fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
|
|
|
- let name = &f.ident;
|
|
|
- let name_str = name.to_string();
|
|
|
-
|
|
|
- let s = &mut c.seeds.clone();
|
|
|
-
|
|
|
- let deriving_program_id = c
|
|
|
- .program_seed
|
|
|
- .clone()
|
|
|
- // If they specified a seeds::program to use when deriving the PDA, use it.
|
|
|
- .map(|program_id| quote! { #program_id.key() })
|
|
|
- // Otherwise fall back to the current program's program_id.
|
|
|
- .unwrap_or(quote! { program_id });
|
|
|
-
|
|
|
- // If the seeds came with a trailing comma, we need to chop it off
|
|
|
- // before we interpolate them below.
|
|
|
- if let Some(pair) = s.pop() {
|
|
|
- s.push_value(pair.into_value());
|
|
|
- }
|
|
|
-
|
|
|
- // If the bump is provided with init *and target*, then force it to be the
|
|
|
- // canonical bump.
|
|
|
- //
|
|
|
- // Note that for `#[account(init, seeds)]`, find_program_address has already
|
|
|
- // been run in the init constraint.
|
|
|
- if c.is_init && c.bump.is_some() {
|
|
|
- let b = c.bump.as_ref().unwrap();
|
|
|
- quote! {
|
|
|
- if #name.key() != __pda_address {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#name.key(), __pda_address)));
|
|
|
- }
|
|
|
- if __bump != #b {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_values((__bump, #b)));
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- // Init seeds but no bump. We already used the canonical to create bump so
|
|
|
- // just check the address.
|
|
|
- //
|
|
|
- // Note that for `#[account(init, seeds)]`, find_program_address has already
|
|
|
- // been run in the init constraint.
|
|
|
- else if c.is_init {
|
|
|
- quote! {
|
|
|
- if #name.key() != __pda_address {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintSeeds).with_account_name(#name_str).with_pubkeys((#name.key(), __pda_address)));
|
|
|
- }
|
|
|
+ if c.is_init {
|
|
|
+ // Note that for `#[account(init, seeds)]`, the seed generation and checks is checked in
|
|
|
+ // the init constraint find_pda/validate_pda block, so we don't do anything here and
|
|
|
+ // return nothing!
|
|
|
+ quote! {}
|
|
|
+ } else {
|
|
|
+ let name = &f.ident;
|
|
|
+ let name_str = name.to_string();
|
|
|
+
|
|
|
+ let s = &mut c.seeds.clone();
|
|
|
+
|
|
|
+ let deriving_program_id = c
|
|
|
+ .program_seed
|
|
|
+ .clone()
|
|
|
+ // If they specified a seeds::program to use when deriving the PDA, use it.
|
|
|
+ .map(|program_id| quote! { #program_id.key() })
|
|
|
+ // Otherwise fall back to the current program's program_id.
|
|
|
+ .unwrap_or(quote! { program_id });
|
|
|
+
|
|
|
+ // If the seeds came with a trailing comma, we need to chop it off
|
|
|
+ // before we interpolate them below.
|
|
|
+ if let Some(pair) = s.pop() {
|
|
|
+ s.push_value(pair.into_value());
|
|
|
}
|
|
|
- }
|
|
|
- // No init. So we just check the address.
|
|
|
- else {
|
|
|
+
|
|
|
let maybe_seeds_plus_comma = (!s.is_empty()).then(|| {
|
|
|
quote! { #s, }
|
|
|
});
|
|
|
|
|
|
+ // Not init here, so do all the checks.
|
|
|
let define_pda = match c.bump.as_ref() {
|
|
|
// Bump target not given. Find it.
|
|
|
None => quote! {
|
|
@@ -730,13 +908,25 @@ fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2
|
|
|
fn generate_constraint_associated_token(
|
|
|
f: &Field,
|
|
|
c: &ConstraintAssociatedToken,
|
|
|
+ accs: &AccountsStruct,
|
|
|
) -> proc_macro2::TokenStream {
|
|
|
let name = &f.ident;
|
|
|
let name_str = name.to_string();
|
|
|
let wallet_address = &c.wallet;
|
|
|
let spl_token_mint_address = &c.mint;
|
|
|
+ let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
|
|
|
+ let wallet_address_optional_check = optional_check_scope.generate_check(wallet_address);
|
|
|
+ let spl_token_mint_address_optional_check =
|
|
|
+ optional_check_scope.generate_check(spl_token_mint_address);
|
|
|
+ let optional_checks = quote! {
|
|
|
+ #wallet_address_optional_check
|
|
|
+ #spl_token_mint_address_optional_check
|
|
|
+ };
|
|
|
+
|
|
|
quote! {
|
|
|
{
|
|
|
+ #optional_checks
|
|
|
+
|
|
|
let my_owner = #name.owner;
|
|
|
let wallet_address = #wallet_address.key();
|
|
|
if my_owner != wallet_address {
|
|
@@ -754,27 +944,43 @@ fn generate_constraint_associated_token(
|
|
|
fn generate_constraint_token_account(
|
|
|
f: &Field,
|
|
|
c: &ConstraintTokenAccountGroup,
|
|
|
+ accs: &AccountsStruct,
|
|
|
) -> proc_macro2::TokenStream {
|
|
|
let name = &f.ident;
|
|
|
+ let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
|
|
|
let authority_check = match &c.authority {
|
|
|
Some(authority) => {
|
|
|
- quote! { if #name.owner != #authority.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenOwner.into()); } }
|
|
|
+ let authority_optional_check = optional_check_scope.generate_check(authority);
|
|
|
+ quote! {
|
|
|
+ #authority_optional_check
|
|
|
+ if #name.owner != #authority.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenOwner.into()); }
|
|
|
+ }
|
|
|
}
|
|
|
None => quote! {},
|
|
|
};
|
|
|
let mint_check = match &c.mint {
|
|
|
Some(mint) => {
|
|
|
- quote! { if #name.mint != #mint.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenMint.into()); } }
|
|
|
+ let mint_optional_check = optional_check_scope.generate_check(mint);
|
|
|
+ quote! {
|
|
|
+ #mint_optional_check
|
|
|
+ if #name.mint != #mint.key() { return Err(anchor_lang::error::ErrorCode::ConstraintTokenMint.into()); }
|
|
|
+ }
|
|
|
}
|
|
|
None => quote! {},
|
|
|
};
|
|
|
quote! {
|
|
|
- #authority_check
|
|
|
- #mint_check
|
|
|
+ {
|
|
|
+ #authority_check
|
|
|
+ #mint_check
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-fn generate_constraint_mint(f: &Field, c: &ConstraintTokenMintGroup) -> proc_macro2::TokenStream {
|
|
|
+fn generate_constraint_mint(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintTokenMintGroup,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
let name = &f.ident;
|
|
|
|
|
|
let decimal_check = match &c.decimals {
|
|
@@ -785,26 +991,77 @@ fn generate_constraint_mint(f: &Field, c: &ConstraintTokenMintGroup) -> proc_mac
|
|
|
},
|
|
|
None => quote! {},
|
|
|
};
|
|
|
+ let mut optional_check_scope = OptionalCheckScope::new_with_field(accs, name);
|
|
|
let mint_authority_check = match &c.mint_authority {
|
|
|
- Some(mint_authority) => quote! {
|
|
|
- if #name.mint_authority != anchor_lang::solana_program::program_option::COption::Some(anchor_lang::Key::key(&#mint_authority)) {
|
|
|
- return Err(anchor_lang::error::ErrorCode::ConstraintMintMintAuthority.into());
|
|
|
+ Some(mint_authority) => {
|
|
|
+ let mint_authority_optional_check = optional_check_scope.generate_check(mint_authority);
|
|
|
+ quote! {
|
|
|
+ #mint_authority_optional_check
|
|
|
+ if #name.mint_authority != anchor_lang::solana_program::program_option::COption::Some(#mint_authority.key()) {
|
|
|
+ return Err(anchor_lang::error::ErrorCode::ConstraintMintMintAuthority.into());
|
|
|
+ }
|
|
|
}
|
|
|
- },
|
|
|
+ }
|
|
|
None => quote! {},
|
|
|
};
|
|
|
let freeze_authority_check = match &c.freeze_authority {
|
|
|
- Some(freeze_authority) => quote! {
|
|
|
- if #name.freeze_authority != anchor_lang::solana_program::program_option::COption::Some(anchor_lang::Key::key(&#freeze_authority)) {
|
|
|
- return Err(anchor_lang::error::ErrorCode::ConstraintMintFreezeAuthority.into());
|
|
|
+ Some(freeze_authority) => {
|
|
|
+ let freeze_authority_optional_check =
|
|
|
+ optional_check_scope.generate_check(freeze_authority);
|
|
|
+ quote! {
|
|
|
+ #freeze_authority_optional_check
|
|
|
+ if #name.freeze_authority != anchor_lang::solana_program::program_option::COption::Some(#freeze_authority.key()) {
|
|
|
+ return Err(anchor_lang::error::ErrorCode::ConstraintMintFreezeAuthority.into());
|
|
|
+ }
|
|
|
}
|
|
|
- },
|
|
|
+ }
|
|
|
None => quote! {},
|
|
|
};
|
|
|
quote! {
|
|
|
- #decimal_check
|
|
|
- #mint_authority_check
|
|
|
- #freeze_authority_check
|
|
|
+ {
|
|
|
+ #decimal_check
|
|
|
+ #mint_authority_check
|
|
|
+ #freeze_authority_check
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Clone, Debug)]
|
|
|
+pub struct OptionalCheckScope<'a> {
|
|
|
+ seen: HashSet<String>,
|
|
|
+ accounts: &'a AccountsStruct,
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> OptionalCheckScope<'a> {
|
|
|
+ pub fn new(accounts: &'a AccountsStruct) -> Self {
|
|
|
+ Self {
|
|
|
+ seen: HashSet::new(),
|
|
|
+ accounts,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ pub fn new_with_field(accounts: &'a AccountsStruct, field: impl ToString) -> Self {
|
|
|
+ let mut check_scope = Self::new(accounts);
|
|
|
+ check_scope.seen.insert(field.to_string());
|
|
|
+ check_scope
|
|
|
+ }
|
|
|
+ pub fn generate_check(&mut self, field: impl ToTokens) -> TokenStream {
|
|
|
+ let field_name = tts_to_string(&field);
|
|
|
+ if self.seen.contains(&field_name) {
|
|
|
+ quote! {}
|
|
|
+ } else {
|
|
|
+ self.seen.insert(field_name.clone());
|
|
|
+ if self.accounts.is_field_optional(&field) {
|
|
|
+ quote! {
|
|
|
+ let #field = if let Some(ref account) = #field {
|
|
|
+ account
|
|
|
+ } else {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintAccountIsNone).with_account_name(#field_name));
|
|
|
+ };
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ quote! {}
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -813,12 +1070,16 @@ fn generate_constraint_mint(f: &Field, c: &ConstraintTokenMintGroup) -> proc_mac
|
|
|
//
|
|
|
// `seeds_with_nonce` should be given for creating PDAs. Otherwise it's an
|
|
|
// empty stream.
|
|
|
-pub fn generate_create_account(
|
|
|
+//
|
|
|
+// This should only be run within scopes where `system_program` is not Optional
|
|
|
+fn generate_create_account(
|
|
|
field: &Ident,
|
|
|
space: proc_macro2::TokenStream,
|
|
|
owner: proc_macro2::TokenStream,
|
|
|
+ payer: proc_macro2::TokenStream,
|
|
|
seeds_with_nonce: proc_macro2::TokenStream,
|
|
|
) -> proc_macro2::TokenStream {
|
|
|
+ // Field, payer, and system program are already validated to not be an Option at this point
|
|
|
quote! {
|
|
|
// If the account being initialized already has lamports, then
|
|
|
// return them all back to the payer so that the account has
|
|
@@ -829,13 +1090,13 @@ pub fn generate_create_account(
|
|
|
// Create the token account with right amount of lamports and space, and the correct owner.
|
|
|
let lamports = __anchor_rent.minimum_balance(#space);
|
|
|
let cpi_accounts = anchor_lang::system_program::CreateAccount {
|
|
|
- from: payer.to_account_info(),
|
|
|
+ from: #payer.to_account_info(),
|
|
|
to: #field.to_account_info()
|
|
|
};
|
|
|
let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
|
|
|
anchor_lang::system_program::create_account(cpi_context.with_signer(&[#seeds_with_nonce]), lamports, #space as u64, #owner)?;
|
|
|
} else {
|
|
|
- require_keys_neq!(payer.key(), #field.key(), anchor_lang::error::ErrorCode::TryingToInitPayerAsProgramAccount);
|
|
|
+ require_keys_neq!(#payer.key(), #field.key(), anchor_lang::error::ErrorCode::TryingToInitPayerAsProgramAccount);
|
|
|
// Fund the account for rent exemption.
|
|
|
let required_lamports = __anchor_rent
|
|
|
.minimum_balance(#space)
|
|
@@ -843,7 +1104,7 @@ pub fn generate_create_account(
|
|
|
.saturating_sub(__current_lamports);
|
|
|
if required_lamports > 0 {
|
|
|
let cpi_accounts = anchor_lang::system_program::Transfer {
|
|
|
- from: payer.to_account_info(),
|
|
|
+ from: #payer.to_account_info(),
|
|
|
to: #field.to_account_info(),
|
|
|
};
|
|
|
let cpi_context = anchor_lang::context::CpiContext::new(system_program.to_account_info(), cpi_accounts);
|
|
@@ -871,6 +1132,9 @@ pub fn generate_constraint_executable(
|
|
|
) -> proc_macro2::TokenStream {
|
|
|
let name = &f.ident;
|
|
|
let name_str = name.to_string();
|
|
|
+
|
|
|
+ // because we are only acting on the field, we know it isnt optional at this point
|
|
|
+ // as it was unwrapped in `generate_constraint`
|
|
|
quote! {
|
|
|
if !#name.to_account_info().executable {
|
|
|
return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintExecutable).with_account_name(#name_str));
|
|
@@ -878,7 +1142,11 @@ pub fn generate_constraint_executable(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2::TokenStream {
|
|
|
+pub fn generate_constraint_state(
|
|
|
+ f: &Field,
|
|
|
+ c: &ConstraintState,
|
|
|
+ accs: &AccountsStruct,
|
|
|
+) -> proc_macro2::TokenStream {
|
|
|
let program_target = c.program_target.clone();
|
|
|
let ident = &f.ident;
|
|
|
let name_str = ident.to_string();
|
|
@@ -886,14 +1154,19 @@ pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2:
|
|
|
Ty::CpiState(ty) => &ty.account_type_path,
|
|
|
_ => panic!("Invalid state constraint"),
|
|
|
};
|
|
|
+ let program_target_optional_check =
|
|
|
+ OptionalCheckScope::new_with_field(accs, ident).generate_check(quote! {#program_target});
|
|
|
quote! {
|
|
|
- // Checks the given state account is the canonical state account for
|
|
|
- // the target program.
|
|
|
- if #ident.key() != anchor_lang::accounts::cpi_state::CpiState::<#account_ty>::address(&#program_target.key()) {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintState).with_account_name(#name_str));
|
|
|
- }
|
|
|
- if AsRef::<AccountInfo>::as_ref(&#ident).owner != &#program_target.key() {
|
|
|
- return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintState).with_account_name(#name_str));
|
|
|
+ {
|
|
|
+ #program_target_optional_check
|
|
|
+ // Checks the given state account is the canonical state account for
|
|
|
+ // the target program.
|
|
|
+ if #ident.key() != anchor_lang::accounts::cpi_state::CpiState::<#account_ty>::address(&#program_target.key()) {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintState).with_account_name(#name_str));
|
|
|
+ }
|
|
|
+ if AsRef::<AccountInfo>::as_ref(&#ident).owner != &#program_target.key() {
|
|
|
+ return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintState).with_account_name(#name_str));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|