Browse Source

Add Websocket to price service (#221)

* Add initial working ws

* Add tests

* Add prom metrics + improve logging

* Handle ids with leading 0x

* Add a multi client test

* Minor code format

* Fix Typo

* rename PriceFeedPriceInfo to PriceStore

It is because in the future we might have multiple spies and this
will be a middleware store

* format the code
Ali Behjati 3 năm trước cách đây
mục cha
commit
df1854752c

+ 1 - 0
Tiltfile

@@ -249,6 +249,7 @@ if pyth:
         resource_deps = ["pyth", "p2w-attest", "spy", "eth-devnet"],
         port_forwards = [
             port_forward(4202, container_port = 4200, name = "Rest API (Status + Query) [:4202]", host = webHost),
+            port_forward(6202, container_port = 6200, name = "WSS API [:6202]", host = webHost),
             port_forward(8083, container_port = 8081, name = "Prometheus [:8083]", host = webHost)],
         labels = ["pyth"]
     )

+ 5 - 0
devnet/pyth-price-service.yaml

@@ -13,6 +13,9 @@ spec:
     - port: 4200
       name: rest-api
       protocol: TCP
+    - port: 6200
+      name: wss-api
+      protocol: TCP
   clusterIP: None
   selector:
     app: pyth-price-service
@@ -68,6 +71,8 @@ spec:
               value: '[{"chain_id":1,"emitter_address":"71f8dcb863d176e2c420ad6610cf687359612b6fb392e0642b0ca6b1f186aa3b"}]'
             - name: REST_PORT
               value: '4200'
+            - name: WS_PORT
+              value: '6200'
             - name: PROM_PORT
               value: '8081'
             - name: READINESS_SPY_SYNC_TIME_SECONDS

+ 1 - 0
third_party/pyth/price-service/.env.sample

@@ -16,6 +16,7 @@ SPY_SERVICE_FILTERS=[{"chain_id":1,"emitter_address":"71f8dcb863d176e2c420ad6610
 READINESS_SPY_SYNC_TIME_SECONDS=60
 READINESS_NUM_LOADED_SYMBOLS=5
 
+WS_PORT=6200
 REST_PORT=4200
 PROM_PORT=8081
 

+ 145 - 14
third_party/pyth/price-service/package-lock.json

@@ -17,16 +17,19 @@
         "@types/express": "^4.17.13",
         "@types/morgan": "^1.9.3",
         "@types/response-time": "^2.3.5",
+        "@types/ws": "^8.5.3",
         "cors": "^2.8.5",
         "dotenv": "^10.0.0",
         "ethers": "^5.4.4",
         "express": "^4.17.2",
         "express-validation": "^4.0.1",
         "http-status-codes": "^2.2.0",
+        "joi": "^17.6.0",
         "morgan": "^1.10.0",
         "prom-client": "^14.0.1",
         "response-time": "^2.3.2",
-        "winston": "^3.3.3"
+        "winston": "^3.3.3",
+        "ws": "^8.6.0"
       },
       "devDependencies": {
         "@types/jest": "^27.5.0",
@@ -40,6 +43,10 @@
         "tslint": "^6.1.3",
         "tslint-config-prettier": "^1.18.0",
         "typescript": "^4.3.5"
+      },
+      "optionalDependencies": {
+        "bufferutil": "^4.0.6",
+        "utf-8-validate": "^5.0.9"
       }
     },
     "../p2w-sdk/js": {
@@ -2443,6 +2450,26 @@
         "react-dom": "^17.0.0"
       }
     },
+    "node_modules/@terra-dev/walletconnect/node_modules/ws": {
+      "version": "7.5.7",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "engines": {
+        "node": ">=8.3.0"
+      },
+      "peerDependencies": {
+        "bufferutil": "^4.0.1",
+        "utf-8-validate": "^5.0.2"
+      },
+      "peerDependenciesMeta": {
+        "bufferutil": {
+          "optional": true
+        },
+        "utf-8-validate": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/@terra-dev/web-connector-controller": {
       "version": "0.8.1",
       "resolved": "https://registry.npmjs.org/@terra-dev/web-connector-controller/-/web-connector-controller-0.8.1.tgz",
@@ -2504,6 +2531,26 @@
         "follow-redirects": "^1.14.0"
       }
     },
+    "node_modules/@terra-money/terra.js/node_modules/ws": {
+      "version": "7.5.7",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "engines": {
+        "node": ">=8.3.0"
+      },
+      "peerDependencies": {
+        "bufferutil": "^4.0.1",
+        "utf-8-validate": "^5.0.2"
+      },
+      "peerDependenciesMeta": {
+        "bufferutil": {
+          "optional": true
+        },
+        "utf-8-validate": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/@terra-money/terra.proto": {
       "version": "0.1.7",
       "resolved": "https://registry.npmjs.org/@terra-money/terra.proto/-/terra.proto-0.1.7.tgz",
@@ -2920,9 +2967,9 @@
       }
     },
     "node_modules/@types/ws": {
-      "version": "7.4.7",
-      "resolved": "https://registry.npmjs.org/@types/ws/-/ws-7.4.7.tgz",
-      "integrity": "sha512-JQbbmxZTZehdc2iszGKs5oC3NFnjeay7mtAWrdt7qNtAVK0g19muApzAy4bm9byz79xa2ZnO/BOBC2R8RC5Lww==",
+      "version": "8.5.3",
+      "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.3.tgz",
+      "integrity": "sha512-6YOoWjruKj1uLf3INHH7D3qTXwFfEsg1kf3c0uDdSBJwfa/llkwIjrAGV7j7mVgGNbzTQ3HiHKKDXl6bJPD97w==",
       "dependencies": {
         "@types/node": "*"
       }
@@ -5135,6 +5182,34 @@
       "resolved": "https://registry.npmjs.org/@types/node/-/node-12.20.48.tgz",
       "integrity": "sha512-4kxzqkrpwYtn6okJUcb2lfUu9ilnb3yhUOH6qX3nug8D2DupZ2drIkff2yJzYcNJVl3begnlcaBJ7tqiTTzjnQ=="
     },
+    "node_modules/jayson/node_modules/@types/ws": {
+      "version": "7.4.7",
+      "resolved": "https://registry.npmjs.org/@types/ws/-/ws-7.4.7.tgz",
+      "integrity": "sha512-JQbbmxZTZehdc2iszGKs5oC3NFnjeay7mtAWrdt7qNtAVK0g19muApzAy4bm9byz79xa2ZnO/BOBC2R8RC5Lww==",
+      "dependencies": {
+        "@types/node": "*"
+      }
+    },
+    "node_modules/jayson/node_modules/ws": {
+      "version": "7.5.7",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "engines": {
+        "node": ">=8.3.0"
+      },
+      "peerDependencies": {
+        "bufferutil": "^4.0.1",
+        "utf-8-validate": "^5.0.2"
+      },
+      "peerDependenciesMeta": {
+        "bufferutil": {
+          "optional": true
+        },
+        "utf-8-validate": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/jest": {
       "version": "28.0.3",
       "resolved": "https://registry.npmjs.org/jest/-/jest-28.0.3.tgz",
@@ -7794,6 +7869,26 @@
         "utf-8-validate": "^5.0.2"
       }
     },
+    "node_modules/rpc-websockets/node_modules/ws": {
+      "version": "7.5.7",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "engines": {
+        "node": ">=8.3.0"
+      },
+      "peerDependencies": {
+        "bufferutil": "^4.0.1",
+        "utf-8-validate": "^5.0.2"
+      },
+      "peerDependenciesMeta": {
+        "bufferutil": {
+          "optional": true
+        },
+        "utf-8-validate": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/rxjs": {
       "version": "7.5.5",
       "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.5.5.tgz",
@@ -8881,11 +8976,11 @@
       }
     },
     "node_modules/ws": {
-      "version": "7.5.7",
-      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
-      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "version": "8.6.0",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-8.6.0.tgz",
+      "integrity": "sha512-AzmM3aH3gk0aX7/rZLYvjdvZooofDu3fFOzGqcSnQ1tOcTWwhM/o+q++E8mAyVVIyUdajrkzWUGftaVSDLn1bw==",
       "engines": {
-        "node": ">=8.3.0"
+        "node": ">=10.0.0"
       },
       "peerDependencies": {
         "bufferutil": "^4.0.1",
@@ -10615,6 +10710,14 @@
         "@walletconnect/utils": "^1.6.6",
         "rxjs": "^7.4.0",
         "ws": "^7.5.5"
+      },
+      "dependencies": {
+        "ws": {
+          "version": "7.5.7",
+          "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+          "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+          "requires": {}
+        }
       }
     },
     "@terra-dev/walletconnect-qrcode-modal": {
@@ -10673,6 +10776,12 @@
           "requires": {
             "follow-redirects": "^1.14.0"
           }
+        },
+        "ws": {
+          "version": "7.5.7",
+          "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+          "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+          "requires": {}
         }
       }
     },
@@ -11042,9 +11151,9 @@
       }
     },
     "@types/ws": {
-      "version": "7.4.7",
-      "resolved": "https://registry.npmjs.org/@types/ws/-/ws-7.4.7.tgz",
-      "integrity": "sha512-JQbbmxZTZehdc2iszGKs5oC3NFnjeay7mtAWrdt7qNtAVK0g19muApzAy4bm9byz79xa2ZnO/BOBC2R8RC5Lww==",
+      "version": "8.5.3",
+      "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.3.tgz",
+      "integrity": "sha512-6YOoWjruKj1uLf3INHH7D3qTXwFfEsg1kf3c0uDdSBJwfa/llkwIjrAGV7j7mVgGNbzTQ3HiHKKDXl6bJPD97w==",
       "requires": {
         "@types/node": "*"
       }
@@ -12809,6 +12918,20 @@
           "version": "12.20.48",
           "resolved": "https://registry.npmjs.org/@types/node/-/node-12.20.48.tgz",
           "integrity": "sha512-4kxzqkrpwYtn6okJUcb2lfUu9ilnb3yhUOH6qX3nug8D2DupZ2drIkff2yJzYcNJVl3begnlcaBJ7tqiTTzjnQ=="
+        },
+        "@types/ws": {
+          "version": "7.4.7",
+          "resolved": "https://registry.npmjs.org/@types/ws/-/ws-7.4.7.tgz",
+          "integrity": "sha512-JQbbmxZTZehdc2iszGKs5oC3NFnjeay7mtAWrdt7qNtAVK0g19muApzAy4bm9byz79xa2ZnO/BOBC2R8RC5Lww==",
+          "requires": {
+            "@types/node": "*"
+          }
+        },
+        "ws": {
+          "version": "7.5.7",
+          "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+          "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+          "requires": {}
         }
       }
     },
@@ -14797,6 +14920,14 @@
         "utf-8-validate": "^5.0.2",
         "uuid": "^8.3.0",
         "ws": "^7.4.5"
+      },
+      "dependencies": {
+        "ws": {
+          "version": "7.5.7",
+          "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
+          "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+          "requires": {}
+        }
       }
     },
     "rxjs": {
@@ -15609,9 +15740,9 @@
       }
     },
     "ws": {
-      "version": "7.5.7",
-      "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz",
-      "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==",
+      "version": "8.6.0",
+      "resolved": "https://registry.npmjs.org/ws/-/ws-8.6.0.tgz",
+      "integrity": "sha512-AzmM3aH3gk0aX7/rZLYvjdvZooofDu3fFOzGqcSnQ1tOcTWwhM/o+q++E8mAyVVIyUdajrkzWUGftaVSDLn1bw==",
       "requires": {}
     },
     "y18n": {

+ 10 - 3
third_party/pyth/price-service/package.json

@@ -7,7 +7,7 @@
     "format": "prettier --write \"src/**/*.ts\"",
     "build": "tsc",
     "start": "node lib/index.js",
-    "test": "jest"
+    "test": "jest src/"
   },
   "author": "",
   "license": "Apache-2.0",
@@ -33,19 +33,26 @@
     "@types/express": "^4.17.13",
     "@types/morgan": "^1.9.3",
     "@types/response-time": "^2.3.5",
+    "@types/ws": "^8.5.3",
     "cors": "^2.8.5",
     "dotenv": "^10.0.0",
     "ethers": "^5.4.4",
     "express": "^4.17.2",
     "express-validation": "^4.0.1",
     "http-status-codes": "^2.2.0",
+    "joi": "^17.6.0",
     "morgan": "^1.10.0",
     "prom-client": "^14.0.1",
     "response-time": "^2.3.2",
-    "winston": "^3.3.3"
+    "winston": "^3.3.3",
+    "ws": "^8.6.0"
   },
   "directories": {
     "lib": "lib"
   },
-  "keywords": []
+  "keywords": [],
+  "optionalDependencies": {
+    "bufferutil": "^4.0.6",
+    "utf-8-validate": "^5.0.9"
+  }
 }

+ 72 - 51
third_party/pyth/price-service/src/__tests__/rest.test.ts

@@ -1,6 +1,6 @@
 import { HexString, PriceFeed, PriceStatus } from "@pythnetwork/pyth-sdk-js";
-import { PriceFeedPriceInfo, PriceInfo } from "../listen";
-import {RestAPI} from "../rest"
+import { PriceStore, PriceInfo } from "../listen";
+import { RestAPI } from "../rest";
 import { Express } from "express";
 import request from "supertest";
 import { StatusCodes } from "http-status-codes";
@@ -27,78 +27,99 @@ function dummyPriceFeed(id: string): PriceFeed {
     price: "11",
     productId: "def456",
     publishTime: 13,
-    status: PriceStatus.Trading
+    status: PriceStatus.Trading,
   });
 }
 
-function dummyPriceInfoPair(id: HexString, seqNum: number, vaa: HexString): [HexString, PriceInfo] {
-  return [id, {
-    priceFeed: dummyPriceFeed(id),
-    receiveTime: 0,
-    seqNum,
-    vaaBytes: Buffer.from(vaa, 'hex').toString('binary')
-  }]
+function dummyPriceInfoPair(
+  id: HexString,
+  seqNum: number,
+  vaa: HexString
+): [HexString, PriceInfo] {
+  return [
+    id,
+    {
+      priceFeed: dummyPriceFeed(id),
+      receiveTime: 0,
+      seqNum,
+      vaaBytes: Buffer.from(vaa, "hex").toString("binary"),
+    },
+  ];
 }
 
 beforeAll(async () => {
-    priceInfoMap = new Map<string, PriceInfo>([
-        dummyPriceInfoPair(expandTo64Len('abcd'), 1, 'a1b2c3d4'),
-        dummyPriceInfoPair(expandTo64Len('ef01'), 1, 'a1b2c3d4'),
-        dummyPriceInfoPair(expandTo64Len('3456'), 2, 'bad01bad'),
-        dummyPriceInfoPair(expandTo64Len('10101'), 3, 'bidbidbid'),
-    ]);
+  priceInfoMap = new Map<string, PriceInfo>([
+    dummyPriceInfoPair(expandTo64Len("abcd"), 1, "a1b2c3d4"),
+    dummyPriceInfoPair(expandTo64Len("ef01"), 1, "a1b2c3d4"),
+    dummyPriceInfoPair(expandTo64Len("3456"), 2, "bad01bad"),
+    dummyPriceInfoPair(expandTo64Len("10101"), 3, "bidbidbid"),
+  ]);
 
-    let priceInfo: PriceFeedPriceInfo = {
-        getLatestPriceInfo: (priceFeedId: string) => {
-            return priceInfoMap.get(priceFeedId);
-        }
-    };
+  let priceInfo: PriceStore = {
+    getLatestPriceInfo: (priceFeedId: string) => {
+      return priceInfoMap.get(priceFeedId);
+    },
+    addUpdateListener: (_callback: (priceFeed: PriceFeed) => any) => {},
+    getPriceIds: () => new Set(),
+  };
 
-    const api = new RestAPI(
-        {port: 8889},
-        priceInfo,
-        () => true
-    );
+  const api = new RestAPI({ port: 8889 }, priceInfo, () => true);
 
-    app = await api.createApp();
-})
+  app = await api.createApp();
+});
 
 describe("Latest Price Feed Endpoint", () => {
-    test("When called with valid ids, returns correct price feed", async () => {
-      const ids = [expandTo64Len('abcd'), expandTo64Len('3456')];
-      const resp = await request(app).get('/latest_price_feeds').query({ids});
-      expect(resp.status).toBe(StatusCodes.OK);
-      expect(resp.body.length).toBe(2);
-      expect(resp.body).toContainEqual(dummyPriceFeed(ids[0]).toJson());
-      expect(resp.body).toContainEqual(dummyPriceFeed(ids[1]).toJson());
-    });
+  test("When called with valid ids, returns correct price feed", async () => {
+    const ids = [expandTo64Len("abcd"), expandTo64Len("3456")];
+    const resp = await request(app).get("/latest_price_feeds").query({ ids });
+    expect(resp.status).toBe(StatusCodes.OK);
+    expect(resp.body.length).toBe(2);
+    expect(resp.body).toContainEqual(dummyPriceFeed(ids[0]).toJson());
+    expect(resp.body).toContainEqual(dummyPriceFeed(ids[1]).toJson());
+  });
 
-    test("When called with some non-existant ids within ids, returns error mentioning non-existant ids", async () => {
-      const ids = [expandTo64Len('ab01'), expandTo64Len('3456'), expandTo64Len('effe')];
-      const resp = await request(app).get('/latest_price_feeds').query({ids});
-      expect(resp.status).toBe(StatusCodes.BAD_REQUEST);
-      expect(resp.body.message).toContain(ids[0]);
-      expect(resp.body.message).not.toContain(ids[1]);
-      expect(resp.body.message).toContain(ids[2]);
-    });
+  test("When called with some non-existant ids within ids, returns error mentioning non-existant ids", async () => {
+    const ids = [
+      expandTo64Len("ab01"),
+      expandTo64Len("3456"),
+      expandTo64Len("effe"),
+    ];
+    const resp = await request(app).get("/latest_price_feeds").query({ ids });
+    expect(resp.status).toBe(StatusCodes.BAD_REQUEST);
+    expect(resp.body.message).toContain(ids[0]);
+    expect(resp.body.message).not.toContain(ids[1]);
+    expect(resp.body.message).toContain(ids[2]);
+  });
 });
 
 describe("Latest Vaa Bytes Endpoint", () => {
   test("When called with valid ids, returns vaa bytes as array, merged if necessary", async () => {
-    const ids = [expandTo64Len('abcd'), expandTo64Len('ef01'), expandTo64Len('3456')];
-    const resp = await request(app).get('/latest_vaas').query({ids});
+    const ids = [
+      expandTo64Len("abcd"),
+      expandTo64Len("ef01"),
+      expandTo64Len("3456"),
+    ];
+    const resp = await request(app).get("/latest_vaas").query({ ids });
     expect(resp.status).toBe(StatusCodes.OK);
     expect(resp.body.length).toBe(2);
-    expect(resp.body).toContain(Buffer.from('a1b2c3d4', 'hex').toString('base64'));
-    expect(resp.body).toContain(Buffer.from('bad01bad', 'hex').toString('base64'));
+    expect(resp.body).toContain(
+      Buffer.from("a1b2c3d4", "hex").toString("base64")
+    );
+    expect(resp.body).toContain(
+      Buffer.from("bad01bad", "hex").toString("base64")
+    );
   });
 
   test("When called with some non-existant ids within ids, returns error mentioning non-existant ids", async () => {
-    const ids = [expandTo64Len('ab01'), expandTo64Len('3456'), expandTo64Len('effe')];
-    const resp = await request(app).get('/latest_vaas').query({ids});
+    const ids = [
+      expandTo64Len("ab01"),
+      expandTo64Len("3456"),
+      expandTo64Len("effe"),
+    ];
+    const resp = await request(app).get("/latest_vaas").query({ ids });
     expect(resp.status).toBe(StatusCodes.BAD_REQUEST);
     expect(resp.body.message).toContain(ids[0]);
     expect(resp.body.message).not.toContain(ids[1]);
     expect(resp.body.message).toContain(ids[2]);
   });
-})
+});

+ 313 - 0
third_party/pyth/price-service/src/__tests__/ws.test.ts

@@ -0,0 +1,313 @@
+import { HexString, PriceFeed, PriceStatus } from "@pythnetwork/pyth-sdk-js";
+import { PriceStore, PriceInfo } from "../listen";
+import { WebSocketAPI, ClientMessage } from "../ws";
+import { Server } from "http";
+import { WebSocket, WebSocketServer } from "ws";
+import { sleep } from "../helpers";
+
+const port = 2524;
+
+let api: WebSocketAPI;
+let server: Server;
+let wss: WebSocketServer;
+
+let priceFeeds: PriceFeed[];
+
+function expandTo64Len(id: string): string {
+  return id.repeat(64).substring(0, 64);
+}
+
+function dummyPriceFeed(id: string): PriceFeed {
+  return new PriceFeed({
+    conf: "0",
+    emaConf: "1",
+    emaPrice: "2",
+    expo: 4,
+    id,
+    maxNumPublishers: 7,
+    numPublishers: 6,
+    prevConf: "8",
+    prevPrice: "9",
+    prevPublishTime: 10,
+    price: "11",
+    productId: "def456",
+    publishTime: 13,
+    status: PriceStatus.Trading,
+  });
+}
+
+async function waitForSocketState(
+  client: WebSocket,
+  state: number
+): Promise<void> {
+  while (client.readyState !== state) {
+    await sleep(10);
+  }
+}
+
+async function waitForMessages(messages: any[], cnt: number): Promise<void> {
+  while (messages.length < cnt) {
+    await sleep(10);
+  }
+}
+
+async function createSocketClient(): Promise<[WebSocket, any[]]> {
+  const client = new WebSocket(`ws://localhost:${port}`);
+
+  await waitForSocketState(client, client.OPEN);
+
+  const messages: any[] = [];
+
+  client.on("message", (data) => {
+    messages.push(JSON.parse(data.toString()));
+  });
+
+  return [client, messages];
+}
+
+beforeAll(async () => {
+  priceFeeds = [
+    dummyPriceFeed(expandTo64Len("abcd")),
+    dummyPriceFeed(expandTo64Len("ef01")),
+    dummyPriceFeed(expandTo64Len("2345")),
+    dummyPriceFeed(expandTo64Len("6789")),
+  ];
+
+  let priceInfo: PriceStore = {
+    getLatestPriceInfo: (_priceFeedId: string) => undefined,
+    addUpdateListener: (_callback: (priceFeed: PriceFeed) => any) => undefined,
+    getPriceIds: () => new Set(priceFeeds.map((priceFeed) => priceFeed.id)),
+  };
+
+  api = new WebSocketAPI({ port }, priceInfo);
+
+  [wss, server] = api.run();
+});
+
+afterAll(async () => {
+  wss.close();
+  server.close();
+});
+
+describe("Client receives data", () => {
+  test("When subscribes with valid ids, returns correct price feed", async () => {
+    let [client, serverMessages] = await createSocketClient();
+
+    let message: ClientMessage = {
+      ids: [priceFeeds[0].id, priceFeeds[1].id],
+      type: "subscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 1);
+
+    expect(serverMessages[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[0]);
+
+    await waitForMessages(serverMessages, 2);
+
+    expect(serverMessages[1]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[0].toJson(),
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[1]);
+
+    await waitForMessages(serverMessages, 3);
+
+    expect(serverMessages[2]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[1].toJson(),
+    });
+
+    client.close();
+    await waitForSocketState(client, client.CLOSED);
+  });
+
+  test("When subscribes with invalid ids, returns error", async () => {
+    let [client, serverMessages] = await createSocketClient();
+
+    let message: ClientMessage = {
+      ids: [expandTo64Len("aaaa")],
+      type: "subscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 1);
+
+    expect(serverMessages.length).toBe(1);
+    expect(serverMessages[0].type).toBe("response");
+    expect(serverMessages[0].status).toBe("error");
+
+    client.close();
+    await waitForSocketState(client, client.CLOSED);
+  });
+
+  test("When subscribes for Price Feed A, doesn't receive updates for Price Feed B", async () => {
+    let [client, serverMessages] = await createSocketClient();
+
+    let message: ClientMessage = {
+      ids: [priceFeeds[0].id],
+      type: "subscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 1);
+
+    expect(serverMessages[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[1]);
+
+    await sleep(100);
+
+    api.dispatchPriceFeedUpdate(priceFeeds[0]);
+
+    await waitForMessages(serverMessages, 2);
+
+    expect(serverMessages[1]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[0].toJson(),
+    });
+
+    await sleep(100);
+    expect(serverMessages.length).toBe(2);
+
+    client.close();
+    await waitForSocketState(client, client.CLOSED);
+  });
+
+  test("When subscribes for Price Feed A, receives updated and when unsubscribes stops receiving", async () => {
+    let [client, serverMessages] = await createSocketClient();
+
+    let message: ClientMessage = {
+      ids: [priceFeeds[0].id],
+      type: "subscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 1);
+
+    expect(serverMessages[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[0]);
+
+    await waitForMessages(serverMessages, 2);
+
+    expect(serverMessages[1]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[0].toJson(),
+    });
+
+    message = {
+      ids: [priceFeeds[0].id],
+      type: "unsubscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 3);
+
+    expect(serverMessages[2]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[0]);
+
+    await sleep(100);
+
+    expect(serverMessages.length).toBe(3);
+
+    client.close();
+    await waitForSocketState(client, client.CLOSED);
+  });
+
+  test("Unsubscribe on not subscribed price feed is ok", async () => {
+    let [client, serverMessages] = await createSocketClient();
+
+    let message: ClientMessage = {
+      ids: [priceFeeds[0].id],
+      type: "unsubscribe",
+    };
+
+    client.send(JSON.stringify(message));
+
+    await waitForMessages(serverMessages, 1);
+
+    expect(serverMessages[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    client.close();
+    await waitForSocketState(client, client.CLOSED);
+  });
+
+  test("Multiple clients with different price feed works", async () => {
+    let [client1, serverMessages1] = await createSocketClient();
+    let [client2, serverMessages2] = await createSocketClient();
+
+    let message1: ClientMessage = {
+      ids: [priceFeeds[0].id],
+      type: "subscribe",
+    };
+
+    client1.send(JSON.stringify(message1));
+
+    let message2: ClientMessage = {
+      ids: [priceFeeds[1].id],
+      type: "subscribe",
+    };
+
+    client2.send(JSON.stringify(message2));
+
+    await waitForMessages(serverMessages1, 1);
+    await waitForMessages(serverMessages2, 1);
+
+    expect(serverMessages1[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    expect(serverMessages2[0]).toStrictEqual({
+      type: "response",
+      status: "success",
+    });
+
+    api.dispatchPriceFeedUpdate(priceFeeds[0]);
+    api.dispatchPriceFeedUpdate(priceFeeds[1]);
+
+    await waitForMessages(serverMessages1, 2);
+    await waitForMessages(serverMessages2, 2);
+
+    expect(serverMessages1[1]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[0].toJson(),
+    });
+
+    expect(serverMessages2[1]).toStrictEqual({
+      type: "price_update",
+      price_feed: priceFeeds[1].toJson(),
+    });
+
+    client1.close();
+    client2.close();
+
+    await waitForSocketState(client1, client1.CLOSED);
+    await waitForSocketState(client2, client2.CLOSED);
+  });
+});

+ 31 - 14
third_party/pyth/price-service/src/index.ts

@@ -5,7 +5,7 @@ import { Listener } from "./listen";
 import { initLogger } from "./logging";
 import { PromClient } from "./promClient";
 import { RestAPI } from "./rest";
-
+import { WebSocketAPI } from "./ws";
 
 let configFile: string = ".env";
 if (process.env.PYTH_RELAY_CONFIG) {
@@ -18,28 +18,45 @@ require("dotenv").config({ path: configFile });
 setDefaultWasm("node");
 
 // Set up the logger.
-initLogger({logLevel: process.env.LOG_LEVEL});
+initLogger({ logLevel: process.env.LOG_LEVEL });
 
 const promClient = new PromClient({
   name: "price_service",
-  port: parseInt(envOrErr("PROM_PORT"))
+  port: parseInt(envOrErr("PROM_PORT")),
 });
 
-const listener = new Listener({
-  spyServiceHost: envOrErr("SPY_SERVICE_HOST"),
-  filtersRaw: process.env.SPY_SERVICE_FILTERS,
-  readiness: {
-    spySyncTimeSeconds: parseInt(envOrErr("READINESS_SPY_SYNC_TIME_SECONDS")),
-    numLoadedSymbols: parseInt(envOrErr("READINESS_NUM_LOADED_SYMBOLS"))
-  }
-}, promClient);
+const listener = new Listener(
+  {
+    spyServiceHost: envOrErr("SPY_SERVICE_HOST"),
+    filtersRaw: process.env.SPY_SERVICE_FILTERS,
+    readiness: {
+      spySyncTimeSeconds: parseInt(envOrErr("READINESS_SPY_SYNC_TIME_SECONDS")),
+      numLoadedSymbols: parseInt(envOrErr("READINESS_NUM_LOADED_SYMBOLS")),
+    },
+  },
+  promClient
+);
 
 // In future if we have more components we will modify it to include them all
 const isReady = () => listener.isReady();
 
-const restAPI = new RestAPI({
-  port: parseInt(envOrErr("REST_PORT"))
-}, listener, isReady, promClient);
+const restAPI = new RestAPI(
+  {
+    port: parseInt(envOrErr("REST_PORT")),
+  },
+  listener,
+  isReady,
+  promClient
+);
+
+const wsAPI = new WebSocketAPI(
+  {
+    port: parseInt(envOrErr("WS_PORT")),
+  },
+  listener,
+  promClient
+);
 
 listener.run();
 restAPI.run();
+wsAPI.run();

+ 61 - 34
third_party/pyth/price-service/src/listen.ts

@@ -13,35 +13,44 @@ import { importCoreWasm } from "@certusone/wormhole-sdk/lib/cjs/solana/wasm";
 
 import { envOrErr, sleep, TimestampInSec } from "./helpers";
 import { PromClient } from "./promClient";
-import { getBatchSummary, parseBatchPriceAttestation, priceAttestationToPriceFeed } from "@certusone/p2w-sdk";
+import {
+  getBatchSummary,
+  parseBatchPriceAttestation,
+  priceAttestationToPriceFeed,
+} from "@certusone/p2w-sdk";
 import { ClientReadableStream } from "@grpc/grpc-js";
-import { FilterEntry, SubscribeSignedVAAResponse } from "@certusone/wormhole-spydk/lib/cjs/proto/spy/v1/spy";
+import {
+  FilterEntry,
+  SubscribeSignedVAAResponse,
+} from "@certusone/wormhole-spydk/lib/cjs/proto/spy/v1/spy";
 import { logger } from "./logging";
-import { PriceFeed } from "@pythnetwork/pyth-sdk-js";
+import { HexString, PriceFeed } from "@pythnetwork/pyth-sdk-js";
 
 export type PriceInfo = {
-  vaaBytes: string,
-  seqNum: number,
-  receiveTime: TimestampInSec,
-  priceFeed: PriceFeed
+  vaaBytes: string;
+  seqNum: number;
+  receiveTime: TimestampInSec;
+  priceFeed: PriceFeed;
 };
 
-export interface PriceFeedPriceInfo {
-  getLatestPriceInfo(priceFeedId: string): PriceInfo | undefined;
+export interface PriceStore {
+  getPriceIds(): Set<HexString>;
+  getLatestPriceInfo(priceFeedId: HexString): PriceInfo | undefined;
+  addUpdateListener(callback: (priceFeed: PriceFeed) => any): void;
 }
 
 type ListenerReadinessConfig = {
-  spySyncTimeSeconds: number,
-  numLoadedSymbols: number,
+  spySyncTimeSeconds: number;
+  numLoadedSymbols: number;
 };
 
 type ListenerConfig = {
-  spyServiceHost: string,
-  filtersRaw?: string,
-  readiness: ListenerReadinessConfig,
+  spyServiceHost: string;
+  filtersRaw?: string;
+  readiness: ListenerReadinessConfig;
 };
 
-export class Listener implements PriceFeedPriceInfo {
+export class Listener implements PriceStore {
   // Mapping of Price Feed Id to Vaa
   private priceFeedVaaMap = new Map<string, PriceInfo>();
   private promClient: PromClient | undefined;
@@ -49,12 +58,14 @@ export class Listener implements PriceFeedPriceInfo {
   private filters: FilterEntry[] = [];
   private spyConnectionTime: TimestampInSec | undefined;
   private readinessConfig: ListenerReadinessConfig;
+  private updateCallbacks: ((priceFeed: PriceFeed) => any)[];
 
   constructor(config: ListenerConfig, promClient?: PromClient) {
     this.promClient = promClient;
     this.spyServiceHost = config.spyServiceHost;
     this.loadFilters(config.filtersRaw);
     this.readinessConfig = config.readiness;
+    this.updateCallbacks = [];
   }
 
   private loadFilters(filtersRaw?: string) {
@@ -76,10 +87,10 @@ export class Listener implements PriceFeedPriceInfo {
       };
       logger.info(
         "adding filter: chainId: [" +
-        myEmitterFilter.emitterFilter!.chainId +
-        "], emitterAddress: [" +
-        myEmitterFilter.emitterFilter!.emitterAddress +
-        "]"
+          myEmitterFilter.emitterFilter!.chainId +
+          "], emitterAddress: [" +
+          myEmitterFilter.emitterFilter!.emitterAddress +
+          "]"
       );
       this.filters.push(myEmitterFilter);
     }
@@ -90,7 +101,7 @@ export class Listener implements PriceFeedPriceInfo {
   async run() {
     logger.info(
       "pyth_relay starting up, will listen for signed VAAs from " +
-      this.spyServiceHost
+        this.spyServiceHost
     );
 
     while (true) {
@@ -101,11 +112,11 @@ export class Listener implements PriceFeedPriceInfo {
         );
         stream = await subscribeSignedVAA(client, { filters: this.filters });
 
-        stream!.on("data", ({ vaaBytes }: { vaaBytes: string; }) => {
+        stream!.on("data", ({ vaaBytes }: { vaaBytes: string }) => {
           this.processVaa(vaaBytes);
         });
 
-        this.spyConnectionTime = (new Date()).getTime() / 1000;
+        this.spyConnectionTime = new Date().getTime() / 1000;
 
         let connected = true;
         stream!.on("error", (err: any) => {
@@ -171,24 +182,29 @@ export class Listener implements PriceFeedPriceInfo {
 
       let lastSeqNum = this.priceFeedVaaMap.get(key)?.seqNum;
       if (lastSeqNum === undefined || lastSeqNum < parsedVAA.sequence) {
+        const priceFeed = priceAttestationToPriceFeed(priceAttestation);
         this.priceFeedVaaMap.set(key, {
           seqNum: parsedVAA.sequence,
           vaaBytes: vaaBytes,
-          receiveTime: (new Date()).getTime() / 1000,
-          priceFeed: priceAttestationToPriceFeed(priceAttestation)
+          receiveTime: new Date().getTime() / 1000,
+          priceFeed,
         });
+
+        for (let callback of this.updateCallbacks) {
+          callback(priceFeed);
+        }
       }
     }
 
     logger.info(
       "Parsed a new Batch Price Attestation: [" +
-      parsedVAA.emitter_chain +
-      ":" +
-      uint8ArrayToHex(parsedVAA.emitter_address) +
-      "], seqNum: " +
-      parsedVAA.sequence +
-      ", Batch Summary: " +
-      getBatchSummary(batchAttestation)
+        parsedVAA.emitter_chain +
+        ":" +
+        uint8ArrayToHex(parsedVAA.emitter_address) +
+        "], seqNum: " +
+        parsedVAA.sequence +
+        ", Batch Summary: " +
+        getBatchSummary(batchAttestation)
     );
 
     this.promClient?.incReceivedVaa();
@@ -198,10 +214,21 @@ export class Listener implements PriceFeedPriceInfo {
     return this.priceFeedVaaMap.get(priceFeedId);
   }
 
+  addUpdateListener(callback: (priceFeed: PriceFeed) => any) {
+    this.updateCallbacks.push(callback);
+  }
+
+  getPriceIds(): Set<HexString> {
+    return new Set(this.priceFeedVaaMap.keys());
+  }
+
   isReady(): boolean {
-    let currentTime: TimestampInSec = (new Date()).getTime() / 1000;
-    if (this.spyConnectionTime === undefined ||
-      currentTime < this.spyConnectionTime + this.readinessConfig.spySyncTimeSeconds) {
+    let currentTime: TimestampInSec = new Date().getTime() / 1000;
+    if (
+      this.spyConnectionTime === undefined ||
+      currentTime <
+        this.spyConnectionTime + this.readinessConfig.spySyncTimeSeconds
+    ) {
       return false;
     }
     if (this.priceFeedVaaMap.size < this.readinessConfig.numLoadedSymbols) {

+ 4 - 3
third_party/pyth/price-service/src/logging.ts

@@ -1,9 +1,11 @@
 import * as winston from "winston";
 
-export let logger = winston.createLogger({transports: [new winston.transports.Console()]});
+export let logger = winston.createLogger({
+  transports: [new winston.transports.Console()],
+});
 
 // Logger should be initialized before using logger
-export function initLogger(config?: {logLevel?: string}) {
+export function initLogger(config?: { logLevel?: string }) {
   let logLevel = "info";
   if (config?.logLevel) {
     logLevel = config.logLevel;
@@ -16,7 +18,6 @@ export function initLogger(config?: {logLevel?: string}) {
     level: logLevel,
   });
 
-
   const logConfiguration = {
     transports: [transport],
     format: winston.format.combine(

+ 40 - 14
third_party/pyth/price-service/src/promClient.ts

@@ -23,13 +23,18 @@ export class PromClient {
   private apiResponseTimeSummary = new client.Summary({
     name: `${SERVICE_PREFIX}api_response_time_ms`,
     help: "Response time of a VAA",
-    labelNames: ["path", "status"]
+    labelNames: ["path", "status"],
   });
   private apiRequestsPriceFreshnessHistogram = new client.Histogram({
     name: `${SERVICE_PREFIX}api_requests_price_freshness_seconds`,
     help: "Freshness time of Vaa (time difference of Vaa and request time)",
     buckets: [1, 5, 10, 15, 30, 60, 120, 180],
-    labelNames: ["path", "price_id"]
+    labelNames: ["path", "price_id"],
+  });
+  private webSocketInteractionCounter = new client.Counter({
+    name: `${SERVICE_PREFIX}websocket_interaction`,
+    help: "number of Web Socket interactions",
+    labelNames: ["type", "status"],
   });
   // End metrics
 
@@ -42,15 +47,19 @@ export class PromClient {
     }
   });
 
-  constructor(config: {name: string, port: number; }) {
+  constructor(config: { name: string; port: number }) {
     this.register.setDefaultLabels({
       app: config.name,
     });
-    this.collectDefaultMetrics({ register: this.register, prefix: SERVICE_PREFIX });
+    this.collectDefaultMetrics({
+      register: this.register,
+      prefix: SERVICE_PREFIX,
+    });
     // Register each metric
     this.register.registerMetric(this.receivedVaaCounter);
-    this.register.registerMetric(this.apiResponseTimeSummary)
+    this.register.registerMetric(this.apiResponseTimeSummary);
     this.register.registerMetric(this.apiRequestsPriceFreshnessHistogram);
+    this.register.registerMetric(this.webSocketInteractionCounter);
     // End registering metric
 
     logger.info("prometheus client listening on port " + config.port);
@@ -62,16 +71,33 @@ export class PromClient {
   }
 
   addResponseTime(path: string, status: number, duration: DurationInMs) {
-    this.apiResponseTimeSummary.observe({
-      path: path,
-      status: status
-    }, duration);
+    this.apiResponseTimeSummary.observe(
+      {
+        path: path,
+        status: status,
+      },
+      duration
+    );
   }
 
-  addApiRequestsPriceFreshness(path: string, priceId: string, duration: DurationInSec) {
-    this.apiRequestsPriceFreshnessHistogram.observe({
-      path: path,
-      price_id: priceId,
-    }, duration);
+  addApiRequestsPriceFreshness(
+    path: string,
+    priceId: string,
+    duration: DurationInSec
+  ) {
+    this.apiRequestsPriceFreshnessHistogram.observe(
+      {
+        path: path,
+        price_id: priceId,
+      },
+      duration
+    );
+  }
+
+  addWebSocketInteraction(type: string, status: "ok" | "err") {
+    this.webSocketInteractionCounter.inc({
+      type: type,
+      status: status,
+    });
   }
 }

+ 115 - 85
third_party/pyth/price-service/src/rest.ts

@@ -1,16 +1,17 @@
-import express, {Express} from "express";
+import express, { Express } from "express";
 import cors from "cors";
 import morgan from "morgan";
 import responseTime from "response-time";
 import { Request, Response, NextFunction } from "express";
-import { PriceFeedPriceInfo } from "./listen";
+import { PriceStore } from "./listen";
 import { logger } from "./logging";
 import { PromClient } from "./promClient";
 import { DurationInMs, DurationInSec } from "./helpers";
 import { StatusCodes } from "http-status-codes";
 import { validate, ValidationError, Joi, schema } from "express-validation";
 
-const MORGAN_LOG_FORMAT = ':remote-addr - :remote-user ":method :url HTTP/:http-version"' +
+const MORGAN_LOG_FORMAT =
+  ':remote-addr - :remote-user ":method :url HTTP/:http-version"' +
   ' :status :res[content-length] :response-time ms ":referrer" ":user-agent"';
 
 export class RestException extends Error {
@@ -23,20 +24,25 @@ export class RestException extends Error {
   }
 
   static PriceFeedIdNotFound(notFoundIds: string[]): RestException {
-    return new RestException(StatusCodes.BAD_REQUEST, `Price Feeds with ids ${notFoundIds.join(', ')} not found`);
+    return new RestException(
+      StatusCodes.BAD_REQUEST,
+      `Price Feeds with ids ${notFoundIds.join(", ")} not found`
+    );
   }
 }
 
 export class RestAPI {
   private port: number;
-  private priceFeedVaaInfo: PriceFeedPriceInfo;
+  private priceFeedVaaInfo: PriceStore;
   private isReady: (() => boolean) | undefined;
   private promClient: PromClient | undefined;
 
-  constructor(config: { port: number; },
-    priceFeedVaaInfo: PriceFeedPriceInfo,
+  constructor(
+    config: { port: number },
+    priceFeedVaaInfo: PriceStore,
     isReady?: () => boolean,
-    promClient?: PromClient) {
+    promClient?: PromClient
+  ) {
     this.port = config.port;
     this.priceFeedVaaInfo = priceFeedVaaInfo;
     this.isReady = isReady;
@@ -51,101 +57,128 @@ export class RestAPI {
     const winstonStream = {
       write: (text: string) => {
         logger.info(text);
-      }
+      },
     };
 
     app.use(morgan(MORGAN_LOG_FORMAT, { stream: winstonStream }));
 
-    app.use(responseTime((req: Request, res: Response, time: DurationInMs) => {
-      if (res.statusCode !== StatusCodes.NOT_FOUND) {
-        this.promClient?.addResponseTime(req.path, res.statusCode, time);
-      }
-    }))
+    app.use(
+      responseTime((req: Request, res: Response, time: DurationInMs) => {
+        if (res.statusCode !== StatusCodes.NOT_FOUND) {
+          this.promClient?.addResponseTime(req.path, res.statusCode, time);
+        }
+      })
+    );
 
     let endpoints: string[] = [];
-    
+
     const latestVaasInputSchema: schema = {
       query: Joi.object({
-        ids: Joi.array().items(Joi.string().regex(/^(0x)?[a-f0-9]{64}$/)).required()
-      }).required()
-    }
-    app.get("/latest_vaas", validate(latestVaasInputSchema), (req: Request, res: Response) => {
-      let priceIds = req.query.ids as string[];
-
-      // Multiple price ids might share same vaa, we use sequence number as
-      // key of a vaa and deduplicate using a map of seqnum to vaa bytes.
-      let vaaMap = new Map<number, string>();
-
-      let notFoundIds: string[] = [];
-
-      for (let id of priceIds) {
-        if (id.startsWith("0x")) {
-          id = id.substring(2);
+        ids: Joi.array()
+          .items(Joi.string().regex(/^(0x)?[a-f0-9]{64}$/))
+          .required(),
+      }).required(),
+    };
+    app.get(
+      "/latest_vaas",
+      validate(latestVaasInputSchema),
+      (req: Request, res: Response) => {
+        let priceIds = req.query.ids as string[];
+
+        // Multiple price ids might share same vaa, we use sequence number as
+        // key of a vaa and deduplicate using a map of seqnum to vaa bytes.
+        let vaaMap = new Map<number, string>();
+
+        let notFoundIds: string[] = [];
+
+        for (let id of priceIds) {
+          if (id.startsWith("0x")) {
+            id = id.substring(2);
+          }
+
+          let latestPriceInfo = this.priceFeedVaaInfo.getLatestPriceInfo(id);
+
+          if (latestPriceInfo === undefined) {
+            notFoundIds.push(id);
+            continue;
+          }
+
+          const freshness: DurationInSec =
+            new Date().getTime() / 1000 - latestPriceInfo.receiveTime;
+          this.promClient?.addApiRequestsPriceFreshness(
+            req.path,
+            id,
+            freshness
+          );
+
+          vaaMap.set(latestPriceInfo.seqNum, latestPriceInfo.vaaBytes);
         }
 
-        let latestPriceInfo = this.priceFeedVaaInfo.getLatestPriceInfo(id);
-
-        if (latestPriceInfo === undefined) {
-          notFoundIds.push(id);
-          continue;
+        if (notFoundIds.length > 0) {
+          throw RestException.PriceFeedIdNotFound(notFoundIds);
         }
 
-        const freshness: DurationInSec = (new Date).getTime() / 1000 - latestPriceInfo.receiveTime;
-        this.promClient?.addApiRequestsPriceFreshness(req.path, id, freshness);
-
-        vaaMap.set(latestPriceInfo.seqNum, latestPriceInfo.vaaBytes);
-      }
+        const jsonResponse = Array.from(vaaMap.values(), (vaaBytes) =>
+          Buffer.from(vaaBytes, "binary").toString("base64")
+        );
 
-      if (notFoundIds.length > 0) {
-        throw RestException.PriceFeedIdNotFound(notFoundIds);
+        res.json(jsonResponse);
       }
-
-      const jsonResponse = Array.from(vaaMap.values(),
-        vaaBytes => Buffer.from(vaaBytes, 'binary').toString('base64')
-      );
-
-      res.json(jsonResponse);
-    });
-    endpoints.push("latest_vaas?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&..");
+    );
+    endpoints.push(
+      "latest_vaas?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&.."
+    );
 
     const latestPriceFeedsInputSchema: schema = {
       query: Joi.object({
-        ids: Joi.array().items(Joi.string().regex(/^(0x)?[a-f0-9]{64}$/)).required()
-      }).required()
-    }
-    app.get("/latest_price_feeds", validate(latestPriceFeedsInputSchema), (req: Request, res: Response) => {
-      let priceIds = req.query.ids as string[];
+        ids: Joi.array()
+          .items(Joi.string().regex(/^(0x)?[a-f0-9]{64}$/))
+          .required(),
+      }).required(),
+    };
+    app.get(
+      "/latest_price_feeds",
+      validate(latestPriceFeedsInputSchema),
+      (req: Request, res: Response) => {
+        let priceIds = req.query.ids as string[];
 
-      let responseJson = [];
+        let responseJson = [];
 
-      let notFoundIds: string[] = [];
+        let notFoundIds: string[] = [];
 
-      for (let id of priceIds) {
-        if (id.startsWith("0x")) {
-          id = id.substring(2);
-        }
+        for (let id of priceIds) {
+          if (id.startsWith("0x")) {
+            id = id.substring(2);
+          }
 
-        let latestPriceInfo = this.priceFeedVaaInfo.getLatestPriceInfo(id);
+          let latestPriceInfo = this.priceFeedVaaInfo.getLatestPriceInfo(id);
 
-        if (latestPriceInfo === undefined) {
-          notFoundIds.push(id);
-          continue;
-        }
+          if (latestPriceInfo === undefined) {
+            notFoundIds.push(id);
+            continue;
+          }
 
-        const freshness: DurationInSec = (new Date).getTime() / 1000 - latestPriceInfo.receiveTime;
-        this.promClient?.addApiRequestsPriceFreshness(req.path, id, freshness);
+          const freshness: DurationInSec =
+            new Date().getTime() / 1000 - latestPriceInfo.receiveTime;
+          this.promClient?.addApiRequestsPriceFreshness(
+            req.path,
+            id,
+            freshness
+          );
 
-        responseJson.push(latestPriceInfo.priceFeed.toJson());
-      }
-
-      if (notFoundIds.length > 0) {
-        throw RestException.PriceFeedIdNotFound(notFoundIds);
-      }
+          responseJson.push(latestPriceInfo.priceFeed.toJson());
+        }
 
-      res.json(responseJson);
-    });
-    endpoints.push("latest_price_feeds?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&..");
+        if (notFoundIds.length > 0) {
+          throw RestException.PriceFeedIdNotFound(notFoundIds);
+        }
 
+        res.json(responseJson);
+      }
+    );
+    endpoints.push(
+      "latest_price_feeds?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&.."
+    );
 
     app.get("/ready", (_, res: Response) => {
       if (this.isReady!()) {
@@ -154,19 +187,16 @@ export class RestAPI {
         res.sendStatus(StatusCodes.SERVICE_UNAVAILABLE);
       }
     });
-    endpoints.push('ready');
+    endpoints.push("ready");
 
     app.get("/live", (_, res: Response) => {
       res.sendStatus(StatusCodes.OK);
     });
     endpoints.push("live");
 
+    app.get("/", (_, res: Response) => res.json(endpoints));
 
-    app.get("/", (_, res: Response) =>
-      res.json(endpoints)
-    );
-
-    app.use(function(err: any, _: Request, res: Response, next: NextFunction) {
+    app.use(function (err: any, _: Request, res: Response, next: NextFunction) {
       if (err instanceof ValidationError) {
         return res.status(err.statusCode).json(err);
       }
@@ -174,9 +204,9 @@ export class RestAPI {
       if (err instanceof RestException) {
         return res.status(err.statusCode).json(err);
       }
-    
+
       return next(err);
-    })
+    });
 
     return app;
   }

+ 231 - 0
third_party/pyth/price-service/src/ws.ts

@@ -0,0 +1,231 @@
+import { HexString, PriceFeed } from "@pythnetwork/pyth-sdk-js";
+import express from "express";
+import * as http from "http";
+import Joi from "joi";
+import WebSocket, { RawData, WebSocketServer } from "ws";
+import { PriceStore } from "./listen";
+import { logger } from "./logging";
+import { PromClient } from "./promClient";
+
+const ClientMessageSchema: Joi.Schema = Joi.object({
+  type: Joi.string().valid("subscribe", "unsubscribe").required(),
+  ids: Joi.array()
+    .items(Joi.string().regex(/^(0x)?[a-f0-9]{64}$/))
+    .required(),
+}).required();
+
+export type ClientMessage = {
+  type: "subscribe" | "unsubscribe";
+  ids: HexString[];
+};
+
+export type ServerResponse = {
+  type: "response";
+  status: "success" | "error";
+  error?: string;
+};
+
+export type ServerPriceUpdate = {
+  type: "price_update";
+  price_feed: any;
+};
+
+export type ServerMessage = ServerResponse | ServerPriceUpdate;
+
+export class WebSocketAPI {
+  private wsCounter: number;
+  private port: number;
+  private priceFeedClients: Map<HexString, Set<WebSocket>>;
+  private aliveClients: Set<WebSocket>;
+  private wsId: Map<WebSocket, number>;
+  private priceFeedVaaInfo: PriceStore;
+  private promClient: PromClient | undefined;
+
+  constructor(
+    config: { port: number },
+    priceFeedVaaInfo: PriceStore,
+    promClient?: PromClient
+  ) {
+    this.port = config.port;
+    this.priceFeedVaaInfo = priceFeedVaaInfo;
+    this.priceFeedClients = new Map();
+    this.aliveClients = new Set();
+    this.wsCounter = 0;
+    this.wsId = new Map();
+    this.promClient = promClient;
+  }
+
+  private addPriceFeedClient(ws: WebSocket, id: HexString) {
+    if (!this.priceFeedClients.has(id)) {
+      this.priceFeedClients.set(id, new Set());
+    }
+
+    this.priceFeedClients.get(id)!.add(ws);
+  }
+
+  private delPriceFeedClient(ws: WebSocket, id: HexString) {
+    this.priceFeedClients.get(id)?.delete(ws);
+  }
+
+  dispatchPriceFeedUpdate(priceFeed: PriceFeed) {
+    if (this.priceFeedClients.get(priceFeed.id) === undefined) {
+      logger.info(`Sending ${priceFeed.id} price update to no clients.`);
+      return;
+    }
+
+    logger.info(
+      `Sending ${priceFeed.id} price update to ${
+        this.priceFeedClients.get(priceFeed.id)!.size
+      } clients`
+    );
+
+    for (let client of this.priceFeedClients.get(priceFeed.id)!.values()) {
+      logger.info(
+        `Sending ${priceFeed.id} price update to client ${this.wsId.get(
+          client
+        )}`
+      );
+      this.promClient?.addWebSocketInteraction("server_update", "ok");
+
+      let priceUpdate: ServerPriceUpdate = {
+        type: "price_update",
+        price_feed: priceFeed.toJson(),
+      };
+
+      client.send(JSON.stringify(priceUpdate));
+    }
+  }
+
+  clientClose(ws: WebSocket) {
+    for (let clients of this.priceFeedClients.values()) {
+      if (clients.has(ws)) {
+        clients.delete(ws);
+      }
+    }
+
+    this.aliveClients.delete(ws);
+    this.wsId.delete(ws);
+  }
+
+  handleMessage(ws: WebSocket, data: RawData) {
+    try {
+      let jsonData = JSON.parse(data.toString());
+      let validationResult = ClientMessageSchema.validate(jsonData);
+      if (validationResult.error !== undefined) {
+        throw validationResult.error;
+      }
+
+      let message = jsonData as ClientMessage;
+
+      message.ids = message.ids.map((id) => {
+        if (id.startsWith("0x")) {
+          return id.substring(2);
+        }
+        return id;
+      });
+
+      const availableIds = this.priceFeedVaaInfo.getPriceIds();
+      let notFoundIds = message.ids.filter((id) => !availableIds.has(id));
+
+      if (notFoundIds.length > 0) {
+        throw new Error(
+          `Price Feeds with ids ${notFoundIds.join(", ")} not found`
+        );
+      }
+
+      if (message.type == "subscribe") {
+        message.ids.forEach((id) => this.addPriceFeedClient(ws, id));
+      } else {
+        message.ids.forEach((id) => this.delPriceFeedClient(ws, id));
+      }
+    } catch (e: any) {
+      let response: ServerResponse = {
+        type: "response",
+        status: "error",
+        error: e.message,
+      };
+
+      logger.info(
+        `Invalid request ${data.toString()} from client ${this.wsId.get(ws)}`
+      );
+      this.promClient?.addWebSocketInteraction("client_message", "err");
+
+      ws.send(JSON.stringify(response));
+      return;
+    }
+
+    logger.info(
+      `Successful request ${data.toString()} from client ${this.wsId.get(ws)}`
+    );
+    this.promClient?.addWebSocketInteraction("client_message", "ok");
+
+    let response: ServerResponse = {
+      type: "response",
+      status: "success",
+    };
+
+    ws.send(JSON.stringify(response));
+  }
+
+  run(): [WebSocketServer, http.Server] {
+    const app = express();
+    const server = http.createServer(app);
+
+    const wss = new WebSocketServer({ server });
+
+    wss.on("connection", (ws: WebSocket, request: http.IncomingMessage) => {
+      logger.info(
+        `Incoming ws connection from ${request.socket.remoteAddress}, assigned id: ${this.wsCounter}`
+      );
+
+      this.wsId.set(ws, this.wsCounter);
+      this.wsCounter += 1;
+
+      ws.on("message", (data: RawData) => this.handleMessage(ws, data));
+
+      this.aliveClients.add(ws);
+
+      ws.on("pong", (_data) => {
+        this.aliveClients.add(ws);
+      });
+
+      ws.on("close", (_code: number, _reason: Buffer) => {
+        logger.info(`client ${this.wsId.get(ws)} closed the connection.`);
+        this.promClient?.addWebSocketInteraction("close", "ok");
+
+        this.clientClose(ws);
+      });
+
+      this.promClient?.addWebSocketInteraction("connection", "ok");
+    });
+
+    const pingInterval = setInterval(() => {
+      wss.clients.forEach((ws) => {
+        if (this.aliveClients.has(ws) === false) {
+          logger.info(
+            `client ${this.wsId.get(ws)} timed out. terminating connection`
+          );
+          this.promClient?.addWebSocketInteraction("timeout", "ok");
+          this.clientClose(ws);
+          ws.terminate();
+          return;
+        }
+
+        this.aliveClients.delete(ws);
+        ws.ping();
+      });
+    }, 30000);
+
+    wss.on("close", () => {
+      clearInterval(pingInterval);
+    });
+
+    server.listen(this.port, () =>
+      logger.debug("listening on WS port " + this.port)
+    );
+    this.priceFeedVaaInfo.addUpdateListener(
+      this.dispatchPriceFeedUpdate.bind(this)
+    );
+    return [wss, server];
+  }
+}

+ 2 - 2
third_party/pyth/price-service/tsconfig.json

@@ -14,8 +14,8 @@
     "resolveJsonModule": true,
     "isolatedModules": true,
     "downlevelIteration": true,
-    "esModuleInterop": true
+    "esModuleInterop": true,
   },
   "include": ["src"],
-  "exclude": ["node_modules", "**/__tests__/*"]
+  "exclude": ["node_modules"]
 }