Pārlūkot izejas kodu

feat(sse): implement connection timeout for SSE streams with a maximum duration

Daniel Chew 7 mēneši atpakaļ
vecāks
revīzija
0967ecfda0
1 mainītis faili ar 20 papildinājumiem un 5 dzēšanām
  1. 20 5
      apps/hermes/server/src/api/rest/v2/sse.rs

+ 20 - 5
apps/hermes/server/src/api/rest/v2/sse.rs

@@ -19,12 +19,15 @@ use {
     pyth_sdk::PriceIdentifier,
     serde::Deserialize,
     serde_qs::axum::QsQuery,
-    std::convert::Infallible,
-    tokio::sync::broadcast,
+    std::{convert::Infallible, time::Duration},
+    tokio::{sync::broadcast, time::Instant},
     tokio_stream::{wrappers::BroadcastStream, StreamExt as _},
     utoipa::IntoParams,
 };
 
+// Constants
+const MAX_CONNECTION_DURATION: Duration = Duration::from_secs(10); // 24 hours
+
 #[derive(Debug, Deserialize, IntoParams)]
 #[into_params(parameter_in = Query)]
 pub struct StreamPriceUpdatesQueryParams {
@@ -93,10 +96,17 @@ where
     // Convert the broadcast receiver into a Stream
     let stream = BroadcastStream::new(update_rx);
 
+    // Set connection deadline
+    let connection_deadline = Instant::now() + MAX_CONNECTION_DURATION;
+
     let sse_stream = stream
+        .take_while(move |_| {
+            let now = Instant::now();
+            now < connection_deadline
+        })
         .then(move |message| {
-            let state_clone = state.clone(); // Clone again to use inside the async block
-            let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
+            let state_clone = state.clone();
+            let price_ids_clone = price_ids.clone();
             async move {
                 match message {
                     Ok(event) => {
@@ -122,7 +132,12 @@ where
                 }
             }
         })
-        .filter_map(|x| x);
+        .filter_map(|x| x)
+        .chain(futures::stream::once(async {
+            Ok(Event::default()
+                .event("error")
+                .data("Connection timeout reached (24h)"))
+        }));
 
     Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
 }