Browse Source

feat: Recursively derive seeds and add custom account resolver (#2194)

Noah Prince 3 years ago
parent
commit
436791b039

+ 5 - 3
CHANGELOG.md

@@ -20,10 +20,12 @@ The minor version will be incremented upon a breaking change and the patch versi
 * lang: Add parsing for consts from impl blocks for IDL PDA seeds generation ([#2128](https://github.com/coral-xyz/anchor/pull/2014))
 * lang: Account closing reassigns to system program and reallocates ([#2169](https://github.com/coral-xyz/anchor/pull/2169)).
 * ts: Add coders for SPL programs ([#2143](https://github.com/coral-xyz/anchor/pull/2143)).
-* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided
-* ts: Add ability to set args after setting accounts and retriving pubkyes
-* ts: Add `.prepare()` to builder pattern
+* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
+* ts: Add ability to set args after setting accounts and retrieving pubkyes ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
+* ts: Add `.prepare()` to builder pattern ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
 * spl: Add `freeze_delegated_account` and `thaw_delegated_account` wrappers ([#2164](https://github.com/coral-xyz/anchor/pull/2164))
+* ts: Add nested PDA inference ([#2194](https://github.com/coral-xyz/anchor/pull/2194))
+* ts: Add ability to resolve missing accounts with a custom resolver ([#2194](https://github.com/coral-xyz/anchor/pull/2194))
 
 ### Fixes
 

+ 19 - 0
tests/pda-derivation/programs/pda-derivation/src/lib.rs

@@ -67,11 +67,30 @@ pub struct InitMyAccount<'info> {
         bump,
     )]
     account: Account<'info, MyAccount>,
+    nested: Nested<'info>,
     #[account(mut)]
     payer: Signer<'info>,
     system_program: Program<'info, System>,
 }
 
+#[derive(Accounts)]
+pub struct Nested<'info> {
+    #[account(
+        seeds = [
+            "nested-seed".as_bytes(),
+            b"test".as_ref(),
+            MY_SEED.as_ref(),
+            MY_SEED_STR.as_bytes(),
+            MY_SEED_U8.to_le_bytes().as_ref(),
+            &MY_SEED_U32.to_le_bytes(),
+            &MY_SEED_U64.to_le_bytes(),
+        ],
+        bump,
+    )]
+    /// CHECK: Not needed
+    account_nested: AccountInfo<'info>,
+}
+
 #[account]
 pub struct MyAccount {
     data: u64,

+ 27 - 0
tests/pda-derivation/tests/typescript.spec.ts

@@ -65,4 +65,31 @@ describe("typescript", () => {
       .data;
     expect(actualData.toNumber()).is.equal(1337);
   });
+
+  it("should allow custom resolvers", async () => {
+    let called = false;
+    const customProgram = new Program<PdaDerivation>(
+      program.idl,
+      program.programId,
+      program.provider,
+      program.coder,
+      (instruction) => {
+        if (instruction.name === "initMyAccount") {
+          return async ({ accounts }) => {
+            called = true;
+            return accounts;
+          };
+        }
+      }
+    );
+    await customProgram.methods
+      .initMyAccount(seedA)
+      .accounts({
+        base: base.publicKey,
+        base2: base.publicKey,
+      })
+      .pubkeys();
+
+    expect(called).is.true;
+  });
 });

+ 57 - 21
ts/packages/anchor/src/program/accounts-resolver.ts

@@ -17,6 +17,14 @@ import { BorshAccountsCoder } from "src/coder/index.js";
 
 type Accounts = { [name: string]: PublicKey | Accounts };
 
+export type CustomAccountResolver<IDL extends Idl> = (params: {
+  args: Array<any>;
+  accounts: Accounts;
+  provider: Provider;
+  programId: PublicKey;
+  idlIx: AllInstructions<IDL>;
+}) => Promise<Accounts>;
+
 // Populates a given accounts context with PDAs and common missing accounts.
 export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
   _args: Array<any>;
@@ -35,7 +43,8 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     private _provider: Provider,
     private _programId: PublicKey,
     private _idlIx: AllInstructions<IDL>,
-    _accountNamespace: AccountNamespace<IDL>
+    _accountNamespace: AccountNamespace<IDL>,
+    private _customResolver?: CustomAccountResolver<IDL>
   ) {
     this._args = _args;
     this._accountStore = new AccountStore(_provider, _accountNamespace);
@@ -84,25 +93,22 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
       }
     }
 
-    for (let k = 0; k < this._idlIx.accounts.length; k += 1) {
-      // Cast is ok because only a non-nested IdlAccount can have a seeds
-      // cosntraint.
-      const accountDesc = this._idlIx.accounts[k] as IdlAccount;
-      const accountDescName = camelCase(accountDesc.name);
-
-      // PDA derived from IDL seeds.
-      if (
-        accountDesc.pda &&
-        accountDesc.pda.seeds.length > 0 &&
-        !this._accounts[accountDescName]
-      ) {
-        await this.autoPopulatePda(accountDesc);
-        continue;
-      }
+    // Auto populate pdas and relations until we stop finding new accounts
+    while (
+      (await this.resolvePdas(this._idlIx.accounts)) +
+        (await this.resolveRelations(this._idlIx.accounts)) >
+      0
+    ) {}
+
+    if (this._customResolver) {
+      this._accounts = await this._customResolver({
+        args: this._args,
+        accounts: this._accounts,
+        provider: this._provider,
+        programId: this._programId,
+        idlIx: this._idlIx,
+      });
     }
-
-    // Auto populate has_one relationships until we stop finding new accounts
-    while ((await this.resolveRelations(this._idlIx.accounts)) > 0) {}
   }
 
   private get(path: string[]): PublicKey | undefined {
@@ -130,6 +136,36 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     });
   }
 
+  private async resolvePdas(
+    accounts: IdlAccountItem[],
+    path: string[] = []
+  ): Promise<number> {
+    let found = 0;
+    for (let k = 0; k < accounts.length; k += 1) {
+      const accountDesc = accounts[k];
+      const subAccounts = (accountDesc as IdlAccounts).accounts;
+      if (subAccounts) {
+        found += await this.resolvePdas(subAccounts, [
+          ...path,
+          accountDesc.name,
+        ]);
+      }
+
+      const accountDescCasted: IdlAccount = accountDesc as IdlAccount;
+      const accountDescName = camelCase(accountDesc.name);
+      // PDA derived from IDL seeds.
+      if (
+        accountDescCasted.pda &&
+        accountDescCasted.pda.seeds.length > 0 &&
+        !this.get([...path, accountDescName])
+      ) {
+        await this.autoPopulatePda(accountDescCasted, path);
+        found += 1;
+      }
+    }
+    return found;
+  }
+
   private async resolveRelations(
     accounts: IdlAccountItem[],
     path: string[] = []
@@ -172,7 +208,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     return found;
   }
 
-  private async autoPopulatePda(accountDesc: IdlAccount) {
+  private async autoPopulatePda(accountDesc: IdlAccount, path: string[] = []) {
     if (!accountDesc.pda || !accountDesc.pda.seeds)
       throw new Error("Must have seeds");
 
@@ -183,7 +219,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     const programId = await this.parseProgramId(accountDesc);
     const [pubkey] = await PublicKey.findProgramAddress(seeds, programId);
 
-    this._accounts[camelCase(accountDesc.name)] = pubkey;
+    this.set([...path, camelCase(accountDesc.name)], pubkey);
   }
 
   private async parseProgramId(accountDesc: IdlAccount): Promise<PublicKey> {

+ 16 - 3
ts/packages/anchor/src/program/index.ts

@@ -1,7 +1,7 @@
 import { inflate } from "pako";
 import { PublicKey } from "@solana/web3.js";
 import Provider, { getProvider } from "../provider.js";
-import { Idl, idlAddress, decodeIdlAccount } from "../idl.js";
+import { Idl, idlAddress, decodeIdlAccount, IdlInstruction } from "../idl.js";
 import { Coder, BorshCoder } from "../coder/index.js";
 import NamespaceFactory, {
   RpcNamespace,
@@ -16,6 +16,7 @@ import NamespaceFactory, {
 import { utf8 } from "../utils/bytes/index.js";
 import { EventManager } from "./event.js";
 import { Address, translateAddress } from "./common.js";
+import { CustomAccountResolver } from "./accounts-resolver.js";
 
 export * from "./common.js";
 export * from "./context.js";
@@ -263,12 +264,18 @@ export class Program<IDL extends Idl = Idl> {
    * @param programId The on-chain address of the program.
    * @param provider  The network and wallet context to use. If not provided
    *                  then uses [[getProvider]].
+   * @param getCustomResolver A function that returns a custom account resolver
+   *                          for the given instruction. This is useful for resolving
+   *                          public keys of missing accounts when building instructions
    */
   public constructor(
     idl: IDL,
     programId: Address,
     provider?: Provider,
-    coder?: Coder
+    coder?: Coder,
+    getCustomResolver?: (
+      instruction: IdlInstruction
+    ) => CustomAccountResolver<IDL> | undefined
   ) {
     programId = translateAddress(programId);
 
@@ -293,7 +300,13 @@ export class Program<IDL extends Idl = Idl> {
       methods,
       state,
       views,
-    ] = NamespaceFactory.build(idl, this._coder, programId, provider);
+    ] = NamespaceFactory.build(
+      idl,
+      this._coder,
+      programId,
+      provider,
+      getCustomResolver ?? (() => undefined)
+    );
     this.rpc = rpc;
     this.instruction = instruction;
     this.transaction = transaction;

+ 8 - 3
ts/packages/anchor/src/program/namespace/index.ts

@@ -2,7 +2,7 @@ import camelCase from "camelcase";
 import { PublicKey } from "@solana/web3.js";
 import { Coder } from "../../coder/index.js";
 import Provider from "../../provider.js";
-import { Idl } from "../../idl.js";
+import { Idl, IdlInstruction } from "../../idl.js";
 import StateFactory, { StateClient } from "./state.js";
 import InstructionFactory, { InstructionNamespace } from "./instruction.js";
 import TransactionFactory, { TransactionNamespace } from "./transaction.js";
@@ -12,6 +12,7 @@ import SimulateFactory, { SimulateNamespace } from "./simulate.js";
 import { parseIdlErrors } from "../common.js";
 import { MethodsBuilderFactory, MethodsNamespace } from "./methods";
 import ViewFactory, { ViewNamespace } from "./views";
+import { CustomAccountResolver } from "../accounts-resolver.js";
 
 // Re-exports.
 export { StateClient } from "./state.js";
@@ -32,7 +33,10 @@ export default class NamespaceFactory {
     idl: IDL,
     coder: Coder,
     programId: PublicKey,
-    provider: Provider
+    provider: Provider,
+    getCustomResolver?: (
+      instruction: IdlInstruction
+    ) => CustomAccountResolver<IDL> | undefined
   ): [
     RpcNamespace<IDL>,
     InstructionNamespace<IDL>,
@@ -85,7 +89,8 @@ export default class NamespaceFactory {
         rpcItem,
         simulateItem,
         viewItem,
-        account
+        account,
+        getCustomResolver && getCustomResolver(idlIx)
       );
       const name = camelCase(idlIx.name);
 

+ 12 - 5
ts/packages/anchor/src/program/namespace/methods.ts

@@ -22,7 +22,10 @@ import { SimulateFn } from "./simulate.js";
 import { ViewFn } from "./views.js";
 import Provider from "../../provider.js";
 import { AccountNamespace } from "./account.js";
-import { AccountsResolver } from "../accounts-resolver.js";
+import {
+  AccountsResolver,
+  CustomAccountResolver,
+} from "../accounts-resolver.js";
 import { Accounts } from "../context.js";
 
 export type MethodsNamespace<
@@ -40,7 +43,8 @@ export class MethodsBuilderFactory {
     rpcFn: RpcFn<IDL>,
     simulateFn: SimulateFn<IDL>,
     viewFn: ViewFn<IDL> | undefined,
-    accountNamespace: AccountNamespace<IDL>
+    accountNamespace: AccountNamespace<IDL>,
+    customResolver?: CustomAccountResolver<IDL>
   ): MethodsFn<IDL, I, MethodsBuilder<IDL, I>> {
     return (...args) =>
       new MethodsBuilder(
@@ -53,7 +57,8 @@ export class MethodsBuilderFactory {
         provider,
         programId,
         idlIx,
-        accountNamespace
+        accountNamespace,
+        customResolver
       );
   }
 }
@@ -78,7 +83,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
     _provider: Provider,
     _programId: PublicKey,
     _idlIx: AllInstructions<IDL>,
-    _accountNamespace: AccountNamespace<IDL>
+    _accountNamespace: AccountNamespace<IDL>,
+    _customResolver?: CustomAccountResolver<IDL>
   ) {
     this._args = _args;
     this._accountsResolver = new AccountsResolver(
@@ -87,7 +93,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
       _provider,
       _programId,
       _idlIx,
-      _accountNamespace
+      _accountNamespace,
+      _customResolver
     );
   }