Explorar o código

feat(hermes): Check price ids exist before each request

This check will make rejects faster (and block invalid requests
to benchmarks). The other benefit is that we can log the
errors from the get_price_feeds_with_update_data since it should
not fail anymore.
Ali Behjati %!s(int64=2) %!d(string=hai) anos
pai
achega
9714a851eb

+ 1 - 1
hermes/Cargo.lock

@@ -1858,7 +1858,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
 
 [[package]]
 name = "hermes"
-version = "0.1.21"
+version = "0.1.22"
 dependencies = [
  "anyhow",
  "async-trait",

+ 1 - 1
hermes/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name        = "hermes"
-version     = "0.1.21"
+version     = "0.1.22"
 description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
 edition     = "2021"
 

+ 8 - 8
hermes/src/aggregate.rs

@@ -313,7 +313,7 @@ async fn build_message_states(
 
 async fn get_verified_price_feeds<S>(
     state: &S,
-    price_ids: Vec<PriceIdentifier>,
+    price_ids: &[PriceIdentifier],
     request_time: RequestTime,
 ) -> Result<PriceFeedsWithUpdateData>
 where
@@ -373,14 +373,14 @@ where
 
 pub async fn get_price_feeds_with_update_data<S>(
     state: &S,
-    price_ids: Vec<PriceIdentifier>,
+    price_ids: &[PriceIdentifier],
     request_time: RequestTime,
 ) -> Result<PriceFeedsWithUpdateData>
 where
     S: AggregateCache,
     S: Benchmarks,
 {
-    match get_verified_price_feeds(state, price_ids.clone(), request_time.clone()).await {
+    match get_verified_price_feeds(state, price_ids, request_time.clone()).await {
         Ok(price_feeds_with_update_data) => Ok(price_feeds_with_update_data),
         Err(e) => {
             if let RequestTime::FirstAfter(publish_time) = request_time {
@@ -567,7 +567,7 @@ mod test {
         // price feed with correct update data.
         let price_feeds_with_update_data = get_price_feeds_with_update_data(
             &*state,
-            vec![PriceIdentifier::new([100; 32])],
+            &[PriceIdentifier::new([100; 32])],
             RequestTime::Latest,
         )
         .await
@@ -688,7 +688,7 @@ mod test {
         // Get the price feeds with update data
         let price_feeds_with_update_data = get_price_feeds_with_update_data(
             &*state,
-            vec![PriceIdentifier::new([100; 32])],
+            &[PriceIdentifier::new([100; 32])],
             RequestTime::Latest,
         )
         .await
@@ -753,7 +753,7 @@ mod test {
         for slot in 900..1000 {
             let price_feeds_with_update_data = get_price_feeds_with_update_data(
                 &*state,
-                vec![
+                &[
                     PriceIdentifier::new([100; 32]),
                     PriceIdentifier::new([200; 32]),
                 ],
@@ -770,9 +770,9 @@ mod test {
         for slot in 0..900 {
             assert!(get_price_feeds_with_update_data(
                 &*state,
-                vec![
+                &[
                     PriceIdentifier::new([100; 32]),
-                    PriceIdentifier::new([200; 32]),
+                    PriceIdentifier::new([200; 32])
                 ],
                 RequestTime::FirstAfter(slot as i64),
             )

+ 42 - 5
hermes/src/api/rest.rs

@@ -1,9 +1,13 @@
-use axum::{
-    http::StatusCode,
-    response::{
-        IntoResponse,
-        Response,
+use {
+    super::ApiState,
+    axum::{
+        http::StatusCode,
+        response::{
+            IntoResponse,
+            Response,
+        },
     },
+    pyth_sdk::PriceIdentifier,
 };
 
 mod get_price_feed;
@@ -32,6 +36,7 @@ pub enum RestError {
     UpdateDataNotFound,
     CcipUpdateDataNotFound,
     InvalidCCIPInput,
+    PriceIdsNotFound { missing_ids: Vec<PriceIdentifier> },
 }
 
 impl IntoResponse for RestError {
@@ -53,6 +58,38 @@ impl IntoResponse for RestError {
             RestError::InvalidCCIPInput => {
                 (StatusCode::BAD_REQUEST, "Invalid CCIP input").into_response()
             }
+            RestError::PriceIdsNotFound { missing_ids } => {
+                let missing_ids = missing_ids
+                    .into_iter()
+                    .map(|id| id.to_string())
+                    .collect::<Vec<_>>()
+                    .join(", ");
+
+                (
+                    StatusCode::NOT_FOUND,
+                    format!("Price ids not found: {}", missing_ids),
+                )
+                    .into_response()
+            }
         }
     }
 }
+
+/// Verify that the price ids exist in the aggregate state.
+pub async fn verify_price_ids_exist(
+    state: &ApiState,
+    price_ids: &[PriceIdentifier],
+) -> Result<(), RestError> {
+    let all_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;
+    let missing_ids = price_ids
+        .iter()
+        .filter(|id| !all_ids.contains(id))
+        .cloned()
+        .collect::<Vec<_>>();
+
+    if !missing_ids.is_empty() {
+        return Err(RestError::PriceIdsNotFound { missing_ids });
+    }
+
+    Ok(())
+}

+ 12 - 2
hermes/src/api/rest/get_price_feed.rs

@@ -1,4 +1,5 @@
 use {
+    super::verify_price_ids_exist,
     crate::{
         aggregate::{
             RequestTime,
@@ -65,13 +66,22 @@ pub async fn get_price_feed(
 ) -> Result<Json<RpcPriceFeed>, RestError> {
     let price_id: PriceIdentifier = params.id.into();
 
+    verify_price_ids_exist(&state, &[price_id]).await?;
+
     let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
         &*state.state,
-        vec![price_id],
+        &[price_id],
         RequestTime::FirstAfter(params.publish_time),
     )
     .await
-    .map_err(|_| RestError::UpdateDataNotFound)?;
+    .map_err(|e| {
+        tracing::warn!(
+            "Error getting price feed {:?} with update data: {:?}",
+            price_id,
+            e
+        );
+        RestError::UpdateDataNotFound
+    })?;
 
     let mut price_feed = price_feeds_with_update_data
         .price_feeds

+ 12 - 2
hermes/src/api/rest/get_vaa.rs

@@ -1,4 +1,5 @@
 use {
+    super::verify_price_ids_exist,
     crate::{
         aggregate::{
             get_price_feeds_with_update_data,
@@ -73,13 +74,22 @@ pub async fn get_vaa(
 ) -> Result<Json<GetVaaResponse>, RestError> {
     let price_id: PriceIdentifier = params.id.into();
 
+    verify_price_ids_exist(&state, &[price_id]).await?;
+
     let price_feeds_with_update_data = get_price_feeds_with_update_data(
         &*state.state,
-        vec![price_id],
+        &[price_id],
         RequestTime::FirstAfter(params.publish_time),
     )
     .await
-    .map_err(|_| RestError::UpdateDataNotFound)?;
+    .map_err(|e| {
+        tracing::warn!(
+            "Error getting price feed {:?} with update data: {:?}",
+            price_id,
+            e
+        );
+        RestError::UpdateDataNotFound
+    })?;
 
     let vaa = price_feeds_with_update_data
         .update_data

+ 12 - 2
hermes/src/api/rest/get_vaa_ccip.rs

@@ -1,4 +1,5 @@
 use {
+    super::verify_price_ids_exist,
     crate::{
         aggregate::{
             RequestTime,
@@ -70,13 +71,22 @@ pub async fn get_vaa_ccip(
             .map_err(|_| RestError::InvalidCCIPInput)?,
     );
 
+    verify_price_ids_exist(&state, &[price_id]).await?;
+
     let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
         &*state.state,
-        vec![price_id],
+        &[price_id],
         RequestTime::FirstAfter(publish_time),
     )
     .await
-    .map_err(|_| RestError::CcipUpdateDataNotFound)?;
+    .map_err(|e| {
+        tracing::warn!(
+            "Error getting price feed {:?} with update data: {:?}",
+            price_id,
+            e
+        );
+        RestError::CcipUpdateDataNotFound
+    })?;
 
     let bytes = price_feeds_with_update_data
         .update_data

+ 13 - 2
hermes/src/api/rest/latest_price_feeds.rs

@@ -1,4 +1,5 @@
 use {
+    super::verify_price_ids_exist,
     crate::{
         aggregate::RequestTime,
         api::{
@@ -63,13 +64,23 @@ pub async fn latest_price_feeds(
     QsQuery(params): QsQuery<LatestPriceFeedsQueryParams>,
 ) -> Result<Json<Vec<RpcPriceFeed>>, RestError> {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
+
+    verify_price_ids_exist(&state, &price_ids).await?;
+
     let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
         &*state.state,
-        price_ids,
+        &price_ids,
         RequestTime::Latest,
     )
     .await
-    .map_err(|_| RestError::UpdateDataNotFound)?;
+    .map_err(|e| {
+        tracing::warn!(
+            "Error getting price feeds {:?} with update data: {:?}",
+            price_ids,
+            e
+        );
+        RestError::UpdateDataNotFound
+    })?;
 
     Ok(Json(
         price_feeds_with_update_data

+ 13 - 2
hermes/src/api/rest/latest_vaas.rs

@@ -1,4 +1,5 @@
 use {
+    super::verify_price_ids_exist,
     crate::{
         aggregate::RequestTime,
         api::{
@@ -58,13 +59,23 @@ pub async fn latest_vaas(
     QsQuery(params): QsQuery<LatestVaasQueryParams>,
 ) -> Result<Json<Vec<String>>, RestError> {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
+
+    verify_price_ids_exist(&state, &price_ids).await?;
+
     let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
         &*state.state,
-        price_ids,
+        &price_ids,
         RequestTime::Latest,
     )
     .await
-    .map_err(|_| RestError::UpdateDataNotFound)?;
+    .map_err(|e| {
+        tracing::warn!(
+            "Error getting price feeds {:?} with update data: {:?}",
+            price_ids,
+            e
+        );
+        RestError::UpdateDataNotFound
+    })?;
 
     Ok(Json(
         price_feeds_with_update_data

+ 15 - 3
hermes/src/api/ws.rs

@@ -209,13 +209,25 @@ impl Subscriber {
     }
 
     async fn handle_price_feeds_update(&mut self, event: AggregationEvent) -> Result<()> {
-        let price_feed_ids = self.price_feeds_with_config.keys().cloned().collect();
+        let price_feed_ids = self
+            .price_feeds_with_config
+            .keys()
+            .cloned()
+            .collect::<Vec<_>>();
         for update in crate::aggregate::get_price_feeds_with_update_data(
             &*self.store,
-            price_feed_ids,
+            &price_feed_ids,
             RequestTime::AtSlot(event.slot()),
         )
-        .await?
+        .await
+        .map_err(|e| {
+            tracing::warn!(
+                "Failed to get price feeds {:?} with update data: {:?}",
+                price_feed_ids,
+                e
+            );
+            e
+        })?
         .price_feeds
         {
             let config = self

+ 2 - 2
hermes/src/state/benchmarks.rs

@@ -80,7 +80,7 @@ impl TryFrom<BenchmarkUpdates> for PriceFeedsWithUpdateData {
 pub trait Benchmarks {
     async fn get_verified_price_feeds(
         &self,
-        price_ids: Vec<PriceIdentifier>,
+        price_ids: &[PriceIdentifier],
         publish_time: UnixTimestamp,
     ) -> Result<PriceFeedsWithUpdateData>;
 }
@@ -89,7 +89,7 @@ pub trait Benchmarks {
 impl Benchmarks for crate::state::State {
     async fn get_verified_price_feeds(
         &self,
-        price_ids: Vec<PriceIdentifier>,
+        price_ids: &[PriceIdentifier],
         publish_time: UnixTimestamp,
     ) -> Result<PriceFeedsWithUpdateData> {
         let endpoint = self