Reisen 2 年之前
父节点
当前提交
454cd03b90

+ 5 - 1
hermes/src/api/rest.rs

@@ -65,7 +65,7 @@ impl IntoResponse for RestError {
 pub async fn price_feed_ids(
     State(state): State<super::State>,
 ) -> Result<Json<HashSet<PriceIdentifier>>, RestError> {
-    let price_feeds = state.store.get_price_feed_ids();
+    let price_feeds = state.store.get_price_feed_ids().await;
     Ok(Json(price_feeds))
 }
 
@@ -83,6 +83,7 @@ pub async fn latest_vaas(
     let price_feeds_with_update_data = state
         .store
         .get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
+        .await
         .map_err(|_| RestError::UpdateDataNotFound)?;
     Ok(Json(
         price_feeds_with_update_data
@@ -111,6 +112,7 @@ pub async fn latest_price_feeds(
     let price_feeds_with_update_data = state
         .store
         .get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
+        .await
         .map_err(|_| RestError::UpdateDataNotFound)?;
     Ok(Json(
         price_feeds_with_update_data
@@ -148,6 +150,7 @@ pub async fn get_vaa(
             vec![price_id],
             RequestTime::FirstAfter(params.publish_time),
         )
+        .await
         .map_err(|_| RestError::UpdateDataNotFound)?;
 
     let vaa = price_feeds_with_update_data
@@ -198,6 +201,7 @@ pub async fn get_vaa_ccip(
     let price_feeds_with_update_data = state
         .store
         .get_price_feeds_with_update_data(vec![price_id], RequestTime::FirstAfter(publish_time))
+        .await
         .map_err(|_| RestError::CcipUpdateDataNotFound)?;
 
     let bytes = price_feeds_with_update_data

+ 27 - 24
hermes/src/api/ws.rs

@@ -161,31 +161,34 @@ impl Subscriber {
     }
 
     async fn handle_price_feeds_update(&mut self) -> Result<()> {
-        let messages = self
-            .price_feeds_with_config
-            .iter()
-            .map(|(price_feed_id, config)| {
-                let price_feeds_with_update_data = self
-                    .store
-                    .get_price_feeds_with_update_data(vec![*price_feed_id], RequestTime::Latest)?;
-                let price_feed = price_feeds_with_update_data
-                    .price_feeds
-                    .into_iter()
-                    .next()
-                    .ok_or_else(|| {
-                        anyhow::anyhow!("Price feed {} not found.", price_feed_id.to_string())
-                    })?;
-                let price_feed =
-                    RpcPriceFeed::from_price_feed_update(price_feed, config.verbose, config.binary);
-
-                Ok(Message::Text(serde_json::to_string(
-                    &ServerMessage::PriceUpdate { price_feed },
+        let price_feed_ids = self.price_feeds_with_config.keys().cloned().collect();
+        for update in self
+            .store
+            .get_price_feeds_with_update_data(price_feed_ids, RequestTime::Latest)
+            .await?
+            .price_feeds
+        {
+            let config = self
+                .price_feeds_with_config
+                .get(&PriceIdentifier::new(update.price_feed.id))
+                .ok_or(anyhow::anyhow!(
+                    "Config missing, price feed list was poisoned during iteration."
+                ))?;
+
+            self.sender
+                .feed(Message::Text(serde_json::to_string(
+                    &ServerMessage::PriceUpdate {
+                        price_feed: RpcPriceFeed::from_price_feed_update(
+                            update,
+                            config.verbose,
+                            config.binary,
+                        ),
+                    },
                 )?))
-            })
-            .collect::<Result<Vec<Message>>>()?;
-        self.sender
-            .send_all(&mut iter(messages.into_iter().map(Ok)))
-            .await?;
+                .await?;
+        }
+
+        self.sender.flush().await?;
         Ok(())
     }
 

+ 54 - 54
hermes/src/store.rs

@@ -1,7 +1,10 @@
 use {
     self::{
         proof::wormhole_merkle::construct_update_data,
-        storage::StorageInstance,
+        storage::{
+            MessageStateFilter,
+            StorageInstance,
+        },
         types::{
             AccumulatorMessages,
             MessageType,
@@ -17,6 +20,7 @@ use {
             construct_message_states_proofs,
             store_wormhole_merkle_verified_message,
         },
+        storage::AccumulatorState,
         types::{
             MessageState,
             ProofSet,
@@ -41,7 +45,6 @@ use {
         collections::HashSet,
         sync::Arc,
         time::{
-            Duration,
             SystemTime,
             UNIX_EPOCH,
         },
@@ -62,28 +65,16 @@ pub mod storage;
 pub mod types;
 pub mod wormhole;
 
-#[derive(Clone, PartialEq, Debug, Builder)]
-#[builder(derive(Debug), pattern = "immutable")]
-pub struct AccumulatorState {
-    pub accumulator_messages:  AccumulatorMessages,
-    pub wormhole_merkle_proof: (WormholeMerkleRoot, Vec<u8>),
-}
-
 pub struct Store {
-    pub storage:               StorageInstance,
-    pub pending_accumulations: Cache<Slot, AccumulatorStateBuilder>,
-    pub guardian_set:          RwLock<Option<Vec<GuardianAddress>>>,
-    pub update_tx:             Sender<()>,
+    pub storage:      StorageInstance,
+    pub guardian_set: RwLock<Option<Vec<GuardianAddress>>>,
+    pub update_tx:    Sender<()>,
 }
 
 impl Store {
-    pub fn new_with_local_cache(update_tx: Sender<()>, max_size_per_key: usize) -> Arc<Self> {
+    pub fn new_with_local_cache(update_tx: Sender<()>, cache_size: u64) -> Arc<Self> {
         Arc::new(Self {
-            storage: storage::local_storage::LocalStorage::new_instance(max_size_per_key),
-            pending_accumulations: Cache::builder()
-                .max_capacity(10_000)
-                .time_to_live(Duration::from_secs(60 * 5))
-                .build(), // FIXME: Make this configurable
+            storage: storage::local_storage::LocalStorage::new_instance(cache_size),
             guardian_set: RwLock::new(None),
             update_tx,
         })
@@ -117,44 +108,47 @@ impl Store {
                     }
                 }
             }
+
             Update::AccumulatorMessages(accumulator_messages) => {
                 let slot = accumulator_messages.slot;
-
                 log::info!("Storing accumulator messages for slot {:?}.", slot,);
-
-                let pending_acc = self
-                    .pending_accumulations
-                    .entry(slot)
-                    .or_default()
-                    .await
-                    .into_value();
-                self.pending_accumulations
-                    .insert(slot, pending_acc.accumulator_messages(accumulator_messages))
-                    .await;
-
+                let mut accumulator_state = self
+                    .storage
+                    .fetch_accumulator_state(slot)
+                    .await?
+                    .unwrap_or(AccumulatorState {
+                        slot,
+                        accumulator_messages: None,
+                        wormhole_merkle_proof: None,
+                    });
+                accumulator_state.accumulator_messages = Some(accumulator_messages);
+                self.storage
+                    .store_accumulator_state(accumulator_state)
+                    .await?;
                 slot
             }
         };
 
-        let pending_state = self.pending_accumulations.get(&slot);
-        let pending_state = match pending_state {
-            Some(pending_state) => pending_state,
-            // Due to some race conditions this might happen when it's processed before
+        let state = match self.storage.fetch_accumulator_state(slot).await? {
+            Some(state) => state,
             None => return Ok(()),
         };
 
-        let state = match pending_state.build() {
-            Ok(state) => state,
-            Err(_) => return Ok(()),
-        };
+        let (accumulator_messages, wormhole_merkle_proof) =
+            match (state.accumulator_messages, state.wormhole_merkle_proof) {
+                (Some(accumulator_messages), Some(wormhole_merkle_proof)) => {
+                    (accumulator_messages, wormhole_merkle_proof)
+                }
+                _ => return Ok(()),
+            };
 
-        let wormhole_merkle_message_states_proofs = construct_message_states_proofs(state.clone())?;
+        let wormhole_merkle_message_states_proofs =
+            construct_message_states_proofs(&accumulator_messages, &wormhole_merkle_proof)?;
 
         let current_time: UnixTimestamp =
             SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as _;
 
-        let message_states = state
-            .accumulator_messages
+        let message_states = accumulator_messages
             .messages
             .iter()
             .enumerate()
@@ -170,7 +164,7 @@ impl Store {
                             .ok_or(anyhow!("Missing proof for message"))?
                             .clone(),
                     },
-                    state.accumulator_messages.slot,
+                    accumulator_messages.slot,
                     current_time,
                 ))
             })
@@ -178,9 +172,7 @@ impl Store {
 
         log::info!("Message states len: {:?}", message_states.len());
 
-        self.storage.store_message_states(message_states)?;
-
-        self.pending_accumulations.invalidate(&slot).await;
+        self.storage.store_message_states(message_states).await?;
 
         self.update_tx.send(()).await?;
 
@@ -191,16 +183,19 @@ impl Store {
         self.guardian_set.write().await.replace(guardian_set);
     }
 
-    pub fn get_price_feeds_with_update_data(
+    pub async fn get_price_feeds_with_update_data(
         &self,
         price_ids: Vec<PriceIdentifier>,
         request_time: RequestTime,
     ) -> Result<PriceFeedsWithUpdateData> {
-        let messages = self.storage.retrieve_message_states(
-            price_ids,
-            request_time,
-            Some(&|message_type| *message_type == MessageType::PriceFeedMessage),
-        )?;
+        let messages = self
+            .storage
+            .fetch_message_states(
+                price_ids,
+                request_time,
+                MessageStateFilter::Only(MessageType::PriceFeedMessage),
+            )
+            .await?;
 
         let price_feeds = messages
             .iter()
@@ -226,7 +221,12 @@ impl Store {
         })
     }
 
-    pub fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier> {
-        self.storage.keys().iter().map(|key| key.price_id).collect()
+    pub async fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier> {
+        self.storage
+            .message_state_keys()
+            .await
+            .iter()
+            .map(|key| key.price_id)
+            .collect()
     }
 }

+ 24 - 23
hermes/src/store/proof/wormhole_merkle.rs

@@ -1,7 +1,10 @@
 use {
     crate::store::{
-        types::MessageState,
-        AccumulatorState,
+        storage::AccumulatorState,
+        types::{
+            AccumulatorMessages,
+            MessageState,
+        },
         Store,
     },
     anyhow::{
@@ -40,45 +43,43 @@ pub async fn store_wormhole_merkle_verified_message(
     proof: WormholeMerkleRoot,
     vaa_bytes: Vec<u8>,
 ) -> Result<()> {
-    let pending_acc = store
-        .pending_accumulations
-        .entry(proof.slot)
-        .or_default()
-        .await
-        .into_value();
+    let mut accumulator_state = store
+        .storage
+        .fetch_accumulator_state(proof.slot)
+        .await?
+        .unwrap_or(AccumulatorState {
+            slot:                  proof.slot,
+            accumulator_messages:  None,
+            wormhole_merkle_proof: None,
+        });
+
+    accumulator_state.wormhole_merkle_proof = Some((proof, vaa_bytes));
     store
-        .pending_accumulations
-        .insert(
-            proof.slot,
-            pending_acc.wormhole_merkle_proof((proof, vaa_bytes)),
-        )
-        .await;
+        .storage
+        .store_accumulator_state(accumulator_state)
+        .await?;
     Ok(())
 }
 
 pub fn construct_message_states_proofs(
-    state: AccumulatorState,
+    accumulator_messages: &AccumulatorMessages,
+    wormhole_merkle_proof: &(WormholeMerkleRoot, Vec<u8>),
 ) -> Result<Vec<WormholeMerkleMessageProof>> {
     // Check whether the state is valid
     let merkle_acc = match MerkleAccumulator::<Keccak160>::from_set(
-        state
-            .accumulator_messages
-            .messages
-            .iter()
-            .map(|m| m.as_ref()),
+        accumulator_messages.messages.iter().map(|m| m.as_ref()),
     ) {
         Some(merkle_acc) => merkle_acc,
         None => return Ok(vec![]), // It only happens when the message set is empty
     };
 
-    let (proof, vaa) = &state.wormhole_merkle_proof;
+    let (proof, vaa) = &wormhole_merkle_proof;
 
     if merkle_acc.root != proof.root {
         return Err(anyhow!("Invalid merkle root"));
     }
 
-    state
-        .accumulator_messages
+    accumulator_messages
         .messages
         .iter()
         .map(|m| {

+ 25 - 4
hermes/src/store/storage.rs

@@ -1,16 +1,33 @@
 use {
     super::types::{
+        AccumulatorMessages,
         MessageIdentifier,
         MessageState,
         MessageType,
         RequestTime,
+        Slot,
     },
     anyhow::Result,
+    async_trait::async_trait,
     pyth_sdk::PriceIdentifier,
+    pythnet_sdk::wire::v1::WormholeMerkleRoot,
 };
 
 pub mod local_storage;
 
+#[derive(Clone, PartialEq, Debug)]
+pub struct AccumulatorState {
+    pub slot:                  Slot,
+    pub accumulator_messages:  Option<AccumulatorMessages>,
+    pub wormhole_merkle_proof: Option<(WormholeMerkleRoot, Vec<u8>)>,
+}
+
+#[derive(Clone, Copy)]
+pub enum MessageStateFilter {
+    All,
+    Only(MessageType),
+}
+
 /// This trait defines the interface for update data storage
 ///
 /// Price update data for Pyth can come in multiple formats, for example VAA's and
@@ -18,15 +35,19 @@ pub mod local_storage;
 /// data to abstract the details of the update data, and so each update data is stored
 /// under a separate key. The caller is responsible for specifying the right
 /// key for the update data they wish to access.
+#[async_trait]
 pub trait Storage: Send + Sync {
-    fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
-    fn retrieve_message_states(
+    async fn message_state_keys(&self) -> Vec<MessageIdentifier>;
+    async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
+    async fn fetch_message_states(
         &self,
         ids: Vec<PriceIdentifier>,
         request_time: RequestTime,
-        filter: Option<&dyn Fn(&MessageType) -> bool>,
+        filter: MessageStateFilter,
     ) -> Result<Vec<MessageState>>;
-    fn keys(&self) -> Vec<MessageIdentifier>;
+
+    async fn store_accumulator_state(&self, state: AccumulatorState) -> Result<()>;
+    async fn fetch_accumulator_state(&self, slot: u64) -> Result<Option<AccumulatorState>>;
 }
 
 pub type StorageInstance = Box<dyn Storage>;

+ 43 - 17
hermes/src/store/storage/local_storage.rs

@@ -1,17 +1,24 @@
 use {
     super::{
+        AccumulatorState,
         MessageIdentifier,
         MessageState,
+        MessageStateFilter,
         RequestTime,
         Storage,
         StorageInstance,
     },
-    crate::store::types::MessageType,
+    crate::store::types::{
+        MessageType,
+        Slot,
+    },
     anyhow::{
         anyhow,
         Result,
     },
+    async_trait::async_trait,
     dashmap::DashMap,
+    moka::sync::Cache,
     pyth_sdk::PriceIdentifier,
     std::{
         collections::VecDeque,
@@ -22,15 +29,20 @@ use {
 
 #[derive(Clone)]
 pub struct LocalStorage {
-    cache:            Arc<DashMap<MessageIdentifier, VecDeque<MessageState>>>,
-    max_size_per_key: usize,
+    message_cache:     Arc<DashMap<MessageIdentifier, VecDeque<MessageState>>>,
+    accumulator_cache: Cache<Slot, AccumulatorState>,
+    cache_size:        u64,
 }
 
 impl LocalStorage {
-    pub fn new_instance(max_size_per_key: usize) -> StorageInstance {
+    pub fn new_instance(cache_size: u64) -> StorageInstance {
         Box::new(Self {
-            cache: Arc::new(DashMap::new()),
-            max_size_per_key,
+            message_cache: Arc::new(DashMap::new()),
+            accumulator_cache: Cache::builder()
+                .max_capacity(cache_size)
+                .time_to_live(std::time::Duration::from_secs(60 * 60 * 24))
+                .build(),
+            cache_size,
         })
     }
 
@@ -39,7 +51,7 @@ impl LocalStorage {
         key: MessageIdentifier,
         request_time: RequestTime,
     ) -> Option<MessageState> {
-        match self.cache.get(&key) {
+        match self.message_cache.get(&key) {
             Some(key_cache) => {
                 match request_time {
                     RequestTime::Latest => key_cache.back().cloned(),
@@ -74,6 +86,7 @@ impl LocalStorage {
     }
 }
 
+#[async_trait]
 impl Storage for LocalStorage {
     /// Add a new db entry to the cache.
     ///
@@ -81,11 +94,11 @@ impl Storage for LocalStorage {
     /// the oldest record in the cache if the max_size is reached. Entries are
     /// usually added in increasing order and likely to be inserted near the
     /// end of the deque. The function is optimized for this specific case.
-    fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()> {
+    async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()> {
         for message_state in message_states {
             let key = message_state.key();
 
-            let mut key_cache = self.cache.entry(key).or_insert_with(VecDeque::new);
+            let mut key_cache = self.message_cache.entry(key).or_insert_with(VecDeque::new);
 
             key_cache.push_back(message_state);
 
@@ -99,7 +112,7 @@ impl Storage for LocalStorage {
             // FIXME remove equal elements by key and time
 
             // Remove the oldest record if the max size is reached.
-            if key_cache.len() > self.max_size_per_key {
+            if key_cache.len() > self.cache_size as usize {
                 key_cache.pop_front();
             }
         }
@@ -107,20 +120,20 @@ impl Storage for LocalStorage {
         Ok(())
     }
 
-    fn retrieve_message_states(
+    async fn fetch_message_states(
         &self,
         ids: Vec<PriceIdentifier>,
         request_time: RequestTime,
-        filter: Option<&dyn Fn(&MessageType) -> bool>,
+        filter: MessageStateFilter,
     ) -> Result<Vec<MessageState>> {
-        // TODO: Should we return an error if any of the ids are not found?
         ids.into_iter()
             .flat_map(|id| {
                 let request_time = request_time.clone();
                 let message_types: Vec<MessageType> = match filter {
-                    Some(filter) => MessageType::iter().filter(filter).collect(),
-                    None => MessageType::iter().collect(),
+                    MessageStateFilter::All => MessageType::iter().collect(),
+                    MessageStateFilter::Only(t) => vec![t],
                 };
+
                 message_types.into_iter().map(move |message_type| {
                     let key = MessageIdentifier {
                         price_id: id,
@@ -133,7 +146,20 @@ impl Storage for LocalStorage {
             .collect()
     }
 
-    fn keys(&self) -> Vec<MessageIdentifier> {
-        self.cache.iter().map(|entry| entry.key().clone()).collect()
+    async fn message_state_keys(&self) -> Vec<MessageIdentifier> {
+        self.message_cache
+            .iter()
+            .map(|entry| entry.key().clone())
+            .collect()
+    }
+
+    async fn store_accumulator_state(&self, state: super::AccumulatorState) -> Result<()> {
+        let key = state.slot;
+        self.accumulator_cache.insert(key, state);
+        Ok(())
+    }
+
+    async fn fetch_accumulator_state(&self, slot: Slot) -> Result<Option<super::AccumulatorState>> {
+        Ok(self.accumulator_cache.get(&slot))
     }
 }

+ 1 - 1
hermes/src/store/types.rs

@@ -11,7 +11,7 @@ use {
 
 
 // TODO: We can use strum on Message enum to derive this.
-#[derive(Clone, Debug, Eq, PartialEq, Hash, EnumIter)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, EnumIter)]
 pub enum MessageType {
     PriceFeedMessage,
     TwapMessage,