瀏覽代碼

Merge pull request #2612 from pyth-network/hermes-ws

feat(apps/hermes): add connection timeout for SSE & WebSocket connections
Daniel Chew 6 月之前
父節點
當前提交
7cd897736f

+ 1 - 1
apps/hermes/server/Cargo.lock

@@ -1868,7 +1868,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
 [[package]]
 name = "hermes"
-version = "0.8.5"
+version = "0.8.6"
 dependencies = [
  "anyhow",
  "async-trait",

+ 1 - 1
apps/hermes/server/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name        = "hermes"
-version     = "0.8.5"
+version     = "0.8.6"
 description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
 edition     = "2021"
 

+ 18 - 3
apps/hermes/server/src/api/rest/v2/sse.rs

@@ -19,12 +19,15 @@ use {
     pyth_sdk::PriceIdentifier,
     serde::Deserialize,
     serde_qs::axum::QsQuery,
-    std::convert::Infallible,
-    tokio::sync::broadcast,
+    std::{convert::Infallible, time::Duration},
+    tokio::{sync::broadcast, time::Instant},
     tokio_stream::{wrappers::BroadcastStream, StreamExt as _},
     utoipa::IntoParams,
 };
 
+// Constants
+const MAX_CONNECTION_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours
+
 #[derive(Debug, Deserialize, IntoParams)]
 #[into_params(parameter_in = Query)]
 pub struct StreamPriceUpdatesQueryParams {
@@ -75,6 +78,9 @@ fn default_true() -> bool {
     params(StreamPriceUpdatesQueryParams)
 )]
 /// SSE route handler for streaming price updates.
+///
+/// The connection will automatically close after 24 hours to prevent resource leaks.
+/// Clients should implement reconnection logic to maintain continuous price updates.
 pub async fn price_stream_sse_handler<S>(
     State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<StreamPriceUpdatesQueryParams>,
@@ -93,7 +99,11 @@ where
     // Convert the broadcast receiver into a Stream
     let stream = BroadcastStream::new(update_rx);
 
+    // Set connection start time
+    let start_time = Instant::now();
+
     let sse_stream = stream
+        .take_while(move |_| start_time.elapsed() < MAX_CONNECTION_DURATION)
         .then(move |message| {
             let state_clone = state.clone(); // Clone again to use inside the async block
             let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
@@ -122,7 +132,12 @@ where
                 }
             }
         })
-        .filter_map(|x| x);
+        .filter_map(|x| x)
+        .chain(futures::stream::once(async {
+            Ok(Event::default()
+                .event("error")
+                .data("Connection timeout reached (24h)"))
+        }));
 
     Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
 }

+ 27 - 1
apps/hermes/server/src/api/ws.rs

@@ -40,11 +40,15 @@ use {
         },
         time::Duration,
     },
-    tokio::sync::{broadcast::Receiver, watch},
+    tokio::{
+        sync::{broadcast::Receiver, watch},
+        time::Instant,
+    },
 };
 
 const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
 const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
+const MAX_CONNECTION_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours
 
 /// The maximum number of bytes that can be sent per second per IP address.
 /// If the limit is exceeded, the connection is closed.
@@ -252,6 +256,7 @@ pub struct Subscriber<S> {
     sender: SplitSink<WebSocket, Message>,
     price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
     ping_interval: tokio::time::Interval,
+    connection_deadline: Instant,
     exit: watch::Receiver<bool>,
     responded_to_ping: bool,
 }
@@ -280,6 +285,7 @@ where
             sender,
             price_feeds_with_config: HashMap::new(),
             ping_interval: tokio::time::interval(PING_INTERVAL_DURATION),
+            connection_deadline: Instant::now() + MAX_CONNECTION_DURATION,
             exit: crate::EXIT.subscribe(),
             responded_to_ping: true, // We start with true so we don't close the connection immediately
         }
@@ -325,6 +331,26 @@ where
                 self.sender.send(Message::Ping(vec![])).await?;
                 Ok(())
             },
+            _ = tokio::time::sleep_until(self.connection_deadline) => {
+                tracing::info!(
+                    id = self.id,
+                    ip = ?self.ip_addr,
+                    "Connection timeout reached (24h). Closing connection.",
+                );
+                self.sender
+                    .send(
+                        serde_json::to_string(&ServerMessage::Response(
+                            ServerResponseMessage::Err {
+                                error: "Connection timeout reached (24h)".to_string(),
+                            },
+                        ))?
+                        .into(),
+                    )
+                    .await?;
+                self.sender.close().await?;
+                self.closed = true;
+                Ok(())
+            },
             _ = self.exit.changed() => {
                 self.sender.close().await?;
                 self.closed = true;