Explorar el Código

feat(hermes): use ip from request headers for ratelimiting

Ali Behjati hace 2 años
padre
commit
d11216f309
Se han modificado 5 ficheros con 91 adiciones y 66 borrados
  1. 8 8
      hermes/Cargo.lock
  2. 1 1
      hermes/Cargo.toml
  3. 15 10
      hermes/src/api.rs
  4. 60 47
      hermes/src/api/ws.rs
  5. 7 0
      hermes/src/config/rpc.rs

+ 8 - 8
hermes/Cargo.lock

@@ -481,7 +481,7 @@ dependencies = [
  "async-trait",
  "axum-core",
  "axum-macros",
- "base64 0.21.2",
+ "base64 0.21.4",
  "bitflags 1.3.2",
  "bytes",
  "futures-util",
@@ -566,9 +566,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
 
 [[package]]
 name = "base64"
-version = "0.21.2"
+version = "0.21.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
+checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2"
 
 [[package]]
 name = "base64ct"
@@ -1898,13 +1898,13 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
 
 [[package]]
 name = "hermes"
-version = "0.3.2"
+version = "0.3.3"
 dependencies = [
  "anyhow",
  "async-trait",
  "axum",
  "axum-macros",
- "base64 0.21.2",
+ "base64 0.21.4",
  "borsh 0.10.3",
  "byteorder",
  "chrono",
@@ -4503,7 +4503,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
 dependencies = [
  "async-compression",
- "base64 0.21.2",
+ "base64 0.21.4",
  "bytes",
  "encoding_rs",
  "futures-core",
@@ -4744,7 +4744,7 @@ version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
 dependencies = [
- "base64 0.21.2",
+ "base64 0.21.4",
 ]
 
 [[package]]
@@ -6277,7 +6277,7 @@ dependencies = [
  "async-stream",
  "async-trait",
  "axum",
- "base64 0.21.2",
+ "base64 0.21.4",
  "bytes",
  "h2",
  "http",

+ 1 - 1
hermes/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name        = "hermes"
-version     = "0.3.2"
+version     = "0.3.3"
 description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
 edition     = "2021"
 

+ 15 - 10
hermes/src/api.rs

@@ -13,12 +13,9 @@ use {
     },
     ipnet::IpNet,
     serde_qs::axum::QsQueryConfig,
-    std::{
-        net::SocketAddr,
-        sync::{
-            atomic::Ordering,
-            Arc,
-        },
+    std::sync::{
+        atomic::Ordering,
+        Arc,
     },
     tokio::{
         signal,
@@ -40,10 +37,14 @@ pub struct ApiState {
 }
 
 impl ApiState {
-    pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
+    pub fn new(
+        state: Arc<State>,
+        ws_whitelist: Vec<IpNet>,
+        requester_ip_header_name: String,
+    ) -> Self {
         Self {
             state,
-            ws: Arc::new(ws::WsState::new(ws_whitelist)),
+            ws: Arc::new(ws::WsState::new(ws_whitelist, requester_ip_header_name)),
         }
     }
 }
@@ -88,7 +89,11 @@ pub async fn run(
     )]
     struct ApiDoc;
 
-    let state = ApiState::new(state, opts.rpc.ws_whitelist);
+    let state = ApiState::new(
+        state,
+        opts.rpc.ws_whitelist,
+        opts.rpc.requester_ip_header_name,
+    );
 
     // Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
     // `with_state` method which replaces `Body` with `State` in the type signature.
@@ -135,7 +140,7 @@ pub async fn run(
     // Binds the axum's server to the configured address and port. This is a blocking call and will
     // not return until the server is shutdown.
     axum::Server::try_bind(&opts.rpc.addr)?
-        .serve(app.into_make_service_with_connect_info::<SocketAddr>())
+        .serve(app.into_make_service())
         .with_graceful_shutdown(async {
             // Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
             // should also have triggered so we will let that one print the shutdown warning.

+ 60 - 47
hermes/src/api/ws.rs

@@ -21,9 +21,9 @@ use {
                 WebSocket,
                 WebSocketUpgrade,
             },
-            ConnectInfo,
             State as AxumState,
         },
+        http::HeaderMap,
         response::IntoResponse,
     },
     dashmap::DashMap,
@@ -50,10 +50,7 @@ use {
     },
     std::{
         collections::HashMap,
-        net::{
-            IpAddr,
-            SocketAddr,
-        },
+        net::IpAddr,
         num::NonZeroU32,
         sync::{
             atomic::{
@@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
 }
 
 pub struct WsState {
-    pub subscriber_counter:    AtomicUsize,
-    pub subscribers:           DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
-    pub bytes_limit_whitelist: Vec<IpNet>,
-    pub rate_limiter:          DefaultKeyedRateLimiter<IpAddr>,
+    pub subscriber_counter:       AtomicUsize,
+    pub subscribers:              DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
+    pub bytes_limit_whitelist:    Vec<IpNet>,
+    pub rate_limiter:             DefaultKeyedRateLimiter<IpAddr>,
+    pub requester_ip_header_name: String,
 }
 
 impl WsState {
-    pub fn new(whitelist: Vec<IpNet>) -> Self {
+    pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self {
         Self {
-            subscriber_counter:    AtomicUsize::new(0),
-            subscribers:           DashMap::new(),
-            rate_limiter:          RateLimiter::dashmap(Quota::per_second(nonzero!(
+            subscriber_counter: AtomicUsize::new(0),
+            subscribers: DashMap::new(),
+            rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
                 BYTES_LIMIT_PER_IP_PER_SECOND
             ))),
             bytes_limit_whitelist: whitelist,
+            requester_ip_header_name,
         }
     }
 }
@@ -142,23 +141,33 @@ enum ServerResponseMessage {
 pub async fn ws_route_handler(
     ws: WebSocketUpgrade,
     AxumState(state): AxumState<super::ApiState>,
-    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    headers: HeaderMap,
 ) -> impl IntoResponse {
+    let requester_ip = headers
+        .get(state.ws.requester_ip_header_name.as_str())
+        .and_then(|value| value.to_str().ok())
+        .and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
+        .and_then(|value| value.parse().ok());
+
     ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
-        .on_upgrade(move |socket| websocket_handler(socket, state, addr))
+        .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
 }
 
-#[tracing::instrument(skip(stream, state, addr))]
-async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
+#[tracing::instrument(skip(stream, state, subscriber_ip))]
+async fn websocket_handler(
+    stream: WebSocket,
+    state: super::ApiState,
+    subscriber_ip: Option<IpAddr>,
+) {
     let ws_state = state.ws.clone();
     let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
-    tracing::debug!(id, %addr, "New Websocket Connection");
+    tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
 
     let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
     let (sender, receiver) = stream.split();
     let mut subscriber = Subscriber::new(
         id,
-        addr.ip(),
+        subscriber_ip,
         state.state.clone(),
         state.ws.clone(),
         notify_receiver,
@@ -176,7 +185,7 @@ pub type SubscriberId = usize;
 /// It listens to the store for updates and sends them to the client.
 pub struct Subscriber {
     id:                      SubscriberId,
-    ip_addr:                 IpAddr,
+    ip_addr:                 Option<IpAddr>,
     closed:                  bool,
     store:                   Arc<State>,
     ws_state:                Arc<WsState>,
@@ -191,7 +200,7 @@ pub struct Subscriber {
 impl Subscriber {
     pub fn new(
         id: SubscriberId,
-        ip_addr: IpAddr,
+        ip_addr: Option<IpAddr>,
         store: Arc<State>,
         ws_state: Arc<WsState>,
         notify_receiver: mpsc::Receiver<AggregationEvent>,
@@ -291,32 +300,36 @@ impl Subscriber {
             })?;
 
             // Close the connection if rate limit is exceeded and the ip is not whitelisted.
-            if !self
-                .ws_state
-                .bytes_limit_whitelist
-                .iter()
-                .any(|ip_net| ip_net.contains(&self.ip_addr))
-                && self.ws_state.rate_limiter.check_key_n(
-                    &self.ip_addr,
-                    NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
-                ) != Ok(Ok(()))
-            {
-                tracing::info!(
-                    self.id,
-                    ip = %self.ip_addr,
-                    "Rate limit exceeded. Closing connection.",
-                );
-                self.sender
-                    .send(
-                        serde_json::to_string(&ServerResponseMessage::Err {
-                            error: "Rate limit exceeded".to_string(),
-                        })?
-                        .into(),
-                    )
-                    .await?;
-                self.sender.close().await?;
-                self.closed = true;
-                return Ok(());
+            // If the ip address is None no rate limiting is applied.
+            if let Some(ip_addr) = self.ip_addr {
+                if !self
+                    .ws_state
+                    .bytes_limit_whitelist
+                    .iter()
+                    .any(|ip_net| ip_net.contains(&ip_addr))
+                    && self.ws_state.rate_limiter.check_key_n(
+                        &ip_addr,
+                        NonZeroU32::new(message.len().try_into()?)
+                            .ok_or(anyhow!("Empty message"))?,
+                    ) != Ok(Ok(()))
+                {
+                    tracing::info!(
+                        self.id,
+                        ip = %ip_addr,
+                        "Rate limit exceeded. Closing connection.",
+                    );
+                    self.sender
+                        .send(
+                            serde_json::to_string(&ServerResponseMessage::Err {
+                                error: "Rate limit exceeded".to_string(),
+                            })?
+                            .into(),
+                        )
+                        .await?;
+                    self.sender.close().await?;
+                    self.closed = true;
+                    return Ok(());
+                }
             }
 
             // `sender.feed` buffers a message to the client but does not flush it, so we can send

+ 7 - 0
hermes/src/config/rpc.rs

@@ -5,6 +5,7 @@ use {
 };
 
 const DEFAULT_RPC_ADDR: &str = "127.0.0.1:33999";
+const DEFAULT_RPC_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For";
 
 #[derive(Args, Clone, Debug)]
 #[command(next_help_heading = "RPC Options")]
@@ -21,4 +22,10 @@ pub struct Options {
     #[arg(value_delimiter = ',')]
     #[arg(env = "RPC_WS_WHITELIST")]
     pub ws_whitelist: Vec<IpNet>,
+
+    /// Header name (case insensitive) to fetch requester IP from.
+    #[arg(long = "rpc-requester-ip-header-name")]
+    #[arg(default_value = DEFAULT_RPC_REQUESTER_IP_HEADER_NAME)]
+    #[arg(env = "RPC_REQUESTER_IP_HEADER_NAME")]
+    pub requester_ip_header_name: String,
 }