Browse Source

feat: Derive has_one's from other programs, recursively search custom resolvers (#2208)

Noah Prince 3 năm trước cách đây
mục cha
commit
fd467df932

+ 1 - 1
tests/pda-derivation/tests/typescript.spec.ts

@@ -77,7 +77,7 @@ describe("typescript", () => {
         if (instruction.name === "initMyAccount") {
           return async ({ accounts }) => {
             called = true;
-            return accounts;
+            return { accounts, resolved: 0 };
           };
         }
       }

+ 77 - 24
ts/packages/anchor/src/program/accounts-resolver.ts

@@ -23,6 +23,7 @@ import Provider from "../provider.js";
 import { AccountNamespace } from "./namespace/account.js";
 import { coder } from "../spl/token";
 import { BorshAccountsCoder } from "src/coder/index.js";
+import { Program } from "./index.js";
 
 type Accounts = { [name: string]: PublicKey | Accounts };
 
@@ -32,7 +33,7 @@ export type CustomAccountResolver<IDL extends Idl> = (params: {
   provider: Provider;
   programId: PublicKey;
   idlIx: AllInstructions<IDL>;
-}) => Promise<Accounts>;
+}) => Promise<{ accounts: Accounts; resolved: number }>;
 
 // Populates a given accounts context with PDAs and common missing accounts.
 export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
@@ -58,7 +59,11 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     private _customResolver?: CustomAccountResolver<IDL>
   ) {
     this._args = _args;
-    this._accountStore = new AccountStore(_provider, _accountNamespace);
+    this._accountStore = new AccountStore(
+      _provider,
+      _accountNamespace,
+      this._programId
+    );
   }
 
   public args(_args: Array<any>): void {
@@ -80,19 +85,25 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     // Auto populate pdas and relations until we stop finding new accounts
     while (
       (await this.resolvePdas(this._idlIx.accounts)) +
-        (await this.resolveRelations(this._idlIx.accounts)) >
+        (await this.resolveRelations(this._idlIx.accounts)) +
+        (await this.resolveCustom()) >
       0
     ) {}
+  }
 
+  private async resolveCustom(): Promise<number> {
     if (this._customResolver) {
-      this._accounts = await this._customResolver({
+      const { accounts, resolved } = await this._customResolver({
         args: this._args,
         accounts: this._accounts,
         provider: this._provider,
         programId: this._programId,
         idlIx: this._idlIx,
       });
+      this._accounts = accounts;
+      return resolved;
     }
+    return 0;
   }
 
   private get(path: string[]): PublicKey | undefined {
@@ -220,7 +231,13 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
 
         found += matching.length;
         if (matching.length > 0) {
-          const account = await this._accountStore.fetchAccount(accountKey);
+          const programId = await this.parseProgramId(
+            accountDesc as IdlAccount
+          );
+          const account = await this._accountStore.fetchAccount({
+            publicKey: accountKey,
+            programId,
+          });
           await Promise.all(
             matching.map(async (rel) => {
               const relName = camelCase(rel);
@@ -248,6 +265,9 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     }
 
     const programId = await this.parseProgramId(accountDesc);
+    if (!programId) {
+      return;
+    }
     const [pubkey] = await PublicKey.findProgramAddress(
       seeds as Buffer[],
       programId
@@ -368,10 +388,10 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     // The key is account data.
     //
     // Fetch and deserialize it.
-    const account = await this._accountStore.fetchAccount(
-      fieldPubkey as PublicKey,
-      seedDesc.account
-    );
+    const account = await this._accountStore.fetchAccount({
+      publicKey: fieldPubkey as PublicKey,
+      name: seedDesc.account,
+    });
 
     // Dereference all fields in the path to get the field value
     // used in the seed.
@@ -424,17 +444,40 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
 // TODO: this should be configureable to avoid unnecessary requests.
 export class AccountStore<IDL extends Idl> {
   private _cache = new Map<string, any>();
+  private _idls: Record<string, AccountNamespace<any>> = {};
 
   // todo: don't use the progrma use the account namespace.
   constructor(
     private _provider: Provider,
-    private _accounts: AccountNamespace<IDL>
-  ) {}
+    _accounts: AccountNamespace<IDL>,
+    private _programId: PublicKey
+  ) {
+    this._idls[_programId.toBase58()] = _accounts;
+  }
+
+  private async ensureIdl(
+    programId: PublicKey
+  ): Promise<AccountNamespace<any> | undefined> {
+    if (!this._idls[programId.toBase58()]) {
+      const idl = await Program.fetchIdl(programId, this._provider);
+      if (idl) {
+        const program = new Program(idl, programId, this._provider);
+        this._idls[programId.toBase58()] = program.account;
+      }
+    }
 
-  public async fetchAccount<T = any>(
-    publicKey: PublicKey,
-    name?: string
-  ): Promise<T> {
+    return this._idls[programId.toBase58()];
+  }
+
+  public async fetchAccount<T = any>({
+    publicKey,
+    name,
+    programId = this._programId,
+  }: {
+    publicKey: PublicKey;
+    name?: string;
+    programId?: PublicKey;
+  }): Promise<T> {
     const address = publicKey.toString();
     if (!this._cache.has(address)) {
       if (name === "TokenAccount") {
@@ -447,8 +490,14 @@ export class AccountStore<IDL extends Idl> {
         const data = coder().accounts.decode("token", accountInfo.data);
         this._cache.set(address, data);
       } else if (name) {
-        const account = this._accounts[camelCase(name)].fetch(publicKey);
-        this._cache.set(address, account);
+        const accounts = await this.ensureIdl(programId);
+        if (accounts) {
+          const accountFetcher = accounts[camelCase(name)];
+          if (accountFetcher) {
+            const account = await accountFetcher.fetch(publicKey);
+            this._cache.set(address, account);
+          }
+        }
       } else {
         const account = await this._provider.connection.getAccountInfo(
           publicKey
@@ -457,14 +506,18 @@ export class AccountStore<IDL extends Idl> {
           throw new Error(`invalid account info for ${address}`);
         }
         const data = account.data;
-        const firstAccountLayout = Object.values(this._accounts)[0] as any;
-        if (!firstAccountLayout) {
-          throw new Error("No accounts for this program");
+        const accounts = await this.ensureIdl(programId);
+        if (accounts) {
+          const firstAccountLayout = Object.values(accounts)[0] as any;
+          if (!firstAccountLayout) {
+            throw new Error("No accounts for this program");
+          }
+
+          const result = (
+            firstAccountLayout.coder.accounts as BorshAccountsCoder
+          ).decodeAny(data);
+          this._cache.set(address, result);
         }
-        const result = (
-          firstAccountLayout.coder.accounts as BorshAccountsCoder
-        ).decodeAny(data);
-        this._cache.set(address, result);
       }
     }
     return this._cache.get(address);