Browse Source

lang: Remove mandatory init when using associated_token constraints (#843)

Alan O'Donnell 4 years ago
parent
commit
9d33e13465

+ 20 - 0
lang/syn/src/codegen/accounts/constraints.rs

@@ -56,6 +56,7 @@ pub fn linearize(c_group: &ConstraintGroup) -> Vec<Constraint> {
         state,
         close,
         address,
+        associated_token,
     } = c_group.clone();
 
     let mut constraints = Vec::new();
@@ -69,6 +70,9 @@ pub fn linearize(c_group: &ConstraintGroup) -> Vec<Constraint> {
     if let Some(c) = seeds {
         constraints.push(Constraint::Seeds(c));
     }
+    if let Some(c) = associated_token {
+        constraints.push(Constraint::AssociatedToken(c));
+    }
     if let Some(c) = mutable {
         constraints.push(Constraint::Mut(c));
     }
@@ -115,6 +119,7 @@ fn generate_constraint(f: &Field, c: &Constraint) -> proc_macro2::TokenStream {
         Constraint::State(c) => generate_constraint_state(f, c),
         Constraint::Close(c) => generate_constraint_close(f, c),
         Constraint::Address(c) => generate_constraint_address(f, c),
+        Constraint::AssociatedToken(c) => generate_constraint_associated_token(f, c),
     }
 }
 
@@ -363,6 +368,21 @@ fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2
     }
 }
 
+fn generate_constraint_associated_token(
+    f: &Field,
+    c: &ConstraintAssociatedToken,
+) -> proc_macro2::TokenStream {
+    let name = &f.ident;
+    let wallet_address = &c.wallet;
+    let spl_token_mint_address = &c.mint;
+    quote! {
+        let __associated_token_address = anchor_spl::associated_token::get_associated_token_address(&#wallet_address.key(), &#spl_token_mint_address.key());
+        if #name.to_account_info().key != &__associated_token_address {
+            return Err(anchor_lang::__private::ErrorCode::ConstraintAssociated.into());
+        }
+    }
+}
+
 pub fn generate_init(
     f: &Field,
     seeds_with_nonce: proc_macro2::TokenStream,

+ 8 - 0
lang/syn/src/lib.rs

@@ -497,6 +497,7 @@ pub struct ConstraintGroup {
     raw: Vec<ConstraintRaw>,
     close: Option<ConstraintClose>,
     address: Option<ConstraintAddress>,
+    associated_token: Option<ConstraintAssociatedToken>,
 }
 
 impl ConstraintGroup {
@@ -533,6 +534,7 @@ pub enum Constraint {
     Owner(ConstraintOwner),
     RentExempt(ConstraintRentExempt),
     Seeds(ConstraintSeedsGroup),
+    AssociatedToken(ConstraintAssociatedToken),
     Executable(ConstraintExecutable),
     State(ConstraintState),
     Close(ConstraintClose),
@@ -714,6 +716,12 @@ pub struct ConstraintTokenBump {
     bump: Option<Expr>,
 }
 
+#[derive(Debug, Clone)]
+pub struct ConstraintAssociatedToken {
+    pub wallet: Expr,
+    pub mint: Expr,
+}
+
 // Syntaxt context object for preserving metadata about the inner item.
 #[derive(Debug, Clone)]
 pub struct Context<T> {

+ 23 - 22
lang/syn/src/parser/accounts/constraints.rs

@@ -489,13 +489,31 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             };
         }
 
+        let is_init = init.is_some();
         let seeds = seeds.map(|c| ConstraintSeedsGroup {
-            is_init: init.is_some(),
+            is_init,
             seeds: c.seeds.clone(),
             bump: into_inner!(bump)
                 .map(|b| b.bump)
                 .expect("bump must be provided with seeds"),
         });
+        let associated_token = match (associated_token_mint, associated_token_authority) {
+            (Some(mint), Some(auth)) => Some(ConstraintAssociatedToken {
+                wallet: auth.into_inner().auth,
+                mint: mint.into_inner().mint,
+            }),
+            (Some(mint), None) => return Err(ParseError::new(
+                mint.span(),
+                "authority must be provided to specify an associated token program derived address",
+            )),
+            (None, Some(auth)) => {
+                return Err(ParseError::new(
+                    auth.span(),
+                    "mint must be provided to specify an associated token program derived address",
+                ))
+            }
+            _ => None,
+        };
         Ok(ConstraintGroup {
             init: init.as_ref().map(|_| Ok(ConstraintInitGroup {
                 seeds: seeds.clone(),
@@ -512,16 +530,10 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
                             )),
                         },
                     }
-                } else if let Some(tm) = &associated_token_mint {
+                } else if let Some(at) = &associated_token {
                     InitKind::AssociatedToken {
-                        mint: tm.clone().into_inner().mint,
-                        owner: match &associated_token_authority {
-                            Some(a) => a.clone().into_inner().auth,
-                            None => return Err(ParseError::new(
-                                tm.span(),
-                                "authority must be provided to initialize a token program derived address"
-                            )),
-                        },
+                        mint: at.mint.clone(),
+                        owner: at.wallet.clone()
                     }
                 } else if let Some(d) = &mint_decimals {
                     InitKind::Mint {
@@ -553,6 +565,7 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             state: into_inner!(state),
             close: into_inner!(close),
             address: into_inner!(address),
+            associated_token: if !is_init { associated_token } else { None },
             seeds,
         })
     }
@@ -669,12 +682,6 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         if self.token_mint.is_some() {
             return Err(ParseError::new(c.span(), "token mint already provided"));
         }
-        if self.init.is_none() {
-            return Err(ParseError::new(
-                c.span(),
-                "init must be provided before token",
-            ));
-        }
         self.associated_token_mint.replace(c);
         Ok(())
     }
@@ -726,12 +733,6 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
                 "token authority already provided",
             ));
         }
-        if self.init.is_none() {
-            return Err(ParseError::new(
-                c.span(),
-                "init must be provided before token authority",
-            ));
-        }
         self.associated_token_authority.replace(c);
         Ok(())
     }

+ 1 - 1
spl/src/associated_token.rs

@@ -4,7 +4,7 @@ use anchor_lang::solana_program::program_error::ProgramError;
 use anchor_lang::solana_program::pubkey::Pubkey;
 use anchor_lang::{Accounts, CpiContext};
 
-pub use spl_associated_token_account::ID;
+pub use spl_associated_token_account::{get_associated_token_address, ID};
 
 pub fn create<'info>(ctx: CpiContext<'_, '_, '_, 'info, Create<'info>>) -> ProgramResult {
     let ix = spl_associated_token_account::create_associated_token_account(

+ 11 - 0
tests/misc/programs/misc/src/context.rs

@@ -49,6 +49,17 @@ pub struct TestInitAssociatedToken<'info> {
     pub associated_token_program: Program<'info, AssociatedToken>,
 }
 
+#[derive(Accounts)]
+pub struct TestValidateAssociatedToken<'info> {
+    #[account(
+        associated_token::mint = mint,
+        associated_token::authority = wallet,
+    )]
+    pub token: Account<'info, TokenAccount>,
+    pub mint: Account<'info, Mint>,
+    pub wallet: AccountInfo<'info>,
+}
+
 #[derive(Accounts)]
 #[instruction(nonce: u8)]
 pub struct TestInstructionConstraint<'info> {

+ 6 - 0
tests/misc/programs/misc/src/lib.rs

@@ -165,6 +165,12 @@ pub mod misc {
         Ok(())
     }
 
+    pub fn test_validate_associated_token(
+        _ctx: Context<TestValidateAssociatedToken>,
+    ) -> ProgramResult {
+        Ok(())
+    }
+
     pub fn test_fetch_all(ctx: Context<TestFetchAll>, filterable: Pubkey) -> ProgramResult {
         ctx.accounts.data.authority = ctx.accounts.authority.key();
         ctx.accounts.data.filterable = filterable;

+ 31 - 3
tests/misc/tests/misc.js

@@ -612,8 +612,10 @@ describe("misc", () => {
     assert.equal(account2.idata, 3);
   });
 
+  let associatedToken = null;
+
   it("Can create an associated token account", async () => {
-    const token = await Token.getAssociatedTokenAddress(
+    associatedToken = await Token.getAssociatedTokenAddress(
       ASSOCIATED_TOKEN_PROGRAM_ID,
       TOKEN_PROGRAM_ID,
       mint.publicKey,
@@ -622,7 +624,7 @@ describe("misc", () => {
 
     await program.rpc.testInitAssociatedToken({
       accounts: {
-        token,
+        token: associatedToken,
         mint: mint.publicKey,
         payer: program.provider.wallet.publicKey,
         rent: anchor.web3.SYSVAR_RENT_PUBKEY,
@@ -637,7 +639,7 @@ describe("misc", () => {
       TOKEN_PROGRAM_ID,
       program.provider.wallet.payer
     );
-    const account = await client.getAccountInfo(token);
+    const account = await client.getAccountInfo(associatedToken);
     assert.ok(account.state === 1);
     assert.ok(account.amount.toNumber() === 0);
     assert.ok(account.isInitialized);
@@ -645,6 +647,32 @@ describe("misc", () => {
     assert.ok(account.mint.equals(mint.publicKey));
   });
 
+  it("Can validate associated_token constraints", async () => {
+    await program.rpc.testValidateAssociatedToken({
+      accounts: {
+        token: associatedToken,
+        mint: mint.publicKey,
+        wallet: program.provider.wallet.publicKey
+      }
+    });
+
+    await assert.rejects(
+      async () => {
+        await program.rpc.testValidateAssociatedToken({
+          accounts: {
+            token: associatedToken,
+            mint: mint.publicKey,
+            wallet: anchor.web3.Keypair.generate().publicKey
+          }
+        });
+      },
+      (err) => {
+        assert.equal(err.code, 149);
+        return true;
+      }
+    );
+  });
+
   it("Can fetch all accounts of a given type", async () => {
     // Initialize the accounts.
     const data1 = anchor.web3.Keypair.generate();