ws.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. use {
  2. super::{
  3. types::{PriceIdInput, RpcPriceFeed},
  4. ApiState,
  5. },
  6. crate::state::{
  7. aggregate::{Aggregates, AggregationEvent, RequestTime},
  8. metrics::Metrics,
  9. Benchmarks, Cache, PriceFeedMeta,
  10. },
  11. anyhow::{anyhow, Result},
  12. axum::{
  13. extract::{
  14. ws::{Message, WebSocket, WebSocketUpgrade},
  15. State as AxumState,
  16. },
  17. http::HeaderMap,
  18. response::IntoResponse,
  19. },
  20. futures::{
  21. stream::{SplitSink, SplitStream},
  22. SinkExt, StreamExt,
  23. },
  24. governor::{DefaultKeyedRateLimiter, Quota, RateLimiter},
  25. ipnet::IpNet,
  26. nonzero_ext::nonzero,
  27. prometheus_client::{
  28. encoding::{EncodeLabelSet, EncodeLabelValue},
  29. metrics::{counter::Counter, family::Family},
  30. },
  31. pyth_sdk::PriceIdentifier,
  32. serde::{Deserialize, Serialize},
  33. std::{
  34. collections::HashMap,
  35. net::IpAddr,
  36. num::NonZeroU32,
  37. sync::{
  38. atomic::{AtomicUsize, Ordering},
  39. Arc,
  40. },
  41. time::Duration,
  42. },
  43. tokio::sync::{broadcast::Receiver, watch},
  44. };
  45. const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
  46. const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
  47. /// The maximum number of bytes that can be sent per second per IP address.
  48. /// If the limit is exceeded, the connection is closed.
  49. const BYTES_LIMIT_PER_IP_PER_SECOND: u32 = 256 * 1024; // 256 KiB
  50. #[derive(Clone)]
  51. pub struct PriceFeedClientConfig {
  52. verbose: bool,
  53. binary: bool,
  54. allow_out_of_order: bool,
  55. }
  56. #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelValue)]
  57. pub enum Interaction {
  58. NewConnection,
  59. CloseConnection,
  60. ClientHeartbeat,
  61. PriceUpdate,
  62. ClientMessage,
  63. RateLimit,
  64. }
  65. #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelValue)]
  66. pub enum Status {
  67. Success,
  68. Error,
  69. }
  70. #[derive(Clone, Debug, PartialEq, Eq, Hash, EncodeLabelSet)]
  71. pub struct Labels {
  72. pub interaction: Interaction,
  73. pub status: Status,
  74. }
  75. pub struct WsMetrics {
  76. pub interactions: Family<Labels, Counter>,
  77. }
  78. impl WsMetrics {
  79. pub fn new<S>(state: Arc<S>) -> Self
  80. where
  81. S: Metrics,
  82. S: Send + Sync + 'static,
  83. {
  84. let new = Self {
  85. interactions: Family::default(),
  86. };
  87. {
  88. let interactions = new.interactions.clone();
  89. tokio::spawn(async move {
  90. Metrics::register(
  91. &*state,
  92. (
  93. "ws_interactions",
  94. "Total number of websocket interactions",
  95. interactions,
  96. ),
  97. )
  98. .await;
  99. });
  100. }
  101. new
  102. }
  103. }
  104. pub struct WsState {
  105. pub subscriber_counter: AtomicUsize,
  106. pub bytes_limit_whitelist: Vec<IpNet>,
  107. pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
  108. pub requester_ip_header_name: String,
  109. pub metrics: WsMetrics,
  110. }
  111. impl WsState {
  112. pub fn new<S>(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<S>) -> Self
  113. where
  114. S: Metrics,
  115. S: Send + Sync + 'static,
  116. {
  117. Self {
  118. subscriber_counter: AtomicUsize::new(0),
  119. rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
  120. BYTES_LIMIT_PER_IP_PER_SECOND
  121. ))),
  122. bytes_limit_whitelist: whitelist,
  123. requester_ip_header_name,
  124. metrics: WsMetrics::new(state.clone()),
  125. }
  126. }
  127. }
  128. #[derive(Deserialize, Debug, Clone)]
  129. #[serde(tag = "type")]
  130. enum ClientMessage {
  131. #[serde(rename = "subscribe")]
  132. Subscribe {
  133. ids: Vec<PriceIdInput>,
  134. #[serde(default)]
  135. verbose: bool,
  136. #[serde(default)]
  137. binary: bool,
  138. #[serde(default)]
  139. allow_out_of_order: bool,
  140. },
  141. #[serde(rename = "unsubscribe")]
  142. Unsubscribe { ids: Vec<PriceIdInput> },
  143. }
  144. #[derive(Serialize, Debug, Clone)]
  145. #[serde(tag = "type")]
  146. enum ServerMessage {
  147. #[serde(rename = "response")]
  148. Response(ServerResponseMessage),
  149. #[serde(rename = "price_update")]
  150. PriceUpdate { price_feed: RpcPriceFeed },
  151. }
  152. #[derive(Serialize, Debug, Clone)]
  153. #[serde(tag = "status")]
  154. enum ServerResponseMessage {
  155. #[serde(rename = "success")]
  156. Success,
  157. #[serde(rename = "error")]
  158. Err { error: String },
  159. }
  160. pub async fn ws_route_handler<S>(
  161. ws: WebSocketUpgrade,
  162. AxumState(state): AxumState<ApiState<S>>,
  163. headers: HeaderMap,
  164. ) -> impl IntoResponse
  165. where
  166. S: Aggregates,
  167. S: Benchmarks,
  168. S: Cache,
  169. S: PriceFeedMeta,
  170. S: Send + Sync + 'static,
  171. {
  172. let requester_ip = headers
  173. .get(state.ws.requester_ip_header_name.as_str())
  174. .and_then(|value| value.to_str().ok())
  175. .and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
  176. .and_then(|value| value.parse().ok());
  177. ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
  178. .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
  179. }
  180. #[tracing::instrument(skip(stream, state, subscriber_ip))]
  181. async fn websocket_handler<S>(stream: WebSocket, state: ApiState<S>, subscriber_ip: Option<IpAddr>)
  182. where
  183. S: Aggregates,
  184. S: Send,
  185. {
  186. let ws_state = state.ws.clone();
  187. // Retain the recent rate limit data for the IP addresses to
  188. // prevent the rate limiter size from growing indefinitely.
  189. ws_state.rate_limiter.retain_recent();
  190. let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
  191. tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
  192. ws_state
  193. .metrics
  194. .interactions
  195. .get_or_create(&Labels {
  196. interaction: Interaction::NewConnection,
  197. status: Status::Success,
  198. })
  199. .inc();
  200. let notify_receiver = Aggregates::subscribe(&*state.state);
  201. let (sender, receiver) = stream.split();
  202. let mut subscriber = Subscriber::new(
  203. id,
  204. subscriber_ip,
  205. state.state.clone(),
  206. state.ws.clone(),
  207. notify_receiver,
  208. receiver,
  209. sender,
  210. );
  211. subscriber.run().await;
  212. }
  213. pub type SubscriberId = usize;
  214. /// Subscriber is an actor that handles a single websocket connection.
  215. /// It listens to the store for updates and sends them to the client.
  216. pub struct Subscriber<S> {
  217. id: SubscriberId,
  218. ip_addr: Option<IpAddr>,
  219. closed: bool,
  220. state: Arc<S>,
  221. ws_state: Arc<WsState>,
  222. notify_receiver: Receiver<AggregationEvent>,
  223. receiver: SplitStream<WebSocket>,
  224. sender: SplitSink<WebSocket, Message>,
  225. price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
  226. ping_interval: tokio::time::Interval,
  227. exit: watch::Receiver<bool>,
  228. responded_to_ping: bool,
  229. }
  230. impl<S> Subscriber<S>
  231. where
  232. S: Aggregates,
  233. {
  234. pub fn new(
  235. id: SubscriberId,
  236. ip_addr: Option<IpAddr>,
  237. state: Arc<S>,
  238. ws_state: Arc<WsState>,
  239. notify_receiver: Receiver<AggregationEvent>,
  240. receiver: SplitStream<WebSocket>,
  241. sender: SplitSink<WebSocket, Message>,
  242. ) -> Self {
  243. Self {
  244. id,
  245. ip_addr,
  246. closed: false,
  247. state,
  248. ws_state,
  249. notify_receiver,
  250. receiver,
  251. sender,
  252. price_feeds_with_config: HashMap::new(),
  253. ping_interval: tokio::time::interval(PING_INTERVAL_DURATION),
  254. exit: crate::EXIT.subscribe(),
  255. responded_to_ping: true, // We start with true so we don't close the connection immediately
  256. }
  257. }
  258. #[tracing::instrument(skip(self))]
  259. pub async fn run(&mut self) {
  260. while !self.closed {
  261. if let Err(e) = self.handle_next().await {
  262. tracing::debug!(subscriber = self.id, error = ?e, "Error Handling Subscriber Message.");
  263. break;
  264. }
  265. }
  266. }
  267. async fn handle_next(&mut self) -> Result<()> {
  268. tokio::select! {
  269. maybe_update_feeds_event = self.notify_receiver.recv() => {
  270. match maybe_update_feeds_event {
  271. Ok(event) => self.handle_price_feeds_update(event).await,
  272. Err(e) => Err(anyhow!("Failed to receive update from store: {:?}", e)),
  273. }
  274. },
  275. maybe_message_or_err = self.receiver.next() => {
  276. self.handle_client_message(
  277. maybe_message_or_err.ok_or(anyhow!("Client channel is closed"))??
  278. ).await
  279. },
  280. _ = self.ping_interval.tick() => {
  281. if !self.responded_to_ping {
  282. self.ws_state
  283. .metrics
  284. .interactions
  285. .get_or_create(&Labels {
  286. interaction: Interaction::ClientHeartbeat,
  287. status: Status::Error,
  288. })
  289. .inc();
  290. return Err(anyhow!("Subscriber did not respond to ping. Closing connection."));
  291. }
  292. self.responded_to_ping = false;
  293. self.sender.send(Message::Ping(vec![])).await?;
  294. Ok(())
  295. },
  296. _ = self.exit.changed() => {
  297. self.sender.close().await?;
  298. self.closed = true;
  299. Err(anyhow!("Application is shutting down. Closing connection."))
  300. }
  301. }
  302. }
  303. async fn handle_price_feeds_update(&mut self, event: AggregationEvent) -> Result<()> {
  304. let price_feed_ids = self
  305. .price_feeds_with_config
  306. .keys()
  307. .cloned()
  308. .collect::<Vec<_>>();
  309. let state = &*self.state;
  310. let updates = match Aggregates::get_price_feeds_with_update_data(
  311. state,
  312. &price_feed_ids,
  313. RequestTime::AtSlot(event.slot()),
  314. )
  315. .await
  316. {
  317. Ok(updates) => updates,
  318. Err(_) => {
  319. // The error can only happen when a price feed was available
  320. // and is no longer there as we check the price feed ids upon
  321. // subscription. In this case we just remove the non-existing
  322. // price feed from the list and will keep sending updates for
  323. // the rest.
  324. let available_price_feed_ids = Aggregates::get_price_feed_ids(state).await;
  325. self.price_feeds_with_config
  326. .retain(|price_feed_id, _| available_price_feed_ids.contains(price_feed_id));
  327. let price_feed_ids = self
  328. .price_feeds_with_config
  329. .keys()
  330. .cloned()
  331. .collect::<Vec<_>>();
  332. Aggregates::get_price_feeds_with_update_data(
  333. state,
  334. &price_feed_ids,
  335. RequestTime::AtSlot(event.slot()),
  336. )
  337. .await?
  338. }
  339. };
  340. for update in updates.price_feeds {
  341. let config = self
  342. .price_feeds_with_config
  343. .get(&update.price_feed.id)
  344. .ok_or(anyhow::anyhow!(
  345. "Config missing, price feed list was poisoned during iteration."
  346. ))?;
  347. if let AggregationEvent::OutOfOrder { slot: _ } = event {
  348. if !config.allow_out_of_order {
  349. continue;
  350. }
  351. }
  352. let message = serde_json::to_string(&ServerMessage::PriceUpdate {
  353. price_feed: RpcPriceFeed::from_price_feed_update(
  354. update,
  355. config.verbose,
  356. config.binary,
  357. ),
  358. })?;
  359. // Close the connection if rate limit is exceeded and the ip is not whitelisted.
  360. // If the ip address is None no rate limiting is applied.
  361. if let Some(ip_addr) = self.ip_addr {
  362. if !self
  363. .ws_state
  364. .bytes_limit_whitelist
  365. .iter()
  366. .any(|ip_net| ip_net.contains(&ip_addr))
  367. && self.ws_state.rate_limiter.check_key_n(
  368. &ip_addr,
  369. NonZeroU32::new(message.len().try_into()?)
  370. .ok_or(anyhow!("Empty message"))?,
  371. ) != Ok(Ok(()))
  372. {
  373. tracing::info!(
  374. self.id,
  375. ip = %ip_addr,
  376. "Rate limit exceeded. Closing connection.",
  377. );
  378. self.ws_state
  379. .metrics
  380. .interactions
  381. .get_or_create(&Labels {
  382. interaction: Interaction::RateLimit,
  383. status: Status::Error,
  384. })
  385. .inc();
  386. self.sender
  387. .send(
  388. serde_json::to_string(&ServerResponseMessage::Err {
  389. error: "Rate limit exceeded".to_string(),
  390. })?
  391. .into(),
  392. )
  393. .await?;
  394. self.sender.close().await?;
  395. self.closed = true;
  396. return Ok(());
  397. }
  398. }
  399. // `sender.feed` buffers a message to the client but does not flush it, so we can send
  400. // multiple messages and flush them all at once.
  401. self.sender.feed(message.into()).await?;
  402. self.ws_state
  403. .metrics
  404. .interactions
  405. .get_or_create(&Labels {
  406. interaction: Interaction::PriceUpdate,
  407. status: Status::Success,
  408. })
  409. .inc();
  410. }
  411. self.sender.flush().await?;
  412. Ok(())
  413. }
  414. #[tracing::instrument(skip(self, message))]
  415. async fn handle_client_message(&mut self, message: Message) -> Result<()> {
  416. let maybe_client_message = match message {
  417. Message::Close(_) => {
  418. // Closing the connection. We don't remove it from the subscribers
  419. // list, instead when the Subscriber struct is dropped the channel
  420. // to subscribers list will be closed and it will eventually get
  421. // removed.
  422. tracing::trace!(id = self.id, "Subscriber Closed Connection.");
  423. self.ws_state
  424. .metrics
  425. .interactions
  426. .get_or_create(&Labels {
  427. interaction: Interaction::CloseConnection,
  428. status: Status::Success,
  429. })
  430. .inc();
  431. // Send the close message to gracefully shut down the connection
  432. // Otherwise the client might get an abnormal Websocket closure
  433. // error.
  434. self.sender.close().await?;
  435. self.closed = true;
  436. return Ok(());
  437. }
  438. Message::Text(text) => serde_json::from_str::<ClientMessage>(&text),
  439. Message::Binary(data) => serde_json::from_slice::<ClientMessage>(&data),
  440. Message::Ping(_) => {
  441. // Axum will send Pong automatically
  442. return Ok(());
  443. }
  444. Message::Pong(_) => {
  445. // This metric can be used to monitor the number of active connections
  446. self.ws_state
  447. .metrics
  448. .interactions
  449. .get_or_create(&Labels {
  450. interaction: Interaction::ClientHeartbeat,
  451. status: Status::Success,
  452. })
  453. .inc();
  454. self.responded_to_ping = true;
  455. return Ok(());
  456. }
  457. };
  458. match maybe_client_message {
  459. Err(e) => {
  460. self.ws_state
  461. .metrics
  462. .interactions
  463. .get_or_create(&Labels {
  464. interaction: Interaction::ClientMessage,
  465. status: Status::Error,
  466. })
  467. .inc();
  468. self.sender
  469. .send(
  470. serde_json::to_string(&ServerMessage::Response(
  471. ServerResponseMessage::Err {
  472. error: e.to_string(),
  473. },
  474. ))?
  475. .into(),
  476. )
  477. .await?;
  478. return Ok(());
  479. }
  480. Ok(ClientMessage::Subscribe {
  481. ids,
  482. verbose,
  483. binary,
  484. allow_out_of_order,
  485. }) => {
  486. let price_ids: Vec<PriceIdentifier> = ids.into_iter().map(|id| id.into()).collect();
  487. let available_price_ids = Aggregates::get_price_feed_ids(&*self.state).await;
  488. let not_found_price_ids: Vec<&PriceIdentifier> = price_ids
  489. .iter()
  490. .filter(|price_id| !available_price_ids.contains(price_id))
  491. .collect();
  492. // If there is a single price id that is not found, we don't subscribe to any of the
  493. // asked correct price feed ids and return an error to be more explicit and clear.
  494. if !not_found_price_ids.is_empty() {
  495. self.sender
  496. .send(
  497. serde_json::to_string(&ServerMessage::Response(
  498. ServerResponseMessage::Err {
  499. error: format!(
  500. "Price feed(s) with id(s) {:?} not found",
  501. not_found_price_ids
  502. ),
  503. },
  504. ))?
  505. .into(),
  506. )
  507. .await?;
  508. return Ok(());
  509. } else {
  510. for price_id in price_ids {
  511. self.price_feeds_with_config.insert(
  512. price_id,
  513. PriceFeedClientConfig {
  514. verbose,
  515. binary,
  516. allow_out_of_order,
  517. },
  518. );
  519. }
  520. }
  521. }
  522. Ok(ClientMessage::Unsubscribe { ids }) => {
  523. for id in ids {
  524. let price_id: PriceIdentifier = id.into();
  525. self.price_feeds_with_config.remove(&price_id);
  526. }
  527. }
  528. }
  529. self.ws_state
  530. .metrics
  531. .interactions
  532. .get_or_create(&Labels {
  533. interaction: Interaction::ClientMessage,
  534. status: Status::Success,
  535. })
  536. .inc();
  537. self.sender
  538. .send(
  539. serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Success))?
  540. .into(),
  541. )
  542. .await?;
  543. Ok(())
  544. }
  545. }