jrpc_handle.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. use crate::config::Config;
  2. use crate::lazer_publisher::LazerPublisher;
  3. use crate::websocket_utils::{handle_websocket_error, send_text};
  4. use anyhow::Error;
  5. use futures::{AsyncRead, AsyncWrite};
  6. use futures_util::io::{BufReader, BufWriter};
  7. use hyper_util::rt::TokioIo;
  8. use pyth_lazer_protocol::jrpc::{
  9. GetMetadataParams, JrpcCall, JrpcError, JrpcErrorResponse, JrpcResponse, JrpcSuccessResponse,
  10. JsonRpcVersion, PythLazerAgentJrpcV1, SymbolMetadata,
  11. };
  12. use soketto::Sender;
  13. use soketto::handshake::http::Server;
  14. use std::str::FromStr;
  15. use tokio::{pin, select};
  16. use tokio_util::compat::TokioAsyncReadCompatExt;
  17. use tracing::{debug, error, instrument};
  18. use url::Url;
  19. const DEFAULT_HISTORY_SERVICE_URL: &str =
  20. "https://history.pyth-lazer.dourolabs.app/history/v1/symbols";
  21. pub struct JrpcConnectionContext {}
  22. #[instrument(
  23. skip(server, request, lazer_publisher, context),
  24. fields(component = "jrpc_ws")
  25. )]
  26. pub async fn handle_jrpc(
  27. config: Config,
  28. server: Server,
  29. request: hyper::Request<hyper::body::Incoming>,
  30. context: JrpcConnectionContext,
  31. lazer_publisher: LazerPublisher,
  32. ) {
  33. if let Err(err) = try_handle_jrpc(config, server, request, context, lazer_publisher).await {
  34. handle_websocket_error(err);
  35. }
  36. }
  37. #[instrument(
  38. skip(server, request, lazer_publisher, _context),
  39. fields(component = "jrpc_ws")
  40. )]
  41. async fn try_handle_jrpc(
  42. config: Config,
  43. server: Server,
  44. request: hyper::Request<hyper::body::Incoming>,
  45. _context: JrpcConnectionContext,
  46. lazer_publisher: LazerPublisher,
  47. ) -> anyhow::Result<()> {
  48. let stream = hyper::upgrade::on(request).await?;
  49. let io = TokioIo::new(stream);
  50. let stream = BufReader::new(BufWriter::new(io.compat()));
  51. let (mut ws_sender, mut ws_receiver) = server.into_builder(stream).finish();
  52. let mut receive_buf = Vec::new();
  53. loop {
  54. receive_buf.clear();
  55. {
  56. // soketto is not cancel-safe, so we need to store the future and poll it
  57. // in the inner loop.
  58. let receive = async { ws_receiver.receive(&mut receive_buf).await };
  59. pin!(receive);
  60. #[allow(clippy::never_loop, reason = "false positive")] // false positive
  61. loop {
  62. select! {
  63. _result = &mut receive => {
  64. break
  65. }
  66. }
  67. }
  68. }
  69. match handle_jrpc_inner(&config, &mut ws_sender, &mut receive_buf, &lazer_publisher).await {
  70. Ok(_) => {}
  71. Err(err) => {
  72. debug!("Error handling JRPC request: {}", err);
  73. send_text(
  74. &mut ws_sender,
  75. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  76. JrpcErrorResponse {
  77. jsonrpc: JsonRpcVersion::V2,
  78. error: JrpcError::InternalError.into(),
  79. id: None,
  80. },
  81. ))?
  82. .as_str(),
  83. )
  84. .await?;
  85. }
  86. }
  87. }
  88. }
  89. async fn handle_jrpc_inner<T: AsyncRead + AsyncWrite + Unpin>(
  90. config: &Config,
  91. sender: &mut Sender<T>,
  92. receive_buf: &mut Vec<u8>,
  93. lazer_publisher: &LazerPublisher,
  94. ) -> anyhow::Result<()> {
  95. match serde_json::from_slice::<PythLazerAgentJrpcV1>(receive_buf.as_slice()) {
  96. Ok(jrpc_request) => match jrpc_request.params {
  97. JrpcCall::PushUpdate(request_params) => {
  98. match lazer_publisher
  99. .push_feed_update(request_params.into())
  100. .await
  101. {
  102. Ok(_) => {
  103. send_text(
  104. sender,
  105. serde_json::to_string::<JrpcResponse<String>>(&JrpcResponse::Success(
  106. JrpcSuccessResponse::<String> {
  107. jsonrpc: JsonRpcVersion::V2,
  108. result: "success".to_string(),
  109. id: jrpc_request.id,
  110. },
  111. ))?
  112. .as_str(),
  113. )
  114. .await?;
  115. }
  116. Err(err) => {
  117. debug!("error while sending updates: {:?}", err);
  118. send_text(
  119. sender,
  120. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  121. JrpcErrorResponse {
  122. jsonrpc: JsonRpcVersion::V2,
  123. error: JrpcError::InternalError.into(),
  124. id: Some(jrpc_request.id),
  125. },
  126. ))?
  127. .as_str(),
  128. )
  129. .await?;
  130. }
  131. }
  132. }
  133. JrpcCall::GetMetadata(request_params) => match get_metadata(config.clone()).await {
  134. Ok(symbols) => {
  135. let symbols = filter_symbols(symbols.clone(), request_params);
  136. send_text(
  137. sender,
  138. serde_json::to_string::<JrpcResponse<Vec<SymbolMetadata>>>(
  139. &JrpcResponse::Success(JrpcSuccessResponse::<Vec<SymbolMetadata>> {
  140. jsonrpc: JsonRpcVersion::V2,
  141. result: symbols,
  142. id: jrpc_request.id,
  143. }),
  144. )?
  145. .as_str(),
  146. )
  147. .await?;
  148. }
  149. Err(err) => {
  150. error!("error while retrieving metadata: {:?}", err);
  151. send_text(
  152. sender,
  153. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  154. JrpcErrorResponse {
  155. jsonrpc: JsonRpcVersion::V2,
  156. // note: right now specifying an invalid method results in a parse error
  157. error: JrpcError::InternalError.into(),
  158. id: None,
  159. },
  160. ))?
  161. .as_str(),
  162. )
  163. .await?;
  164. }
  165. },
  166. },
  167. Err(err) => {
  168. debug!("Error parsing JRPC request: {}", err);
  169. send_text(
  170. sender,
  171. serde_json::to_string::<JrpcResponse<()>>(&JrpcResponse::Error(
  172. JrpcErrorResponse {
  173. jsonrpc: JsonRpcVersion::V2,
  174. error: JrpcError::ParseError(err.to_string()).into(),
  175. id: None,
  176. },
  177. ))?
  178. .as_str(),
  179. )
  180. .await?;
  181. }
  182. }
  183. Ok(())
  184. }
  185. async fn get_metadata(config: Config) -> Result<Vec<SymbolMetadata>, Error> {
  186. let result = reqwest::get(
  187. config
  188. .history_service_url
  189. .unwrap_or(Url::from_str(DEFAULT_HISTORY_SERVICE_URL)?),
  190. )
  191. .await?;
  192. if result.status().is_success() {
  193. Ok(serde_json::from_str::<Vec<SymbolMetadata>>(
  194. &result.text().await?,
  195. )?)
  196. } else {
  197. Err(anyhow::anyhow!(
  198. "Error getting metadata (status_code={}, body={})",
  199. result.status(),
  200. result.text().await.unwrap_or("none".to_string())
  201. ))
  202. }
  203. }
  204. fn filter_symbols(
  205. symbols: Vec<SymbolMetadata>,
  206. get_metadata_params: GetMetadataParams,
  207. ) -> Vec<SymbolMetadata> {
  208. let names = &get_metadata_params.names.clone();
  209. let asset_types = &get_metadata_params.asset_types.clone();
  210. let res: Vec<SymbolMetadata> = symbols
  211. .into_iter()
  212. .filter(|symbol| {
  213. if let Some(names) = names {
  214. if !names.contains(&symbol.name) {
  215. return false;
  216. }
  217. }
  218. if let Some(asset_types) = asset_types {
  219. if !asset_types.contains(&symbol.asset_type) {
  220. return false;
  221. }
  222. }
  223. true
  224. })
  225. .collect();
  226. res
  227. }
  228. #[cfg(test)]
  229. pub mod tests {
  230. use super::*;
  231. use pyth_lazer_protocol::router::{Channel, FixedRate, PriceFeedId};
  232. use pyth_lazer_protocol::symbol_state::SymbolState;
  233. use std::net::SocketAddr;
  234. fn gen_test_symbol(name: String, asset_type: String) -> SymbolMetadata {
  235. SymbolMetadata {
  236. pyth_lazer_id: PriceFeedId(1),
  237. name,
  238. symbol: "".to_string(),
  239. description: "".to_string(),
  240. asset_type,
  241. exponent: 0,
  242. cmc_id: None,
  243. funding_rate_interval: None,
  244. min_publishers: 0,
  245. min_channel: Channel::FixedRate(FixedRate::MIN),
  246. state: SymbolState::Stable,
  247. hermes_id: None,
  248. quote_currency: None,
  249. }
  250. }
  251. #[tokio::test]
  252. #[ignore]
  253. async fn test_try_get_metadata() {
  254. let config = Config {
  255. listen_address: SocketAddr::from(([127, 0, 0, 1], 0)),
  256. relayer_urls: vec![],
  257. authorization_token: None,
  258. publish_keypair_path: Default::default(),
  259. publish_interval_duration: Default::default(),
  260. history_service_url: None,
  261. };
  262. println!("{:?}", get_metadata(config).await.unwrap());
  263. }
  264. #[test]
  265. fn test_filter_symbols() {
  266. let symbol1 = gen_test_symbol("BTC".to_string(), "crypto".to_string());
  267. let symbol2 = gen_test_symbol("XMR".to_string(), "crypto".to_string());
  268. let symbol3 = gen_test_symbol("BTCUSDT".to_string(), "funding-rate".to_string());
  269. let symbols = vec![symbol1.clone(), symbol2.clone(), symbol3.clone()];
  270. // just a name filter
  271. assert_eq!(
  272. filter_symbols(
  273. symbols.clone(),
  274. GetMetadataParams {
  275. names: Some(vec!["XMR".to_string()]),
  276. asset_types: None,
  277. },
  278. ),
  279. vec![symbol2.clone()]
  280. );
  281. // just an asset type filter
  282. assert_eq!(
  283. filter_symbols(
  284. symbols.clone(),
  285. GetMetadataParams {
  286. names: None,
  287. asset_types: Some(vec!["crypto".to_string()]),
  288. },
  289. ),
  290. vec![symbol1.clone(), symbol2.clone()]
  291. );
  292. // name and asset type
  293. assert_eq!(
  294. filter_symbols(
  295. symbols.clone(),
  296. GetMetadataParams {
  297. names: Some(vec!["BTC".to_string()]),
  298. asset_types: Some(vec!["crypto".to_string()]),
  299. },
  300. ),
  301. vec![symbol1.clone()]
  302. );
  303. }
  304. }