Sfoglia il codice sorgente

Move Message type extensions to pyth-client

Ali Behjati 2 anni fa
parent
commit
4fa9304b4f

+ 3 - 1
hermes/Cargo.lock

@@ -4120,14 +4120,16 @@ dependencies = [
 [[package]]
 name = "pyth-oracle"
 version = "2.21.0"
-source = "git+https://github.com/pyth-network/pyth-client?rev=7d593d87e07a1e2486e7ca21597d664ee72be1ec#7d593d87e07a1e2486e7ca21597d664ee72be1ec"
+source = "git+https://github.com/pyth-network/pyth-client?rev=319cdc1baade5c4780b830eaf927f9bfef89ee39#319cdc1baade5c4780b830eaf927f9bfef89ee39"
 dependencies = [
  "bindgen",
  "bytemuck",
  "byteorder",
  "num-derive",
  "num-traits",
+ "serde",
  "solana-program",
+ "strum",
  "thiserror",
 ]
 

+ 3 - 1
hermes/Cargo.toml

@@ -64,7 +64,9 @@ serde_qs = { version = "0.12.0", features = ["axum"] }
 
 serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1"}
 wormhole-sdk = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" }
-pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "7d593d87e07a1e2486e7ca21597d664ee72be1ec", features = ["library"] }
+# pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "7d593d87e07a1e2486e7ca21597d664ee72be1ec", features = ["library"] }
+pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "319cdc1baade5c4780b830eaf927f9bfef89ee39" , features = ["library"] }
+
 
 strum = { version = "0.24", features = ["derive"] }
 ethabi = { version = "18.0.0", features = ["serde"] }

+ 38 - 1
hermes/src/network/pythnet.rs

@@ -15,7 +15,9 @@ use {
         Result,
     },
     borsh::BorshDeserialize,
+    byteorder::BE,
     futures::stream::StreamExt,
+    pyth_oracle::Message,
     solana_account_decoder::UiAccountEncoding,
     solana_client::{
         nonblocking::pubsub_client::PubsubClient,
@@ -74,9 +76,44 @@ pub async fn run(store: Arc<Store>, pythnet_ws_endpoint: String) -> Result<!> {
                     }
                 };
 
-                let accumulator_messages = AccumulatorMessages::try_from_slice(&account.data);
+                // The validators writes the accumulator messages using Borsh with
+                // the following struct. We cannot directly have messages as Vec<Messages>
+                // because they are serialized using big-endian byte order and Borsh
+                // uses little-endian byte order.
+                #[derive(BorshDeserialize)]
+                struct RawAccumulatorMessages {
+                    pub magic:        [u8; 4],
+                    pub slot:         u64,
+                    pub ring_size:    u32,
+                    pub raw_messages: Vec<Vec<u8>>,
+                }
+
+                let accumulator_messages = RawAccumulatorMessages::try_from_slice(&account.data);
                 match accumulator_messages {
                     Ok(accumulator_messages) => {
+                        let messages = accumulator_messages
+                            .raw_messages
+                            .iter()
+                            .map(|message| {
+                                pythnet_sdk::wire::from_slice::<BE, Message>(message.as_slice())
+                            })
+                            .collect::<Result<Vec<Message>, _>>();
+
+                        let messages = match messages {
+                            Ok(messages) => messages,
+                            Err(err) => {
+                                log::error!("Failed to parse messages: {:?}", err);
+                                continue;
+                            }
+                        };
+
+                        let accumulator_messages = AccumulatorMessages {
+                            magic: accumulator_messages.magic,
+                            slot: accumulator_messages.slot,
+                            ring_size: accumulator_messages.ring_size,
+                            messages,
+                        };
+
                         let (candidate, _) = Pubkey::find_program_address(
                             &[
                                 b"AccumulatorState",

+ 7 - 8
hermes/src/store.rs

@@ -7,7 +7,6 @@ use {
             StorageInstance,
         },
         types::{
-            MessageType,
             PriceFeedUpdate,
             PriceFeedsWithUpdateData,
             RequestTime,
@@ -33,7 +32,10 @@ use {
         anyhow,
         Result,
     },
-    pyth_oracle::Message,
+    pyth_oracle::{
+        Message,
+        MessageType,
+    },
     pyth_sdk::PriceIdentifier,
     pythnet_sdk::wire::v1::{
         WormholeMessage,
@@ -160,12 +162,9 @@ impl Store {
             .messages
             .iter()
             .enumerate()
-            .map(|(idx, raw_message)| {
-                let message = Message::try_from_bytes(raw_message)?;
-
+            .map(|(idx, message)| {
                 Ok(MessageState::new(
-                    message,
-                    raw_message.clone(),
+                    message.clone(),
                     ProofSet {
                         wormhole_merkle_proof: wormhole_merkle_message_states_proofs
                             .get(idx)
@@ -232,7 +231,7 @@ impl Store {
             .message_state_keys()
             .await
             .iter()
-            .map(|key| key.price_id)
+            .map(|key| PriceIdentifier::new(key.id))
             .collect()
     }
 }

+ 13 - 9
hermes/src/store/proof/wormhole_merkle.rs

@@ -76,20 +76,24 @@ pub fn construct_message_states_proofs(
     let accumulator_messages = &completed_accumulator_state.accumulator_messages;
     let wormhole_merkle_state = &completed_accumulator_state.wormhole_merkle_state;
 
+    let raw_messages = accumulator_messages
+        .messages
+        .iter()
+        .map(|m| m.to_bytes())
+        .collect::<Vec<Vec<u8>>>();
+
     // Check whether the state is valid
-    let merkle_acc = match MerkleAccumulator::<Keccak160>::from_set(
-        accumulator_messages.messages.iter().map(|m| m.as_ref()),
-    ) {
-        Some(merkle_acc) => merkle_acc,
-        None => return Ok(vec![]), // It only happens when the message set is empty
-    };
+    let merkle_acc =
+        match MerkleAccumulator::<Keccak160>::from_set(raw_messages.iter().map(|m| m.as_ref())) {
+            Some(merkle_acc) => merkle_acc,
+            None => return Ok(vec![]), // It only happens when the message set is empty
+        };
 
     if merkle_acc.root != wormhole_merkle_state.root.root {
         return Err(anyhow!("Invalid merkle root"));
     }
 
-    accumulator_messages
-        .messages
+    raw_messages
         .iter()
         .map(|m| {
             Ok(WormholeMerkleMessageProof {
@@ -126,7 +130,7 @@ pub fn construct_update_data(mut message_states: Vec<&MessageState>) -> Result<V
                     updates: messages
                         .iter()
                         .map(|message| MerklePriceUpdate {
-                            message: message.raw_message.clone().into(),
+                            message: message.message.to_bytes().into(),
                             proof:   message.proof_set.wormhole_merkle_proof.proof.clone(),
                         })
                         .collect(),

+ 21 - 20
hermes/src/store/storage.rs

@@ -3,11 +3,7 @@ use {
         proof::wormhole_merkle::WormholeMerkleState,
         types::{
             AccumulatorMessages,
-            MessageExt,
-            MessageIdentifier,
-            MessageType,
             ProofSet,
-            RawMessage,
             RequestTime,
             Slot,
             UnixTimestamp,
@@ -18,7 +14,10 @@ use {
         Result,
     },
     async_trait::async_trait,
-    pyth_oracle::Message,
+    pyth_oracle::{
+        Message,
+        MessageType,
+    },
     pyth_sdk::PriceIdentifier,
 };
 
@@ -56,6 +55,12 @@ impl TryFrom<AccumulatorState> for CompletedAccumulatorState {
     }
 }
 
+#[derive(Clone, PartialEq, Eq, Debug, Hash)]
+pub struct MessageStateKey {
+    pub id:    [u8; 32],
+    pub type_: MessageType,
+}
+
 #[derive(Clone, PartialEq, Eq, Debug, PartialOrd, Ord)]
 pub struct MessageStateTime {
     pub publish_time: UnixTimestamp,
@@ -64,40 +69,36 @@ pub struct MessageStateTime {
 
 #[derive(Clone, PartialEq, Debug)]
 pub struct MessageState {
-    pub publish_time: UnixTimestamp,
-    pub slot:         Slot,
-    pub id:           MessageIdentifier,
-    pub message:      Message,
-    pub raw_message:  RawMessage,
-    pub proof_set:    ProofSet,
-    pub received_at:  UnixTimestamp,
+    pub slot:        Slot,
+    pub message:     Message,
+    pub proof_set:   ProofSet,
+    pub received_at: UnixTimestamp,
 }
 
 impl MessageState {
     pub fn time(&self) -> MessageStateTime {
         MessageStateTime {
-            publish_time: self.publish_time,
+            publish_time: self.message.publish_time(),
             slot:         self.slot,
         }
     }
 
-    pub fn key(&self) -> MessageIdentifier {
-        self.id.clone()
+    pub fn key(&self) -> MessageStateKey {
+        MessageStateKey {
+            id:    self.message.id(),
+            type_: self.message.into(),
+        }
     }
 
     pub fn new(
         message: Message,
-        raw_message: RawMessage,
         proof_set: ProofSet,
         slot: Slot,
         received_at: UnixTimestamp,
     ) -> Self {
         Self {
-            publish_time: message.publish_time(),
             slot,
-            id: message.id(),
             message,
-            raw_message,
             proof_set,
             received_at,
         }
@@ -119,7 +120,7 @@ pub enum MessageStateFilter {
 /// key for the update data they wish to access.
 #[async_trait]
 pub trait Storage: Send + Sync {
-    async fn message_state_keys(&self) -> Vec<MessageIdentifier>;
+    async fn message_state_keys(&self) -> Vec<MessageStateKey>;
     async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
     async fn fetch_message_states(
         &self,

+ 9 - 11
hermes/src/store/storage/local_storage.rs

@@ -1,17 +1,14 @@
 use {
     super::{
         AccumulatorState,
-        MessageIdentifier,
         MessageState,
         MessageStateFilter,
+        MessageStateKey,
         RequestTime,
         Storage,
         StorageInstance,
     },
-    crate::store::types::{
-        MessageType,
-        Slot,
-    },
+    crate::store::types::Slot,
     anyhow::{
         anyhow,
         Result,
@@ -19,6 +16,7 @@ use {
     async_trait::async_trait,
     dashmap::DashMap,
     moka::sync::Cache,
+    pyth_oracle::MessageType,
     pyth_sdk::PriceIdentifier,
     std::{
         collections::VecDeque,
@@ -29,7 +27,7 @@ use {
 
 #[derive(Clone)]
 pub struct LocalStorage {
-    message_cache:     Arc<DashMap<MessageIdentifier, VecDeque<MessageState>>>,
+    message_cache:     Arc<DashMap<MessageStateKey, VecDeque<MessageState>>>,
     accumulator_cache: Cache<Slot, AccumulatorState>,
     cache_size:        u64,
 }
@@ -48,7 +46,7 @@ impl LocalStorage {
 
     fn retrieve_message_state(
         &self,
-        key: MessageIdentifier,
+        key: MessageStateKey,
         request_time: RequestTime,
     ) -> Option<MessageState> {
         match self.message_cache.get(&key) {
@@ -135,9 +133,9 @@ impl Storage for LocalStorage {
                 };
 
                 message_types.into_iter().map(move |message_type| {
-                    let key = MessageIdentifier {
-                        price_id: id,
-                        type_:    message_type,
+                    let key = MessageStateKey {
+                        id:    id.to_bytes(),
+                        type_: message_type,
                     };
                     self.retrieve_message_state(key, request_time.clone())
                         .ok_or(anyhow!("Message not found"))
@@ -146,7 +144,7 @@ impl Storage for LocalStorage {
             .collect()
     }
 
-    async fn message_state_keys(&self) -> Vec<MessageIdentifier> {
+    async fn message_state_keys(&self) -> Vec<MessageStateKey> {
         self.message_cache
             .iter()
             .map(|entry| entry.key().clone())

+ 4 - 49
hermes/src/store/types.rs

@@ -1,60 +1,15 @@
 use {
     super::proof::wormhole_merkle::WormholeMerkleMessageProof,
-    borsh::BorshDeserialize,
     pyth_oracle::{
         Message,
+        MessageType,
         PriceFeedMessage,
     },
-    pyth_sdk::PriceIdentifier,
-    strum::EnumIter,
 };
 
-
-// TODO: We can use strum on Message enum to derive this.
-#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, EnumIter)]
-pub enum MessageType {
-    PriceFeedMessage,
-    TwapMessage,
-}
-
-// TODO: Move this methods to Message enum
-pub trait MessageExt {
-    fn type_(&self) -> MessageType;
-    fn id(&self) -> MessageIdentifier;
-    fn publish_time(&self) -> UnixTimestamp;
-}
-
-impl MessageExt for Message {
-    fn type_(&self) -> MessageType {
-        match self {
-            Message::PriceFeedMessage(_) => MessageType::PriceFeedMessage,
-            Message::TwapMessage(_) => MessageType::TwapMessage,
-        }
-    }
-
-    fn id(&self) -> MessageIdentifier {
-        MessageIdentifier {
-            price_id: match self {
-                Message::PriceFeedMessage(message) => PriceIdentifier::new(message.id),
-                Message::TwapMessage(message) => PriceIdentifier::new(message.id),
-            },
-            type_:    self.type_(),
-        }
-    }
-
-    fn publish_time(&self) -> UnixTimestamp {
-        match self {
-            Message::PriceFeedMessage(message) => message.publish_time,
-            Message::TwapMessage(message) => message.publish_time,
-        }
-    }
-}
-
-pub type RawMessage = Vec<u8>;
-
 #[derive(Clone, PartialEq, Eq, Debug, Hash)]
 pub struct MessageIdentifier {
-    pub price_id: PriceIdentifier,
+    pub price_id: [u8; 32],
     pub type_:    MessageType,
 }
 
@@ -74,12 +29,12 @@ pub enum RequestTime {
     FirstAfter(UnixTimestamp),
 }
 
-#[derive(Clone, PartialEq, Debug, BorshDeserialize)]
+#[derive(Clone, PartialEq, Debug)]
 pub struct AccumulatorMessages {
     pub magic:     [u8; 4],
     pub slot:      Slot,
     pub ring_size: u32,
-    pub messages:  Vec<RawMessage>,
+    pub messages:  Vec<Message>,
 }
 
 impl AccumulatorMessages {