|
|
@@ -21,9 +21,9 @@ use {
|
|
|
WebSocket,
|
|
|
WebSocketUpgrade,
|
|
|
},
|
|
|
- ConnectInfo,
|
|
|
State as AxumState,
|
|
|
},
|
|
|
+ http::HeaderMap,
|
|
|
response::IntoResponse,
|
|
|
},
|
|
|
dashmap::DashMap,
|
|
|
@@ -50,10 +50,7 @@ use {
|
|
|
},
|
|
|
std::{
|
|
|
collections::HashMap,
|
|
|
- net::{
|
|
|
- IpAddr,
|
|
|
- SocketAddr,
|
|
|
- },
|
|
|
+ net::IpAddr,
|
|
|
num::NonZeroU32,
|
|
|
sync::{
|
|
|
atomic::{
|
|
|
@@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
|
|
|
}
|
|
|
|
|
|
pub struct WsState {
|
|
|
- pub subscriber_counter: AtomicUsize,
|
|
|
- pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
|
|
- pub bytes_limit_whitelist: Vec<IpNet>,
|
|
|
- pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
|
|
|
+ pub subscriber_counter: AtomicUsize,
|
|
|
+ pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
|
|
+ pub bytes_limit_whitelist: Vec<IpNet>,
|
|
|
+ pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
|
|
|
+ pub requester_ip_header_name: String,
|
|
|
}
|
|
|
|
|
|
impl WsState {
|
|
|
- pub fn new(whitelist: Vec<IpNet>) -> Self {
|
|
|
+ pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self {
|
|
|
Self {
|
|
|
- subscriber_counter: AtomicUsize::new(0),
|
|
|
- subscribers: DashMap::new(),
|
|
|
- rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
|
|
|
+ subscriber_counter: AtomicUsize::new(0),
|
|
|
+ subscribers: DashMap::new(),
|
|
|
+ rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
|
|
|
BYTES_LIMIT_PER_IP_PER_SECOND
|
|
|
))),
|
|
|
bytes_limit_whitelist: whitelist,
|
|
|
+ requester_ip_header_name,
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -142,23 +141,33 @@ enum ServerResponseMessage {
|
|
|
pub async fn ws_route_handler(
|
|
|
ws: WebSocketUpgrade,
|
|
|
AxumState(state): AxumState<super::ApiState>,
|
|
|
- ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
|
+ headers: HeaderMap,
|
|
|
) -> impl IntoResponse {
|
|
|
+ let requester_ip = headers
|
|
|
+ .get(state.ws.requester_ip_header_name.as_str())
|
|
|
+ .and_then(|value| value.to_str().ok())
|
|
|
+ .and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
|
|
|
+ .and_then(|value| value.parse().ok());
|
|
|
+
|
|
|
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
|
|
|
- .on_upgrade(move |socket| websocket_handler(socket, state, addr))
|
|
|
+ .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
|
|
|
}
|
|
|
|
|
|
-#[tracing::instrument(skip(stream, state, addr))]
|
|
|
-async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
|
|
|
+#[tracing::instrument(skip(stream, state, subscriber_ip))]
|
|
|
+async fn websocket_handler(
|
|
|
+ stream: WebSocket,
|
|
|
+ state: super::ApiState,
|
|
|
+ subscriber_ip: Option<IpAddr>,
|
|
|
+) {
|
|
|
let ws_state = state.ws.clone();
|
|
|
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
|
|
|
- tracing::debug!(id, %addr, "New Websocket Connection");
|
|
|
+ tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
|
|
|
|
|
|
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
|
|
|
let (sender, receiver) = stream.split();
|
|
|
let mut subscriber = Subscriber::new(
|
|
|
id,
|
|
|
- addr.ip(),
|
|
|
+ subscriber_ip,
|
|
|
state.state.clone(),
|
|
|
state.ws.clone(),
|
|
|
notify_receiver,
|
|
|
@@ -176,7 +185,7 @@ pub type SubscriberId = usize;
|
|
|
/// It listens to the store for updates and sends them to the client.
|
|
|
pub struct Subscriber {
|
|
|
id: SubscriberId,
|
|
|
- ip_addr: IpAddr,
|
|
|
+ ip_addr: Option<IpAddr>,
|
|
|
closed: bool,
|
|
|
store: Arc<State>,
|
|
|
ws_state: Arc<WsState>,
|
|
|
@@ -191,7 +200,7 @@ pub struct Subscriber {
|
|
|
impl Subscriber {
|
|
|
pub fn new(
|
|
|
id: SubscriberId,
|
|
|
- ip_addr: IpAddr,
|
|
|
+ ip_addr: Option<IpAddr>,
|
|
|
store: Arc<State>,
|
|
|
ws_state: Arc<WsState>,
|
|
|
notify_receiver: mpsc::Receiver<AggregationEvent>,
|
|
|
@@ -291,32 +300,36 @@ impl Subscriber {
|
|
|
})?;
|
|
|
|
|
|
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
|
|
|
- if !self
|
|
|
- .ws_state
|
|
|
- .bytes_limit_whitelist
|
|
|
- .iter()
|
|
|
- .any(|ip_net| ip_net.contains(&self.ip_addr))
|
|
|
- && self.ws_state.rate_limiter.check_key_n(
|
|
|
- &self.ip_addr,
|
|
|
- NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
|
|
|
- ) != Ok(Ok(()))
|
|
|
- {
|
|
|
- tracing::info!(
|
|
|
- self.id,
|
|
|
- ip = %self.ip_addr,
|
|
|
- "Rate limit exceeded. Closing connection.",
|
|
|
- );
|
|
|
- self.sender
|
|
|
- .send(
|
|
|
- serde_json::to_string(&ServerResponseMessage::Err {
|
|
|
- error: "Rate limit exceeded".to_string(),
|
|
|
- })?
|
|
|
- .into(),
|
|
|
- )
|
|
|
- .await?;
|
|
|
- self.sender.close().await?;
|
|
|
- self.closed = true;
|
|
|
- return Ok(());
|
|
|
+ // If the ip address is None no rate limiting is applied.
|
|
|
+ if let Some(ip_addr) = self.ip_addr {
|
|
|
+ if !self
|
|
|
+ .ws_state
|
|
|
+ .bytes_limit_whitelist
|
|
|
+ .iter()
|
|
|
+ .any(|ip_net| ip_net.contains(&ip_addr))
|
|
|
+ && self.ws_state.rate_limiter.check_key_n(
|
|
|
+ &ip_addr,
|
|
|
+ NonZeroU32::new(message.len().try_into()?)
|
|
|
+ .ok_or(anyhow!("Empty message"))?,
|
|
|
+ ) != Ok(Ok(()))
|
|
|
+ {
|
|
|
+ tracing::info!(
|
|
|
+ self.id,
|
|
|
+ ip = %ip_addr,
|
|
|
+ "Rate limit exceeded. Closing connection.",
|
|
|
+ );
|
|
|
+ self.sender
|
|
|
+ .send(
|
|
|
+ serde_json::to_string(&ServerResponseMessage::Err {
|
|
|
+ error: "Rate limit exceeded".to_string(),
|
|
|
+ })?
|
|
|
+ .into(),
|
|
|
+ )
|
|
|
+ .await?;
|
|
|
+ self.sender.close().await?;
|
|
|
+ self.closed = true;
|
|
|
+ return Ok(());
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// `sender.feed` buffers a message to the client but does not flush it, so we can send
|