Răsfoiți Sursa

lang, spl, cli: Add associated_token keyword (#790)

Armani Ferrante 4 ani în urmă
părinte
comite
2c827bc839

+ 3 - 2
CHANGELOG.md

@@ -11,11 +11,12 @@ incremented for features.
 
 ## [Unreleased]
 
-## [0.16.1] - 2021-09-17
-
 ### Features
 
 * lang: Add `--detach` flag to `anchor test` ([#770](https://github.com/project-serum/anchor/pull/770)).
+* lang: Add `associated_token` keyword for initializing associated token accounts within `#[derive(Accounts)]` ([#790](https://github.com/project-serum/anchor/pull/790)).
+
+## [0.16.1] - 2021-09-17
 
 ### Fixes
 

+ 1 - 0
Cargo.lock

@@ -216,6 +216,7 @@ dependencies = [
  "lazy_static",
  "serum_dex",
  "solana-program",
+ "spl-associated-token-account",
  "spl-token 3.2.0",
 ]
 

+ 16 - 12
cli/src/lib.rs

@@ -1370,18 +1370,10 @@ fn test(
                 .context(cmd)
         };
 
-        match test_result {
-            Ok(exit) => {
-                if detach {
-                    println!("Local validator still running. Press Ctrl + C quit.");
-                    std::io::stdin().lock().lines().next().unwrap().unwrap();
-                } else if !exit.status.success() && !detach {
-                    std::process::exit(exit.status.code().unwrap());
-                }
-            }
-            Err(err) => {
-                println!("Failed to run test: {:#}", err)
-            }
+        // Keep validator running if needed.
+        if test_result.is_ok() && detach {
+            println!("Local validator still running. Press Ctrl + C quit.");
+            std::io::stdin().lock().lines().next().unwrap().unwrap();
         }
 
         // Check all errors and shut down.
@@ -1396,6 +1388,18 @@ fn test(
             }
         }
 
+        // Must exist *after* shutting down the validator and log streams.
+        match test_result {
+            Ok(exit) => {
+                if !exit.status.success() {
+                    std::process::exit(exit.status.code().unwrap());
+                }
+            }
+            Err(err) => {
+                println!("Failed to run test: {:#}", err)
+            }
+        }
+
         Ok(())
     })
 }

+ 22 - 0
lang/syn/src/codegen/accounts/constraints.rs

@@ -411,6 +411,28 @@ pub fn generate_init(
                 };
             }
         }
+        InitKind::AssociatedToken { owner, mint } => {
+            quote! {
+                let #field: #ty_decl = {
+                    #payer
+
+                    let cpi_program = associated_token_program.to_account_info();
+                    let cpi_accounts = anchor_spl::associated_token::Create {
+                        payer: payer.to_account_info(),
+                        associated_token: #field.to_account_info(),
+                        authority: #owner.to_account_info(),
+                        mint: #mint.to_account_info(),
+                        system_program: system_program.to_account_info(),
+                        token_program: token_program.to_account_info(),
+                        rent: rent.to_account_info(),
+                    };
+                    let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
+                    anchor_spl::associated_token::create(cpi_ctx)?;
+                    let pa: #ty_decl = #from_account_info;
+                    pa
+                };
+            }
+        }
         InitKind::Mint { owner, decimals } => {
             let create_account = generate_create_account(
                 field,

+ 10 - 0
lang/syn/src/lib.rs

@@ -243,6 +243,13 @@ impl Field {
                     }
                 }
             }
+            Ty::CpiAccount(_) => {
+                quote! {
+                    #container_ty::try_from_unchecked(
+                        &#field,
+                    )?
+                }
+            }
             _ => {
                 let owner_addr = match &kind {
                     None => quote! { program_id },
@@ -554,6 +561,8 @@ pub enum ConstraintToken {
     Address(Context<ConstraintAddress>),
     TokenMint(Context<ConstraintTokenMint>),
     TokenAuthority(Context<ConstraintTokenAuthority>),
+    AssociatedTokenMint(Context<ConstraintTokenMint>),
+    AssociatedTokenAuthority(Context<ConstraintTokenAuthority>),
     MintAuthority(Context<ConstraintMintAuthority>),
     MintDecimals(Context<ConstraintMintDecimals>),
     Bump(Context<ConstraintTokenBump>),
@@ -653,6 +662,7 @@ pub enum InitKind {
     // Owner for token and mint represents the authority. Not to be confused
     // with the owner of the AccountInfo.
     Token { owner: Expr, mint: Expr },
+    AssociatedToken { owner: Expr, mint: Expr },
     Mint { owner: Expr, decimals: Expr },
 }
 

+ 100 - 1
lang/syn/src/parser/accounts/constraints.rs

@@ -121,6 +121,33 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
                 _ => return Err(ParseError::new(ident.span(), "Invalid attribute")),
             }
         }
+        "associated_token" => {
+            stream.parse::<Token![:]>()?;
+            stream.parse::<Token![:]>()?;
+            let kw = stream.call(Ident::parse_any)?.to_string();
+            stream.parse::<Token![=]>()?;
+
+            let span = ident
+                .span()
+                .join(stream.span())
+                .unwrap_or_else(|| ident.span());
+
+            match kw.as_str() {
+                "mint" => ConstraintToken::AssociatedTokenMint(Context::new(
+                    span,
+                    ConstraintTokenMint {
+                        mint: stream.parse()?,
+                    },
+                )),
+                "authority" => ConstraintToken::AssociatedTokenAuthority(Context::new(
+                    span,
+                    ConstraintTokenAuthority {
+                        auth: stream.parse()?,
+                    },
+                )),
+                _ => return Err(ParseError::new(ident.span(), "Invalid attribute")),
+            }
+        }
         "bump" => {
             let bump = {
                 if stream.peek(Token![=]) {
@@ -246,6 +273,8 @@ pub struct ConstraintGroupBuilder<'ty> {
     pub address: Option<Context<ConstraintAddress>>,
     pub token_mint: Option<Context<ConstraintTokenMint>>,
     pub token_authority: Option<Context<ConstraintTokenAuthority>>,
+    pub associated_token_mint: Option<Context<ConstraintTokenMint>>,
+    pub associated_token_authority: Option<Context<ConstraintTokenAuthority>>,
     pub mint_authority: Option<Context<ConstraintMintAuthority>>,
     pub mint_decimals: Option<Context<ConstraintMintDecimals>>,
     pub bump: Option<Context<ConstraintTokenBump>>,
@@ -273,6 +302,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             address: None,
             token_mint: None,
             token_authority: None,
+            associated_token_mint: None,
+            associated_token_authority: None,
             mint_authority: None,
             mint_decimals: None,
             bump: None,
@@ -307,7 +338,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             // When initializing a non-PDA account, the account being
             // initialized must sign to invoke the system program's create
             // account instruction.
-            if self.signer.is_none() && self.seeds.is_none() {
+            if self.signer.is_none() && self.seeds.is_none() && self.associated_token_mint.is_none()
+            {
                 self.signer
                     .replace(Context::new(i.span(), ConstraintSigner {}));
             }
@@ -425,6 +457,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             address,
             token_mint,
             token_authority,
+            associated_token_mint,
+            associated_token_authority,
             mint_authority,
             mint_decimals,
             bump,
@@ -469,6 +503,17 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
                             )),
                         },
                     }
+                } else if let Some(tm) = &associated_token_mint {
+                    InitKind::AssociatedToken {
+                        mint: tm.clone().into_inner().mint,
+                        owner: match &associated_token_authority {
+                            Some(a) => a.clone().into_inner().auth,
+                            None => return Err(ParseError::new(
+                                tm.span(),
+                                "authority must be provided to initialize a token program derived address"
+                            )),
+                        },
+                    }
                 } else if let Some(d) = &mint_decimals {
                     InitKind::Mint {
                         decimals: d.clone().into_inner().decimals,
@@ -522,6 +567,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             ConstraintToken::Address(c) => self.add_address(c),
             ConstraintToken::TokenAuthority(c) => self.add_token_authority(c),
             ConstraintToken::TokenMint(c) => self.add_token_mint(c),
+            ConstraintToken::AssociatedTokenAuthority(c) => self.add_associated_token_authority(c),
+            ConstraintToken::AssociatedTokenMint(c) => self.add_associated_token_mint(c),
             ConstraintToken::MintAuthority(c) => self.add_mint_authority(c),
             ConstraintToken::MintDecimals(c) => self.add_mint_decimals(c),
             ConstraintToken::Bump(c) => self.add_bump(c),
@@ -585,6 +632,12 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         if self.token_mint.is_some() {
             return Err(ParseError::new(c.span(), "token mint already provided"));
         }
+        if self.associated_token_mint.is_some() {
+            return Err(ParseError::new(
+                c.span(),
+                "associated token mint already provided",
+            ));
+        }
         if self.init.is_none() {
             return Err(ParseError::new(
                 c.span(),
@@ -595,6 +648,26 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         Ok(())
     }
 
+    fn add_associated_token_mint(&mut self, c: Context<ConstraintTokenMint>) -> ParseResult<()> {
+        if self.associated_token_mint.is_some() {
+            return Err(ParseError::new(
+                c.span(),
+                "associated token mint already provided",
+            ));
+        }
+        if self.token_mint.is_some() {
+            return Err(ParseError::new(c.span(), "token mint already provided"));
+        }
+        if self.init.is_none() {
+            return Err(ParseError::new(
+                c.span(),
+                "init must be provided before token",
+            ));
+        }
+        self.associated_token_mint.replace(c);
+        Ok(())
+    }
+
     fn add_bump(&mut self, c: Context<ConstraintTokenBump>) -> ParseResult<()> {
         if self.bump.is_some() {
             return Err(ParseError::new(c.span(), "bump already provided"));
@@ -626,6 +699,32 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         Ok(())
     }
 
+    fn add_associated_token_authority(
+        &mut self,
+        c: Context<ConstraintTokenAuthority>,
+    ) -> ParseResult<()> {
+        if self.associated_token_authority.is_some() {
+            return Err(ParseError::new(
+                c.span(),
+                "associated token authority already provided",
+            ));
+        }
+        if self.token_authority.is_some() {
+            return Err(ParseError::new(
+                c.span(),
+                "token authority already provided",
+            ));
+        }
+        if self.init.is_none() {
+            return Err(ParseError::new(
+                c.span(),
+                "init must be provided before token authority",
+            ));
+        }
+        self.associated_token_authority.replace(c);
+        Ok(())
+    }
+
     fn add_mint_authority(&mut self, c: Context<ConstraintMintAuthority>) -> ParseResult<()> {
         if self.mint_authority.is_some() {
             return Err(ParseError::new(c.span(), "mint authority already provided"));

+ 1 - 0
spl/Cargo.toml

@@ -15,3 +15,4 @@ lazy_static = "1.4.0"
 serum_dex = { git = "https://github.com/project-serum/serum-dex", rev = "1be91f2", version = "0.4.0", features = ["no-entrypoint"] }
 solana-program = "=1.7.11"
 spl-token = { version = "3.1.1", features = ["no-entrypoint"] }
+spl-associated-token-account = { version = "1.0.3", features = ["no-entrypoint"] }

+ 58 - 0
spl/src/associated_token.rs

@@ -0,0 +1,58 @@
+use anchor_lang::solana_program::account_info::AccountInfo;
+use anchor_lang::solana_program::entrypoint::ProgramResult;
+use anchor_lang::solana_program::program_error::ProgramError;
+use anchor_lang::solana_program::pubkey::Pubkey;
+use anchor_lang::{Accounts, CpiContext};
+
+pub use spl_associated_token_account::ID;
+
+pub fn create<'info>(ctx: CpiContext<'_, '_, '_, 'info, Create<'info>>) -> ProgramResult {
+    let ix = spl_associated_token_account::create_associated_token_account(
+        ctx.accounts.payer.key,
+        ctx.accounts.authority.key,
+        ctx.accounts.mint.key,
+    );
+    solana_program::program::invoke_signed(
+        &ix,
+        &[
+            ctx.accounts.payer,
+            ctx.accounts.associated_token,
+            ctx.accounts.authority,
+            ctx.accounts.mint,
+            ctx.accounts.system_program,
+            ctx.accounts.token_program,
+            ctx.accounts.rent,
+        ],
+        ctx.signer_seeds,
+    )
+}
+
+#[derive(Accounts)]
+pub struct Create<'info> {
+    pub payer: AccountInfo<'info>,
+    pub associated_token: AccountInfo<'info>,
+    pub authority: AccountInfo<'info>,
+    pub mint: AccountInfo<'info>,
+    pub system_program: AccountInfo<'info>,
+    pub token_program: AccountInfo<'info>,
+    pub rent: AccountInfo<'info>,
+}
+
+#[derive(Clone)]
+pub struct AssociatedToken;
+
+impl anchor_lang::AccountDeserialize for AssociatedToken {
+    fn try_deserialize(buf: &mut &[u8]) -> Result<Self, ProgramError> {
+        AssociatedToken::try_deserialize_unchecked(buf)
+    }
+
+    fn try_deserialize_unchecked(_buf: &mut &[u8]) -> Result<Self, ProgramError> {
+        Ok(AssociatedToken)
+    }
+}
+
+impl anchor_lang::Id for AssociatedToken {
+    fn id() -> Pubkey {
+        ID
+    }
+}

+ 1 - 0
spl/src/lib.rs

@@ -1,3 +1,4 @@
+pub mod associated_token;
 pub mod dex;
 pub mod mint;
 pub mod shmem;

+ 19 - 1
tests/misc/programs/misc/src/context.rs

@@ -1,6 +1,7 @@
 use crate::account::*;
 use anchor_lang::prelude::*;
-use anchor_spl::token::{Mint, TokenAccount};
+use anchor_spl::associated_token::AssociatedToken;
+use anchor_spl::token::{Mint, Token, TokenAccount};
 use misc2::misc2::MyState as Misc2State;
 use std::mem::size_of;
 
@@ -31,6 +32,23 @@ pub struct TestTokenSeedsInit<'info> {
     pub token_program: AccountInfo<'info>,
 }
 
+#[derive(Accounts)]
+pub struct TestInitAssociatedToken<'info> {
+    #[account(
+        init,
+        payer = payer,
+        associated_token::mint = mint,
+        associated_token::authority = payer,
+    )]
+    pub token: Account<'info, TokenAccount>,
+    pub mint: Account<'info, Mint>,
+    pub payer: Signer<'info>,
+    pub rent: Sysvar<'info, Rent>,
+    pub system_program: Program<'info, System>,
+    pub token_program: Program<'info, Token>,
+    pub associated_token_program: Program<'info, AssociatedToken>,
+}
+
 #[derive(Accounts)]
 #[instruction(nonce: u8)]
 pub struct TestInstructionConstraint<'info> {

+ 5 - 0
tests/misc/programs/misc/src/lib.rs

@@ -159,4 +159,9 @@ pub mod misc {
         ctx.accounts.data.idata = 3;
         Ok(())
     }
+
+    pub fn test_init_associated_token(ctx: Context<TestInitAssociatedToken>) -> ProgramResult {
+        assert!(ctx.accounts.token.mint == ctx.accounts.mint.key());
+        Ok(())
+    }
 }

+ 40 - 3
tests/misc/tests/misc.js

@@ -1,7 +1,11 @@
 const anchor = require("@project-serum/anchor");
 const PublicKey = anchor.web3.PublicKey;
 const assert = require("assert");
-const { TOKEN_PROGRAM_ID, Token } = require("@solana/spl-token");
+const {
+  ASSOCIATED_TOKEN_PROGRAM_ID,
+  TOKEN_PROGRAM_ID,
+  Token,
+} = require("@solana/spl-token");
 
 describe("misc", () => {
   // Configure the client to use the local cluster.
@@ -155,7 +159,7 @@ describe("misc", () => {
     assert.ok(resp.events[2].data.data === 9);
   });
 
-	let dataI8;
+  let dataI8;
 
   it("Can use i8 in the idl", async () => {
     dataI8 = anchor.web3.Keypair.generate();
@@ -593,7 +597,7 @@ describe("misc", () => {
         data: data2.publicKey,
         systemProgram: anchor.web3.SystemProgram.programId,
       },
-      signers: [data1, data2]
+      signers: [data1, data2],
     });
 
     const account1 = await program.account.dataI8.fetch(data1.publicKey);
@@ -603,4 +607,37 @@ describe("misc", () => {
     assert.equal(account2.udata, 2);
     assert.equal(account2.idata, 3);
   });
+
+  it("Can create an associated token account", async () => {
+    const token = await Token.getAssociatedTokenAddress(
+      ASSOCIATED_TOKEN_PROGRAM_ID,
+      TOKEN_PROGRAM_ID,
+      mint.publicKey,
+      program.provider.wallet.publicKey
+    );
+
+    await program.rpc.testInitAssociatedToken({
+      accounts: {
+        token,
+        mint: mint.publicKey,
+        payer: program.provider.wallet.publicKey,
+        rent: anchor.web3.SYSVAR_RENT_PUBKEY,
+        systemProgram: anchor.web3.SystemProgram.programId,
+        tokenProgram: TOKEN_PROGRAM_ID,
+        associatedTokenProgram: ASSOCIATED_TOKEN_PROGRAM_ID,
+      },
+    });
+    const client = new Token(
+      program.provider.connection,
+      mint.publicKey,
+      TOKEN_PROGRAM_ID,
+      program.provider.wallet.payer
+    );
+    const account = await client.getAccountInfo(token);
+    assert.ok(account.state === 1);
+    assert.ok(account.amount.toNumber() === 0);
+    assert.ok(account.isInitialized);
+    assert.ok(account.owner.equals(program.provider.wallet.publicKey));
+    assert.ok(account.mint.equals(mint.publicKey));
+  });
 });