Jelajahi Sumber

feat(hermes): add out of order subscription

Also improve the readiness probe

Co-authored-by: Reisen <Reisen@users.noreply.github.com>
Ali Behjati 2 tahun lalu
induk
melakukan
5e45146acb
8 mengubah file dengan 299 tambahan dan 118 penghapusan
  1. 1 1
      hermes/Cargo.lock
  2. 1 1
      hermes/Cargo.toml
  3. 107 12
      hermes/src/aggregate.rs
  4. 18 11
      hermes/src/api.rs
  5. 96 78
      hermes/src/api/ws.rs
  6. 1 3
      hermes/src/main.rs
  7. 14 12
      hermes/src/state.rs
  8. 61 0
      hermes/src/state/cache.rs

+ 1 - 1
hermes/Cargo.lock

@@ -1764,7 +1764,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
 
 [[package]]
 name = "hermes"
-version = "0.1.16"
+version = "0.1.17"
 dependencies = [
  "anyhow",
  "async-trait",

+ 1 - 1
hermes/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name    = "hermes"
-version = "0.1.16"
+version = "0.1.17"
 edition = "2021"
 
 [dependencies]

+ 107 - 12
hermes/src/aggregate.rs

@@ -79,10 +79,51 @@ pub type UnixTimestamp = i64;
 pub enum RequestTime {
     Latest,
     FirstAfter(UnixTimestamp),
+    AtSlot(Slot),
 }
 
 pub type RawMessage = Vec<u8>;
 
+/// An event that is emitted when an aggregation is completed.
+#[derive(Clone, PartialEq, Debug)]
+pub enum AggregationEvent {
+    New { slot: Slot },
+    OutOfOrder { slot: Slot },
+}
+
+impl AggregationEvent {
+    pub fn slot(&self) -> Slot {
+        match self {
+            AggregationEvent::New { slot } => *slot,
+            AggregationEvent::OutOfOrder { slot } => *slot,
+        }
+    }
+}
+
+#[derive(Clone, PartialEq, Debug)]
+pub struct AggregateState {
+    /// The latest completed slot. This is used to check whether a completed state is new or out of
+    /// order.
+    pub latest_completed_slot: Option<Slot>,
+
+    /// Time of the latest completed update. This is used for the health probes.
+    pub latest_completed_update_at: Option<Instant>,
+
+    /// The latest observed slot among different Aggregate updates. This is used for the health
+    /// probes.
+    pub latest_observed_slot: Option<Slot>,
+}
+
+impl AggregateState {
+    pub fn new() -> Self {
+        Self {
+            latest_completed_slot:      None,
+            latest_completed_update_at: None,
+            latest_observed_slot:       None,
+        }
+    }
+}
+
 /// Accumulator messages coming from Pythnet validators.
 ///
 /// The validators writes the accumulator messages using Borsh with
@@ -125,6 +166,10 @@ pub struct PriceFeedsWithUpdateData {
 
 const READINESS_STALENESS_THRESHOLD: Duration = Duration::from_secs(30);
 
+/// The maximum allowed slot lag between the latest observed slot and the latest completed slot.
+/// 10 slots is almost 5 seconds.
+const READINESS_MAX_ALLOWED_SLOT_LAG: Slot = 10;
+
 /// Stores the update data in the store
 #[tracing::instrument(skip(state, update))]
 pub async fn store_update(state: &State, update: Update) -> Result<()> {
@@ -150,7 +195,6 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
                 }
             }
         }
-
         Update::AccumulatorMessages(accumulator_messages) => {
             let slot = accumulator_messages.slot;
             tracing::info!(slot = slot, "Storing Accumulator Messages.");
@@ -162,6 +206,14 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
         }
     };
 
+    // Track the highest slot observed so far, seeding it with the first slot ever observed.
+    {
+        let mut aggregate_state = state.aggregate_state.write().await;
+        aggregate_state.latest_observed_slot = aggregate_state
+            .latest_observed_slot
+            .map_or(Some(slot), |latest| Some(latest.max(slot)));
+    }
+
     let accumulator_messages = state.fetch_accumulator_messages(slot).await?;
     let wormhole_merkle_state = state.fetch_wormhole_merkle_state(slot).await?;
 
@@ -179,12 +231,39 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
     // we can build the message states
     build_message_states(state, accumulator_messages, wormhole_merkle_state).await?;
 
-    state.update_tx.send(()).await?;
+    // Update the aggregate state
+    let mut aggregate_state = state.aggregate_state.write().await;
 
-    state
-        .last_completed_update_at
-        .write()
-        .await
+    // Check if the update is new or out of order
+    match aggregate_state.latest_completed_slot {
+        None => {
+            aggregate_state.latest_completed_slot.replace(slot);
+            state
+                .api_update_tx
+                .send(AggregationEvent::New { slot })
+                .await?;
+        }
+        Some(latest) if slot > latest => {
+            aggregate_state.latest_completed_slot.replace(slot);
+            state
+                .api_update_tx
+                .send(AggregationEvent::New { slot })
+                .await?;
+        }
+        _ => {
+            state
+                .api_update_tx
+                .send(AggregationEvent::OutOfOrder { slot })
+                .await?;
+        }
+    }
+
+    aggregate_state.latest_completed_slot = aggregate_state
+        .latest_completed_slot
+        .map(|latest| latest.max(slot));
+
+    aggregate_state
+        .latest_completed_update_at
         .replace(Instant::now());
 
     Ok(())
@@ -321,13 +400,26 @@ where
 }
 
 pub async fn is_ready(state: &State) -> bool {
-    let last_completed_update_at = state.last_completed_update_at.read().await;
-    match last_completed_update_at.as_ref() {
-        Some(last_completed_update_at) => {
-            last_completed_update_at.elapsed() < READINESS_STALENESS_THRESHOLD
+    let metadata = state.aggregate_state.read().await;
+
+    let has_completed_recently = match metadata.latest_completed_update_at.as_ref() {
+        Some(latest_completed_update_time) => {
+            latest_completed_update_time.elapsed() < READINESS_STALENESS_THRESHOLD
         }
         None => false,
-    }
+    };
+
+    let is_not_behind = match (
+        metadata.latest_completed_slot,
+        metadata.latest_observed_slot,
+    ) {
+        (Some(latest_completed_slot), Some(latest_observed_slot)) => {
+            latest_observed_slot - latest_completed_slot <= READINESS_MAX_ALLOWED_SLOT_LAG
+        }
+        _ => false,
+    };
+
+    has_completed_recently && is_not_behind
 }
 
 #[cfg(test)]
@@ -456,7 +548,10 @@ mod test {
         .await;
 
         // Check that the update_rx channel has received a message
-        assert_eq!(update_rx.recv().await, Some(()));
+        assert_eq!(
+            update_rx.recv().await,
+            Some(AggregationEvent::New { slot: 10 })
+        );
 
         // Check the price ids are stored correctly
         assert_eq!(

+ 18 - 11
hermes/src/api.rs

@@ -1,6 +1,7 @@
 use {
     self::ws::notify_updates,
     crate::{
+        aggregate::AggregationEvent,
         config::RunOptions,
         state::State,
     },
@@ -48,7 +49,11 @@ impl ApiState {
 /// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
 /// packages they are based on (tokio & hyper).
 #[tracing::instrument(skip(opts, state, update_rx))]
-pub async fn run(opts: RunOptions, state: Arc<State>, mut update_rx: Receiver<()>) -> Result<()> {
+pub async fn run(
+    opts: RunOptions,
+    state: Arc<State>,
+    mut update_rx: Receiver<AggregationEvent>,
+) -> Result<()> {
     tracing::info!(endpoint = %opts.api_addr, "Starting RPC Server.");
 
     #[derive(OpenApi)]
@@ -103,19 +108,21 @@ pub async fn run(opts: RunOptions, state: Arc<State>, mut update_rx: Receiver<()
         // default value for this parameter).
         .layer(Extension(QsQueryConfig::new(5, false)));
 
-    // Call dispatch updates to websocket every 1 seconds
-    // FIXME use a channel to get updates from the store
     tokio::spawn(async move {
         while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-            // Causes a full application shutdown if an error occurs, we can't recover from this so
-            // we just quit.
-            if update_rx.recv().await.is_none() {
-                tracing::error!("Failed to receive update from store.");
-                crate::SHOULD_EXIT.store(true, Ordering::Release);
-                break;
+            match update_rx.recv().await {
+                None => {
+                    // When the received message is None it means the channel has been closed. This
+                    // should never happen as the channel is never closed. As we can't recover from
+                    // this we shut down the application.
+                    tracing::error!("Failed to receive update from store.");
+                    crate::SHOULD_EXIT.store(true, Ordering::Release);
+                    break;
+                }
+                Some(event) => {
+                    notify_updates(state.ws.clone(), event).await;
+                }
             }
-
-            notify_updates(state.ws.clone()).await;
         }
 
         tracing::info!("Shutting down websocket updates...")

+ 96 - 78
hermes/src/api/ws.rs

@@ -4,7 +4,10 @@ use {
         RpcPriceFeed,
     },
     crate::{
-        aggregate::RequestTime,
+        aggregate::{
+            AggregationEvent,
+            RequestTime,
+        },
         state::State,
     },
     anyhow::{
@@ -54,6 +57,64 @@ use {
 pub const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
 pub const NOTIFICATIONS_CHAN_LEN: usize = 1000;
 
+#[derive(Clone)]
+pub struct PriceFeedClientConfig {
+    verbose:            bool,
+    binary:             bool,
+    allow_out_of_order: bool,
+}
+
+pub struct WsState {
+    pub subscriber_counter: AtomicUsize,
+    pub subscribers:        DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
+}
+
+impl WsState {
+    pub fn new() -> Self {
+        Self {
+            subscriber_counter: AtomicUsize::new(0),
+            subscribers:        DashMap::new(),
+        }
+    }
+}
+
+
+#[derive(Deserialize, Debug, Clone)]
+#[serde(tag = "type")]
+enum ClientMessage {
+    #[serde(rename = "subscribe")]
+    Subscribe {
+        ids:                Vec<PriceIdInput>,
+        #[serde(default)]
+        verbose:            bool,
+        #[serde(default)]
+        binary:             bool,
+        #[serde(default)]
+        allow_out_of_order: bool,
+    },
+    #[serde(rename = "unsubscribe")]
+    Unsubscribe { ids: Vec<PriceIdInput> },
+}
+
+
+#[derive(Serialize, Debug, Clone)]
+#[serde(tag = "type")]
+enum ServerMessage {
+    #[serde(rename = "response")]
+    Response(ServerResponseMessage),
+    #[serde(rename = "price_update")]
+    PriceUpdate { price_feed: RpcPriceFeed },
+}
+
+#[derive(Serialize, Debug, Clone)]
+#[serde(tag = "status")]
+enum ServerResponseMessage {
+    #[serde(rename = "success")]
+    Success,
+    #[serde(rename = "error")]
+    Err { error: String },
+}
+
 pub async fn ws_route_handler(
     ws: WebSocketUpgrade,
     AxumState(state): AxumState<super::ApiState>,
@@ -67,7 +128,7 @@ async fn websocket_handler(stream: WebSocket, state: super::ApiState) {
     let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
     tracing::debug!(id, "New Websocket Connection");
 
-    let (notify_sender, notify_receiver) = mpsc::channel::<()>(NOTIFICATIONS_CHAN_LEN);
+    let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
     let (sender, receiver) = stream.split();
     let mut subscriber =
         Subscriber::new(id, state.state.clone(), notify_receiver, receiver, sender);
@@ -84,7 +145,7 @@ pub struct Subscriber {
     id:                      SubscriberId,
     closed:                  bool,
     store:                   Arc<State>,
-    notify_receiver:         mpsc::Receiver<()>,
+    notify_receiver:         mpsc::Receiver<AggregationEvent>,
     receiver:                SplitStream<WebSocket>,
     sender:                  SplitSink<WebSocket, Message>,
     price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
@@ -96,7 +157,7 @@ impl Subscriber {
     pub fn new(
         id: SubscriberId,
         store: Arc<State>,
-        notify_receiver: mpsc::Receiver<()>,
+        notify_receiver: mpsc::Receiver<AggregationEvent>,
         receiver: SplitStream<WebSocket>,
         sender: SplitSink<WebSocket, Message>,
     ) -> Self {
@@ -125,11 +186,11 @@ impl Subscriber {
 
     async fn handle_next(&mut self) -> Result<()> {
         tokio::select! {
-            maybe_update_feeds = self.notify_receiver.recv() => {
-                if maybe_update_feeds.is_none() {
-                    return Err(anyhow!("Update channel closed. This should never happen. Closing connection."));
-                };
-                self.handle_price_feeds_update().await
+            maybe_update_feeds_event = self.notify_receiver.recv() => {
+                match maybe_update_feeds_event {
+                    Some(event) => self.handle_price_feeds_update(event).await,
+                    None => Err(anyhow!("Update channel closed. This should never happen. Closing connection."))
+                }
             },
             maybe_message_or_err = self.receiver.next() => {
                 self.handle_client_message(
@@ -147,12 +208,12 @@ impl Subscriber {
         }
     }
 
-    async fn handle_price_feeds_update(&mut self) -> Result<()> {
+    async fn handle_price_feeds_update(&mut self, event: AggregationEvent) -> Result<()> {
         let price_feed_ids = self.price_feeds_with_config.keys().cloned().collect();
         for update in crate::aggregate::get_price_feeds_with_update_data(
             &*self.store,
             price_feed_ids,
-            RequestTime::Latest,
+            RequestTime::AtSlot(event.slot()),
         )
         .await?
         .price_feeds
@@ -164,6 +225,12 @@ impl Subscriber {
                     "Config missing, price feed list was poisoned during iteration."
                 ))?;
 
+            if let AggregationEvent::OutOfOrder { slot: _ } = event {
+                if !config.allow_out_of_order {
+                    continue;
+                }
+            }
+
             // `sender.feed` buffers a message to the client but does not flush it, so we can send
             // multiple messages and flush them all at once.
             self.sender
@@ -231,6 +298,7 @@ impl Subscriber {
                 ids,
                 verbose,
                 binary,
+                allow_out_of_order,
             }) => {
                 let price_ids: Vec<PriceIdentifier> = ids.into_iter().map(|id| id.into()).collect();
                 let available_price_ids = crate::aggregate::get_price_feed_ids(&*self.store).await;
@@ -259,8 +327,14 @@ impl Subscriber {
                     return Ok(());
                 } else {
                     for price_id in price_ids {
-                        self.price_feeds_with_config
-                            .insert(price_id, PriceFeedClientConfig { verbose, binary });
+                        self.price_feeds_with_config.insert(
+                            price_id,
+                            PriceFeedClientConfig {
+                                verbose,
+                                binary,
+                                allow_out_of_order,
+                            },
+                        );
                     }
                 }
             }
@@ -283,13 +357,12 @@ impl Subscriber {
     }
 }
 
-pub async fn notify_updates(ws_state: Arc<WsState>) {
-    let closed_subscribers: Vec<Option<SubscriberId>> = join_all(
-        ws_state
-            .subscribers
-            .iter_mut()
-            .map(|subscriber| async move {
-                match subscriber.send(()).await {
+pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
+    let closed_subscribers: Vec<Option<SubscriberId>> =
+        join_all(ws_state.subscribers.iter_mut().map(|subscriber| {
+            let event = event.clone();
+            async move {
+                match subscriber.send(event).await {
                     Ok(_) => None,
                     Err(_) => {
                         // An error here indicates the channel is closed (which may happen either when the
@@ -299,9 +372,9 @@ pub async fn notify_updates(ws_state: Arc<WsState>) {
                         Some(*subscriber.key())
                     }
                 }
-            }),
-    )
-    .await;
+            }
+        }))
+        .await;
 
     // Remove closed_subscribers from ws_state
     closed_subscribers.into_iter().for_each(|id| {
@@ -310,58 +383,3 @@ pub async fn notify_updates(ws_state: Arc<WsState>) {
         }
     });
 }
-
-#[derive(Clone)]
-pub struct PriceFeedClientConfig {
-    verbose: bool,
-    binary:  bool,
-}
-
-pub struct WsState {
-    pub subscriber_counter: AtomicUsize,
-    pub subscribers:        DashMap<SubscriberId, mpsc::Sender<()>>,
-}
-
-impl WsState {
-    pub fn new() -> Self {
-        Self {
-            subscriber_counter: AtomicUsize::new(0),
-            subscribers:        DashMap::new(),
-        }
-    }
-}
-
-
-#[derive(Deserialize, Debug, Clone)]
-#[serde(tag = "type")]
-enum ClientMessage {
-    #[serde(rename = "subscribe")]
-    Subscribe {
-        ids:     Vec<PriceIdInput>,
-        #[serde(default)]
-        verbose: bool,
-        #[serde(default)]
-        binary:  bool,
-    },
-    #[serde(rename = "unsubscribe")]
-    Unsubscribe { ids: Vec<PriceIdInput> },
-}
-
-
-#[derive(Serialize, Debug, Clone)]
-#[serde(tag = "type")]
-enum ServerMessage {
-    #[serde(rename = "response")]
-    Response(ServerResponseMessage),
-    #[serde(rename = "price_update")]
-    PriceUpdate { price_feed: RpcPriceFeed },
-}
-
-#[derive(Serialize, Debug, Clone)]
-#[serde(tag = "status")]
-enum ServerResponseMessage {
-    #[serde(rename = "success")]
-    Success,
-    #[serde(rename = "error")]
-    Err { error: String },
-}

+ 1 - 3
hermes/src/main.rs

@@ -47,14 +47,12 @@ async fn init() -> Result<()> {
             // Initialize a cache store with a 1000 element circular buffer.
             let store = State::new(update_tx.clone(), 1000, opts.benchmarks_endpoint.clone());
 
-            // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown. We
-            // also send off any notifications needed to close off any waiting tasks.
+            // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
             spawn(async move {
                 tracing::info!("Registered shutdown signal handler...");
                 tokio::signal::ctrl_c().await.unwrap();
                 tracing::info!("Shut down signal received, waiting for tasks...");
                 SHOULD_EXIT.store(true, std::sync::atomic::Ordering::Release);
-                let _ = update_tx.send(()).await;
             });
 
             // Spawn all worker tasks, and wait for all to complete (which will happen if a shutdown

+ 14 - 12
hermes/src/state.rs

@@ -1,12 +1,14 @@
 //! This module contains the global state of the application.
 
-#[cfg(test)]
-use mock_instant::Instant;
-#[cfg(not(test))]
-use std::time::Instant;
 use {
     self::cache::Cache,
-    crate::wormhole::GuardianSet,
+    crate::{
+        aggregate::{
+            AggregateState,
+            AggregationEvent,
+        },
+        wormhole::GuardianSet,
+    },
     reqwest::Url,
     std::{
         collections::{
@@ -37,10 +39,10 @@ pub struct State {
     pub guardian_set: RwLock<BTreeMap<u32, GuardianSet>>,
 
     /// The sender to the channel between Store and Api to notify completed updates.
-    pub update_tx: Sender<()>,
+    pub api_update_tx: Sender<AggregationEvent>,
 
-    /// Time of the last completed update. This is used for the health probes.
-    pub last_completed_update_at: RwLock<Option<Instant>>,
+    /// The aggregate module state.
+    pub aggregate_state: RwLock<AggregateState>,
 
     /// Benchmarks endpoint
     pub benchmarks_endpoint: Option<Url>,
@@ -48,7 +50,7 @@ pub struct State {
 
 impl State {
     pub fn new(
-        update_tx: Sender<()>,
+        update_tx: Sender<AggregationEvent>,
         cache_size: u64,
         benchmarks_endpoint: Option<Url>,
     ) -> Arc<Self> {
@@ -56,8 +58,8 @@ impl State {
             cache: Cache::new(cache_size),
             observed_vaa_seqs: RwLock::new(Default::default()),
             guardian_set: RwLock::new(Default::default()),
-            update_tx,
-            last_completed_update_at: RwLock::new(None),
+            api_update_tx: update_tx,
+            aggregate_state: RwLock::new(AggregateState::new()),
             benchmarks_endpoint,
         })
     }
@@ -71,7 +73,7 @@ pub mod test {
         tokio::sync::mpsc::Receiver,
     };
 
-    pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<()>) {
+    pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<AggregationEvent>) {
         let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
         let state = State::new(update_tx, cache_size, None);
 

+ 61 - 0
hermes/src/state/cache.rs

@@ -136,6 +136,15 @@ fn retrieve_message_state(
                         .value()
                         .cloned()
                 }
+                RequestTime::AtSlot(slot) => {
+                    // Get the state with slot equal to the lookup slot.
+                    key_cache
+                        .iter()
+                        .rev() // Usually the slot lies at the end of the map
+                        .find(|(k, _)| k.slot == slot)
+                        .map(|(_, v)| v)
+                        .cloned()
+                }
             }
         }
         None => None,
@@ -458,6 +467,58 @@ mod test {
         }
     }
 
+    #[tokio::test]
+    pub async fn test_store_and_retrieve_at_slot_message_state_works() {
+        // Initialize state with a cache size of 2 per key.
+        let (state, _) = setup_state(2).await;
+
+        // Create and store a message state with feed id [1....] and publish time 10 at slot 5.
+        let old_message_state =
+            create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 10, 5).await;
+
+        // Create and store a message state with feed id [1....] and publish time 13 at slot 10.
+        let new_message_state =
+            create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 13, 10).await;
+
+        // The first message state at slot 5 should be the old message state.
+        assert_eq!(
+            state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::AtSlot(5),
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage)
+                )
+                .await
+                .unwrap(),
+            vec![old_message_state]
+        );
+
+        // Querying slots 6..9, where nothing was stored, should return an error.
+        for request_slot in 6..10 {
+            assert!(state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::AtSlot(request_slot),
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage)
+                )
+                .await
+                .is_err());
+        }
+
+        // The first message state at slot 10 should be the new message state.
+        assert_eq!(
+            state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::AtSlot(10),
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage)
+                )
+                .await
+                .unwrap(),
+            vec![new_message_state]
+        );
+    }
+
     #[tokio::test]
     pub async fn test_store_and_retrieve_latest_message_state_with_same_pubtime_works() {
         // Initialize state with a cache size of 2 per key.