Explorar o código

feat(lazer): add resilient client in rust (#2859)

* feat(lazer): add resilient client in rust

* configurable backoff

* add backoff reset

* impl dedup

* fix

* use Url

* asset non empty endponts in constructor

* configurable channel capacity

* use single channel for connections

* add expo backoff builder wrapper

* add timeout

* add pyth lazer client builder

* fix backoff reset logic
Keyvan Khademi hai 4 meses
pai
achega
928f003d4d

+ 19 - 1
Cargo.lock

@@ -4423,6 +4423,12 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "linked-hash-map"
+version = "0.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
+
 [[package]]
 name = "linux-raw-sys"
 version = "0.4.15"
@@ -5647,10 +5653,11 @@ dependencies = [
 
 [[package]]
 name = "pyth-lazer-client"
-version = "0.1.3"
+version = "1.0.0"
 dependencies = [
  "alloy-primitives 0.8.25",
  "anyhow",
+ "backoff",
  "base64 0.22.1",
  "bincode 1.3.3",
  "bs58",
@@ -5665,6 +5672,8 @@ dependencies = [
  "tokio",
  "tokio-tungstenite 0.20.1",
  "tracing",
+ "tracing-subscriber",
+ "ttl_cache",
  "url",
 ]
 
@@ -10299,6 +10308,15 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "ttl_cache"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a"
+dependencies = [
+ "linked-hash-map",
+]
+
 [[package]]
 name = "tungstenite"
 version = "0.20.1"

+ 5 - 1
lazer/sdk/rust/client/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "pyth-lazer-client"
-version = "0.1.3"
+version = "1.0.0"
 edition = "2021"
 description = "A Rust client for Pyth Lazer"
 license = "Apache-2.0"
@@ -17,6 +17,9 @@ anyhow = "1.0"
 tracing = "0.1"
 url = "2.4"
 derive_more = { version = "1.0.0", features = ["from"] }
+backoff = { version = "0.4.0", features = ["futures", "tokio"] }
+ttl_cache = "0.5.1"
+
 
 [dev-dependencies]
 bincode = "1.3.3"
@@ -25,3 +28,4 @@ hex = "0.4.3"
 libsecp256k1 = "0.7.1"
 bs58 = "0.5.1"
 alloy-primitives = "0.8.19"
+tracing-subscriber = { version = "0.3.19", features = ["env-filter", "json"] }

+ 36 - 12
lazer/sdk/rust/client/examples/subscribe_price_feeds.rs

@@ -1,6 +1,9 @@
+use std::time::Duration;
+
 use base64::Engine;
-use futures_util::StreamExt;
-use pyth_lazer_client::{AnyResponse, LazerClient};
+use pyth_lazer_client::backoff::PythLazerExponentialBackoffBuilder;
+use pyth_lazer_client::client::PythLazerClientBuilder;
+use pyth_lazer_client::ws_connection::AnyResponse;
 use pyth_lazer_protocol::message::{
     EvmMessage, LeEcdsaMessage, LeUnsignedMessage, Message, SolanaMessage,
 };
@@ -9,8 +12,10 @@ use pyth_lazer_protocol::router::{
     Channel, DeliveryFormat, FixedRate, Format, JsonBinaryEncoding, PriceFeedId, PriceFeedProperty,
     SubscriptionParams, SubscriptionParamsRepr,
 };
-use pyth_lazer_protocol::subscription::{Request, Response, SubscribeRequest, SubscriptionId};
+use pyth_lazer_protocol::subscription::{Response, SubscribeRequest, SubscriptionId};
 use tokio::pin;
+use tracing::level_filters::LevelFilter;
+use tracing_subscriber::EnvFilter;
 
 fn get_lazer_access_token() -> String {
     // Place your access token in your env at LAZER_ACCESS_TOKEN or set it here
@@ -20,11 +25,32 @@ fn get_lazer_access_token() -> String {
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            EnvFilter::builder()
+                .with_default_directive(LevelFilter::INFO.into())
+                .from_env()?,
+        )
+        .json()
+        .init();
+
     // Create and start the client
-    let mut client = LazerClient::new(
-        "wss://pyth-lazer.dourolabs.app/v1/stream",
-        &get_lazer_access_token(),
-    )?;
+    let mut client = PythLazerClientBuilder::new(get_lazer_access_token())
+        // Optionally override the default endpoints
+        .with_endpoints(vec![
+            "wss://pyth-lazer-0.dourolabs.app/v1/stream".parse()?,
+            "wss://pyth-lazer-1.dourolabs.app/v1/stream".parse()?,
+        ])
+        // Optionally set the number of connections
+        .with_num_connections(4)
+        // Optionally set the backoff strategy
+        .with_backoff(PythLazerExponentialBackoffBuilder::default().build())
+        // Optionally set the timeout for each connection
+        .with_timeout(Duration::from_secs(5))
+        // Optionally set the channel capacity for responses
+        .with_channel_capacity(1000)
+        .build()?;
+
     let stream = client.start().await?;
     pin!(stream);
 
@@ -72,16 +98,16 @@ async fn main() -> anyhow::Result<()> {
     ];
 
     for req in subscription_requests {
-        client.subscribe(Request::Subscribe(req)).await?;
+        client.subscribe(req).await?;
     }
 
     println!("Subscribed to price feeds. Waiting for updates...");
 
     // Process the first few updates
     let mut count = 0;
-    while let Some(msg) = stream.next().await {
+    while let Some(msg) = stream.recv().await {
         // The stream gives us base64-encoded binary messages. We need to decode, parse, and verify them.
-        match msg? {
+        match msg {
             AnyResponse::Json(msg) => match msg {
                 Response::StreamUpdated(update) => {
                     println!("Received a JSON update for {:?}", update.subscription_id);
@@ -189,8 +215,6 @@ async fn main() -> anyhow::Result<()> {
         println!("Unsubscribed from {sub_id:?}");
     }
 
-    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
-    client.close().await?;
     Ok(())
 }
 

+ 69 - 0
lazer/sdk/rust/client/src/backoff.rs

@@ -0,0 +1,69 @@
+use std::time::Duration;
+
+use backoff::{
+    default::{INITIAL_INTERVAL_MILLIS, MAX_INTERVAL_MILLIS, MULTIPLIER, RANDOMIZATION_FACTOR},
+    ExponentialBackoff, ExponentialBackoffBuilder,
+};
+
+#[derive(Debug)]
+pub struct PythLazerExponentialBackoffBuilder {
+    initial_interval: Duration,
+    randomization_factor: f64,
+    multiplier: f64,
+    max_interval: Duration,
+}
+
+impl Default for PythLazerExponentialBackoffBuilder {
+    fn default() -> Self {
+        Self {
+            initial_interval: Duration::from_millis(INITIAL_INTERVAL_MILLIS),
+            randomization_factor: RANDOMIZATION_FACTOR,
+            multiplier: MULTIPLIER,
+            max_interval: Duration::from_millis(MAX_INTERVAL_MILLIS),
+        }
+    }
+}
+
+impl PythLazerExponentialBackoffBuilder {
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// The initial retry interval.
+    pub fn with_initial_interval(&mut self, initial_interval: Duration) -> &mut Self {
+        self.initial_interval = initial_interval;
+        self
+    }
+
+    /// The randomization factor to use for creating a range around the retry interval.
+    ///
+    /// A randomization factor of 0.5 results in a random period ranging between 50% below and 50%
+    /// above the retry interval.
+    pub fn with_randomization_factor(&mut self, randomization_factor: f64) -> &mut Self {
+        self.randomization_factor = randomization_factor;
+        self
+    }
+
+    /// The value to multiply the current interval with for each retry attempt.
+    pub fn with_multiplier(&mut self, multiplier: f64) -> &mut Self {
+        self.multiplier = multiplier;
+        self
+    }
+
+    /// The maximum value of the back off period. Once the retry interval reaches this
+    /// value it stops increasing.
+    pub fn with_max_interval(&mut self, max_interval: Duration) -> &mut Self {
+        self.max_interval = max_interval;
+        self
+    }
+
+    pub fn build(&self) -> ExponentialBackoff {
+        ExponentialBackoffBuilder::default()
+            .with_initial_interval(self.initial_interval)
+            .with_randomization_factor(self.randomization_factor)
+            .with_multiplier(self.multiplier)
+            .with_max_interval(self.max_interval)
+            .with_max_elapsed_time(None)
+            .build()
+    }
+}

+ 186 - 0
lazer/sdk/rust/client/src/client.rs

@@ -0,0 +1,186 @@
+use std::time::Duration;
+
+use crate::{
+    resilient_ws_connection::PythLazerResilientWSConnection, ws_connection::AnyResponse,
+    CHANNEL_CAPACITY,
+};
+use anyhow::{bail, Result};
+use backoff::ExponentialBackoff;
+use pyth_lazer_protocol::subscription::{SubscribeRequest, SubscriptionId};
+use tokio::sync::mpsc::{self, error::TrySendError};
+use tracing::{error, warn};
+use ttl_cache::TtlCache;
+use url::Url;
+
+const DEDUP_CACHE_SIZE: usize = 100_000;
+const DEDUP_TTL: Duration = Duration::from_secs(10);
+
+const DEFAULT_ENDPOINTS: [&str; 2] = [
+    "wss://pyth-lazer-0.dourolabs.app/v1/stream",
+    "wss://pyth-lazer-1.dourolabs.app/v1/stream",
+];
+const DEFAULT_NUM_CONNECTIONS: usize = 4;
+const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
+
+pub struct PythLazerClient {
+    endpoints: Vec<Url>,
+    access_token: String,
+    num_connections: usize,
+    ws_connections: Vec<PythLazerResilientWSConnection>,
+    backoff: ExponentialBackoff,
+    timeout: Duration,
+    channel_capacity: usize,
+}
+
+impl PythLazerClient {
+    /// Creates a new client instance
+    ///
+    /// # Arguments
+    /// * `endpoints` - A vector of endpoint URLs
+    /// * `access_token` - The access token for authentication
+    /// * `num_connections` - The number of WebSocket connections to maintain
+    pub fn new(
+        endpoints: Vec<Url>,
+        access_token: String,
+        num_connections: usize,
+        backoff: ExponentialBackoff,
+        timeout: Duration,
+        channel_capacity: usize,
+    ) -> Result<Self> {
+        if backoff.max_elapsed_time.is_some() {
+            bail!("max_elapsed_time is not supported in Pyth Lazer client");
+        }
+        if endpoints.is_empty() {
+            bail!("At least one endpoint must be provided");
+        }
+        Ok(Self {
+            endpoints,
+            access_token,
+            num_connections,
+            ws_connections: Vec::with_capacity(num_connections),
+            backoff,
+            timeout,
+            channel_capacity,
+        })
+    }
+
+    pub async fn start(&mut self) -> Result<mpsc::Receiver<AnyResponse>> {
+        let (sender, receiver) = mpsc::channel::<AnyResponse>(self.channel_capacity);
+        let (ws_connection_sender, mut ws_connection_receiver) =
+            mpsc::channel::<AnyResponse>(CHANNEL_CAPACITY);
+
+        for i in 0..self.num_connections {
+            let endpoint = self.endpoints[i % self.endpoints.len()].clone();
+            let connection = PythLazerResilientWSConnection::new(
+                endpoint,
+                self.access_token.clone(),
+                self.backoff.clone(),
+                self.timeout,
+                ws_connection_sender.clone(),
+            );
+            self.ws_connections.push(connection);
+        }
+
+        let mut seen_updates = TtlCache::new(DEDUP_CACHE_SIZE);
+
+        tokio::spawn(async move {
+            while let Some(response) = ws_connection_receiver.recv().await {
+                let cache_key = response.cache_key();
+                if seen_updates.contains_key(&cache_key) {
+                    continue;
+                }
+                seen_updates.insert(cache_key, response.clone(), DEDUP_TTL);
+
+                match sender.try_send(response) {
+                    Ok(_) => (),
+                    Err(TrySendError::Full(r)) => {
+                        warn!("Sender channel is full, responses will be delayed");
+                        if sender.send(r).await.is_err() {
+                            error!("Sender channel is closed, stopping client");
+                        }
+                    }
+                    Err(TrySendError::Closed(_)) => {
+                        error!("Sender channel is closed, stopping client");
+                    }
+                }
+            }
+        });
+
+        Ok(receiver)
+    }
+
+    pub async fn subscribe(&mut self, subscribe_request: SubscribeRequest) -> Result<()> {
+        for connection in &mut self.ws_connections {
+            connection.subscribe(subscribe_request.clone()).await?;
+        }
+        Ok(())
+    }
+
+    pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> {
+        for connection in &mut self.ws_connections {
+            connection.unsubscribe(subscription_id).await?;
+        }
+        Ok(())
+    }
+}
+
+pub struct PythLazerClientBuilder {
+    endpoints: Vec<Url>,
+    access_token: String,
+    num_connections: usize,
+    backoff: ExponentialBackoff,
+    timeout: Duration,
+    channel_capacity: usize,
+}
+
+impl PythLazerClientBuilder {
+    pub fn new(access_token: String) -> Self {
+        Self {
+            endpoints: DEFAULT_ENDPOINTS
+                .iter()
+                .map(|&s| s.parse().unwrap())
+                .collect(),
+            access_token,
+            num_connections: DEFAULT_NUM_CONNECTIONS,
+            backoff: ExponentialBackoff::default(),
+            timeout: DEFAULT_TIMEOUT,
+            channel_capacity: CHANNEL_CAPACITY,
+        }
+    }
+
+    pub fn with_endpoints(mut self, endpoints: Vec<Url>) -> Self {
+        self.endpoints = endpoints;
+        self
+    }
+
+    pub fn with_num_connections(mut self, num_connections: usize) -> Self {
+        self.num_connections = num_connections;
+        self
+    }
+
+    pub fn with_backoff(mut self, backoff: ExponentialBackoff) -> Self {
+        self.backoff = backoff;
+        self
+    }
+
+    pub fn with_timeout(mut self, timeout: Duration) -> Self {
+        self.timeout = timeout;
+        self
+    }
+
+    pub fn with_channel_capacity(mut self, channel_capacity: usize) -> Self {
+        self.channel_capacity = channel_capacity;
+        self
+    }
+
+    pub fn build(self) -> Result<PythLazerClient> {
+        PythLazerClient::new(
+            self.endpoints,
+            self.access_token,
+            self.num_connections,
+            self.backoff,
+            self.timeout,
+            self.channel_capacity,
+        )
+    }
+}

+ 5 - 137
lazer/sdk/rust/client/src/lib.rs

@@ -1,138 +1,6 @@
-use anyhow::Result;
-use derive_more::From;
-use futures_util::{SinkExt, StreamExt, TryStreamExt};
-use pyth_lazer_protocol::{
-    binary_update::BinaryWsUpdate,
-    subscription::{ErrorResponse, Request, Response, SubscriptionId, UnsubscribeRequest},
-};
-use tokio_tungstenite::{connect_async, tungstenite::Message};
-use url::Url;
+const CHANNEL_CAPACITY: usize = 1000;
 
-/// A WebSocket client for consuming Pyth Lazer price feed updates
-///
-/// This client provides a simple interface to:
-/// - Connect to a Lazer WebSocket endpoint
-/// - Subscribe to price feed updates
-/// - Receive updates as a stream of messages
-///
-pub struct LazerClient {
-    endpoint: Url,
-    access_token: String,
-    ws_sender: Option<
-        futures_util::stream::SplitSink<
-            tokio_tungstenite::WebSocketStream<
-                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
-            >,
-            Message,
-        >,
-    >,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, Hash, From)]
-pub enum AnyResponse {
-    Json(Response),
-    Binary(BinaryWsUpdate),
-}
-
-impl LazerClient {
-    /// Creates a new Lazer client instance
-    ///
-    /// # Arguments
-    /// * `endpoint` - The WebSocket URL of the Lazer service
-    /// * `access_token` - Access token for authentication
-    ///
-    /// # Returns
-    /// Returns a new client instance (not yet connected)
-    pub fn new(endpoint: &str, access_token: &str) -> Result<Self> {
-        let endpoint = Url::parse(endpoint)?;
-        let access_token = access_token.to_string();
-        Ok(Self {
-            endpoint,
-            access_token,
-            ws_sender: None,
-        })
-    }
-
-    /// Starts the WebSocket connection
-    ///
-    /// # Returns
-    /// Returns a stream of responses from the server
-    pub async fn start(&mut self) -> Result<impl futures_util::Stream<Item = Result<AnyResponse>>> {
-        let url = self.endpoint.clone();
-        let mut request =
-            tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(url)?;
-
-        request.headers_mut().insert(
-            "Authorization",
-            format!("Bearer {}", self.access_token).parse().unwrap(),
-        );
-
-        let (ws_stream, _) = connect_async(request).await?;
-        let (ws_sender, ws_receiver) = ws_stream.split();
-
-        self.ws_sender = Some(ws_sender);
-        let response_stream =
-            ws_receiver
-                .map_err(anyhow::Error::from)
-                .try_filter_map(|msg| async {
-                    let r: Result<Option<AnyResponse>> = match msg {
-                        Message::Text(text) => {
-                            Ok(Some(serde_json::from_str::<Response>(&text)?.into()))
-                        }
-                        Message::Binary(data) => {
-                            Ok(Some(BinaryWsUpdate::deserialize_slice(&data)?.into()))
-                        }
-                        Message::Close(_) => Ok(Some(
-                            Response::Error(ErrorResponse {
-                                error: "WebSocket connection closed".to_string(),
-                            })
-                            .into(),
-                        )),
-                        _ => Ok(None),
-                    };
-                    r
-                });
-
-        Ok(response_stream)
-    }
-
-    /// Subscribes to price feed updates
-    ///
-    /// # Arguments
-    /// * `request` - A subscription request containing feed IDs and parameters
-    pub async fn subscribe(&mut self, request: Request) -> Result<()> {
-        if let Some(sender) = &mut self.ws_sender {
-            let msg = serde_json::to_string(&request)?;
-            sender.send(Message::Text(msg)).await?;
-            Ok(())
-        } else {
-            anyhow::bail!("WebSocket connection not started")
-        }
-    }
-
-    /// Unsubscribes from a previously subscribed feed
-    ///
-    /// # Arguments
-    /// * `subscription_id` - The ID of the subscription to cancel
-    pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> {
-        if let Some(sender) = &mut self.ws_sender {
-            let request = Request::Unsubscribe(UnsubscribeRequest { subscription_id });
-            let msg = serde_json::to_string(&request)?;
-            sender.send(Message::Text(msg)).await?;
-            Ok(())
-        } else {
-            anyhow::bail!("WebSocket connection not started")
-        }
-    }
-
-    /// Closes the WebSocket connection
-    pub async fn close(&mut self) -> Result<()> {
-        if let Some(sender) = &mut self.ws_sender {
-            sender.send(Message::Close(None)).await?;
-            self.ws_sender = None;
-            Ok(())
-        } else {
-            anyhow::bail!("WebSocket connection not started")
-        }
-    }
-}
+pub mod backoff;
+pub mod client;
+pub mod resilient_ws_connection;
+pub mod ws_connection;

+ 210 - 0
lazer/sdk/rust/client/src/resilient_ws_connection.rs

@@ -0,0 +1,210 @@
+use std::time::Duration;
+
+use backoff::{backoff::Backoff, ExponentialBackoff};
+use futures_util::StreamExt;
+use pyth_lazer_protocol::subscription::{
+    Request, SubscribeRequest, SubscriptionId, UnsubscribeRequest,
+};
+use tokio::{pin, select, sync::mpsc, time::Instant};
+use tracing::{error, info, warn};
+use url::Url;
+
+use crate::{
+    ws_connection::{AnyResponse, PythLazerWSConnection},
+    CHANNEL_CAPACITY,
+};
+use anyhow::{bail, Context, Result};
+
+const BACKOFF_RESET_DURATION: Duration = Duration::from_secs(10);
+
+pub struct PythLazerResilientWSConnection {
+    request_sender: mpsc::Sender<Request>,
+}
+
+impl PythLazerResilientWSConnection {
+    /// Creates a new resilient WebSocket client instance
+    ///
+    /// # Arguments
+    /// * `endpoint` - The WebSocket URL of the Lazer service
+    /// * `access_token` - Access token for authentication
+    /// * `sender` - A sender to send responses back to the client
+    ///
+    /// # Returns
+    /// Returns a new client instance (not yet connected)
+    pub fn new(
+        endpoint: Url,
+        access_token: String,
+        backoff: ExponentialBackoff,
+        timeout: Duration,
+        sender: mpsc::Sender<AnyResponse>,
+    ) -> Self {
+        let (request_sender, mut request_receiver) = mpsc::channel(CHANNEL_CAPACITY);
+        let mut task =
+            PythLazerResilientWSConnectionTask::new(endpoint, access_token, backoff, timeout);
+
+        tokio::spawn(async move {
+            if let Err(e) = task.run(sender, &mut request_receiver).await {
+                error!("Resilient WebSocket connection task failed: {}", e);
+            }
+        });
+
+        Self { request_sender }
+    }
+
+    pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> {
+        self.request_sender
+            .send(Request::Subscribe(request))
+            .await
+            .context("Failed to send subscribe request")?;
+        Ok(())
+    }
+
+    pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> {
+        self.request_sender
+            .send(Request::Unsubscribe(UnsubscribeRequest { subscription_id }))
+            .await
+            .context("Failed to send unsubscribe request")?;
+        Ok(())
+    }
+}
+
+struct PythLazerResilientWSConnectionTask {
+    endpoint: Url,
+    access_token: String,
+    subscriptions: Vec<SubscribeRequest>,
+    backoff: ExponentialBackoff,
+    timeout: Duration,
+}
+
+impl PythLazerResilientWSConnectionTask {
+    pub fn new(
+        endpoint: Url,
+        access_token: String,
+        backoff: ExponentialBackoff,
+        timeout: Duration,
+    ) -> Self {
+        Self {
+            endpoint,
+            access_token,
+            subscriptions: Vec::new(),
+            backoff,
+            timeout,
+        }
+    }
+
+    pub async fn run(
+        &mut self,
+        response_sender: mpsc::Sender<AnyResponse>,
+        request_receiver: &mut mpsc::Receiver<Request>,
+    ) -> Result<()> {
+        loop {
+            let start_time = Instant::now();
+            if let Err(e) = self.start(response_sender.clone(), request_receiver).await {
+                // If a connection was working for BACKOFF_RESET_DURATION
+                // and timeout + 1sec, it was considered successful therefore reset the backoff
+                if start_time.elapsed() > BACKOFF_RESET_DURATION
+                    && start_time.elapsed() > self.timeout + Duration::from_secs(1)
+                {
+                    self.backoff.reset();
+                }
+
+                let delay = self.backoff.next_backoff();
+                match delay {
+                    Some(d) => {
+                        info!("WebSocket connection failed: {}. Retrying in {:?}", e, d);
+                        tokio::time::sleep(d).await;
+                    }
+                    None => {
+                        bail!(
+                            "Max retries reached for WebSocket connection to {}, this should never happen, please contact developers",
+                            self.endpoint
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    pub async fn start(
+        &mut self,
+        sender: mpsc::Sender<AnyResponse>,
+        request_receiver: &mut mpsc::Receiver<Request>,
+    ) -> Result<()> {
+        let mut ws_connection =
+            PythLazerWSConnection::new(self.endpoint.clone(), self.access_token.clone())?;
+        let stream = ws_connection.start().await?;
+        pin!(stream);
+
+        for subscription in self.subscriptions.clone() {
+            ws_connection
+                .send_request(Request::Subscribe(subscription))
+                .await?;
+        }
+        loop {
+            let timeout_response = tokio::time::timeout(self.timeout, stream.next());
+
+            select! {
+                response = timeout_response => {
+                    match response {
+                        Ok(Some(response)) => match response {
+                            Ok(response) => {
+                                sender
+                                    .send(response)
+                                    .await
+                                    .context("Failed to send response")?;
+                            }
+                            Err(e) => {
+                                bail!("WebSocket stream error: {}", e);
+                            }
+                        },
+                        Ok(None) => {
+                            bail!("WebSocket stream ended unexpectedly");
+                        }
+                        Err(_elapsed) => {
+                            bail!("WebSocket stream timed out");
+                        }
+                    }
+                }
+                Some(request) = request_receiver.recv() => {
+                   match request {
+                        Request::Subscribe(request) => {
+                            self.subscribe(&mut ws_connection, request).await?;
+                        }
+                        Request::Unsubscribe(request) => {
+                            self.unsubscribe(&mut ws_connection, request).await?;
+                        }
+                   }
+                }
+            }
+        }
+    }
+
+    pub async fn subscribe(
+        &mut self,
+        ws_connection: &mut PythLazerWSConnection,
+        request: SubscribeRequest,
+    ) -> Result<()> {
+        self.subscriptions.push(request.clone());
+        ws_connection.subscribe(request).await
+    }
+
+    pub async fn unsubscribe(
+        &mut self,
+        ws_connection: &mut PythLazerWSConnection,
+        request: UnsubscribeRequest,
+    ) -> Result<()> {
+        if let Some(index) = self
+            .subscriptions
+            .iter()
+            .position(|r| r.subscription_id == request.subscription_id)
+        {
+            self.subscriptions.remove(index);
+        } else {
+            warn!(
+                "Unsubscribe called for non-existent subscription: {:?}",
+                request.subscription_id
+            );
+        }
+        ws_connection.unsubscribe(request).await
+    }
+}

+ 144 - 0
lazer/sdk/rust/client/src/ws_connection.rs

@@ -0,0 +1,144 @@
+use std::hash::{DefaultHasher, Hash, Hasher};
+
+use anyhow::Result;
+use derive_more::From;
+use futures_util::{SinkExt, StreamExt, TryStreamExt};
+use pyth_lazer_protocol::{
+    binary_update::BinaryWsUpdate,
+    subscription::{ErrorResponse, Request, Response, SubscribeRequest, UnsubscribeRequest},
+};
+use tokio_tungstenite::{connect_async, tungstenite::Message};
+use url::Url;
+
+/// A WebSocket client for consuming Pyth Lazer price feed updates
+///
+/// This client provides a simple interface to:
+/// - Connect to a Lazer WebSocket endpoint
+/// - Subscribe to price feed updates
+/// - Receive updates as a stream of messages
+///
+pub struct PythLazerWSConnection {
+    endpoint: Url,
+    access_token: String,
+    ws_sender: Option<
+        futures_util::stream::SplitSink<
+            tokio_tungstenite::WebSocketStream<
+                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+            >,
+            Message,
+        >,
+    >,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, From)]
+pub enum AnyResponse {
+    Json(Response),
+    Binary(BinaryWsUpdate),
+}
+
+impl AnyResponse {
+    pub fn cache_key(&self) -> u64 {
+        let mut hasher = DefaultHasher::new();
+        self.hash(&mut hasher);
+        hasher.finish()
+    }
+}
+impl PythLazerWSConnection {
+    /// Creates a new Lazer client instance
+    ///
+    /// # Arguments
+    /// * `endpoint` - The WebSocket URL of the Lazer service
+    /// * `access_token` - Access token for authentication
+    ///
+    /// # Returns
+    /// Returns a new client instance (not yet connected)
+    pub fn new(endpoint: Url, access_token: String) -> Result<Self> {
+        Ok(Self {
+            endpoint,
+            access_token,
+            ws_sender: None,
+        })
+    }
+
+    /// Starts the WebSocket connection
+    ///
+    /// # Returns
+    /// Returns a stream of responses from the server
+    pub async fn start(&mut self) -> Result<impl futures_util::Stream<Item = Result<AnyResponse>>> {
+        let url = self.endpoint.clone();
+        let mut request =
+            tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(url)?;
+
+        request.headers_mut().insert(
+            "Authorization",
+            format!("Bearer {}", self.access_token).parse().unwrap(),
+        );
+
+        let (ws_stream, _) = connect_async(request).await?;
+        let (ws_sender, ws_receiver) = ws_stream.split();
+
+        self.ws_sender = Some(ws_sender);
+        let response_stream =
+            ws_receiver
+                .map_err(anyhow::Error::from)
+                .try_filter_map(|msg| async {
+                    let r: Result<Option<AnyResponse>> = match msg {
+                        Message::Text(text) => {
+                            Ok(Some(serde_json::from_str::<Response>(&text)?.into()))
+                        }
+                        Message::Binary(data) => {
+                            Ok(Some(BinaryWsUpdate::deserialize_slice(&data)?.into()))
+                        }
+                        Message::Close(_) => Ok(Some(
+                            Response::Error(ErrorResponse {
+                                error: "WebSocket connection closed".to_string(),
+                            })
+                            .into(),
+                        )),
+                        _ => Ok(None),
+                    };
+                    r
+                });
+
+        Ok(response_stream)
+    }
+
+    pub async fn send_request(&mut self, request: Request) -> Result<()> {
+        if let Some(sender) = &mut self.ws_sender {
+            let msg = serde_json::to_string(&request)?;
+            sender.send(Message::Text(msg)).await?;
+            Ok(())
+        } else {
+            anyhow::bail!("WebSocket connection not started")
+        }
+    }
+
+    /// Subscribes to price feed updates
+    ///
+    /// # Arguments
+    /// * `request` - A subscription request containing feed IDs and parameters
+    pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> {
+        let request = Request::Subscribe(request);
+        self.send_request(request).await
+    }
+
+    /// Unsubscribes from a previously subscribed feed
+    ///
+    /// # Arguments
+    /// * `subscription_id` - The ID of the subscription to cancel
+    pub async fn unsubscribe(&mut self, request: UnsubscribeRequest) -> Result<()> {
+        let request = Request::Unsubscribe(request);
+        self.send_request(request).await
+    }
+
+    /// Closes the WebSocket connection
+    pub async fn close(&mut self) -> Result<()> {
+        if let Some(sender) = &mut self.ws_sender {
+            sender.send(Message::Close(None)).await?;
+            self.ws_sender = None;
+            Ok(())
+        } else {
+            anyhow::bail!("WebSocket connection not started")
+        }
+    }
+}