Explorar el Código

ts: More efficient event subscriptions (#502)

aac hace 4 años
padre
commit
a4a8b6f769
Se han modificado 4 ficheros con 142 adiciones y 18 borrados
  1. 126 1
      ts/src/program/event.ts
  2. 14 15
      ts/src/program/index.ts
  3. 1 1
      ts/src/program/namespace/simulate.ts
  4. 1 1
      ts/tests/events.spec.ts

+ 126 - 1
ts/src/program/event.ts

@@ -1,6 +1,7 @@
 import { PublicKey } from "@solana/web3.js";
 import * as assert from "assert";
 import Coder from "../coder";
+import Provider from "../provider";
 
 const LOG_START_INDEX = "Program log: ".length;
 
@@ -10,11 +11,135 @@ export type Event = {
   data: Object;
 };
 
+type EventCallback = (event: any, slot: number) => void;
+
+export class EventManager {
+  /**
+   * Program ID for event subscriptions.
+   */
+  private _programId: PublicKey;
+
+  /**
+   * Network and wallet provider.
+   */
+  private _provider: Provider;
+
+  /**
+   * Event parser to handle onLogs callbacks.
+   */
+  private _eventParser: EventParser;
+
+  /**
+   * Maps event listener id to [event-name, callback].
+   */
+  private _eventCallbacks: Map<number, [string, EventCallback]>;
+
+  /**
+   * Maps event name to all listeners for the event.
+   */
+  private _eventListeners: Map<string, Array<number>>;
+
+  /**
+   * The next listener id to allocate.
+   */
+  private _listenerIdCount: number;
+
+  /**
+   * The subscription id from the connection onLogs subscription.
+   */
+  private _onLogsSubscriptionId: number | undefined;
+
+  constructor(programId: PublicKey, provider: Provider, coder: Coder) {
+    this._programId = programId;
+    this._provider = provider;
+    this._eventParser = new EventParser(programId, coder);
+    this._eventCallbacks = new Map();
+    this._eventListeners = new Map();
+    this._listenerIdCount = 0;
+  }
+
+  public addEventListener(
+    eventName: string,
+    callback: (event: any, slot: number) => void
+  ): number {
+    let listener = this._listenerIdCount;
+    this._listenerIdCount += 1;
+
+    // Store the listener into the event map.
+    if (!(eventName in this._eventCallbacks)) {
+      this._eventListeners.set(eventName, []);
+    }
+    this._eventListeners.set(
+      eventName,
+      this._eventListeners.get(eventName).concat(listener)
+    );
+
+    // Store the callback into the listener map.
+    this._eventCallbacks.set(listener, [eventName, callback]);
+
+    // Create the subscription singleton, if needed.
+    if (this._onLogsSubscriptionId !== undefined) {
+      return;
+    }
+    this._onLogsSubscriptionId = this._provider.connection.onLogs(
+      this._programId,
+      (logs, ctx) => {
+        if (logs.err) {
+          console.error(logs);
+          return;
+        }
+        this._eventParser.parseLogs(logs.logs, (event) => {
+          const allListeners = this._eventListeners.get(eventName);
+          if (allListeners) {
+            allListeners.forEach((listener) => {
+              const [, callback] = this._eventCallbacks.get(listener);
+              callback(event.data, ctx.slot);
+            });
+          }
+        });
+      }
+    );
+
+    return listener;
+  }
+
+  public async removeEventListener(listener: number): Promise<void> {
+    // Get the callback.
+    const callback = this._eventCallbacks.get(listener);
+    if (!callback) {
+      throw new Error(`Event listener ${listener} doesn't exist!`);
+    }
+    const [eventName] = callback;
+
+    // Get the listeners.
+    let listeners = this._eventListeners.get(eventName);
+    if (!listeners) {
+      throw new Error(`Event listeners dont' exist for ${eventName}!`);
+    }
+
+    // Update both maps.
+    this._eventCallbacks.delete(listener);
+    listeners = listeners.filter((l) => l !== listener);
+    if (listeners.length === 0) {
+      this._eventListeners.delete(eventName);
+    }
+
+    // Kill the websocket connection if all listeners have been removed.
+    if (this._eventCallbacks.size == 0) {
+      assert.ok(this._eventListeners.size === 0);
+      await this._provider.connection.removeOnLogsListener(
+        this._onLogsSubscriptionId
+      );
+      this._onLogsSubscriptionId = undefined;
+    }
+  }
+}
+
 export class EventParser {
   private coder: Coder;
   private programId: PublicKey;
 
-  constructor(coder: Coder, programId: PublicKey) {
+  constructor(programId: PublicKey, coder: Coder) {
     this.coder = coder;
     this.programId = programId;
   }

+ 14 - 15
ts/src/program/index.ts

@@ -13,7 +13,7 @@ import NamespaceFactory, {
 } from "./namespace";
 import { getProvider } from "../";
 import { utf8 } from "../utils/bytes";
-import { EventParser } from "./event";
+import { EventManager } from "./event";
 import { Address, translateAddress } from "./common";
 
 /**
@@ -234,6 +234,11 @@ export class Program {
   }
   private _provider: Provider;
 
+  /**
+   * Handles event subscriptions.
+   */
+  private _events: EventManager;
+
   /**
    * @param idl       The interface definition.
    * @param programId The on-chain address of the program.
@@ -248,6 +253,11 @@ export class Program {
     this._programId = programId;
     this._provider = provider ?? getProvider();
     this._coder = new Coder(idl);
+    this._events = new EventManager(
+      this._programId,
+      this._provider,
+      this._coder
+    );
 
     // Dynamic namespaces.
     const [
@@ -314,24 +324,13 @@ export class Program {
     eventName: string,
     callback: (event: any, slot: number) => void
   ): number {
-    const eventParser = new EventParser(this._coder, this._programId);
-    return this._provider.connection.onLogs(this._programId, (logs, ctx) => {
-      if (logs.err) {
-        console.error(logs);
-        return;
-      }
-      eventParser.parseLogs(logs.logs, (event) => {
-        if (event.name === eventName) {
-          callback(event.data, ctx.slot);
-        }
-      });
-    });
+    return this._events.addEventListener(eventName, callback);
   }
 
   /**
-   * Unsubscribes from the given event listener.
+   * Unsubscribes from the given eventName.
    */
   public async removeEventListener(listener: number): Promise<void> {
-    return this._provider.connection.removeOnLogsListener(listener);
+    return await this._events.removeEventListener(listener);
   }
 }

+ 1 - 1
ts/src/program/namespace/simulate.ts

@@ -45,7 +45,7 @@ export default class SimulateFactory {
 
       const events = [];
       if (idl.events) {
-        let parser = new EventParser(coder, programId);
+        let parser = new EventParser(programId, coder);
         parser.parseLogs(logs, (event) => {
           events.push(event);
         });

+ 1 - 1
ts/tests/events.spec.ts

@@ -24,7 +24,7 @@ describe("Events", () => {
     };
     const coder = new Coder(idl);
     const programId = PublicKey.default;
-    const eventParser = new EventParser(coder, programId);
+    const eventParser = new EventParser(programId, coder);
 
     eventParser.parseLogs(logs, () => {
       throw new Error("Should never find logs");