Browse Source

feat: Update seeds inference to allow nested user defined structs as part of the seeds (#2198)

Noah Prince 3 years ago
parent
commit
6f3877f36c

+ 1 - 0
CHANGELOG.md

@@ -30,6 +30,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 ### Fixes
 
 * lang: Fix IDL `seed` generation for byte string literals. ([#2125](https://github.com/coral-xyz/anchor/pull/2125))
+* ts: Update seeds inference to allow nested user defined structs within the seeds ([#2198](https://github.com/coral-xyz/anchor/pull/2198))
 
 ## [0.25.0] - 2022-07-05
 

+ 1 - 1
lang/syn/src/idl/mod.rs

@@ -271,7 +271,7 @@ impl std::str::FromStr for IdlType {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct IdlErrorCode {
     pub code: u32,
     pub name: String,

+ 2 - 5
lang/syn/src/parser/error.rs

@@ -27,12 +27,9 @@ pub fn parse(error_enum: &mut syn::ItemEnum, args: Option<ErrorArgs>) -> Error {
             last_discriminant = id + 1;
 
             // Remove any non-doc attributes on the error variant.
-            variant.attrs = variant
+            variant
                 .attrs
-                .iter()
-                .filter(|attr| attr.path.segments[0].ident == "doc")
-                .cloned()
-                .collect();
+                .retain(|attr| attr.path.segments[0].ident == "doc");
 
             ErrorCode { id, ident, msg }
         })

+ 62 - 11
ts/packages/anchor/src/program/accounts-resolver.ts

@@ -11,6 +11,10 @@ import {
   IdlAccount,
   IdlAccountItem,
   IdlAccounts,
+  IdlTypeDef,
+  IdlTypeDefStruct,
+  IdlTypeDefTyStruct,
+  IdlType,
 } from "../idl.js";
 import * as utf8 from "../utils/bytes/utf8.js";
 import { TOKEN_PROGRAM_ID, ASSOCIATED_PROGRAM_ID } from "../utils/token.js";
@@ -50,6 +54,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     private _programId: PublicKey,
     private _idlIx: AllInstructions<IDL>,
     _accountNamespace: AccountNamespace<IDL>,
+    private _idlTypes: IdlTypeDef[],
     private _customResolver?: CustomAccountResolver<IDL>
   ) {
     this._args = _args;
@@ -180,8 +185,9 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
         accountDescCasted.pda.seeds.length > 0 &&
         !this.get([...path, accountDescName])
       ) {
-        await this.autoPopulatePda(accountDescCasted, path);
-        found += 1;
+        if (Boolean(await this.autoPopulatePda(accountDescCasted, path))) {
+          found += 1;
+        }
       }
     }
     return found;
@@ -233,12 +239,19 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     if (!accountDesc.pda || !accountDesc.pda.seeds)
       throw new Error("Must have seeds");
 
-    const seeds: Buffer[] = await Promise.all(
+    const seeds: (Buffer | undefined)[] = await Promise.all(
       accountDesc.pda.seeds.map((seedDesc: IdlSeed) => this.toBuffer(seedDesc))
     );
 
+    if (seeds.some((seed) => typeof seed == "undefined")) {
+      return;
+    }
+
     const programId = await this.parseProgramId(accountDesc);
-    const [pubkey] = await PublicKey.findProgramAddress(seeds, programId);
+    const [pubkey] = await PublicKey.findProgramAddress(
+      seeds as Buffer[],
+      programId
+    );
 
     this.set([...path, camelCase(accountDesc.name)], pubkey);
   }
@@ -263,7 +276,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     }
   }
 
-  private async toBuffer(seedDesc: IdlSeed): Promise<Buffer> {
+  private async toBuffer(seedDesc: IdlSeed): Promise<Buffer | undefined> {
     switch (seedDesc.kind) {
       case "const":
         return this.toBufferConst(seedDesc);
@@ -276,17 +289,48 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
     }
   }
 
+  /**
+   * Recursively get the type at some path of either a primitive or a user defined struct.
+   */
+  private getType(type: IdlType, path: string[] = []): string {
+    if (path.length > 0 && (type as any).defined) {
+      const subType = this._idlTypes.find(
+        (t) => t.name === (type as any).defined
+      );
+      if (!subType) {
+        throw new Error(`Cannot find type ${(type as any).defined}`);
+      }
+
+      const structType = subType.type as IdlTypeDefTyStruct; // enum not supported yet
+      const field = structType.fields.find((field) => field.name === path[0]);
+
+      return this.getType(field!.type, path.slice(1));
+    }
+
+    return type as string;
+  }
+
   private toBufferConst(seedDesc: IdlSeed): Buffer {
-    return this.toBufferValue(seedDesc.type, seedDesc.value);
+    return this.toBufferValue(
+      this.getType(seedDesc.type, (seedDesc.path || "").split(".").slice(1)),
+      seedDesc.value
+    );
   }
 
-  private async toBufferArg(seedDesc: IdlSeed): Promise<Buffer> {
+  private async toBufferArg(seedDesc: IdlSeed): Promise<Buffer | undefined> {
     const argValue = this.argValue(seedDesc);
-    return this.toBufferValue(seedDesc.type, argValue);
+    if (!argValue) {
+      return;
+    }
+    return this.toBufferValue(
+      this.getType(seedDesc.type, (seedDesc.path || "").split(".").slice(1)),
+      argValue
+    );
   }
 
   private argValue(seedDesc: IdlSeed): any {
-    const seedArgName = camelCase(seedDesc.path.split(".")[0]);
+    const split = seedDesc.path.split(".");
+    const seedArgName = camelCase(split[0]);
 
     const idlArgPosition = this._idlIx.args.findIndex(
       (argDesc: any) => argDesc.name === seedArgName
@@ -295,11 +339,18 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
       throw new Error(`Unable to find argument for seed: ${seedArgName}`);
     }
 
-    return this._args[idlArgPosition];
+    return split
+      .slice(1)
+      .reduce((curr, path) => (curr || {})[path], this._args[idlArgPosition]);
   }
 
-  private async toBufferAccount(seedDesc: IdlSeed): Promise<Buffer> {
+  private async toBufferAccount(
+    seedDesc: IdlSeed
+  ): Promise<Buffer | undefined> {
     const accountValue = await this.accountValue(seedDesc);
+    if (!accountValue) {
+      return;
+    }
     return this.toBufferValue(seedDesc.type, accountValue);
   }
 

+ 1 - 0
ts/packages/anchor/src/program/namespace/index.ts

@@ -90,6 +90,7 @@ export default class NamespaceFactory {
         simulateItem,
         viewItem,
         account,
+        idl.types || [],
         getCustomResolver && getCustomResolver(idlIx)
       );
       const name = camelCase(idlIx.name);

+ 19 - 16
ts/packages/anchor/src/program/namespace/methods.ts

@@ -1,32 +1,31 @@
 import {
-  ConfirmOptions,
   AccountMeta,
+  ConfirmOptions,
+  PublicKey,
   Signer,
   Transaction,
   TransactionInstruction,
   TransactionSignature,
-  PublicKey,
 } from "@solana/web3.js";
-import { SimulateResponse } from "./simulate.js";
-import { TransactionFn } from "./transaction.js";
-import { Idl } from "../../idl.js";
-import {
-  AllInstructions,
-  MethodsFn,
-  MakeMethodsNamespace,
-  InstructionAccountAddresses,
-} from "./types.js";
-import { InstructionFn } from "./instruction.js";
-import { RpcFn } from "./rpc.js";
-import { SimulateFn } from "./simulate.js";
-import { ViewFn } from "./views.js";
+import { Idl, IdlTypeDef } from "../../idl.js";
 import Provider from "../../provider.js";
-import { AccountNamespace } from "./account.js";
 import {
   AccountsResolver,
   CustomAccountResolver,
 } from "../accounts-resolver.js";
 import { Accounts } from "../context.js";
+import { AccountNamespace } from "./account.js";
+import { InstructionFn } from "./instruction.js";
+import { RpcFn } from "./rpc.js";
+import { SimulateFn, SimulateResponse } from "./simulate.js";
+import { TransactionFn } from "./transaction.js";
+import {
+  AllInstructions,
+  InstructionAccountAddresses,
+  MakeMethodsNamespace,
+  MethodsFn,
+} from "./types.js";
+import { ViewFn } from "./views.js";
 
 export type MethodsNamespace<
   IDL extends Idl = Idl,
@@ -44,6 +43,7 @@ export class MethodsBuilderFactory {
     simulateFn: SimulateFn<IDL>,
     viewFn: ViewFn<IDL> | undefined,
     accountNamespace: AccountNamespace<IDL>,
+    idlTypes: IdlTypeDef[],
     customResolver?: CustomAccountResolver<IDL>
   ): MethodsFn<IDL, I, MethodsBuilder<IDL, I>> {
     return (...args) =>
@@ -58,6 +58,7 @@ export class MethodsBuilderFactory {
         programId,
         idlIx,
         accountNamespace,
+        idlTypes,
         customResolver
       );
   }
@@ -84,6 +85,7 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
     _programId: PublicKey,
     _idlIx: AllInstructions<IDL>,
     _accountNamespace: AccountNamespace<IDL>,
+    _idlTypes: IdlTypeDef[],
     _customResolver?: CustomAccountResolver<IDL>
   ) {
     this._args = _args;
@@ -94,6 +96,7 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
       _programId,
       _idlIx,
       _accountNamespace,
+      _idlTypes,
       _customResolver
     );
   }