Armani Ferrante 3 年之前
父节点
当前提交
3cbf74cace
共有 3 个文件被更改,包括 54 次插入37 次删除
  1. 29 21
      ts/src/coder/borsh/accounts.ts
  2. 19 12
      ts/src/coder/borsh/event.ts
  3. 6 4
      ts/src/coder/borsh/state.ts

+ 29 - 21
ts/src/coder/borsh/accounts.ts

@@ -8,8 +8,7 @@ import { Idl, IdlTypeDef } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
 import { AccountsCoder } from "../index.js";
 import { accountSize } from "../common.js";
-import * as features from "../../utils/features";
-import { Features } from "../../utils/features";
+import { FeatureSet } from "../../utils/features";
 
 /**
  * Number of bytes of the account header.
@@ -35,9 +34,14 @@ export class BorshAccountsCoder<A extends string = string>
   /**
    * IDL whose acconts will be coded.
    */
-  private idl: Idl;
+	private idl: Idl;
 
-  public constructor(idl: Idl) {
+	/**
+	 * Header configuration.
+	 */
+	readonly header: BorshAccountHeader;
+
+	public constructor(idl: Idl) {
     if (idl.accounts === undefined) {
       this.accountLayouts = new Map();
       return;
@@ -48,6 +52,7 @@ export class BorshAccountsCoder<A extends string = string>
 
     this.accountLayouts = new Map(layouts);
     this.idl = idl;
+		this.header = new BorshAccountHeader(idl);
   }
 
   public async encode<T = any>(accountName: A, account: T): Promise<Buffer> {
@@ -58,13 +63,13 @@ export class BorshAccountsCoder<A extends string = string>
     }
     const len = layout.encode(account, buffer);
     let accountData = buffer.slice(0, len);
-    let header = BorshAccountHeader.encode(accountName);
+    let header = this.header.encode(accountName);
     return Buffer.concat([header, accountData]);
   }
 
   public decode<T = any>(accountName: A, data: Buffer): T {
-    const expectedDiscriminator = BorshAccountHeader.discriminator(accountName);
-    const givenDisc = BorshAccountHeader.parseDiscriminator(data);
+    const expectedDiscriminator = this.header.discriminator(accountName);
+    const givenDisc = this.header.parseDiscriminator(data);
     if (expectedDiscriminator.compare(givenDisc)) {
       throw new Error("Invalid account discriminator");
     }
@@ -81,10 +86,10 @@ export class BorshAccountsCoder<A extends string = string>
   }
 
   public memcmp(accountName: A): GetProgramAccountsFilter {
-    const discriminator = BorshAccountHeader.discriminator(accountName);
+    const discriminator = this.header.discriminator(accountName);
     return {
       memcmp: {
-        offset: BorshAccountHeader.discriminatorOffset(),
+        offset: this.header.discriminatorOffset(),
         bytes: bs58.encode(discriminator),
       },
     };
@@ -100,17 +105,20 @@ export class BorshAccountsCoder<A extends string = string>
 }
 
 export class BorshAccountHeader {
+
+	constructor(_idl: Idl) {}
+
   /**
    * Returns the default account header for an account with the given name.
    */
-  public static encode(accountName: string, nameSpace?: string): Buffer {
-    if (features.isSet(Features.DeprecatedLayout)) {
-      return BorshAccountHeader.discriminator(accountName, nameSpace);
+  public encode(accountName: string, nameSpace?: string): Buffer {
+    if (this._features.deprecatedLayout) {
+      return this.discriminator(accountName, nameSpace);
     } else {
       return Buffer.concat([
         Buffer.from([0]), // Version.
         Buffer.from([0]), // Bump.
-        BorshAccountHeader.discriminator(accountName), // Disc.
+        this.discriminator(accountName, nameSpace), // Disc.
         Buffer.from([0, 0]), // Unused.
       ]);
     }
@@ -121,16 +129,16 @@ export class BorshAccountHeader {
    *
    * @param name The name of the account to calculate the discriminator.
    */
-  public static discriminator(name: string, nameSpace?: string): Buffer {
+  public discriminator(name: string, nameSpace?: string): Buffer {
     return Buffer.from(
       sha256.digest(
         `${nameSpace ?? "account"}:${camelcase(name, { pascalCase: true })}`
       )
-    ).slice(0, BorshAccountHeader.discriminatorSize());
+    ).slice(0, this.discriminatorSize());
   }
 
-  public static discriminatorSize(): number {
-    return features.isSet(Features.DeprecatedLayout)
+  public discriminatorSize(): number {
+    return this._features.deprecatedLayout
       ? DEPRECATED_ACCOUNT_DISCRIMINATOR_SIZE
       : ACCOUNT_DISCRIMINATOR_SIZE;
   }
@@ -138,8 +146,8 @@ export class BorshAccountHeader {
   /**
    * Returns the account data index at which the discriminator starts.
    */
-  public static discriminatorOffset(): number {
-    if (features.isSet(Features.DeprecatedLayout)) {
+  public discriminatorOffset(): number {
+    if (this._features.deprecatedLayout) {
       return 0;
     } else {
       return 2;
@@ -156,8 +164,8 @@ export class BorshAccountHeader {
   /**
    * Returns the discriminator from the given account data.
    */
-  public static parseDiscriminator(data: Buffer): Buffer {
-    if (features.isSet(Features.DeprecatedLayout)) {
+  public parseDiscriminator(data: Buffer): Buffer {
+    if (this._features.deprecatedLayout) {
       return data.slice(0, 8);
     } else {
       return data.slice(2, 6);

+ 19 - 12
ts/src/coder/borsh/event.ts

@@ -6,8 +6,7 @@ import { Idl, IdlEvent, IdlTypeDef } from "../../idl.js";
 import { Event, EventData } from "../../program/event.js";
 import { IdlCoder } from "./idl.js";
 import { EventCoder } from "../index.js";
-import * as features from "../../utils/features";
-import { Features } from "../../utils/features";
+import { FeatureSet } from "../../utils/features";
 
 export class BorshEventCoder implements EventCoder {
   /**
@@ -20,11 +19,17 @@ export class BorshEventCoder implements EventCoder {
    */
   private discriminators: Map<string, string>;
 
+	/**
+	 * Header configuration.
+	 */
+	private header: EventHeader;
+
   public constructor(idl: Idl) {
     if (idl.events === undefined) {
       this.layouts = new Map();
       return;
     }
+		this.header = new EventHeader(features);
     const layouts: [string, Layout<any>][] = idl.events.map((event) => {
       let eventTypeDef: IdlTypeDef = {
         name: event.name,
@@ -43,7 +48,7 @@ export class BorshEventCoder implements EventCoder {
       idl.events === undefined
         ? []
         : idl.events.map((e) => [
-            base64.fromByteArray(EventHeader.discriminator(e.name)),
+            base64.fromByteArray(this.header.discriminator(e.name)),
             e.name,
           ])
     );
@@ -59,7 +64,7 @@ export class BorshEventCoder implements EventCoder {
     } catch (e) {
       return null;
     }
-    const disc = base64.fromByteArray(EventHeader.parseDiscriminator(logArr));
+    const disc = base64.fromByteArray(this.header.parseDiscriminator(logArr));
 
     // Only deserialize if the discriminator implies a proper event.
     const eventName = this.discriminators.get(disc);
@@ -71,7 +76,7 @@ export class BorshEventCoder implements EventCoder {
     if (!layout) {
       throw new Error(`Unknown event: ${eventName}`);
     }
-    const data = layout.decode(logArr.slice(EventHeader.size())) as EventData<
+    const data = layout.decode(logArr.slice(this.header.size())) as EventData<
       E["fields"][number],
       T
     >;
@@ -80,28 +85,30 @@ export class BorshEventCoder implements EventCoder {
 }
 
 export function eventDiscriminator(name: string): Buffer {
-  return EventHeader.discriminator(name);
+  return this.header.discriminator(name);
 }
 
 class EventHeader {
-  public static parseDiscriminator(data: Buffer): Buffer {
-    if (features.isSet(Features.DeprecatedLayout)) {
+	constructor(private _features: FeatureSet) {}
+
+  public parseDiscriminator(data: Buffer): Buffer {
+    if (this._features.deprecatedLayout) {
       return data.slice(0, 8);
     } else {
       return data.slice(0, 4);
     }
   }
 
-  public static size(): number {
-    if (features.isSet(Features.DeprecatedLayout)) {
+  public size(): number {
+    if (this._features.deprecatedLayout) {
       return 8;
     } else {
       return 4;
     }
   }
 
-  public static discriminator(name: string): Buffer {
-    if (features.isSet(Features.DeprecatedLayout)) {
+  public discriminator(name: string): Buffer {
+    if (this._features.deprecatedLayout) {
       return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 8);
     } else {
       return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 4);

+ 6 - 4
ts/src/coder/borsh/state.ts

@@ -3,17 +3,19 @@ import { Layout } from "buffer-layout";
 import { sha256 } from "js-sha256";
 import { Idl } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
-import * as features from "../../utils/features.js";
+import * as features from '../../utils/features';
 import { BorshAccountHeader } from "./accounts";
 
 export class BorshStateCoder {
   private layout: Layout;
+	private header: BorshAccountHeader;
 
-  public constructor(idl: Idl) {
+  public constructor(idl: Idl, header: BorshAccountHeader) {
     if (idl.state === undefined) {
       throw new Error("Idl state not defined.");
     }
     this.layout = IdlCoder.typeDefLayout(idl.state.struct, idl.types);
+		this.header = header;
   }
 
   public async encode<T = any>(name: string, account: T): Promise<Buffer> {
@@ -21,7 +23,7 @@ export class BorshStateCoder {
     const len = this.layout.encode(account, buffer);
 
     let ns = features.isSet("anchor-deprecated-state") ? "account" : "state";
-    const header = BorshAccountHeader.encode(name, ns);
+    const header = this.header.encode(name, ns);
     const accData = buffer.slice(0, len);
 
     return Buffer.concat([header, accData]);
@@ -39,6 +41,6 @@ export async function stateDiscriminator(name: string): Promise<Buffer> {
   let ns = features.isSet("anchor-deprecated-state") ? "account" : "state";
   return Buffer.from(sha256.digest(`${ns}:${name}`)).slice(
     0,
-    BorshAccountHeader.discriminatorSize()
+    this.header.discriminatorSize()
   );
 }