Bladeren bron

various fixes

Mike Rolish 1 maand geleden
bovenliggende
commit
105a403b4c

+ 40 - 32
apps/hip-3-pusher/src/pusher/kms_signer.py

@@ -1,3 +1,4 @@
+
 import boto3
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
@@ -9,40 +10,55 @@ from hyperliquid.exchange import Exchange
 from hyperliquid.utils.constants import TESTNET_API_URL, MAINNET_API_URL
 from hyperliquid.utils.signing import get_timestamp_ms, action_hash, construct_phantom_agent, l1_payload
 from loguru import logger
+from pathlib import Path
 
 from pusher.config import Config
 
 SECP256K1_N_HALF = SECP256K1_N // 2
 
 
+def _init_client():
+    # AWS_DEFAULT_REGION, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY should be set as environment variables
+    return boto3.client(
+        "kms",
+        # can specify an endpoint for e.g. LocalStack
+        # endpoint_url="http://localhost:4566"
+    )
+
+
 class KMSSigner:
     def __init__(self, config: Config):
-        use_testnet = config.hyperliquid.use_testnet
-        url = TESTNET_API_URL if use_testnet else MAINNET_API_URL
+        self.use_testnet = config.hyperliquid.use_testnet
+        url = TESTNET_API_URL if self.use_testnet else MAINNET_API_URL
         self.oracle_publisher_exchange: Exchange = Exchange(wallet=None, base_url=url)
-        self.client = self._init_client(config)
 
+        # AWS client and public key load
+        self.client = _init_client()
+        self._load_public_key(config.kms.key_path)
+
+    def _load_public_key(self, key_path: str):
         # Fetch public key once so we can derive address and check recovery id
-        key_path = config.kms.key_path
-        self.key_id = open(key_path, "r").read().strip()
-        self.pubkey_der = self.client.get_public_key(KeyId=self.key_id)["PublicKey"]
+        self.key_id = Path(key_path).read_text().strip()
+        pubkey_der = self.client.get_public_key(KeyId=self.key_id)["PublicKey"]
+        self.pubkey = serialization.load_der_public_key(pubkey_der)
+        self._construct_pubkey_address_and_bytes()
+
+    def _construct_pubkey_address_and_bytes(self):
         # Construct eth address to log
-        pub = serialization.load_der_public_key(self.pubkey_der)
-        numbers = pub.public_numbers()
+        numbers = self.pubkey.public_numbers()
         x = numbers.x.to_bytes(32, "big")
         y = numbers.y.to_bytes(32, "big")
         uncompressed = b"\x04" + x + y
-        self.public_key_bytes = uncompressed
         self.address = "0x" + keccak(uncompressed[1:])[-20:].hex()
-        logger.info("KMSSigner address: {}", self.address)
+        logger.info("public key loaded from KMS: {}", self.address)
 
-    def _init_client(self, config):
-        # AWS_REGION, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY should be set as environment variables
-        return boto3.client(
-            "kms",
-            # can specify an endpoint for e.g. LocalStack
-            # endpoint_url="http://localhost:4566"
+        # Parse KMS public key into uncompressed secp256k1 bytes
+        pubkey_bytes = self.pubkey.public_bytes(
+            serialization.Encoding.X962,
+            serialization.PublicFormat.UncompressedPoint,
         )
+        # Strip leading 0x04 (uncompressed point indicator)
+        self.raw_pubkey_bytes = pubkey_bytes[1:]
 
     def set_oracle(self, dex, oracle_pxs, all_mark_pxs, external_perp_pxs):
         timestamp = get_timestamp_ms()
@@ -59,14 +75,14 @@ class KMSSigner:
             },
         }
         signature = self.sign_l1_action(
-            action,
-            timestamp,
-            self.oracle_publisher_exchange.base_url == MAINNET_API_URL,
+            action=action,
+            nonce=timestamp,
+            is_mainnet= self.use_testnet,
         )
         return self.oracle_publisher_exchange._post_action(
-            action,
-            signature,
-            timestamp,
+            action=action,
+            signature=signature,
+            nonce=timestamp,
         )
 
     def sign_l1_action(self, action, nonce, is_mainnet):
@@ -91,20 +107,12 @@ class KMSSigner:
         # Ethereum requires low-s form
         if s > SECP256K1_N_HALF:
             s = SECP256K1_N - s
-        # Parse KMS public key into uncompressed secp256k1 bytes
-        # TODO: Pull this into init
-        pubkey = serialization.load_der_public_key(self.pubkey_der)
-        pubkey_bytes = pubkey.public_bytes(
-            serialization.Encoding.X962,
-            serialization.PublicFormat.UncompressedPoint,
-        )
-        # Strip leading 0x04 (uncompressed point indicator)
-        raw_pubkey_bytes = pubkey_bytes[1:]
+
         # Try both recovery ids
         for v in (0, 1):
             sig_obj = Signature(vrs=(v, r, s))
             recovered_pub = sig_obj.recover_public_key_from_msg_hash(message_hash)
-            if recovered_pub.to_bytes() == raw_pubkey_bytes:
+            if recovered_pub.to_bytes() == self.raw_pubkey_bytes:
                 return {
                     "r": to_hex(r),
                     "s": to_hex(s),

+ 0 - 4
apps/hip-3-pusher/src/pusher/metrics.py

@@ -17,9 +17,7 @@ class Metrics:
         reader = PrometheusMetricReader()
         # Meter is responsible for creating and recording metrics
         set_meter_provider(MeterProvider(metric_readers=[reader]))
-        # TODO: sync version number and add?
         self.meter = get_meter_provider().get_meter(METER_NAME)
-
         self._init_metrics()
 
     def _init_metrics(self):
@@ -35,5 +33,3 @@ class Metrics:
             name="hip_3_pusher_failed_push_count",
             description="Number of failed push attempts",
         )
-
-        # TODO: labels/attributes

+ 9 - 7
apps/hip-3-pusher/src/pusher/publisher.py

@@ -1,5 +1,6 @@
 import asyncio
 from loguru import logger
+from pathlib import Path
 
 from eth_account import Account
 from eth_account.signers.local import LocalAccount
@@ -30,7 +31,7 @@ class Publisher:
             self.kms_signer = KMSSigner(config)
         else:
             oracle_pusher_key_path = config.hyperliquid.oracle_pusher_key_path
-            oracle_pusher_key = open(oracle_pusher_key_path, "r").read().strip()
+            oracle_pusher_key = Path(oracle_pusher_key_path).read_text().strip()
             oracle_account: LocalAccount = Account.from_key(oracle_pusher_key)
             logger.info("oracle pusher local pubkey: {}", oracle_account.address)
 
@@ -42,6 +43,7 @@ class Publisher:
 
         self.price_state = price_state
         self.metrics = metrics
+        self.metrics_labels = {"dex": self.market_name}
 
     async def run(self):
         while True:
@@ -56,17 +58,17 @@ class Publisher:
         oracle_px = self.price_state.get_current_oracle_price()
         if not oracle_px:
             logger.error("No valid oracle price available")
-            self.metrics.no_oracle_price_counter.add(1)
+            self.metrics.no_oracle_price_counter.add(1, self.metrics_labels)
             return
         else:
             logger.debug("Current oracle price: {}", oracle_px)
             oracle_pxs[self.market_symbol] = oracle_px
 
         mark_pxs = []
-        #if self.price_state.hl_mark_price:
-        #    mark_pxs.append({self.market_symbol: self.price_state.hl_mark_price})
-
         external_perp_pxs = {}
+        if self.price_state.hl_mark_price:
+            external_perp_pxs[self.market_symbol] = self.price_state.hl_mark_price.price
+
         # TODO: "Each update can change oraclePx and markPx by at most 1%."
         # TODO: "The markPx cannot be updated such that open interest would be 10x the open interest cap."
 
@@ -90,7 +92,7 @@ class Publisher:
             logger.debug("publish: push response: {} {}", push_response, type(push_response))
             status = push_response.get("status", "")
             if status == "ok":
-                self.metrics.successful_push_counter.add(1)
+                self.metrics.successful_push_counter.add(1, self.metrics_labels)
             elif status == "err":
-                self.metrics.failed_push_counter.add(1)
+                self.metrics.failed_push_counter.add(1, self.metrics_labels)
                 logger.error("publish: publish error: {}", push_response)