Forráskód Böngészése

fix: unify deduplication across all connections

Co-Authored-By: Tejas Badadare <tejas@dourolabs.xyz>
Devin AI 9 hónapja
szülő
commit
b0280521cf
2 módosított fájl, 224 hozzáadás és 70 törlés
  1. 223 69
      lazer/sdk/rust/consumer/src/client.rs
  2. 1 1
      lazer/sdk/rust/consumer/src/lib.rs

+ 223 - 69
lazer/sdk/rust/consumer/src/client.rs

@@ -1,24 +1,26 @@
 use {
     anyhow::{Context, Result},
-    tokio_stream,
     futures_util::{SinkExt, StreamExt},
     pyth_lazer_protocol::{
-        router::{Channel, PriceFeedId, PriceFeedProperty, SubscriptionParams, SubscriptionParamsRepr, JsonUpdate},
-        subscription::{Request, Response, SubscriptionId, SubscribeRequest, UnsubscribeRequest, StreamUpdatedResponse},
+        router::{
+            Channel, JsonUpdate, PriceFeedId, PriceFeedProperty, SubscriptionParams,
+            SubscriptionParamsRepr,
+        },
+        subscription::{
+            Request, Response, StreamUpdatedResponse, SubscribeRequest, SubscriptionId,
+            UnsubscribeRequest,
+        },
     },
     std::{
         sync::Arc,
         time::{Duration, Instant},
     },
+    tokio::sync::Mutex,
     tokio::{
         net::TcpStream,
         sync::{mpsc, RwLock},
     },
-    tokio_tungstenite::{
-        connect_async,
-        tungstenite::Message,
-        MaybeTlsStream, WebSocketStream,
-    },
+    tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream},
     tracing::{error, info, warn},
     ttl_cache::TtlCache,
     url::Url,
@@ -36,7 +38,6 @@ type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
 struct WebSocketState {
     stream: WsStream,
     created_at: Instant,
-    seen_updates: TtlCache<u64, bool>,
 }
 
 pub struct RedundantLazerClient {
@@ -47,14 +48,32 @@ pub struct RedundantLazerClient {
     _timeout: Duration,
     endpoint: Url,
     subscription_id: Option<SubscriptionId>,
+    seen_updates: Arc<RwLock<TtlCache<String, bool>>>,
+    was_all_down: Arc<RwLock<bool>>,
+    connection_state_tx: mpsc::UnboundedSender<bool>,
+    connection_state_rx: Arc<Mutex<mpsc::UnboundedReceiver<bool>>>,
+}
+
+impl Clone for RedundantLazerClient {
+    fn clone(&self) -> Self {
+        Self {
+            connections: self.connections.clone(),
+            feed_ids: self.feed_ids.clone(),
+            properties: self.properties.clone(),
+            channel: self.channel,
+            _timeout: self._timeout,
+            endpoint: self.endpoint.clone(),
+            subscription_id: self.subscription_id,
+            seen_updates: self.seen_updates.clone(),
+            was_all_down: self.was_all_down.clone(),
+            connection_state_tx: self.connection_state_tx.clone(),
+            connection_state_rx: Arc::clone(&self.connection_state_rx),
+        }
+    }
 }
 
 impl RedundantLazerClient {
-    pub async fn new(
-        endpoint: Url,
-        num_connections: usize,
-        timeout: Duration,
-    ) -> Result<Self> {
+    pub async fn new(endpoint: Url, num_connections: usize, timeout: Duration) -> Result<Self> {
         if num_connections > MAX_NUM_CONNECTIONS {
             anyhow::bail!("too many connections requested");
         }
@@ -65,22 +84,62 @@ impl RedundantLazerClient {
             connections.push(Arc::new(RwLock::new(WebSocketState {
                 stream,
                 created_at: Instant::now(),
-                seen_updates: TtlCache::new(DEDUP_CACHE_SIZE),
             })));
         }
 
+        let (connection_state_tx, rx) = mpsc::unbounded_channel();
+        let connection_state_rx = Arc::new(Mutex::new(rx));
         Ok(Self {
             connections,
-
             feed_ids: Vec::new(),
             properties: Vec::new(),
             channel: Channel::FixedRate(pyth_lazer_protocol::router::FixedRate::MIN),
             _timeout: timeout,
             endpoint,
             subscription_id: None,
+            seen_updates: Arc::new(RwLock::new(TtlCache::new(DEDUP_CACHE_SIZE))),
+            was_all_down: Arc::new(RwLock::new(false)),
+            connection_state_tx,
+            connection_state_rx,
         })
     }
 
+    pub fn get_connection_state_receiver(&self) -> mpsc::UnboundedReceiver<bool> {
+        let (tx, rx) = mpsc::unbounded_channel();
+        let state_rx = Arc::clone(&self.connection_state_rx);
+        tokio::spawn(async move {
+            let mut rx = state_rx.lock().await;
+            while let Some(state) = rx.recv().await {
+                if tx.send(state).is_err() {
+                    break;
+                }
+            }
+        });
+        rx
+    }
+
+    #[allow(dead_code)]
+    async fn check_connection_states(&self) {
+        let mut all_down = true;
+        for connection in &self.connections {
+            let state = connection.read().await;
+            if Instant::now().duration_since(state.created_at) <= CONNECTION_TTL {
+                all_down = false;
+                break;
+            }
+        }
+
+        let mut was_down = self.was_all_down.write().await;
+        if all_down && !*was_down {
+            *was_down = true;
+            error!("All WebSocket connections are down or reconnecting");
+            let _ = self.connection_state_tx.send(true);
+        } else if !all_down && *was_down {
+            *was_down = false;
+            let _ = self.connection_state_tx.send(false);
+        }
+    }
+
     async fn reconnect_connection(
         endpoint: &Url,
         connection: Arc<RwLock<WebSocketState>>,
@@ -113,7 +172,9 @@ impl RedundantLazerClient {
         channel: Channel,
     ) -> Result<()> {
         if self.subscription_id.is_some() {
-            return self.update_subscription(feed_ids, properties, channel).await;
+            return self
+                .update_subscription(feed_ids, properties, channel)
+                .await;
         }
 
         let subscription_id = SubscriptionId(1);
@@ -125,7 +186,8 @@ impl RedundantLazerClient {
             json_binary_encoding: Default::default(),
             parsed: true,
             channel,
-        }).map_err(|e| anyhow::anyhow!(e))?;
+        })
+        .map_err(|e| anyhow::anyhow!(e))?;
 
         for connection in &self.connections {
             let mut state = connection.write().await;
@@ -147,7 +209,9 @@ impl RedundantLazerClient {
         properties: &[PriceFeedProperty],
         channel: Channel,
     ) -> Result<()> {
-        let subscription_id = self.subscription_id.ok_or_else(|| anyhow::anyhow!("no active subscription"))?;
+        let subscription_id = self
+            .subscription_id
+            .ok_or_else(|| anyhow::anyhow!("no active subscription"))?;
         let params = SubscriptionParams::new(SubscriptionParamsRepr {
             price_feed_ids: feed_ids.to_vec(),
             properties: properties.to_vec(),
@@ -156,7 +220,8 @@ impl RedundantLazerClient {
             json_binary_encoding: Default::default(),
             parsed: true,
             channel,
-        }).map_err(|e| anyhow::anyhow!(e))?;
+        })
+        .map_err(|e| anyhow::anyhow!(e))?;
 
         for connection in &self.connections {
             let mut state = connection.write().await;
@@ -180,7 +245,9 @@ impl RedundantLazerClient {
                 let mut state = connection.write().await;
                 let request = Request::Unsubscribe(UnsubscribeRequest { subscription_id });
                 let message = serde_json::to_string(&request)?;
-                state.stream.send(Message::Text(message))
+                state
+                    .stream
+                    .send(Message::Text(message))
                     .await
                     .context("failed to send unsubscribe request")?;
             }
@@ -190,18 +257,34 @@ impl RedundantLazerClient {
         Ok(())
     }
 
-    pub async fn into_stream(
-        self,
-    ) -> Result<impl futures_util::Stream<Item = JsonUpdate>> {
+    pub async fn into_stream(self) -> Result<impl futures_util::Stream<Item = Result<JsonUpdate>>> {
         let (tx, rx) = mpsc::channel(STREAM_POOL_CHANNEL_SIZE);
         let mut response_rx = self.start().await?;
-        
         tokio::spawn(async move {
             while let Some(response) = response_rx.recv().await {
-                if let Response::StreamUpdated(StreamUpdatedResponse { payload, .. }) = response {
-                    if tx.send(payload).await.is_err() {
-                        break;
+                match response {
+                    Response::StreamUpdated(StreamUpdatedResponse { payload, .. }) => {
+                        if tx.send(Ok(payload)).await.is_err() {
+                            break;
+                        }
+                    }
+                    Response::SubscriptionError(error) => {
+                        let err = anyhow::anyhow!(
+                            "Error occurred for subscription ID {}: {}",
+                            error.subscription_id.0,
+                            error.error
+                        );
+                        if tx.send(Err(err)).await.is_err() {
+                            break;
+                        }
                     }
+                    Response::Error(error) => {
+                        let err = anyhow::anyhow!("Error: {}", error.error);
+                        if tx.send(Err(err)).await.is_err() {
+                            break;
+                        }
+                    }
+                    _ => {}
                 }
             }
         });
@@ -209,11 +292,12 @@ impl RedundantLazerClient {
         Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
     }
 
-    pub async fn start(
-        self,
-    ) -> Result<mpsc::Receiver<Response>> {
+    pub async fn start(self) -> Result<mpsc::Receiver<Response>> {
         let (tx, rx) = mpsc::channel(STREAM_POOL_CHANNEL_SIZE);
         let subscription_id = self.subscription_id.unwrap_or(SubscriptionId(1));
+        let seen_updates = Arc::clone(&self.seen_updates);
+        let connections = self.connections.clone();
+        let endpoint = self.endpoint.clone();
         let params = SubscriptionParams::new(SubscriptionParamsRepr {
             price_feed_ids: self.feed_ids.clone(),
             properties: self.properties.clone(),
@@ -222,18 +306,23 @@ impl RedundantLazerClient {
             json_binary_encoding: Default::default(),
             parsed: true,
             channel: self.channel,
-        }).map_err(|e| anyhow::anyhow!(e))?;
+        })
+        .map_err(|e| anyhow::anyhow!(e))?;
 
-        for connection in self.connections {
+        for connection in &connections {
             let tx = tx.clone();
-            let endpoint = self.endpoint.clone();
+            let endpoint = endpoint.clone();
             let params = params.clone();
+            let seen_updates = Arc::clone(&seen_updates);
+            let connection = Arc::clone(connection);
 
             tokio::spawn(async move {
                 loop {
                     let mut state = connection.write().await;
                     if Instant::now().duration_since(state.created_at) > CONNECTION_TTL {
-                        if let Err(e) = Self::reconnect_connection(&endpoint, connection.clone()).await {
+                        if let Err(e) =
+                            Self::reconnect_connection(&endpoint, connection.clone()).await
+                        {
                             error!("Failed to reconnect: {}", e);
                             tokio::time::sleep(RECONNECT_WAIT).await;
                             continue;
@@ -247,7 +336,7 @@ impl RedundantLazerClient {
                         params.clone(),
                     )
                     .await
-                        .context("failed to send subscription request")
+                    .context("failed to send subscription request")
                     {
                         error!("Failed to send subscribe request: {}", e);
                         tokio::time::sleep(RECONNECT_WAIT).await;
@@ -262,43 +351,79 @@ impl RedundantLazerClient {
                             Some(Ok(msg)) => match msg {
                                 Message::Text(text) => {
                                     match serde_json::from_str::<Response>(&text) {
-                                        Ok(response) => {
-                                            match &response {
-                                                Response::Subscribed(_) => {
-                                                    info!("Subscription confirmed");
+                                        Ok(response) => match &response {
+                                            Response::Subscribed(_) => {
+                                                info!("Subscription confirmed");
+                                            }
+                                            Response::SubscriptionError(error) => {
+                                                let err_msg = format!(
+                                                    "Error occurred for subscription ID {}: {}",
+                                                    error.subscription_id.0, error.error
+                                                );
+                                                error!("{}", err_msg);
+                                                if let Err(e) = tx
+                                                    .send(Response::SubscriptionError(
+                                                        error.clone(),
+                                                    ))
+                                                    .await
+                                                {
+                                                    error!(
+                                                        "Failed to forward subscription error: {}",
+                                                        e
+                                                    );
                                                 }
-                                                Response::SubscriptionError(error) => {
-                                                    error!("Subscription error: {}", error.error);
-                                                    break;
+                                                break;
+                                            }
+                                            Response::StreamUpdated(StreamUpdatedResponse {
+                                                subscription_id: _,
+                                                payload,
+                                            }) => {
+                                                let message = serde_json::to_string(
+                                                    &Response::StreamUpdated(
+                                                        StreamUpdatedResponse {
+                                                            subscription_id,
+                                                            payload: payload.clone(),
+                                                        },
+                                                    ),
+                                                )
+                                                .map_err(|e| {
+                                                    error!("Failed to serialize message: {}", e);
+                                                })
+                                                .unwrap_or_default();
+
+                                                let mut cache = seen_updates.write().await;
+                                                if cache.get(&message).is_some() {
+                                                    continue;
                                                 }
-                                                Response::StreamUpdated(StreamUpdatedResponse { subscription_id: _, payload }) => {
-                                                    // Generate a unique ID for deduplication based on update content
-                                                    let update_hash = {
-                                                        use std::hash::{Hash, Hasher};
-                                                        let mut hasher = std::collections::hash_map::DefaultHasher::new();
-                                                        payload.hash(&mut hasher);
-                                                        hasher.finish()
-                                                    };
-
-                                                    // Skip if we've seen this update recently
-                                                    if state.seen_updates.get(&update_hash).is_some() {
-                                                        continue;
-                                                    }
-
-                                                    // Insert into TTL cache and forward the update
-                                                    state.seen_updates.insert(update_hash, true, DEDUP_TTL);
-                                                    if tx.send(Response::StreamUpdated(StreamUpdatedResponse {
-                                                        subscription_id,
-                                                        payload: payload.clone(),
-                                                    })).await.is_err() {
-                                                        return;
-                                                    }
+
+                                                cache.insert(message, true, DEDUP_TTL);
+                                                if tx
+                                                    .send(Response::StreamUpdated(
+                                                        StreamUpdatedResponse {
+                                                            subscription_id,
+                                                            payload: payload.clone(),
+                                                        },
+                                                    ))
+                                                    .await
+                                                    .is_err()
+                                                {
+                                                    return;
                                                 }
-                                                _ => {
-                                                    warn!("Unexpected response type: {:?}", response);
+                                            }
+                                            Response::Error(error) => {
+                                                let err_msg = format!("Error: {}", error.error);
+                                                error!("{}", err_msg);
+                                                if let Err(e) =
+                                                    tx.send(Response::Error(error.clone())).await
+                                                {
+                                                    error!("Failed to forward error: {}", e);
                                                 }
+                                                break;
                                             }
-                                        }
+                                            _ => {
+                                                warn!("Unexpected response type: {:?}", response);
+                                            }
+                                        },
                                         Err(e) => {
                                             warn!("Failed to parse response: {}", e);
                                         }
@@ -315,7 +440,8 @@ impl RedundantLazerClient {
                         }
                     }
 
-                    if let Err(e) = Self::reconnect_connection(&endpoint, connection.clone()).await {
+                    if let Err(e) = Self::reconnect_connection(&endpoint, connection.clone()).await
+                    {
                         error!("Failed to reconnect: {}", e);
                         tokio::time::sleep(RECONNECT_WAIT).await;
                     }
@@ -323,6 +449,34 @@ impl RedundantLazerClient {
             });
         }
 
+        // Spawn connection state monitoring task
+        let connections_monitor = connections.clone();
+        let was_all_down = Arc::clone(&self.was_all_down);
+        let connection_state_tx = self.connection_state_tx.clone();
+        tokio::spawn(async move {
+            loop {
+                let mut all_down = true;
+                for connection in &connections_monitor {
+                    let state = connection.read().await;
+                    if Instant::now().duration_since(state.created_at) <= CONNECTION_TTL {
+                        all_down = false;
+                        break;
+                    }
+                }
+
+                let mut was_down = was_all_down.write().await;
+                if all_down && !*was_down {
+                    *was_down = true;
+                    error!("All WebSocket connections are down or reconnecting");
+                    let _ = connection_state_tx.send(true);
+                } else if !all_down && *was_down {
+                    *was_down = false;
+                    let _ = connection_state_tx.send(false);
+                }
+                tokio::time::sleep(Duration::from_millis(100)).await;
+            }
+        });
+
         Ok(rx)
     }
 }

+ 1 - 1
lazer/sdk/rust/consumer/src/lib.rs

@@ -1,5 +1,5 @@
 //! Rust consumer SDK for Pyth Lazer.
-//! 
+//!
 //! This SDK allows subscribing to Pyth Lazer WebSocket feeds and receiving price updates.
 
 mod client;