
refactor(hermes): use broadcast channel for api notifications (#1388)

This change removes the manual broadcast implementation used to send
API notifications out to WebSocket (WS) subscribers and replaces it with
a tokio broadcast channel.
Ali Behjati, 1 year ago
Commit 26c3d08f33
5 changed files with 34 additions and 100 deletions
  1. hermes/src/aggregate.rs (+6 -13)
  2. hermes/src/api.rs (+11 -39)
  3. hermes/src/api/ws.rs (+11 -42)
  4. hermes/src/main.rs (+3 -3)
  5. hermes/src/state.rs (+3 -3)
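
The core of the change is replacing tokio's mpsc channel plus a hand-rolled fan-out loop with a tokio::sync::broadcast channel, where every WebSocket subscriber obtains its own Receiver via subscribe(). A minimal standalone sketch of that pattern (not the Hermes code; Event stands in for AggregationEvent):

```rust
use tokio::sync::broadcast;

#[derive(Clone, Debug, PartialEq)]
enum Event {
    New { slot: u64 },
}

#[tokio::main]
async fn main() {
    // Only the Sender is kept; the initial Receiver can be dropped because
    // new ones are created on demand with `subscribe()`.
    let (tx, _) = broadcast::channel::<Event>(1000);

    // Every subscriber gets an independent Receiver and sees each event
    // sent after the moment it subscribed.
    let mut rx_a = tx.subscribe();
    let mut rx_b = tx.subscribe();

    // `send` clones the event to all active receivers.
    tx.send(Event::New { slot: 10 }).unwrap();

    assert_eq!(rx_a.recv().await, Ok(Event::New { slot: 10 }));
    assert_eq!(rx_b.recv().await, Ok(Event::New { slot: 10 }));
}
```

Because the broadcast channel clones each event to every active Receiver, the DashMap of per-subscriber mpsc senders and the manual notify_updates fan-out in ws.rs are no longer needed.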

+ 6 - 13
hermes/src/aggregate.rs

@@ -268,24 +268,17 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
     match aggregate_state.latest_completed_slot {
         None => {
             aggregate_state.latest_completed_slot.replace(slot);
-            state
-                .api_update_tx
-                .send(AggregationEvent::New { slot })
-                .await?;
+            state.api_update_tx.send(AggregationEvent::New { slot })?;
         }
         Some(latest) if slot > latest => {
             state.prune_removed_keys(message_state_keys).await;
             aggregate_state.latest_completed_slot.replace(slot);
-            state
-                .api_update_tx
-                .send(AggregationEvent::New { slot })
-                .await?;
+            state.api_update_tx.send(AggregationEvent::New { slot })?;
         }
         _ => {
             state
                 .api_update_tx
-                .send(AggregationEvent::OutOfOrder { slot })
-                .await?;
+                .send(AggregationEvent::OutOfOrder { slot })?;
         }
     }
 
@@ -583,7 +576,7 @@ mod test {
         // Check that the update_rx channel has received a message
         assert_eq!(
             update_rx.recv().await,
-            Some(AggregationEvent::New { slot: 10 })
+            Ok(AggregationEvent::New { slot: 10 })
         );
 
         // Check the price ids are stored correctly
@@ -708,7 +701,7 @@ mod test {
         // Check that the update_rx channel has received a message
         assert_eq!(
             update_rx.recv().await,
-            Some(AggregationEvent::New { slot: 10 })
+            Ok(AggregationEvent::New { slot: 10 })
         );
 
         // Check the price ids are stored correctly
@@ -745,7 +738,7 @@ mod test {
         // Check that the update_rx channel has received a message
         assert_eq!(
             update_rx.recv().await,
-            Some(AggregationEvent::New { slot: 15 })
+            Ok(AggregationEvent::New { slot: 15 })
         );
 
         // Check that price feed 2 does not exist anymore
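
The dropped .await calls above follow from the API difference: broadcast::Sender::send is synchronous, and broadcast::Receiver::recv returns a Result, which is why the tests now match Ok(...) instead of Some(...). A small sketch of the send-side semantics, assuming nothing beyond the tokio broadcast API:

```rust
use tokio::sync::broadcast;

fn main() {
    let (tx, rx) = broadcast::channel::<u64>(16);

    // `send` is not async; on success it reports how many receivers will see the value.
    assert_eq!(tx.send(1).ok(), Some(1));

    // With the last receiver gone, `send` fails and the value is dropped.
    drop(rx);
    assert!(tx.send(2).is_err());
}
```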

+ 11 - 39
hermes/src/api.rs

@@ -1,5 +1,4 @@
 use {
-    self::ws::notify_updates,
     crate::{
         aggregate::AggregationEvent,
         config::RunOptions,
@@ -18,7 +17,7 @@ use {
         atomic::Ordering,
         Arc,
     },
-    tokio::sync::mpsc::Receiver,
+    tokio::sync::broadcast::Sender,
     tower_http::cors::CorsLayer,
     utoipa::OpenApi,
     utoipa_swagger_ui::SwaggerUi,
@@ -32,9 +31,10 @@ mod ws;
 
 #[derive(Clone)]
 pub struct ApiState {
-    pub state:   Arc<State>,
-    pub ws:      Arc<ws::WsState>,
-    pub metrics: Arc<metrics_middleware::Metrics>,
+    pub state:     Arc<State>,
+    pub ws:        Arc<ws::WsState>,
+    pub metrics:   Arc<metrics_middleware::Metrics>,
+    pub update_tx: Sender<AggregationEvent>,
 }
 
 impl ApiState {
@@ -42,6 +42,7 @@ impl ApiState {
         state: Arc<State>,
         ws_whitelist: Vec<IpNet>,
         requester_ip_header_name: String,
+        update_tx: Sender<AggregationEvent>,
     ) -> Self {
         Self {
             metrics: Arc::new(metrics_middleware::Metrics::new(state.clone())),
@@ -51,15 +52,16 @@ impl ApiState {
                 state.clone(),
             )),
             state,
+            update_tx,
         }
     }
 }
 
-#[tracing::instrument(skip(opts, state, update_rx))]
+#[tracing::instrument(skip(opts, state, update_tx))]
 pub async fn spawn(
     opts: RunOptions,
     state: Arc<State>,
-    mut update_rx: Receiver<AggregationEvent>,
+    update_tx: Sender<AggregationEvent>,
 ) -> Result<()> {
     let state = {
         let opts = opts.clone();
@@ -67,41 +69,11 @@ pub async fn spawn(
             state,
             opts.rpc.ws_whitelist,
             opts.rpc.requester_ip_header_name,
+            update_tx,
         )
     };
 
-    let rpc_server = tokio::spawn(run(opts, state.clone()));
-
-    let ws_notifier = tokio::spawn(async move {
-        let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
-
-        while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-            tokio::select! {
-                update = update_rx.recv() => {
-                    match update {
-                        None => {
-                            // When the received message is None it means the channel has been closed. This
-                            // should never happen as the channel is never closed. As we can't recover from
-                            // this we shut down the application.
-                            tracing::error!("Failed to receive update from store.");
-                            crate::SHOULD_EXIT.store(true, Ordering::Release);
-                            break;
-                        }
-                        Some(event) => {
-                            notify_updates(state.ws.clone(), event).await;
-                        },
-                    }
-                },
-                _ = interval.tick() => {}
-            }
-        }
-
-        tracing::info!("Shutting down Websocket notifier...")
-    });
-
-
-    let _ = tokio::join!(ws_notifier, rpc_server);
-    Ok(())
+    run(opts, state.clone()).await
 }
 
 /// This method provides a background service that responds to REST requests
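
With the notifier task gone, api.rs only needs to keep the broadcast Sender in ApiState so that each connection handler can subscribe on demand. A minimal sketch of that arrangement, using hypothetical stand-ins (SharedState, handle_connection) rather than the actual Hermes types:

```rust
use std::sync::Arc;
use tokio::sync::broadcast;

#[derive(Clone, Debug)]
enum AggregationEvent {
    New { slot: u64 },
}

// Hypothetical stand-in for ApiState: only the broadcast Sender is shared.
struct SharedState {
    update_tx: broadcast::Sender<AggregationEvent>,
}

// Hypothetical per-connection handler: it subscribes on demand, so no central
// registry of subscribers is required.
async fn handle_connection(state: Arc<SharedState>) {
    let mut rx = state.update_tx.subscribe();
    // A real handler loops and writes to the WebSocket; one event is enough here.
    if let Ok(event) = rx.recv().await {
        println!("pushing {event:?} to this client");
    }
}

#[tokio::main]
async fn main() {
    let (update_tx, _) = broadcast::channel(1000);
    let state = Arc::new(SharedState { update_tx });

    let handler = tokio::spawn(handle_connection(state.clone()));

    // `send` fails while nobody is subscribed, so retry until the handler is up.
    while state.update_tx.send(AggregationEvent::New { slot: 10 }).is_err() {
        tokio::task::yield_now().await;
    }
    handler.await.unwrap();
}
```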

+ 11 - 42
hermes/src/api/ws.rs

@@ -26,9 +26,7 @@ use {
         http::HeaderMap,
         response::IntoResponse,
     },
-    dashmap::DashMap,
     futures::{
-        future::join_all,
         stream::{
             SplitSink,
             SplitStream,
@@ -71,11 +69,10 @@ use {
         },
         time::Duration,
     },
-    tokio::sync::mpsc,
+    tokio::sync::broadcast::Receiver,
 };
 
 const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
-const NOTIFICATIONS_CHAN_LEN: usize = 1000;
 const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
 
 /// The maximum number of bytes that can be sent per second per IP address.
@@ -139,7 +136,6 @@ impl Metrics {
 
 pub struct WsState {
     pub subscriber_counter:       AtomicUsize,
-    pub subscribers:              DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
     pub bytes_limit_whitelist:    Vec<IpNet>,
     pub rate_limiter:             DefaultKeyedRateLimiter<IpAddr>,
     pub requester_ip_header_name: String,
@@ -150,7 +146,6 @@ impl WsState {
     pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<State>) -> Self {
         Self {
             subscriber_counter: AtomicUsize::new(0),
-            subscribers: DashMap::new(),
             rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
                 BYTES_LIMIT_PER_IP_PER_SECOND
             ))),
@@ -220,6 +215,11 @@ async fn websocket_handler(
     subscriber_ip: Option<IpAddr>,
 ) {
     let ws_state = state.ws.clone();
+
+    // Retain the recent rate limit data for the IP addresses to
+    // prevent the rate limiter size from growing indefinitely.
+    ws_state.rate_limiter.retain_recent();
+
     let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
 
     tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
@@ -232,7 +232,7 @@ async fn websocket_handler(
         })
         .inc();
 
-    let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
+    let notify_receiver = state.update_tx.subscribe();
     let (sender, receiver) = stream.split();
     let mut subscriber = Subscriber::new(
         id,
@@ -244,7 +244,6 @@ async fn websocket_handler(
         sender,
     );
 
-    ws_state.subscribers.insert(id, notify_sender);
     subscriber.run().await;
 }
 
@@ -258,7 +257,7 @@ pub struct Subscriber {
     closed:                  bool,
     store:                   Arc<State>,
     ws_state:                Arc<WsState>,
-    notify_receiver:         mpsc::Receiver<AggregationEvent>,
+    notify_receiver:         Receiver<AggregationEvent>,
     receiver:                SplitStream<WebSocket>,
     sender:                  SplitSink<WebSocket, Message>,
     price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
@@ -273,7 +272,7 @@ impl Subscriber {
         ip_addr: Option<IpAddr>,
         store: Arc<State>,
         ws_state: Arc<WsState>,
-        notify_receiver: mpsc::Receiver<AggregationEvent>,
+        notify_receiver: Receiver<AggregationEvent>,
         receiver: SplitStream<WebSocket>,
         sender: SplitSink<WebSocket, Message>,
     ) -> Self {
@@ -307,8 +306,8 @@ impl Subscriber {
         tokio::select! {
             maybe_update_feeds_event = self.notify_receiver.recv() => {
                 match maybe_update_feeds_event {
-                    Some(event) => self.handle_price_feeds_update(event).await,
-                    None => Err(anyhow!("Update channel closed. This should never happen. Closing connection."))
+                    Ok(event) => self.handle_price_feeds_update(event).await,
+                    Err(e) => Err(anyhow!("Failed to receive update from store: {:?}", e)),
                 }
             },
             maybe_message_or_err = self.receiver.next() => {
@@ -610,33 +609,3 @@ impl Subscriber {
         Ok(())
     }
 }
-
-pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
-    let closed_subscribers: Vec<Option<SubscriberId>> =
-        join_all(ws_state.subscribers.iter_mut().map(|subscriber| {
-            let event = event.clone();
-            async move {
-                match subscriber.send(event).await {
-                    Ok(_) => None,
-                    Err(_) => {
-                        // An error here indicates the channel is closed (which may happen either when the
-                        // client has sent Message::Close or some other abrupt disconnection). We remove
-                        // subscribers only when send fails so we can handle closure only once when we are
-                        // able to see send() fail.
-                        Some(*subscriber.key())
-                    }
-                }
-            }
-        }))
-        .await;
-
-    // Remove closed_subscribers from ws_state
-    closed_subscribers.into_iter().for_each(|id| {
-        if let Some(id) = id {
-            ws_state.subscribers.remove(&id);
-        }
-    });
-
-    // Clean the bytes limiting dictionary
-    ws_state.rate_limiter.retain_recent();
-}
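
On the receive side, the subscriber loop above maps any recv error into a connection-closing anyhow error. The broadcast Receiver distinguishes two failure modes, sketched below using only the tokio API: Lagged, when a slow subscriber fell more than the channel capacity behind and events were dropped for it, and Closed, when every Sender has been dropped.

```rust
// A hedged sketch, not the Hermes handler: demonstrates RecvError::Lagged and
// RecvError::Closed with a deliberately tiny channel.
use tokio::sync::broadcast::{self, error::RecvError};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<u64>(2);

    // Overflow the 2-slot buffer so the receiver lags behind.
    for slot in 0..5u64 {
        let _ = tx.send(slot);
    }
    drop(tx);

    loop {
        match rx.recv().await {
            // A real subscriber would forward the event to its WebSocket client.
            Ok(slot) => println!("got slot {slot}"),
            // The receiver missed `skipped` events; it can resync or resubscribe.
            Err(RecvError::Lagged(skipped)) => println!("lagged, skipped {skipped} events"),
            // All Senders are gone; no further events will arrive.
            Err(RecvError::Closed) => break,
        }
    }
}
```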

+ 3 - 3
hermes/src/main.rs

@@ -44,8 +44,8 @@ async fn init() -> Result<()> {
         config::Options::Run(opts) => {
             tracing::info!("Starting hermes service...");
 
-            // The update channel is used to send store update notifications to the public API.
-            let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
+            // The update broadcast channel is used to send store update notifications to the public API.
+            let (update_tx, _) = tokio::sync::broadcast::channel(1000);
 
             // Initialize a cache store with a 1000 element circular buffer.
             let store = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
@@ -64,7 +64,7 @@ async fn init() -> Result<()> {
                 Box::pin(spawn(network::wormhole::spawn(opts.clone(), store.clone()))),
                 Box::pin(spawn(network::pythnet::spawn(opts.clone(), store.clone()))),
                 Box::pin(spawn(metrics_server::run(opts.clone(), store.clone()))),
-                Box::pin(spawn(api::spawn(opts.clone(), store.clone(), update_rx))),
+                Box::pin(spawn(api::spawn(opts.clone(), store.clone(), update_tx))),
             ])
             .await;
 

+ 3 - 3
hermes/src/state.rs

@@ -20,7 +20,7 @@ use {
         sync::Arc,
     },
     tokio::sync::{
-        mpsc::Sender,
+        broadcast::Sender,
         RwLock,
     },
 };
@@ -81,11 +81,11 @@ pub mod test {
     use {
         super::*,
         crate::network::wormhole::update_guardian_set,
-        tokio::sync::mpsc::Receiver,
+        tokio::sync::broadcast::Receiver,
     };
 
     pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<AggregationEvent>) {
-        let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
+        let (update_tx, update_rx) = tokio::sync::broadcast::channel(1000);
         let state = State::new(update_tx, cache_size, None);
 
         // Add an initial guardian set with public key 0