Browse Source

idl: Add accounts resolution for associated token accounts (#2927)

acheron 1 year ago
parent
commit
81c8c556e8

+ 1 - 0
CHANGELOG.md

@@ -13,6 +13,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 ### Features
 ### Features
 
 
 - avm: Support customizing the installation location using `AVM_HOME` environment variable ([#2917](https://github.com/coral-xyz/anchor/pull/2917))
 - avm: Support customizing the installation location using `AVM_HOME` environment variable ([#2917](https://github.com/coral-xyz/anchor/pull/2917))
+- idl, ts: Add accounts resolution for associated token accounts ([#2927](https://github.com/coral-xyz/anchor/pull/2927))
 
 
 ### Fixes
 ### Fixes
 
 

+ 80 - 19
lang/syn/src/idl/accounts.rs

@@ -3,7 +3,7 @@ use proc_macro2::TokenStream;
 use quote::{quote, ToTokens};
 use quote::{quote, ToTokens};
 
 
 use super::common::{get_idl_module_path, get_no_docs};
 use super::common::{get_idl_module_path, get_no_docs};
-use crate::{AccountField, AccountsStruct, Field, Ty};
+use crate::{AccountField, AccountsStruct, Field, InitKind, Ty};
 
 
 /// Generate the IDL build impl for the Accounts struct.
 /// Generate the IDL build impl for the Accounts struct.
 pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
 pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
@@ -168,26 +168,87 @@ fn get_address(acc: &Field) -> TokenStream {
 
 
 fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
 fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
     let idl = get_idl_module_path();
     let idl = get_idl_module_path();
+    let parse_default = |expr: &syn::Expr| parse_seed(expr, accounts);
+
+    // Seeds
     let seed_constraints = acc.constraints.seeds.as_ref();
     let seed_constraints = acc.constraints.seeds.as_ref();
-    let seeds = seed_constraints
-        .map(|seed| seed.seeds.iter().map(|seed| parse_seed(seed, accounts)))
-        .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok());
-    let program = seed_constraints
-        .and_then(|seed| seed.program_seed.as_ref())
-        .and_then(|program| parse_seed(program, accounts).ok())
-        .map(|program| quote! { Some(#program) })
-        .unwrap_or_else(|| quote! { None });
-    match seeds {
-        Some(seeds) => quote! {
-            Some(
-                #idl::IdlPda {
-                    seeds: vec![#(#seeds),*],
-                    program: #program,
-                }
-            )
-        },
-        _ => quote! { None },
+    let pda = seed_constraints
+        .map(|seed| seed.seeds.iter().map(parse_default))
+        .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok())
+        .map(|seeds| {
+            let program = seed_constraints
+                .and_then(|seed| seed.program_seed.as_ref())
+                .and_then(|program| parse_default(program).ok())
+                .map(|program| quote! { Some(#program) })
+                .unwrap_or_else(|| quote! { None });
+
+            quote! {
+                Some(
+                    #idl::IdlPda {
+                        seeds: vec![#(#seeds),*],
+                        program: #program,
+                    }
+                )
+            }
+        });
+    if let Some(pda) = pda {
+        return pda;
     }
     }
+
+    // Associated token
+    let pda = acc
+        .constraints
+        .init
+        .as_ref()
+        .and_then(|init| match &init.kind {
+            InitKind::AssociatedToken {
+                owner,
+                mint,
+                token_program,
+            } => Some((owner, mint, token_program)),
+            _ => None,
+        })
+        .or_else(|| {
+            acc.constraints
+                .associated_token
+                .as_ref()
+                .map(|ata| (&ata.wallet, &ata.mint, &ata.token_program))
+        })
+        .and_then(|(wallet, mint, token_program)| {
+            // ATA constraints have implicit `.key()` call
+            let parse_expr = |ts| parse_default(&syn::parse2(ts).unwrap()).ok();
+            let parse_ata = |expr| parse_expr(quote! { #expr.key().as_ref() });
+
+            let wallet = parse_ata(wallet);
+            let mint = parse_ata(mint);
+            let token_program = token_program
+                .as_ref()
+                .and_then(parse_ata)
+                .or_else(|| parse_expr(quote!(anchor_spl::token::ID)));
+
+            let seeds = match (wallet, mint, token_program) {
+                (Some(w), Some(m), Some(tp)) => quote! { vec![#w, #tp, #m] },
+                _ => return None,
+            };
+
+            let program = parse_expr(quote!(anchor_spl::associated_token::ID))
+                .map(|program| quote! { Some(#program) })
+                .unwrap();
+
+            Some(quote! {
+                Some(
+                    #idl::IdlPda {
+                        seeds: #seeds,
+                        program: #program,
+                    }
+                )
+            })
+        });
+    if let Some(pda) = pda {
+        return pda;
+    }
+
+    quote! { None }
 }
 }
 
 
 /// Parse a seeds constraint, extracting the `IdlSeed` types.
 /// Parse a seeds constraint, extracting the `IdlSeed` types.

+ 2 - 1
tests/pda-derivation/programs/pda-derivation/Cargo.toml

@@ -14,7 +14,8 @@ no-entrypoint = []
 no-idl = []
 no-idl = []
 cpi = ["no-entrypoint"]
 cpi = ["no-entrypoint"]
 default = []
 default = []
-idl-build = ["anchor-lang/idl-build"]
+idl-build = ["anchor-lang/idl-build", "anchor-spl/idl-build"]
 
 
 [dependencies]
 [dependencies]
 anchor-lang = { path = "../../../../lang" }
 anchor-lang = { path = "../../../../lang" }
+anchor-spl = { path = "../../../../spl" }

+ 31 - 0
tests/pda-derivation/programs/pda-derivation/src/lib.rs

@@ -4,6 +4,10 @@
 mod other;
 mod other;
 
 
 use anchor_lang::prelude::*;
 use anchor_lang::prelude::*;
+use anchor_spl::{
+    associated_token::AssociatedToken,
+    token::{Mint, Token, TokenAccount},
+};
 
 
 declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");
 declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");
 
 
@@ -34,6 +38,10 @@ pub mod pda_derivation {
         ctx.accounts.account.data = 1337;
         ctx.accounts.account.data = 1337;
         Ok(())
         Ok(())
     }
     }
+
+    pub fn associated_token_resolution(_ctx: Context<AssociatedTokenResolution>) -> Result<()> {
+        Ok(())
+    }
 }
 }
 
 
 #[derive(Accounts)]
 #[derive(Accounts)]
@@ -115,6 +123,29 @@ pub struct Nested<'info> {
     account_nested: AccountInfo<'info>,
     account_nested: AccountInfo<'info>,
 }
 }
 
 
+#[derive(Accounts)]
+pub struct AssociatedTokenResolution<'info> {
+    #[account(
+        init,
+        payer = payer,
+        mint::authority = payer,
+        mint::decimals = 9,
+    )]
+    pub mint: Account<'info, Mint>,
+    #[account(
+        init,
+        payer = payer,
+        associated_token::authority = payer,
+        associated_token::mint = mint,
+    )]
+    pub ata: Account<'info, TokenAccount>,
+    #[account(mut)]
+    pub payer: Signer<'info>,
+    pub system_program: Program<'info, System>,
+    pub token_program: Program<'info, Token>,
+    pub associated_token_program: Program<'info, AssociatedToken>,
+}
+
 #[account]
 #[account]
 pub struct MyAccount {
 pub struct MyAccount {
     data: u64,
     data: u64,

+ 9 - 0
tests/pda-derivation/tests/typescript.spec.ts

@@ -103,4 +103,13 @@ describe("typescript", () => {
 
 
     expect(called).is.true;
     expect(called).is.true;
   });
   });
+
+  it("Can resolve associated token accounts", async () => {
+    const mintKp = anchor.web3.Keypair.generate();
+    await program.methods
+      .associatedTokenResolution()
+      .accounts({ mint: mintKp.publicKey })
+      .signers([mintKp])
+      .rpc();
+  });
 });
 });