Explorar el Código

feat(hermes): Add `ignore_invalid_price_ids` flag to Hermes v2 REST APIs (#2091)

* feat: add ignore_invalid_price_ids parameter to v2 apis

* update readme

* undo build hack

* better names

* apply precommit

* refactor: address PR comments

* test: add tests for validate_price_ids

* docs: address PR comments

* feat: include ignoreInvalidPriceIds flag in HermesClient

* fix: bump ver, address PR comments

* docs: update to reflect change from nightly to stable rust

* fix: semver
Tejas Badadare hace 1 año
padre
commit
059b7b99ac

+ 1 - 1
apps/hermes/client/js/package.json

@@ -1,6 +1,6 @@
 {
   "name": "@pythnetwork/hermes-client",
-  "version": "1.1.0",
+  "version": "1.2.0",
   "description": "Pyth Hermes Client",
   "author": {
     "name": "Pyth Data Association"

+ 16 - 7
apps/hermes/client/js/src/HermesClient.ts

@@ -157,6 +157,7 @@ export class HermesClient {
    * @param options Optional parameters:
    *        - encoding: Encoding type. If specified, return the price update in the encoding specified by the encoding parameter. Default is hex.
    *        - parsed: Boolean to specify if the parsed price update should be included in the response. Default is false.
+   *        - ignoreInvalidPriceIds: Boolean to specify if invalid price IDs should be ignored instead of returning an error. Default is false.
    *
    * @returns PriceUpdate object containing the latest updates.
    */
@@ -165,6 +166,7 @@ export class HermesClient {
     options?: {
       encoding?: EncodingType;
       parsed?: boolean;
+      ignoreInvalidPriceIds?: boolean;
     }
   ): Promise<PriceUpdate> {
     const url = new URL("v2/updates/price/latest", this.baseURL);
@@ -173,7 +175,8 @@ export class HermesClient {
     }
 
     if (options) {
-      this.appendUrlSearchParams(url, options);
+      const transformedOptions = camelToSnakeCaseObject(options);
+      this.appendUrlSearchParams(url, transformedOptions);
     }
 
     return this.httpRequest(url.toString(), schemas.PriceUpdate);
@@ -189,6 +192,7 @@ export class HermesClient {
    * @param options Optional parameters:
    *        - encoding: Encoding type. If specified, return the price update in the encoding specified by the encoding parameter. Default is hex.
    *        - parsed: Boolean to specify if the parsed price update should be included in the response. Default is false.
+   *        - ignoreInvalidPriceIds: Boolean to specify if invalid price IDs should be ignored instead of returning an error. Default is false.
    *
    * @returns PriceUpdate object containing the updates at the specified timestamp.
    */
@@ -198,6 +202,7 @@ export class HermesClient {
     options?: {
       encoding?: EncodingType;
       parsed?: boolean;
+      ignoreInvalidPriceIds?: boolean;
     }
   ): Promise<PriceUpdate> {
     const url = new URL(`v2/updates/price/${publishTime}`, this.baseURL);
@@ -206,7 +211,8 @@ export class HermesClient {
     }
 
     if (options) {
-      this.appendUrlSearchParams(url, options);
+      const transformedOptions = camelToSnakeCaseObject(options);
+      this.appendUrlSearchParams(url, transformedOptions);
     }
 
     return this.httpRequest(url.toString(), schemas.PriceUpdate);
@@ -219,12 +225,14 @@ export class HermesClient {
    * This will return an EventSource that can be used to listen to streaming updates.
    * If an invalid hex-encoded ID is passed, it will throw an error.
    *
-   *
    * @param ids Array of hex-encoded price feed IDs for which streaming updates are requested.
-   * @param encoding Optional encoding type. If specified, updates are returned in the specified encoding. Default is hex.
-   * @param parsed Optional boolean to specify if the parsed price update should be included in the response. Default is false.
-   * @param allow_unordered Optional boolean to specify if unordered updates are allowed to be included in the stream. Default is false.
-   * @param benchmarks_only Optional boolean to specify if only benchmark prices that are the initial price updates at a given timestamp (i.e., prevPubTime != pubTime) should be returned. Default is false.
+   * @param options Optional parameters:
+   *        - encoding: Encoding type. If specified, updates are returned in the specified encoding. Default is hex.
+   *        - parsed: Boolean to specify if the parsed price update should be included in the response. Default is false.
+   *        - allowUnordered: Boolean to specify if unordered updates are allowed to be included in the stream. Default is false.
+   *        - benchmarksOnly: Boolean to specify if only benchmark prices should be returned. Default is false.
+   *        - ignoreInvalidPriceIds: Boolean to specify if invalid price IDs should be ignored instead of returning an error. Default is false.
+   *
    * @returns An EventSource instance for receiving streaming updates.
    */
   async getPriceUpdatesStream(
@@ -234,6 +242,7 @@ export class HermesClient {
       parsed?: boolean;
       allowUnordered?: boolean;
       benchmarksOnly?: boolean;
+      ignoreInvalidPriceIds?: boolean;
     }
   ): Promise<EventSource> {
     const url = new URL("v2/updates/price/stream", this.baseURL);

+ 1 - 1
apps/hermes/server/Cargo.lock

@@ -1796,7 +1796,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
 [[package]]
 name = "hermes"
-version = "0.6.1"
+version = "0.7.0"
 dependencies = [
  "anyhow",
  "async-trait",

+ 1 - 1
apps/hermes/server/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name        = "hermes"
-version     = "0.6.1"
+version     = "0.7.0"
 description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
 edition     = "2021"
 

+ 1 - 1
apps/hermes/server/README.md

@@ -52,7 +52,7 @@ To set up and run a Hermes node, follow the steps below:
    can interact with the node using the REST and Websocket APIs on port 33999.
 
    For local development, you can also run the node with [cargo watch](https://crates.io/crates/cargo-watch) to restart
-   it automatically when the code changes:
+   it automatically when the code changes.
 
    ```bash
    cargo watch -w src -x "run -- run --pythnet-http-addr https://pythnet-rpc/ --pythnet-ws-addr wss://pythnet-rpc/ --wormhole-spy-rpc-addr https://wormhole-spy-rpc/

+ 166 - 11
apps/hermes/server/src/api/rest.rs

@@ -94,25 +94,180 @@ impl IntoResponse for RestError {
     }
 }
 
-/// Verify that the price ids exist in the aggregate state.
-pub async fn verify_price_ids_exist<S>(
+/// Validate that the passed in price_ids exist in the aggregate state. Return a Vec of valid price ids.
+/// # Returns
+/// If `remove_invalid` is true, invalid price ids are filtered out and only valid price ids are returned.
+/// If `remove_invalid` is false and any passed in IDs are invalid, an error is returned.
+pub async fn validate_price_ids<S>(
     state: &ApiState<S>,
     price_ids: &[PriceIdentifier],
-) -> Result<(), RestError>
+    remove_invalid: bool,
+) -> Result<Vec<PriceIdentifier>, RestError>
 where
     S: Aggregates,
 {
     let state = &*state.state;
-    let all_ids = Aggregates::get_price_feed_ids(state).await;
-    let missing_ids = price_ids
+    let available_ids = Aggregates::get_price_feed_ids(state).await;
+
+    // Partition into (valid_ids, invalid_ids)
+    let (valid_ids, invalid_ids): (Vec<_>, Vec<_>) = price_ids
         .iter()
-        .filter(|id| !all_ids.contains(id))
-        .cloned()
-        .collect::<Vec<_>>();
+        .copied()
+        .partition(|id| available_ids.contains(id));
+
+    if invalid_ids.is_empty() || remove_invalid {
+        // All IDs are valid
+        Ok(valid_ids)
+    } else {
+        // Return error with list of missing IDs
+        Err(RestError::PriceIdsNotFound {
+            missing_ids: invalid_ids,
+        })
+    }
+}
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        crate::state::{
+            aggregate::{
+                AggregationEvent,
+                PriceFeedsWithUpdateData,
+                PublisherStakeCapsWithUpdateData,
+                ReadinessMetadata,
+                RequestTime,
+                Update,
+            },
+            benchmarks::BenchmarksState,
+            cache::CacheState,
+            metrics::MetricsState,
+            price_feeds_metadata::PriceFeedMetaState,
+        },
+        anyhow::Result,
+        std::{
+            collections::HashSet,
+            sync::Arc,
+        },
+        tokio::sync::broadcast::Receiver,
+    };
 
-    if !missing_ids.is_empty() {
-        return Err(RestError::PriceIdsNotFound { missing_ids });
+    // Simplified mock that only contains what we need
+    struct MockAggregates {
+        available_ids: HashSet<PriceIdentifier>,
+    }
+
+    // Implement all required From traits with unimplemented!()
+    impl<'a> From<&'a MockAggregates> for &'a CacheState {
+        fn from(_: &'a MockAggregates) -> Self {
+            unimplemented!("Not needed for this test")
+        }
+    }
+
+    impl<'a> From<&'a MockAggregates> for &'a BenchmarksState {
+        fn from(_: &'a MockAggregates) -> Self {
+            unimplemented!("Not needed for this test")
+        }
     }
 
-    Ok(())
+    impl<'a> From<&'a MockAggregates> for &'a PriceFeedMetaState {
+        fn from(_: &'a MockAggregates) -> Self {
+            unimplemented!("Not needed for this test")
+        }
+    }
+
+    impl<'a> From<&'a MockAggregates> for &'a MetricsState {
+        fn from(_: &'a MockAggregates) -> Self {
+            unimplemented!("Not needed for this test")
+        }
+    }
+
+    #[async_trait::async_trait]
+    impl Aggregates for MockAggregates {
+        async fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier> {
+            self.available_ids.clone()
+        }
+
+        fn subscribe(&self) -> Receiver<AggregationEvent> {
+            unimplemented!("Not needed for this test")
+        }
+
+        async fn is_ready(&self) -> (bool, ReadinessMetadata) {
+            unimplemented!("Not needed for this test")
+        }
+
+        async fn store_update(&self, _update: Update) -> Result<()> {
+            unimplemented!("Not needed for this test")
+        }
+
+        async fn get_price_feeds_with_update_data(
+            &self,
+            _price_ids: &[PriceIdentifier],
+            _request_time: RequestTime,
+        ) -> Result<PriceFeedsWithUpdateData> {
+            unimplemented!("Not needed for this test")
+        }
+
+        async fn get_latest_publisher_stake_caps_with_update_data(
+            &self,
+        ) -> Result<PublisherStakeCapsWithUpdateData> {
+            unimplemented!("Not needed for this test")
+        }
+    }
+
+    #[tokio::test]
+    async fn validate_price_ids_accepts_all_valid_ids() {
+        let id1 = PriceIdentifier::new([1; 32]);
+        let id2 = PriceIdentifier::new([2; 32]);
+
+        let mut available_ids = HashSet::new();
+        available_ids.insert(id1);
+        available_ids.insert(id2);
+
+        let mock_state = MockAggregates { available_ids };
+        let api_state = ApiState::new(Arc::new(mock_state), vec![], String::new());
+
+        let input_ids = vec![id1, id2];
+        let result = validate_price_ids(&api_state, &input_ids, false).await;
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), input_ids);
+    }
+
+    #[tokio::test]
+    async fn validate_price_ids_removes_invalid_ids_when_requested() {
+        let id1 = PriceIdentifier::new([1; 32]);
+        let id2 = PriceIdentifier::new([2; 32]);
+        let id3 = PriceIdentifier::new([3; 32]);
+
+        let mut available_ids = HashSet::new();
+        available_ids.insert(id1);
+        available_ids.insert(id2);
+
+        let mock_state = MockAggregates { available_ids };
+        let api_state = ApiState::new(Arc::new(mock_state), vec![], String::new());
+
+        let input_ids = vec![id1, id2, id3];
+        let result = validate_price_ids(&api_state, &input_ids, true).await;
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), vec![id1, id2]);
+    }
+
+    #[tokio::test]
+    async fn validate_price_ids_errors_on_invalid_ids() {
+        let id1 = PriceIdentifier::new([1; 32]);
+        let id2 = PriceIdentifier::new([2; 32]);
+        let id3 = PriceIdentifier::new([3; 32]);
+
+        let mut available_ids = HashSet::new();
+        available_ids.insert(id1);
+        available_ids.insert(id2);
+
+        let mock_state = MockAggregates { available_ids };
+        let api_state = ApiState::new(Arc::new(mock_state), vec![], String::new());
+
+        let input_ids = vec![id1, id2, id3];
+        let result = validate_price_ids(&api_state, &input_ids, false).await;
+        assert!(
+            matches!(result, Err(RestError::PriceIdsNotFound { missing_ids }) if missing_ids == vec![id3])
+        );
+    }
 }

+ 2 - 2
apps/hermes/server/src/api/rest/get_price_feed.rs

@@ -1,5 +1,5 @@
 use {
-    super::verify_price_ids_exist,
+    super::validate_price_ids,
     crate::{
         api::{
             doc_examples,
@@ -73,7 +73,7 @@ where
     S: Aggregates,
 {
     let price_id: PriceIdentifier = params.id.into();
-    verify_price_ids_exist(&state, &[price_id]).await?;
+    validate_price_ids(&state, &[price_id], false).await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(

+ 2 - 2
apps/hermes/server/src/api/rest/get_vaa.rs

@@ -1,5 +1,5 @@
 use {
-    super::verify_price_ids_exist,
+    super::validate_price_ids,
     crate::{
         api::{
             doc_examples,
@@ -80,7 +80,7 @@ where
     S: Aggregates,
 {
     let price_id: PriceIdentifier = params.id.into();
-    verify_price_ids_exist(&state, &[price_id]).await?;
+    validate_price_ids(&state, &[price_id], false).await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(

+ 2 - 2
apps/hermes/server/src/api/rest/get_vaa_ccip.rs

@@ -1,5 +1,5 @@
 use {
-    super::verify_price_ids_exist,
+    super::validate_price_ids,
     crate::{
         api::{
             rest::RestError,
@@ -75,7 +75,7 @@ where
             .try_into()
             .map_err(|_| RestError::InvalidCCIPInput)?,
     );
-    verify_price_ids_exist(&state, &[price_id]).await?;
+    validate_price_ids(&state, &[price_id], false).await?;
 
     let publish_time = UnixTimestamp::from_be_bytes(
         params.data[32..40]

+ 2 - 2
apps/hermes/server/src/api/rest/latest_price_feeds.rs

@@ -1,5 +1,5 @@
 use {
-    super::verify_price_ids_exist,
+    super::validate_price_ids,
     crate::{
         api::{
             rest::RestError,
@@ -74,7 +74,7 @@ where
     S: Aggregates,
 {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-    verify_price_ids_exist(&state, &price_ids).await?;
+    validate_price_ids(&state, &price_ids, false).await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data =

+ 2 - 2
apps/hermes/server/src/api/rest/latest_vaas.rs

@@ -1,5 +1,5 @@
 use {
-    super::verify_price_ids_exist,
+    super::validate_price_ids,
     crate::{
         api::{
             doc_examples,
@@ -69,7 +69,7 @@ where
     S: Aggregates,
 {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-    verify_price_ids_exist(&state, &price_ids).await?;
+    validate_price_ids(&state, &price_ids, false).await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data =

+ 9 - 3
apps/hermes/server/src/api/rest/v2/latest_price_updates.rs

@@ -2,7 +2,7 @@ use {
     crate::{
         api::{
             rest::{
-                verify_price_ids_exist,
+                validate_price_ids,
                 RestError,
             },
             types::{
@@ -57,6 +57,10 @@ pub struct LatestPriceUpdatesQueryParams {
     /// If true, include the parsed price update in the `parsed` field of each returned feed. Default is `true`.
     #[serde(default = "default_true")]
     parsed: bool,
+
+    /// If true, invalid price IDs in the `ids` parameter are ignored. Only applicable to the v2 APIs. Default is `false`.
+    #[serde(default)]
+    ignore_invalid_price_ids: bool,
 }
 
 fn default_true() -> bool {
@@ -84,8 +88,10 @@ pub async fn latest_price_updates<S>(
 where
     S: Aggregates,
 {
-    let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-    verify_price_ids_exist(&state, &price_ids).await?;
+    let price_id_inputs: Vec<PriceIdentifier> =
+        params.ids.into_iter().map(|id| id.into()).collect();
+    let price_ids: Vec<PriceIdentifier> =
+        validate_price_ids(&state, &price_id_inputs, params.ignore_invalid_price_ids).await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data =

+ 7 - 4
apps/hermes/server/src/api/rest/v2/sse.rs

@@ -2,7 +2,7 @@ use {
     crate::{
         api::{
             rest::{
-                verify_price_ids_exist,
+                validate_price_ids,
                 RestError,
             },
             types::{
@@ -73,6 +73,9 @@ pub struct StreamPriceUpdatesQueryParams {
     /// If true, only include benchmark prices that are the initial price updates at a given timestamp (i.e., prevPubTime != pubTime).
     #[serde(default)]
     benchmarks_only: bool,
+
+    /// If true, invalid price IDs in the `ids` parameter are ignored. Only applicable to the v2 APIs. Default is `false`.    #[serde(default)]
+    ignore_invalid_price_ids: bool,
 }
 
 fn default_true() -> bool {
@@ -97,9 +100,9 @@ where
     S: Aggregates,
     S: Send + Sync + 'static,
 {
-    let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
-
-    verify_price_ids_exist(&state, &price_ids).await?;
+    let price_id_inputs: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
+    let price_ids: Vec<PriceIdentifier> =
+        validate_price_ids(&state, &price_id_inputs, params.ignore_invalid_price_ids).await?;
 
     // Clone the update_tx receiver to listen for new price updates
     let update_rx: broadcast::Receiver<AggregationEvent> = Aggregates::subscribe(&*state.state);

+ 12 - 4
apps/hermes/server/src/api/rest/v2/timestamp_price_updates.rs

@@ -3,7 +3,7 @@ use {
         api::{
             doc_examples,
             rest::{
-                verify_price_ids_exist,
+                validate_price_ids,
                 RestError,
             },
             types::{
@@ -67,6 +67,10 @@ pub struct TimestampPriceUpdatesQueryParams {
     /// If true, include the parsed price update in the `parsed` field of each returned feed. Default is `true`.
     #[serde(default = "default_true")]
     parsed: bool,
+
+    /// If true, invalid price IDs in the `ids` parameter are ignored. Only applicable to the v2 APIs. Default is `false`.
+    #[serde(default)]
+    ignore_invalid_price_ids: bool,
 }
 
 
@@ -97,10 +101,14 @@ pub async fn timestamp_price_updates<S>(
 where
     S: Aggregates,
 {
-    let price_ids: Vec<PriceIdentifier> =
+    let price_id_inputs: Vec<PriceIdentifier> =
         query_params.ids.into_iter().map(|id| id.into()).collect();
-
-    verify_price_ids_exist(&state, &price_ids).await?;
+    let price_ids: Vec<PriceIdentifier> = validate_price_ids(
+        &state,
+        &price_id_inputs,
+        query_params.ignore_invalid_price_ids,
+    )
+    .await?;
 
     let state = &*state.state;
     let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(