|
|
@@ -1,6 +1,7 @@
|
|
|
-use anyhow::{Result, bail};
|
|
|
+use anyhow::{Context, Result, bail};
|
|
|
use backoff::ExponentialBackoffBuilder;
|
|
|
use backoff::backoff::Backoff;
|
|
|
+use base64::Engine;
|
|
|
use futures_util::stream::{SplitSink, SplitStream};
|
|
|
use futures_util::{SinkExt, StreamExt};
|
|
|
use http::HeaderValue;
|
|
|
@@ -9,12 +10,13 @@ use pyth_lazer_publisher_sdk::transaction::SignedLazerTransaction;
|
|
|
use std::sync::Arc;
|
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
|
use std::time::{Duration, Instant};
|
|
|
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
use tokio::net::TcpStream;
|
|
|
use tokio::select;
|
|
|
use tokio::sync::broadcast;
|
|
|
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
|
|
use tokio_tungstenite::{
|
|
|
- MaybeTlsStream, WebSocketStream, connect_async_with_config,
|
|
|
+ MaybeTlsStream, WebSocketStream, client_async, connect_async_with_config,
|
|
|
tungstenite::Message as TungsteniteMessage,
|
|
|
};
|
|
|
use url::Url;
|
|
|
@@ -22,19 +24,180 @@ use url::Url;
|
|
|
type RelayerWsSender = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, TungsteniteMessage>;
|
|
|
type RelayerWsReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
|
|
|
|
|
-async fn connect_to_relayer(url: Url, token: &str) -> Result<(RelayerWsSender, RelayerWsReceiver)> {
|
|
|
- tracing::info!("connecting to the relayer at {}", url);
|
|
|
- let mut req = url.clone().into_client_request()?;
|
|
|
+async fn connect_through_proxy(
|
|
|
+ proxy_url: &Url,
|
|
|
+ target_url: &Url,
|
|
|
+ token: &str,
|
|
|
+) -> Result<(RelayerWsSender, RelayerWsReceiver)> {
|
|
|
+ tracing::info!(
|
|
|
+ "connecting to the relayer at {} via proxy {}",
|
|
|
+ target_url,
|
|
|
+ proxy_url
|
|
|
+ );
|
|
|
+
|
|
|
+ let proxy_host = proxy_url.host_str().context("Proxy URL must have a host")?;
|
|
|
+ let proxy_port = proxy_url
|
|
|
+ .port()
|
|
|
+ .unwrap_or(if proxy_url.scheme() == "https" {
|
|
|
+ 443
|
|
|
+ } else {
|
|
|
+ 80
|
|
|
+ });
|
|
|
+
|
|
|
+ let proxy_addr = format!("{proxy_host}:{proxy_port}");
|
|
|
+ let mut stream = TcpStream::connect(&proxy_addr)
|
|
|
+ .await
|
|
|
+ .context(format!("Failed to connect to proxy at {proxy_addr}"))?;
|
|
|
+
|
|
|
+ let target_host = target_url
|
|
|
+ .host_str()
|
|
|
+ .context("Target URL must have a host")?;
|
|
|
+ let target_port = target_url
|
|
|
+ .port()
|
|
|
+ .unwrap_or(if target_url.scheme() == "wss" {
|
|
|
+ 443
|
|
|
+ } else {
|
|
|
+ 80
|
|
|
+ });
|
|
|
+
|
|
|
+ let target_authority = format!("{target_host}:{target_port}");
|
|
|
+ let mut request_parts = vec![format!("CONNECT {target_authority} HTTP/1.1")];
|
|
|
+ request_parts.push(format!("Host: {target_authority}"));
|
|
|
+
|
|
|
+ let username = proxy_url.username();
|
|
|
+ if !username.is_empty() {
|
|
|
+ let password = proxy_url.password().unwrap_or("");
|
|
|
+ let credentials = format!("{username}:{password}");
|
|
|
+ let encoded = base64::engine::general_purpose::STANDARD.encode(credentials.as_bytes());
|
|
|
+ request_parts.push(format!("Proxy-Authorization: Basic {encoded}"));
|
|
|
+ }
|
|
|
+
|
|
|
+ request_parts.push("Proxy-Connection: Keep-Alive".to_string());
|
|
|
+ request_parts.push(String::new()); // Empty line to end headers
|
|
|
+ request_parts.push(String::new()); // CRLF to end request
|
|
|
+
|
|
|
+ let connect_request = request_parts.join("\r\n");
|
|
|
+
|
|
|
+ stream
|
|
|
+ .write_all(connect_request.as_bytes())
|
|
|
+ .await
|
|
|
+ .context(format!(
|
|
|
+ "Failed to send CONNECT request to proxy at {proxy_url}"
|
|
|
+ ))?;
|
|
|
+
|
|
|
+ let mut response_buffer = Vec::new();
|
|
|
+ let mut temp_buf = [0u8; 1024];
|
|
|
+ let mut headers_complete = false;
|
|
|
+
|
|
|
+ while !headers_complete {
|
|
|
+ let n = stream.read(&mut temp_buf).await.context(format!(
|
|
|
+ "Failed to read CONNECT response from proxy at {proxy_url}"
|
|
|
+ ))?;
|
|
|
+
|
|
|
+ if n == 0 {
|
|
|
+ bail!("Proxy closed connection before sending complete response");
|
|
|
+ }
|
|
|
+
|
|
|
+ response_buffer.extend_from_slice(temp_buf.get(..n).context("Invalid buffer slice")?);
|
|
|
+
|
|
|
+ if response_buffer.windows(4).any(|w| w == b"\r\n\r\n") {
|
|
|
+ headers_complete = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ let response_str = String::from_utf8_lossy(&response_buffer);
|
|
|
+
|
|
|
+ let status_line = response_str
|
|
|
+ .lines()
|
|
|
+ .next()
|
|
|
+ .context("Empty response from proxy")?;
|
|
|
+
|
|
|
+ let parts: Vec<&str> = status_line.split_whitespace().collect();
|
|
|
+ if parts.len() < 2 {
|
|
|
+ bail!(
|
|
|
+ "Invalid HTTP response from proxy at {}: {}",
|
|
|
+ proxy_url,
|
|
|
+ status_line
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ let status_code = parts
|
|
|
+ .get(1)
|
|
|
+ .context("Missing status code in proxy response")?
|
|
|
+ .parse::<u16>()
|
|
|
+ .context("Invalid status code in proxy response")?;
|
|
|
+
|
|
|
+ if status_code != 200 {
|
|
|
+ let status_text = parts
|
|
|
+ .get(2..)
|
|
|
+ .map(|s| s.join(" "))
|
|
|
+ .unwrap_or_else(|| "Unknown".to_string());
|
|
|
+ bail!(
|
|
|
+ "Proxy CONNECT failed with status {} {}: {}",
|
|
|
+ status_code,
|
|
|
+ status_text,
|
|
|
+ status_line
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ tracing::info!("Successfully connected through proxy at {}", proxy_url);
|
|
|
+
|
|
|
+ let mut req = target_url.clone().into_client_request()?;
|
|
|
let headers = req.headers_mut();
|
|
|
headers.insert(
|
|
|
"Authorization",
|
|
|
HeaderValue::from_str(&format!("Bearer {token}"))?,
|
|
|
);
|
|
|
- let (ws_stream, _) = connect_async_with_config(req, None, true).await?;
|
|
|
- tracing::info!("connected to the relayer at {}", url);
|
|
|
+
|
|
|
+ let maybe_tls_stream = if target_url.scheme() == "wss" {
|
|
|
+ let tls_connector = tokio_native_tls::native_tls::TlsConnector::builder()
|
|
|
+ .build()
|
|
|
+ .context("Failed to build TLS connector")?;
|
|
|
+ let tokio_connector = tokio_native_tls::TlsConnector::from(tls_connector);
|
|
|
+ let domain = target_host;
|
|
|
+ let tls_stream = tokio_connector
|
|
|
+ .connect(domain, stream)
|
|
|
+ .await
|
|
|
+ .context("Failed to establish TLS connection")?;
|
|
|
+
|
|
|
+ MaybeTlsStream::NativeTls(tls_stream)
|
|
|
+ } else {
|
|
|
+ MaybeTlsStream::Plain(stream)
|
|
|
+ };
|
|
|
+
|
|
|
+ let (ws_stream, _) = client_async(req, maybe_tls_stream)
|
|
|
+ .await
|
|
|
+ .context("Failed to complete WebSocket handshake")?;
|
|
|
+
|
|
|
+ tracing::info!(
|
|
|
+ "WebSocket connection established to relayer at {} via proxy {}",
|
|
|
+ target_url,
|
|
|
+ proxy_url
|
|
|
+ );
|
|
|
Ok(ws_stream.split())
|
|
|
}
|
|
|
|
|
|
+async fn connect_to_relayer(
|
|
|
+ url: Url,
|
|
|
+ token: &str,
|
|
|
+ proxy_url: Option<&Url>,
|
|
|
+) -> Result<(RelayerWsSender, RelayerWsReceiver)> {
|
|
|
+ if let Some(proxy) = proxy_url {
|
|
|
+ connect_through_proxy(proxy, &url, token).await
|
|
|
+ } else {
|
|
|
+ tracing::info!("connecting to the relayer at {}", url);
|
|
|
+ let mut req = url.clone().into_client_request()?;
|
|
|
+ let headers = req.headers_mut();
|
|
|
+ headers.insert(
|
|
|
+ "Authorization",
|
|
|
+ HeaderValue::from_str(&format!("Bearer {token}"))?,
|
|
|
+ );
|
|
|
+ let (ws_stream, _) = connect_async_with_config(req, None, true).await?;
|
|
|
+ tracing::info!("connected to the relayer at {}", url);
|
|
|
+ Ok(ws_stream.split())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
struct RelayerWsSession {
|
|
|
ws_sender: RelayerWsSender,
|
|
|
}
|
|
|
@@ -58,11 +221,11 @@ impl RelayerWsSession {
|
|
|
}
|
|
|
|
|
|
pub struct RelayerSessionTask {
|
|
|
- // connection state
|
|
|
pub url: Url,
|
|
|
pub token: String,
|
|
|
pub receiver: broadcast::Receiver<SignedLazerTransaction>,
|
|
|
pub is_ready: Arc<AtomicBool>,
|
|
|
+ pub proxy_url: Option<Url>,
|
|
|
}
|
|
|
|
|
|
impl RelayerSessionTask {
|
|
|
@@ -108,10 +271,8 @@ impl RelayerSessionTask {
|
|
|
}
|
|
|
|
|
|
pub async fn run_relayer_connection(&mut self) -> Result<()> {
|
|
|
- // Establish relayer connection
|
|
|
- // Relayer will drop the connection if no data received in 5s
|
|
|
let (relayer_ws_sender, mut relayer_ws_receiver) =
|
|
|
- connect_to_relayer(self.url.clone(), &self.token).await?;
|
|
|
+ connect_to_relayer(self.url.clone(), &self.token, self.proxy_url.as_ref()).await?;
|
|
|
let mut relayer_ws_session = RelayerWsSession {
|
|
|
ws_sender: relayer_ws_sender,
|
|
|
};
|
|
|
@@ -236,11 +397,11 @@ mod tests {
|
|
|
let (relayer_sender, relayer_receiver) = broadcast::channel(RELAYER_CHANNEL_CAPACITY);
|
|
|
|
|
|
let mut relayer_session_task = RelayerSessionTask {
|
|
|
- // connection state
|
|
|
url: Url::parse("ws://127.0.0.1:12346").unwrap(),
|
|
|
token: "token1".to_string(),
|
|
|
receiver: relayer_receiver,
|
|
|
is_ready: Arc::new(AtomicBool::new(false)),
|
|
|
+ proxy_url: None,
|
|
|
};
|
|
|
tokio::spawn(async move { relayer_session_task.run().await });
|
|
|
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
|