jrpc_handle.rs 12 KB


  1. use crate::config::Config;
  2. use crate::lazer_publisher::LazerPublisher;
  3. use crate::websocket_utils::{handle_websocket_error, send_json, send_text};
  4. use anyhow::Error;
  5. use futures::{AsyncRead, AsyncWrite};
  6. use futures_util::io::{BufReader, BufWriter};
  7. use hyper_util::rt::TokioIo;
  8. use pyth_lazer_protocol::jrpc::{
  9. FeedUpdateParams, GetMetadataParams, JrpcCall, JrpcError, JrpcErrorResponse, JrpcId,
  10. JrpcResponse, JrpcSuccessResponse, JsonRpcVersion, PythLazerAgentJrpcV1, SymbolMetadata,
  11. };
  12. use soketto::Sender;
  13. use soketto::handshake::http::Server;
  14. use std::str::FromStr;
  15. use tokio::{pin, select};
  16. use tokio_util::compat::TokioAsyncReadCompatExt;
  17. use tracing::{debug, error, instrument};
  18. use url::Url;
  /// Symbol-metadata endpoint of the history service; `get_metadata` falls back to
  /// it when `Config::history_service_url` is not set.
  const DEFAULT_HISTORY_SERVICE_URL: &str =
      "https://history.pyth-lazer.dourolabs.app/history/v1/symbols";
  /// Per-connection context handed to the JRPC WebSocket handler.
  ///
  /// It currently carries no data and is accepted but not read by `try_handle_jrpc`.
  pub struct JrpcConnectionContext {}
  22. #[instrument(
  23. skip(server, request, lazer_publisher, context),
  24. fields(component = "jrpc_ws")
  25. )]
  26. pub async fn handle_jrpc(
  27. config: Config,
  28. server: Server,
  29. request: hyper::Request<hyper::body::Incoming>,
  30. context: JrpcConnectionContext,
  31. lazer_publisher: LazerPublisher,
  32. ) {
  33. if let Err(err) = try_handle_jrpc(config, server, request, context, lazer_publisher).await {
  34. handle_websocket_error(err);
  35. }
  36. }
  37. #[instrument(
  38. skip(server, request, lazer_publisher, _context),
  39. fields(component = "jrpc_ws")
  40. )]
  41. async fn try_handle_jrpc(
  42. config: Config,
  43. server: Server,
  44. request: hyper::Request<hyper::body::Incoming>,
  45. _context: JrpcConnectionContext,
  46. lazer_publisher: LazerPublisher,
  47. ) -> anyhow::Result<()> {
  48. let stream = hyper::upgrade::on(request).await?;
  49. let io = TokioIo::new(stream);
  50. let stream = BufReader::new(BufWriter::new(io.compat()));
  51. let (mut ws_sender, mut ws_receiver) = server.into_builder(stream).finish();
  52. let mut receive_buf = Vec::new();
  53. loop {
  54. receive_buf.clear();
  55. {
  56. // soketto is not cancel-safe, so we need to store the future and poll it
  57. // in the inner loop.
  58. let receive = async { ws_receiver.receive(&mut receive_buf).await };
  59. pin!(receive);
  60. #[allow(clippy::never_loop, reason = "false positive")] // false positive
  61. loop {
  62. select! {
  63. _result = &mut receive => {
  64. break
  65. }
  66. }
  67. }
  68. }
  69. match handle_jrpc_inner(&config, &mut ws_sender, &mut receive_buf, &lazer_publisher).await {
  70. Ok(_) => {}
  71. Err(err) => {
  72. debug!("Error handling JRPC request: {}", err);
  73. send_text(
  74. &mut ws_sender,
  75. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  76. JrpcErrorResponse {
  77. jsonrpc: JsonRpcVersion::V2,
  78. error: JrpcError::InternalError(err.to_string()).into(),
  79. id: JrpcId::Null,
  80. },
  81. ))?
  82. .as_str(),
  83. )
  84. .await?;
  85. }
  86. }
  87. }
  88. }
  89. async fn handle_jrpc_inner<T: AsyncRead + AsyncWrite + Unpin>(
  90. config: &Config,
  91. sender: &mut Sender<T>,
  92. receive_buf: &mut Vec<u8>,
  93. lazer_publisher: &LazerPublisher,
  94. ) -> anyhow::Result<()> {
  95. match serde_json::from_slice::<PythLazerAgentJrpcV1>(receive_buf.as_slice()) {
  96. Ok(jrpc_request) => match jrpc_request.params {
  97. JrpcCall::PushUpdate(request_params) => {
  98. match lazer_publisher
  99. .push_feed_update(request_params.clone().into())
  100. .await
  101. {
  102. Ok(()) => send_update_success_response(sender, jrpc_request.id).await,
  103. Err(err) => {
  104. send_update_failure_response(sender, request_params, jrpc_request.id, err)
  105. .await
  106. }
  107. }
  108. }
  109. JrpcCall::PushUpdates(request_params_batch) => {
  110. for request_params in request_params_batch {
  111. match lazer_publisher
  112. .push_feed_update(request_params.clone().into())
  113. .await
  114. {
  115. Ok(()) => (),
  116. Err(err) => {
  117. return send_update_failure_response(
  118. sender,
  119. request_params,
  120. jrpc_request.id,
  121. err,
  122. )
  123. .await;
  124. }
  125. }
  126. }
  127. send_update_success_response(sender, jrpc_request.id).await
  128. }
  129. JrpcCall::GetMetadata(request_params) => match jrpc_request.id {
  130. JrpcId::Null => {
  131. send_json(
  132. sender,
  133. &JrpcErrorResponse {
  134. jsonrpc: JsonRpcVersion::V2,
  135. error: JrpcError::ParseError(
  136. "The request to method 'get_metadata' requires an 'id'".to_string(),
  137. )
  138. .into(),
  139. id: JrpcId::Null,
  140. },
  141. )
  142. .await
  143. }
  144. _ => handle_get_metadata(sender, config, request_params, jrpc_request.id).await,
  145. },
  146. },
  147. Err(err) => {
  148. debug!("Error parsing JRPC request: {}", err);
  149. send_json(
  150. sender,
  151. &JrpcErrorResponse {
  152. jsonrpc: JsonRpcVersion::V2,
  153. error: JrpcError::ParseError(err.to_string()).into(),
  154. id: JrpcId::Null,
  155. },
  156. )
  157. .await
  158. }
  159. }
  160. }
  161. async fn get_metadata(config: Config) -> Result<Vec<SymbolMetadata>, Error> {
  162. let result = reqwest::get(
  163. config
  164. .history_service_url
  165. .unwrap_or(Url::from_str(DEFAULT_HISTORY_SERVICE_URL)?),
  166. )
  167. .await?;
  168. if result.status().is_success() {
  169. Ok(serde_json::from_str::<Vec<SymbolMetadata>>(
  170. &result.text().await?,
  171. )?)
  172. } else {
  173. Err(anyhow::anyhow!(
  174. "Error getting metadata (status_code={}, body={})",
  175. result.status(),
  176. result.text().await.unwrap_or("none".to_string())
  177. ))
  178. }
  179. }
  180. fn filter_symbols(
  181. symbols: Vec<SymbolMetadata>,
  182. get_metadata_params: GetMetadataParams,
  183. ) -> Vec<SymbolMetadata> {
  184. let names = &get_metadata_params.names.clone();
  185. let asset_types = &get_metadata_params.asset_types.clone();
  186. let res: Vec<SymbolMetadata> = symbols
  187. .into_iter()
  188. .filter(|symbol| {
  189. if let Some(names) = names {
  190. if !names.contains(&symbol.name) {
  191. return false;
  192. }
  193. }
  194. if let Some(asset_types) = asset_types {
  195. if !asset_types.contains(&symbol.asset_type) {
  196. return false;
  197. }
  198. }
  199. true
  200. })
  201. .collect();
  202. res
  203. }
  204. async fn send_update_success_response<T: AsyncRead + AsyncWrite + Unpin>(
  205. sender: &mut Sender<T>,
  206. request_id: JrpcId,
  207. ) -> anyhow::Result<()> {
  208. match request_id {
  209. JrpcId::Null => Ok(()),
  210. _ => {
  211. send_json(
  212. sender,
  213. &JrpcSuccessResponse::<String> {
  214. jsonrpc: JsonRpcVersion::V2,
  215. result: "success".to_string(),
  216. id: request_id,
  217. },
  218. )
  219. .await
  220. }
  221. }
  222. }
  223. async fn send_update_failure_response<T: AsyncRead + AsyncWrite + Unpin>(
  224. sender: &mut Sender<T>,
  225. request_params: FeedUpdateParams,
  226. request_id: JrpcId,
  227. err: Error,
  228. ) -> anyhow::Result<()> {
  229. debug!("error while sending updates: {:?}", err);
  230. send_json(
  231. sender,
  232. &JrpcErrorResponse {
  233. jsonrpc: JsonRpcVersion::V2,
  234. error: JrpcError::SendUpdateError(request_params).into(),
  235. id: request_id,
  236. },
  237. )
  238. .await
  239. }
  240. async fn handle_get_metadata<T: AsyncRead + AsyncWrite + Unpin>(
  241. sender: &mut Sender<T>,
  242. config: &Config,
  243. request_params: GetMetadataParams,
  244. request_id: JrpcId,
  245. ) -> anyhow::Result<()> {
  246. match get_metadata(config.clone()).await {
  247. Ok(symbols) => {
  248. let symbols = filter_symbols(symbols.clone(), request_params);
  249. send_json(
  250. sender,
  251. &JrpcSuccessResponse::<Vec<SymbolMetadata>> {
  252. jsonrpc: JsonRpcVersion::V2,
  253. result: symbols,
  254. id: request_id,
  255. },
  256. )
  257. .await
  258. }
  259. Err(err) => {
  260. error!("error while retrieving metadata: {:?}", err);
  261. send_json(
  262. sender,
  263. &JrpcErrorResponse {
  264. jsonrpc: JsonRpcVersion::V2,
  265. error: JrpcError::InternalError(err.to_string()).into(),
  266. id: request_id,
  267. },
  268. )
  269. .await
  270. }
  271. }
  272. }
  #[cfg(test)]
  pub mod tests {
      use pyth_lazer_protocol::{PriceFeedId, SymbolState, api::Channel, time::FixedRate};
      use super::*;
      use std::net::SocketAddr;

      /// Builds a `SymbolMetadata` in which only `name` and `asset_type` vary; these are
      /// the two fields `filter_symbols` inspects.
      fn gen_test_symbol(name: String, asset_type: String) -> SymbolMetadata {
          SymbolMetadata {
              pyth_lazer_id: PriceFeedId(1),
              name,
              symbol: "".to_string(),
              description: "".to_string(),
              asset_type,
              exponent: 0,
              cmc_id: None,
              funding_rate_interval: None,
              min_publishers: 0,
              min_channel: Channel::FixedRate(FixedRate::MIN),
              state: SymbolState::Stable,
              hermes_id: None,
              quote_currency: None,
              nasdaq_symbol: None,
          }
      }

      /// Performs a real HTTP request against the default history service; marked
      /// `#[ignore]` so it only runs when explicitly requested.
      #[tokio::test]
      #[ignore]
      async fn test_try_get_metadata() {
          let config = Config {
              listen_address: SocketAddr::from(([127, 0, 0, 1], 0)),
              relayer_urls: vec![],
              authorization_token: None,
              publish_keypair_path: Default::default(),
              publish_interval_duration: Default::default(),
              history_service_url: None,
              enable_update_deduplication: false,
              update_deduplication_ttl: Default::default(),
              proxy_url: None,
          };
          println!("{:?}", get_metadata(config).await.unwrap());
      }

      #[test]
      fn test_filter_symbols() {
          let symbol1 = gen_test_symbol("BTC".to_string(), "crypto".to_string());
          let symbol2 = gen_test_symbol("XMR".to_string(), "crypto".to_string());
          let symbol3 = gen_test_symbol("BTCUSDT".to_string(), "funding-rate".to_string());
          let symbols = vec![symbol1.clone(), symbol2.clone(), symbol3.clone()];
          // just a name filter
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["XMR".to_string()]),
                      asset_types: None,
                  },
              ),
              vec![symbol2.clone()]
          );
          // just an asset type filter
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: None,
                      asset_types: Some(vec!["crypto".to_string()]),
                  },
              ),
              vec![symbol1.clone(), symbol2.clone()]
          );
          // name and asset type
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["BTC".to_string()]),
                      asset_types: Some(vec!["crypto".to_string()]),
                  },
              ),
              vec![symbol1.clone()]
          );
          // no filters: everything is kept, in the original order
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: None,
                      asset_types: None,
                  },
              ),
              symbols.clone()
          );
          // both filters given, but no symbol satisfies both
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["BTCUSDT".to_string()]),
                      asset_types: Some(vec!["crypto".to_string()]),
                  },
              ),
              Vec::<SymbolMetadata>::new()
          );
          // an empty name list matches nothing
          assert_eq!(
              filter_symbols(
                  symbols,
                  GetMetadataParams {
                      names: Some(vec![]),
                      asset_types: None,
                  },
              ),
              Vec::<SymbolMetadata>::new()
          );
      }
  }