Browse Source

[hermes] Add more rest api methods (#746)

* [hermes] Add more rest api methods

Add many of the price service apis. Per David suggestion, we do
validation in parsing instead of doing it later. I didn't find
any suitable library to deserialize our hex format so I created
a macro to implement it because we use it in a couple of places.
I tried making a generic HexInput but couldn't make it working
(and I need other crates like generic_array for it which makes
the code more complex)

* Address feedbacks
Ali Behjati 2 years ago
parent
commit
32596d5d4e
6 changed files with 188 additions and 56 deletions
  1. 5 4
      hermes/Cargo.lock
  2. 4 4
      hermes/Cargo.toml
  3. 41 0
      hermes/src/macros.rs
  4. 1 0
      hermes/src/main.rs
  5. 5 3
      hermes/src/network/rpc.rs
  6. 132 45
      hermes/src/network/rpc/rest.rs

+ 5 - 4
hermes/Cargo.lock

@@ -335,6 +335,7 @@ checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967"
 dependencies = [
  "async-trait",
  "axum-core",
+ "axum-macros",
  "base64 0.21.0",
  "bitflags",
  "bytes",
@@ -3466,9 +3467,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.7.1"
+version = "1.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
+checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -3483,9 +3484,9 @@ checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
 
 [[package]]
 name = "regex-syntax"
-version = "0.6.28"
+version = "0.6.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
+checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
 
 [[package]]
 name = "remove_dir_all"

+ 4 - 4
hermes/Cargo.toml

@@ -4,14 +4,16 @@ version                        = "0.1.0"
 edition                        = "2021"
 
 [dependencies]
-axum                           = { version = "0.6.9", features = ["json", "ws"] }
+axum                           = { version = "0.6.9", features = ["json", "ws", "macros"] }
 axum-extra                     = { version = "0.7.2", features = ["query"] }
 axum-macros                    = { version = "0.3.4" }
 anyhow                         = { version = "1.0.69" }
+base64                         = { version = "0.21.0" }
 borsh                          = { version = "0.9.0" }
 bs58                           = { version = "0.4.0" }
 dashmap                        = { version = "5.4.0" }
 der                            = { version = "0.7.0" }
+derive_more                    = { version = "0.99.17" }
 env_logger                     = { version = "0.10.0" }
 futures                        = { version = "0.3.26" }
 hex                            = { version = "0.4.3" }
@@ -26,7 +28,7 @@ secp256k1                      = { version = "0.26.0", features = ["rand", "reco
 serde                          = { version = "1.0.152", features = ["derive"] }
 serde_arrays                   = { version = "0.1.0" }
 serde_cbor                     = { version = "0.11.2" }
-serde_json                      = { version = "1.0.93" }
+serde_json                     = { version = "1.0.93" }
 sha256                         = { version = "1.1.2" }
 structopt                      = { version = "0.3.26" }
 tokio                          = { version = "1.26.0", features = ["full"] }
@@ -58,5 +60,3 @@ libp2p                         = { version = "0.51.1", features = [
     "websocket",
     "yamux",
 ]}
-base64 = "0.21.0"
-derive_more = "0.99.17"

+ 41 - 0
hermes/src/macros.rs

@@ -0,0 +1,41 @@
+#[macro_export]
+/// A macro that generates Deserialize from string for a struct S that wraps [u8; N] where N is a
+/// compile-time constant. This macro deserializes a string with or without leading 0x and supports
+/// both lower case and upper case hex characters.
+macro_rules! impl_deserialize_for_hex_string_wrapper {
+    ($struct_name:ident, $array_size:expr) => {
+        impl<'de> serde::Deserialize<'de> for $struct_name {
+            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct HexVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for HexVisitor {
+                    type Value = [u8; $array_size];
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                        write!(formatter, "a hex string of length {}", $array_size * 2)
+                    }
+
+                    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        let s = s.trim_start_matches("0x");
+                        let bytes = hex::decode(s)
+                            .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(s), &self))?;
+                        if bytes.len() != $array_size {
+                            return Err(E::invalid_length(bytes.len(), &self));
+                        }
+                        let mut array = [0_u8; $array_size];
+                        array.copy_from_slice(&bytes);
+                        Ok(array)
+                    }
+                }
+
+                deserializer.deserialize_str(HexVisitor).map($struct_name)
+            }
+        }
+    };
+}

+ 1 - 0
hermes/src/main.rs

@@ -16,6 +16,7 @@ use {
 };
 
 mod config;
+mod macros;
 mod network;
 mod store;
 

+ 5 - 3
hermes/src/network/rpc.rs

@@ -28,7 +28,7 @@ impl State {
 
 /// This method provides a background service that responds to REST requests
 ///
-/// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
+/// Currently this is based on Axum due to the simplicity and strong ecosyjtem support for the
 /// packages they are based on (tokio & hyper).
 pub async fn spawn(rpc_addr: String, store: Store) -> Result<()> {
     let state = State::new(store);
@@ -39,8 +39,10 @@ pub async fn spawn(rpc_addr: String, store: Store) -> Result<()> {
     let app = app
         .route("/", get(rest::index))
         .route("/live", get(rest::live))
-        .route("/latest_price_feeds", get(rest::latest_price_feeds))
-        .route("/latest_vaas", get(rest::latest_vaas))
+        .route("/api/latest_price_feeds", get(rest::latest_price_feeds))
+        .route("/api/latest_vaas", get(rest::latest_vaas))
+        .route("/api/get_vaa", get(rest::get_vaa))
+        .route("/api/get_vaa_ccip", get(rest::get_vaa_ccip))
         .with_state(state.clone());
 
     // Listen in the background for new VAA's from the Wormhole RPC.

+ 132 - 45
hermes/src/network/rpc/rest.rs

@@ -1,18 +1,9 @@
 use {
     crate::store::RequestTime,
-    base64::{
-        engine::general_purpose::STANDARD as base64_standard_engine,
-        Engine as _,
-    },
-    pyth_sdk::{
-        PriceFeed,
-        PriceIdentifier,
+    crate::{
+        impl_deserialize_for_hex_string_wrapper,
+        store::UnixTimestamp,
     },
-};
-// This file implements a REST service for the Price Service. This is a mostly direct copy of the
-// TypeScript implementation in the `pyth-crosschain` repo. It uses `axum` as the web framework and
-// `tokio` as the async runtime.
-use {
     anyhow::Result,
     axum::{
         extract::State,
@@ -24,19 +15,38 @@ use {
         Json,
     },
     axum_extra::extract::Query, // Axum extra Query allows us to parse multi-value query parameters.
+    base64::{
+        engine::general_purpose::STANDARD as base64_standard_engine,
+        Engine as _,
+    },
+    derive_more::{
+        Deref,
+        DerefMut,
+    },
+    pyth_sdk::{
+        PriceFeed,
+        PriceIdentifier,
+    },
 };
 
+#[derive(Debug, Clone, Deref, DerefMut)]
+pub struct PriceIdInput([u8; 32]);
+// TODO: Use const generics instead of macro.
+impl_deserialize_for_hex_string_wrapper!(PriceIdInput, 32);
+
+impl From<PriceIdInput> for PriceIdentifier {
+    fn from(id: PriceIdInput) -> Self {
+        Self::new(*id)
+    }
+}
+
 pub enum RestError {
-    InvalidPriceId,
     UpdateDataNotFound,
 }
 
 impl IntoResponse for RestError {
     fn into_response(self) -> Response {
         match self {
-            RestError::InvalidPriceId => {
-                (StatusCode::BAD_REQUEST, "Invalid Price Id").into_response()
-            }
             RestError::UpdateDataNotFound => {
                 (StatusCode::NOT_FOUND, "Update data not found").into_response()
             }
@@ -44,27 +54,18 @@ impl IntoResponse for RestError {
     }
 }
 
-#[derive(Debug, serde::Serialize, serde::Deserialize)]
-pub struct LatestVaaQueryParams {
-    ids: Vec<String>,
+
+#[derive(Debug, serde::Deserialize)]
+pub struct LatestVaasQueryParams {
+    ids: Vec<PriceIdInput>,
 }
 
-/// REST endpoint /latest_vaas?ids[]=...&ids[]=...&ids[]=...
-///
-/// TODO: This endpoint returns update data as an array of base64 encoded strings. We want
-/// to support other formats such as hex in the future.
+
 pub async fn latest_vaas(
     State(state): State<super::State>,
-    Query(params): Query<LatestVaaQueryParams>,
+    Query(params): Query<LatestVaasQueryParams>,
 ) -> Result<Json<Vec<String>>, RestError> {
-    // TODO: Find better ways to validate query parameters.
-    // FIXME: Handle ids with leading 0x
-    let price_ids: Vec<PriceIdentifier> = params
-        .ids
-        .iter()
-        .map(PriceIdentifier::from_hex)
-        .collect::<Result<Vec<PriceIdentifier>, _>>()
-        .map_err(|_| RestError::InvalidPriceId)?;
+    let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
     let price_feeds_with_update_data = state
         .store
         .get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
@@ -74,27 +75,22 @@ pub async fn latest_vaas(
             .update_data
             .batch_vaa
             .iter()
-            .map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes))
+            .map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes)) // TODO: Support multiple
+            // encoding formats
             .collect(),
     ))
 }
 
-#[derive(Debug, serde::Serialize, serde::Deserialize)]
-pub struct LatestPriceFeedParams {
-    ids: Vec<String>,
+#[derive(Debug, serde::Deserialize)]
+pub struct LatestPriceFeedsQueryParams {
+    ids: Vec<PriceIdInput>,
 }
 
-/// REST endpoint /latest_vaas?ids[]=...&ids[]=...&ids[]=...
 pub async fn latest_price_feeds(
     State(state): State<super::State>,
-    Query(params): Query<LatestPriceFeedParams>,
+    Query(params): Query<LatestPriceFeedsQueryParams>,
 ) -> Result<Json<Vec<PriceFeed>>, RestError> {
-    let price_ids: Vec<PriceIdentifier> = params
-        .ids
-        .iter()
-        .map(PriceIdentifier::from_hex)
-        .collect::<Result<Vec<PriceIdentifier>, _>>()
-        .map_err(|_| RestError::InvalidPriceId)?;
+    let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
     let price_feeds_with_update_data = state
         .store
         .get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
@@ -107,6 +103,91 @@ pub async fn latest_price_feeds(
     ))
 }
 
+#[derive(Debug, serde::Deserialize)]
+pub struct GetVaaQueryParams {
+    id:           PriceIdInput,
+    publish_time: UnixTimestamp,
+}
+
+#[derive(Debug, serde::Serialize)]
+pub struct GetVaaResponse {
+    pub vaa:          String,
+    #[serde(rename = "publishTime")]
+    pub publish_time: UnixTimestamp,
+}
+
+pub async fn get_vaa(
+    State(state): State<super::State>,
+    Query(params): Query<GetVaaQueryParams>,
+) -> Result<Json<GetVaaResponse>, RestError> {
+    let price_id: PriceIdentifier = params.id.into();
+
+    let price_feeds_with_update_data = state
+        .store
+        .get_price_feeds_with_update_data(
+            vec![price_id],
+            RequestTime::FirstAfter(params.publish_time),
+        )
+        .map_err(|_| RestError::UpdateDataNotFound)?;
+
+    let vaa = price_feeds_with_update_data
+        .update_data
+        .batch_vaa
+        .get(0)
+        .map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes))
+        .ok_or(RestError::UpdateDataNotFound)?;
+
+    let publish_time = price_feeds_with_update_data
+        .price_feeds
+        .get(&price_id)
+        .map(|price_feed| price_feed.get_price_unchecked().publish_time)
+        .ok_or(RestError::UpdateDataNotFound)?;
+    let publish_time: UnixTimestamp = publish_time
+        .try_into()
+        .map_err(|_| RestError::UpdateDataNotFound)?;
+
+    Ok(Json(GetVaaResponse { vaa, publish_time }))
+}
+
+#[derive(Debug, Clone, Deref, DerefMut)]
+pub struct GetVaaCcipInput([u8; 40]);
+impl_deserialize_for_hex_string_wrapper!(GetVaaCcipInput, 40);
+
+#[derive(Debug, serde::Deserialize)]
+pub struct GetVaaCcipQueryParams {
+    data: GetVaaCcipInput,
+}
+
+#[derive(Debug, serde::Serialize)]
+pub struct GetVaaCcipResponse {
+    data: String, // TODO: Use a typed wrapper for the hex output with leading 0x.
+}
+
+pub async fn get_vaa_ccip(
+    State(state): State<super::State>,
+    Query(params): Query<GetVaaCcipQueryParams>,
+) -> Result<Json<GetVaaCcipResponse>, RestError> {
+    let price_id: PriceIdentifier = PriceIdentifier::new(params.data[0..32].try_into().unwrap());
+    let publish_time = UnixTimestamp::from_be_bytes(params.data[32..40].try_into().unwrap());
+
+    let price_feeds_with_update_data = state
+        .store
+        .get_price_feeds_with_update_data(vec![price_id], RequestTime::FirstAfter(publish_time))
+        .map_err(|_| RestError::UpdateDataNotFound)?;
+
+    let vaa = price_feeds_with_update_data
+        .update_data
+        .batch_vaa
+        .get(0) // One price feed has only a single VAA as proof.
+        .ok_or(RestError::UpdateDataNotFound)?;
+
+    // FIXME: We should return 5xx when the vaa is not found and 4xx when the price id is not there
+
+    Ok(Json(GetVaaCcipResponse {
+        data: format!("0x{}", hex::encode(vaa)),
+    }))
+}
+
 // This function implements the `/live` endpoint. It returns a `200` status code. This endpoint is
 // used by the Kubernetes liveness probe.
 pub async fn live() -> Result<impl IntoResponse, std::convert::Infallible> {
@@ -116,5 +197,11 @@ pub async fn live() -> Result<impl IntoResponse, std::convert::Infallible> {
 // This is the index page for the REST service. It will list all the available endpoints.
 // TODO: Dynamically generate this list if possible.
 pub async fn index() -> impl IntoResponse {
-    Json(["/live", "/latest_price_feeds", "/latest_vaas"])
+    Json([
+        "/live",
+        "/api/latest_price_feeds?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&..",
+        "/api/latest_vaas?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&...",
+        "/api/get_vaa?id=<price_feed_id>&publish_time=<publish_time_in_unix_timestamp>",
+        "/api/get_vaa_ccip?data=<0x<price_feed_id_32_bytes>+<publish_time_unix_timestamp_be_8_bytes>>",
+    ])
 }