jrpc_handle.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. use crate::config::Config;
  2. use crate::lazer_publisher::LazerPublisher;
  3. use crate::websocket_utils::{handle_websocket_error, send_json, send_text};
  4. use anyhow::Error;
  5. use futures::{AsyncRead, AsyncWrite};
  6. use futures_util::io::{BufReader, BufWriter};
  7. use hyper_util::rt::TokioIo;
  8. use pyth_lazer_protocol::jrpc::{
  9. FeedUpdateParams, GetMetadataParams, JrpcCall, JrpcError, JrpcErrorResponse, JrpcId,
  10. JrpcResponse, JrpcSuccessResponse, JsonRpcVersion, PythLazerAgentJrpcV1, SymbolMetadata,
  11. };
  12. use soketto::Sender;
  13. use soketto::handshake::http::Server;
  14. use std::str::FromStr;
  15. use tokio::{pin, select};
  16. use tokio_util::compat::TokioAsyncReadCompatExt;
  17. use tracing::{debug, error, instrument};
  18. use url::Url;
  /// History-service endpoint queried for symbol metadata when the config does not set
  /// `history_service_url`.
  const DEFAULT_HISTORY_SERVICE_URL: &str =
      "https://history.pyth-lazer.dourolabs.app/history/v1/symbols";

  /// Per-connection context handed to [`handle_jrpc`]; it currently carries no state.
  pub struct JrpcConnectionContext {}
  22. #[instrument(
  23. skip(server, request, lazer_publisher, context),
  24. fields(component = "jrpc_ws")
  25. )]
  26. pub async fn handle_jrpc(
  27. config: Config,
  28. server: Server,
  29. request: hyper::Request<hyper::body::Incoming>,
  30. context: JrpcConnectionContext,
  31. lazer_publisher: LazerPublisher,
  32. ) {
  33. if let Err(err) = try_handle_jrpc(config, server, request, context, lazer_publisher).await {
  34. handle_websocket_error(err);
  35. }
  36. }
  37. #[instrument(
  38. skip(server, request, lazer_publisher, _context),
  39. fields(component = "jrpc_ws")
  40. )]
  41. async fn try_handle_jrpc(
  42. config: Config,
  43. server: Server,
  44. request: hyper::Request<hyper::body::Incoming>,
  45. _context: JrpcConnectionContext,
  46. lazer_publisher: LazerPublisher,
  47. ) -> anyhow::Result<()> {
  48. let stream = hyper::upgrade::on(request).await?;
  49. let io = TokioIo::new(stream);
  50. let stream = BufReader::new(BufWriter::new(io.compat()));
  51. let (mut ws_sender, mut ws_receiver) = server.into_builder(stream).finish();
  52. let mut receive_buf = Vec::new();
  53. loop {
  54. receive_buf.clear();
  55. {
  56. // soketto is not cancel-safe, so we need to store the future and poll it
  57. // in the inner loop.
  58. let receive = async { ws_receiver.receive(&mut receive_buf).await };
  59. pin!(receive);
  60. #[allow(clippy::never_loop, reason = "false positive")] // false positive
  61. loop {
  62. select! {
  63. _result = &mut receive => {
  64. break
  65. }
  66. }
  67. }
  68. }
  69. match handle_jrpc_inner(&config, &mut ws_sender, &mut receive_buf, &lazer_publisher).await {
  70. Ok(_) => {}
  71. Err(err) => {
  72. debug!("Error handling JRPC request: {}", err);
  73. send_text(
  74. &mut ws_sender,
  75. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  76. JrpcErrorResponse {
  77. jsonrpc: JsonRpcVersion::V2,
  78. error: JrpcError::InternalError(err.to_string()).into(),
  79. id: JrpcId::Null,
  80. },
  81. ))?
  82. .as_str(),
  83. )
  84. .await?;
  85. }
  86. }
  87. }
  88. }
  89. async fn handle_jrpc_inner<T: AsyncRead + AsyncWrite + Unpin>(
  90. config: &Config,
  91. sender: &mut Sender<T>,
  92. receive_buf: &mut Vec<u8>,
  93. lazer_publisher: &LazerPublisher,
  94. ) -> anyhow::Result<()> {
  95. match serde_json::from_slice::<PythLazerAgentJrpcV1>(receive_buf.as_slice()) {
  96. Ok(jrpc_request) => match jrpc_request.params {
  97. JrpcCall::PushUpdate(request_params) => {
  98. handle_push_update(
  99. sender,
  100. lazer_publisher,
  101. request_params,
  102. jrpc_request.id.clone(),
  103. )
  104. .await
  105. }
  106. JrpcCall::PushUpdates(request_params) => {
  107. for feed in request_params {
  108. handle_push_update(sender, lazer_publisher, feed, jrpc_request.id.clone())
  109. .await?;
  110. }
  111. Ok(())
  112. }
  113. JrpcCall::GetMetadata(request_params) => match jrpc_request.id {
  114. JrpcId::Null => {
  115. send_json(
  116. sender,
  117. &JrpcErrorResponse {
  118. jsonrpc: JsonRpcVersion::V2,
  119. error: JrpcError::ParseError(
  120. "The request to method 'get_metadata' requires an 'id'".to_string(),
  121. )
  122. .into(),
  123. id: JrpcId::Null,
  124. },
  125. )
  126. .await
  127. }
  128. _ => handle_get_metadata(sender, config, request_params, jrpc_request.id).await,
  129. },
  130. },
  131. Err(err) => {
  132. debug!("Error parsing JRPC request: {}", err);
  133. send_json(
  134. sender,
  135. &JrpcErrorResponse {
  136. jsonrpc: JsonRpcVersion::V2,
  137. error: JrpcError::ParseError(err.to_string()).into(),
  138. id: JrpcId::Null,
  139. },
  140. )
  141. .await
  142. }
  143. }
  144. }
  145. async fn get_metadata(config: Config) -> Result<Vec<SymbolMetadata>, Error> {
  146. let result = reqwest::get(
  147. config
  148. .history_service_url
  149. .unwrap_or(Url::from_str(DEFAULT_HISTORY_SERVICE_URL)?),
  150. )
  151. .await?;
  152. if result.status().is_success() {
  153. Ok(serde_json::from_str::<Vec<SymbolMetadata>>(
  154. &result.text().await?,
  155. )?)
  156. } else {
  157. Err(anyhow::anyhow!(
  158. "Error getting metadata (status_code={}, body={})",
  159. result.status(),
  160. result.text().await.unwrap_or("none".to_string())
  161. ))
  162. }
  163. }
  164. fn filter_symbols(
  165. symbols: Vec<SymbolMetadata>,
  166. get_metadata_params: GetMetadataParams,
  167. ) -> Vec<SymbolMetadata> {
  168. let names = &get_metadata_params.names.clone();
  169. let asset_types = &get_metadata_params.asset_types.clone();
  170. let res: Vec<SymbolMetadata> = symbols
  171. .into_iter()
  172. .filter(|symbol| {
  173. if let Some(names) = names {
  174. if !names.contains(&symbol.name) {
  175. return false;
  176. }
  177. }
  178. if let Some(asset_types) = asset_types {
  179. if !asset_types.contains(&symbol.asset_type) {
  180. return false;
  181. }
  182. }
  183. true
  184. })
  185. .collect();
  186. res
  187. }
  188. async fn handle_push_update<T: AsyncRead + AsyncWrite + Unpin>(
  189. sender: &mut Sender<T>,
  190. lazer_publisher: &LazerPublisher,
  191. request_params: FeedUpdateParams,
  192. request_id: JrpcId,
  193. ) -> anyhow::Result<()> {
  194. match lazer_publisher
  195. .push_feed_update(request_params.clone().into())
  196. .await
  197. {
  198. Ok(_) => match request_id {
  199. JrpcId::Null => Ok(()),
  200. _ => {
  201. send_json(
  202. sender,
  203. &JrpcSuccessResponse::<String> {
  204. jsonrpc: JsonRpcVersion::V2,
  205. result: "success".to_string(),
  206. id: request_id,
  207. },
  208. )
  209. .await
  210. }
  211. },
  212. Err(err) => {
  213. debug!("error while sending updates: {:?}", err);
  214. send_json(
  215. sender,
  216. &JrpcErrorResponse {
  217. jsonrpc: JsonRpcVersion::V2,
  218. error: JrpcError::SendUpdateError(request_params).into(),
  219. id: request_id,
  220. },
  221. )
  222. .await
  223. }
  224. }
  225. }
  226. async fn handle_get_metadata<T: AsyncRead + AsyncWrite + Unpin>(
  227. sender: &mut Sender<T>,
  228. config: &Config,
  229. request_params: GetMetadataParams,
  230. request_id: JrpcId,
  231. ) -> anyhow::Result<()> {
  232. match get_metadata(config.clone()).await {
  233. Ok(symbols) => {
  234. let symbols = filter_symbols(symbols.clone(), request_params);
  235. send_json(
  236. sender,
  237. &JrpcSuccessResponse::<Vec<SymbolMetadata>> {
  238. jsonrpc: JsonRpcVersion::V2,
  239. result: symbols,
  240. id: request_id,
  241. },
  242. )
  243. .await
  244. }
  245. Err(err) => {
  246. error!("error while retrieving metadata: {:?}", err);
  247. send_json(
  248. sender,
  249. &JrpcErrorResponse {
  250. jsonrpc: JsonRpcVersion::V2,
  251. error: JrpcError::InternalError(err.to_string()).into(),
  252. id: request_id,
  253. },
  254. )
  255. .await
  256. }
  257. }
  258. }
  #[cfg(test)]
  pub mod tests {
      use pyth_lazer_protocol::{PriceFeedId, SymbolState, api::Channel, time::FixedRate};

      use super::*;
      use std::net::SocketAddr;

      /// Builds a [`SymbolMetadata`] with the given `name` and `asset_type`.
      ///
      /// Every other field is a fixed placeholder, because `filter_symbols` only inspects
      /// those two fields.
      fn gen_test_symbol(name: String, asset_type: String) -> SymbolMetadata {
          SymbolMetadata {
              pyth_lazer_id: PriceFeedId(1),
              name,
              symbol: "".to_string(),
              description: "".to_string(),
              asset_type,
              exponent: 0,
              cmc_id: None,
              funding_rate_interval: None,
              min_publishers: 0,
              min_channel: Channel::FixedRate(FixedRate::MIN),
              state: SymbolState::Stable,
              hermes_id: None,
              quote_currency: None,
          }
      }

      /// Queries the real default history service, so it is `#[ignore]`d and must be run
      /// explicitly with network access.
      #[tokio::test]
      #[ignore]
      async fn test_try_get_metadata() {
          let config = Config {
              listen_address: SocketAddr::from(([127, 0, 0, 1], 0)),
              relayer_urls: vec![],
              authorization_token: None,
              publish_keypair_path: Default::default(),
              publish_interval_duration: Default::default(),
              history_service_url: None,
              enable_update_deduplication: false,
              update_deduplication_ttl: Default::default(),
          };
          println!("{:?}", get_metadata(config).await.unwrap());
      }

      #[test]
      fn test_filter_symbols() {
          let symbol1 = gen_test_symbol("BTC".to_string(), "crypto".to_string());
          let symbol2 = gen_test_symbol("XMR".to_string(), "crypto".to_string());
          let symbol3 = gen_test_symbol("BTCUSDT".to_string(), "funding-rate".to_string());
          let symbols = vec![symbol1.clone(), symbol2.clone(), symbol3.clone()];

          // no filters: every symbol is kept, in order
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: None,
                      asset_types: None,
                  },
              ),
              symbols
          );
          // just a name filter
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["XMR".to_string()]),
                      asset_types: None,
                  },
              ),
              vec![symbol2.clone()]
          );
          // just an asset type filter
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: None,
                      asset_types: Some(vec!["crypto".to_string()]),
                  },
              ),
              vec![symbol1.clone(), symbol2.clone()]
          );
          // asset type filter selecting a different asset type
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: None,
                      asset_types: Some(vec!["funding-rate".to_string()]),
                  },
              ),
              vec![symbol3.clone()]
          );
          // name and asset type
          assert_eq!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["BTC".to_string()]),
                      asset_types: Some(vec!["crypto".to_string()]),
                  },
              ),
              vec![symbol1.clone()]
          );
          // each filter matches some symbol, but never the same one: filters are
          // conjunctive, so nothing is kept
          assert!(
              filter_symbols(
                  symbols.clone(),
                  GetMetadataParams {
                      names: Some(vec!["BTC".to_string()]),
                      asset_types: Some(vec!["funding-rate".to_string()]),
                  },
              )
              .is_empty()
          );
      }
  }