Browse Source

ts: Get discriminator lengths dynamically (#3120)

acheron 1 year ago
parent
commit
9fce3dfc9c

+ 2 - 0
CHANGELOG.md

@@ -23,6 +23,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)).
 - client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)).
 - lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)).
 - lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)).
 - lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)).
 - lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)).
+- ts: Get discriminator lengths dynamically ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
 
 
 ### Fixes
 ### Fixes
 
 
@@ -44,6 +45,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
 - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
 - lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
 - lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
 - lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)).
 - lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)).
+- ts: Remove `DISCRIMINATOR_SIZE` constant ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
 
 
 ## [0.30.1] - 2024-06-20
 ## [0.30.1] - 2024-06-20
 
 

+ 24 - 17
ts/packages/anchor/src/coder/borsh/accounts.ts

@@ -1,10 +1,9 @@
 import bs58 from "bs58";
 import bs58 from "bs58";
 import { Buffer } from "buffer";
 import { Buffer } from "buffer";
 import { Layout } from "buffer-layout";
 import { Layout } from "buffer-layout";
-import { Idl } from "../../idl.js";
+import { Idl, IdlDiscriminator } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
 import { IdlCoder } from "./idl.js";
 import { AccountsCoder } from "../index.js";
 import { AccountsCoder } from "../index.js";
-import { DISCRIMINATOR_SIZE } from "./discriminator.js";
 
 
 /**
 /**
  * Encodes and decodes account objects.
  * Encodes and decodes account objects.
@@ -15,7 +14,10 @@ export class BorshAccountsCoder<A extends string = string>
   /**
   /**
    * Maps account type identifier to a layout.
    * Maps account type identifier to a layout.
    */
    */
-  private accountLayouts: Map<A, Layout>;
+  private accountLayouts: Map<
+    A,
+    { discriminator: IdlDiscriminator; layout: Layout }
+  >;
 
 
   public constructor(private idl: Idl) {
   public constructor(private idl: Idl) {
     if (!idl.accounts) {
     if (!idl.accounts) {
@@ -28,12 +30,18 @@ export class BorshAccountsCoder<A extends string = string>
       throw new Error("Accounts require `idl.types`");
       throw new Error("Accounts require `idl.types`");
     }
     }
 
 
-    const layouts: [A, Layout][] = idl.accounts.map((acc) => {
+    const layouts = idl.accounts.map((acc) => {
       const typeDef = types.find((ty) => ty.name === acc.name);
       const typeDef = types.find((ty) => ty.name === acc.name);
       if (!typeDef) {
       if (!typeDef) {
         throw new Error(`Account not found: ${acc.name}`);
         throw new Error(`Account not found: ${acc.name}`);
       }
       }
-      return [acc.name as A, IdlCoder.typeDefLayout({ typeDef, types })];
+      return [
+        acc.name as A,
+        {
+          discriminator: acc.discriminator,
+          layout: IdlCoder.typeDefLayout({ typeDef, types }),
+        },
+      ] as const;
     });
     });
 
 
     this.accountLayouts = new Map(layouts);
     this.accountLayouts = new Map(layouts);
@@ -45,7 +53,7 @@ export class BorshAccountsCoder<A extends string = string>
     if (!layout) {
     if (!layout) {
       throw new Error(`Unknown account: ${accountName}`);
       throw new Error(`Unknown account: ${accountName}`);
     }
     }
-    const len = layout.encode(account, buffer);
+    const len = layout.layout.encode(account, buffer);
     const accountData = buffer.slice(0, len);
     const accountData = buffer.slice(0, len);
     const discriminator = this.accountDiscriminator(accountName);
     const discriminator = this.accountDiscriminator(accountName);
     return Buffer.concat([discriminator, accountData]);
     return Buffer.concat([discriminator, accountData]);
@@ -54,32 +62,31 @@ export class BorshAccountsCoder<A extends string = string>
   public decode<T = any>(accountName: A, data: Buffer): T {
   public decode<T = any>(accountName: A, data: Buffer): T {
     // Assert the account discriminator is correct.
     // Assert the account discriminator is correct.
     const discriminator = this.accountDiscriminator(accountName);
     const discriminator = this.accountDiscriminator(accountName);
-    if (discriminator.compare(data.slice(0, DISCRIMINATOR_SIZE))) {
+    if (discriminator.compare(data.slice(0, discriminator.length))) {
       throw new Error("Invalid account discriminator");
       throw new Error("Invalid account discriminator");
     }
     }
     return this.decodeUnchecked(accountName, data);
     return this.decodeUnchecked(accountName, data);
   }
   }
 
 
   public decodeAny<T = any>(data: Buffer): T {
   public decodeAny<T = any>(data: Buffer): T {
-    const discriminator = data.slice(0, DISCRIMINATOR_SIZE);
-    const accountName = Array.from(this.accountLayouts.keys()).find((key) =>
-      this.accountDiscriminator(key).equals(discriminator)
-    );
-    if (!accountName) {
-      throw new Error("Account not found");
+    for (const [name, layout] of this.accountLayouts) {
+      const givenDisc = data.subarray(0, layout.discriminator.length);
+      const matches = givenDisc.equals(Buffer.from(layout.discriminator));
+      if (matches) return this.decodeUnchecked(name, data);
     }
     }
 
 
-    return this.decodeUnchecked<T>(accountName, data);
+    throw new Error("Account not found");
   }
   }
 
 
   public decodeUnchecked<T = any>(accountName: A, acc: Buffer): T {
   public decodeUnchecked<T = any>(accountName: A, acc: Buffer): T {
     // Chop off the discriminator before decoding.
     // Chop off the discriminator before decoding.
-    const data = acc.subarray(DISCRIMINATOR_SIZE);
+    const discriminator = this.accountDiscriminator(accountName);
+    const data = acc.subarray(discriminator.length);
     const layout = this.accountLayouts.get(accountName);
     const layout = this.accountLayouts.get(accountName);
     if (!layout) {
     if (!layout) {
       throw new Error(`Unknown account: ${accountName}`);
       throw new Error(`Unknown account: ${accountName}`);
     }
     }
-    return layout.decode(data);
+    return layout.layout.decode(data);
   }
   }
 
 
   public memcmp(accountName: A, appendData?: Buffer): any {
   public memcmp(accountName: A, appendData?: Buffer): any {
@@ -94,7 +101,7 @@ export class BorshAccountsCoder<A extends string = string>
 
 
   public size(accountName: A): number {
   public size(accountName: A): number {
     return (
     return (
-      DISCRIMINATOR_SIZE +
+      this.accountDiscriminator(accountName).length +
       IdlCoder.typeSize({ defined: { name: accountName } }, this.idl)
       IdlCoder.typeSize({ defined: { name: accountName } }, this.idl)
     );
     );
   }
   }

+ 0 - 4
ts/packages/anchor/src/coder/borsh/discriminator.ts

@@ -1,4 +0,0 @@
-/**
- * Number of bytes in anchor discriminators
- */
-export const DISCRIMINATOR_SIZE = 8;

+ 23 - 27
ts/packages/anchor/src/coder/borsh/event.ts

@@ -1,7 +1,7 @@
 import { Buffer } from "buffer";
 import { Buffer } from "buffer";
 import { Layout } from "buffer-layout";
 import { Layout } from "buffer-layout";
 import * as base64 from "../../utils/bytes/base64.js";
 import * as base64 from "../../utils/bytes/base64.js";
-import { Idl } from "../../idl.js";
+import { Idl, IdlDiscriminator } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
 import { IdlCoder } from "./idl.js";
 import { EventCoder } from "../index.js";
 import { EventCoder } from "../index.js";
 
 
@@ -9,12 +9,10 @@ export class BorshEventCoder implements EventCoder {
   /**
   /**
    * Maps account type identifier to a layout.
    * Maps account type identifier to a layout.
    */
    */
-  private layouts: Map<string, Layout>;
-
-  /**
-   * Maps base64 encoded event discriminator to event name.
-   */
-  private discriminators: Map<string, string>;
+  private layouts: Map<
+    string,
+    { discriminator: IdlDiscriminator; layout: Layout }
+  >;
 
 
   public constructor(idl: Idl) {
   public constructor(idl: Idl) {
     if (!idl.events) {
     if (!idl.events) {
@@ -27,21 +25,20 @@ export class BorshEventCoder implements EventCoder {
       throw new Error("Events require `idl.types`");
       throw new Error("Events require `idl.types`");
     }
     }
 
 
-    const layouts: [string, Layout<any>][] = idl.events.map((ev) => {
+    const layouts = idl.events.map((ev) => {
       const typeDef = types.find((ty) => ty.name === ev.name);
       const typeDef = types.find((ty) => ty.name === ev.name);
       if (!typeDef) {
       if (!typeDef) {
         throw new Error(`Event not found: ${ev.name}`);
         throw new Error(`Event not found: ${ev.name}`);
       }
       }
-      return [ev.name, IdlCoder.typeDefLayout({ typeDef, types })];
+      return [
+        ev.name,
+        {
+          discriminator: ev.discriminator,
+          layout: IdlCoder.typeDefLayout({ typeDef, types }),
+        },
+      ] as const;
     });
     });
     this.layouts = new Map(layouts);
     this.layouts = new Map(layouts);
-
-    this.discriminators = new Map<string, string>(
-      (idl.events ?? []).map((ev) => [
-        base64.encode(Buffer.from(ev.discriminator)),
-        ev.name,
-      ])
-    );
   }
   }
 
 
   public decode(log: string): {
   public decode(log: string): {
@@ -55,19 +52,18 @@ export class BorshEventCoder implements EventCoder {
     } catch (e) {
     } catch (e) {
       return null;
       return null;
     }
     }
-    const disc = base64.encode(logArr.slice(0, 8));
 
 
-    // Only deserialize if the discriminator implies a proper event.
-    const eventName = this.discriminators.get(disc);
-    if (!eventName) {
-      return null;
+    for (const [name, layout] of this.layouts) {
+      const givenDisc = logArr.subarray(0, layout.discriminator.length);
+      const matches = givenDisc.equals(Buffer.from(layout.discriminator));
+      if (matches) {
+        return {
+          name,
+          data: layout.layout.decode(logArr.subarray(givenDisc.length)),
+        };
+      }
     }
     }
 
 
-    const layout = this.layouts.get(eventName);
-    if (!layout) {
-      throw new Error(`Unknown event: ${eventName}`);
-    }
-    const data = layout.decode(logArr.slice(8));
-    return { data, name: eventName };
+    return null;
   }
   }
 }
 }

+ 0 - 1
ts/packages/anchor/src/coder/borsh/index.ts

@@ -7,7 +7,6 @@ import { Coder } from "../index.js";
 
 
 export { BorshInstructionCoder } from "./instruction.js";
 export { BorshInstructionCoder } from "./instruction.js";
 export { BorshAccountsCoder } from "./accounts.js";
 export { BorshAccountsCoder } from "./accounts.js";
-export { DISCRIMINATOR_SIZE } from "./discriminator.js";
 export { BorshEventCoder } from "./event.js";
 export { BorshEventCoder } from "./event.js";
 
 
 /**
 /**

+ 11 - 20
ts/packages/anchor/src/coder/borsh/instruction.ts

@@ -16,7 +16,7 @@ import {
   IdlDiscriminator,
   IdlDiscriminator,
 } from "../../idl.js";
 } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
 import { IdlCoder } from "./idl.js";
-import { DISCRIMINATOR_SIZE, InstructionCoder } from "../index.js";
+import { InstructionCoder } from "../index.js";
 
 
 /**
 /**
  * Encodes and decodes program instructions.
  * Encodes and decodes program instructions.
@@ -28,9 +28,6 @@ export class BorshInstructionCoder implements InstructionCoder {
     { discriminator: IdlDiscriminator; layout: Layout }
     { discriminator: IdlDiscriminator; layout: Layout }
   >;
   >;
 
 
-  // Base58 encoded sighash to instruction layout.
-  private sighashLayouts: Map<string, { name: string; layout: Layout }>;
-
   public constructor(private idl: Idl) {
   public constructor(private idl: Idl) {
     const ixLayouts = idl.instructions.map((ix) => {
     const ixLayouts = idl.instructions.map((ix) => {
       const name = ix.name;
       const name = ix.name;
@@ -41,13 +38,6 @@ export class BorshInstructionCoder implements InstructionCoder {
       return [name, { discriminator: ix.discriminator, layout }] as const;
       return [name, { discriminator: ix.discriminator, layout }] as const;
     });
     });
     this.ixLayouts = new Map(ixLayouts);
     this.ixLayouts = new Map(ixLayouts);
-
-    const sighashLayouts = ixLayouts.map(
-      ([name, { discriminator, layout }]) => {
-        return [bs58.encode(discriminator), { name, layout }] as const;
-      }
-    );
-    this.sighashLayouts = new Map(sighashLayouts);
   }
   }
 
 
   /**
   /**
@@ -77,17 +67,18 @@ export class BorshInstructionCoder implements InstructionCoder {
       ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix);
       ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix);
     }
     }
 
 
-    const disc = ix.slice(0, DISCRIMINATOR_SIZE);
-    const data = ix.slice(DISCRIMINATOR_SIZE);
-    const decoder = this.sighashLayouts.get(bs58.encode(disc));
-    if (!decoder) {
-      return null;
+    for (const [name, layout] of this.ixLayouts) {
+      const givenDisc = ix.subarray(0, layout.discriminator.length);
+      const matches = givenDisc.equals(Buffer.from(layout.discriminator));
+      if (matches) {
+        return {
+          name,
+          data: layout.layout.decode(ix.subarray(givenDisc.length)),
+        };
+      }
     }
     }
 
 
-    return {
-      name: decoder.name,
-      data: decoder.layout.decode(data),
-    };
+    return null;
   }
   }
 
 
   /**
   /**

+ 0 - 2
ts/packages/anchor/tests/coder-accounts.spec.ts

@@ -1,7 +1,5 @@
 import * as assert from "assert";
 import * as assert from "assert";
 import { BorshCoder, Idl } from "../src";
 import { BorshCoder, Idl } from "../src";
-import { DISCRIMINATOR_SIZE } from "../src/coder/borsh/discriminator";
-import { sha256 } from "@noble/hashes/sha256";
 
 
 describe("coder.accounts", () => {
 describe("coder.accounts", () => {
   test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => {
   test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => {