Ver Fonte

refactor(hermes): state->aggregate downcasting (#1479)

Reisen há 1 ano atrás
pai
commit
8b76d8c19a

+ 8 - 19
apps/hermes/src/api.rs

@@ -1,6 +1,5 @@
 use {
     crate::{
-        aggregate::AggregationEvent,
         config::RunOptions,
         state::State,
     },
@@ -14,7 +13,6 @@ use {
     ipnet::IpNet,
     serde_qs::axum::QsQueryConfig,
     std::sync::Arc,
-    tokio::sync::broadcast::Sender,
     tower_http::cors::CorsLayer,
     utoipa::OpenApi,
     utoipa_swagger_ui::SwaggerUi,
@@ -27,10 +25,9 @@ pub mod types;
 mod ws;
 
 pub struct ApiState<S = State> {
-    pub state:     Arc<S>,
-    pub ws:        Arc<ws::WsState>,
-    pub metrics:   Arc<metrics_middleware::Metrics>,
-    pub update_tx: Sender<AggregationEvent>,
+    pub state:   Arc<S>,
+    pub ws:      Arc<ws::WsState>,
+    pub metrics: Arc<metrics_middleware::Metrics>,
 }
 
 /// Manually implement `Clone` as the derive macro will try and slap `Clone` on
@@ -38,10 +35,9 @@ pub struct ApiState<S = State> {
 impl<S> Clone for ApiState<S> {
     fn clone(&self) -> Self {
         Self {
-            state:     self.state.clone(),
-            ws:        self.ws.clone(),
-            metrics:   self.metrics.clone(),
-            update_tx: self.update_tx.clone(),
+            state:   self.state.clone(),
+            ws:      self.ws.clone(),
+            metrics: self.metrics.clone(),
         }
     }
 }
@@ -51,7 +47,6 @@ impl ApiState<State> {
         state: Arc<State>,
         ws_whitelist: Vec<IpNet>,
         requester_ip_header_name: String,
-        update_tx: Sender<AggregationEvent>,
     ) -> Self {
         Self {
             metrics: Arc::new(metrics_middleware::Metrics::new(state.clone())),
@@ -61,24 +56,18 @@ impl ApiState<State> {
                 state.clone(),
             )),
             state,
-            update_tx,
         }
     }
 }
 
-#[tracing::instrument(skip(opts, state, update_tx))]
-pub async fn spawn(
-    opts: RunOptions,
-    state: Arc<State>,
-    update_tx: Sender<AggregationEvent>,
-) -> Result<()> {
+#[tracing::instrument(skip(opts, state))]
+pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
     let state = {
         let opts = opts.clone();
         ApiState::new(
             state,
             opts.rpc.ws_whitelist,
             opts.rpc.requester_ip_header_name,
-            update_tx,
         )
     };
 

+ 1 - 1
apps/hermes/src/api/doc_examples.rs

@@ -1,4 +1,4 @@
-use crate::aggregate::UnixTimestamp;
+use crate::state::aggregate::UnixTimestamp;
 
 // Example values for the utoipa API docs.
 // Note that each of these expressions is only evaluated once when the documentation is created,

+ 9 - 4
apps/hermes/src/api/rest.rs

@@ -1,5 +1,6 @@
 use {
     super::ApiState,
+    crate::state::aggregate::Aggregates,
     axum::{
         http::StatusCode,
         response::{
@@ -93,11 +94,15 @@ impl IntoResponse for RestError {
 }
 
 /// Verify that the price ids exist in the aggregate state.
-pub async fn verify_price_ids_exist(
-    state: &ApiState,
+pub async fn verify_price_ids_exist<S>(
+    state: &ApiState<S>,
     price_ids: &[PriceIdentifier],
-) -> Result<(), RestError> {
-    let all_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;
+) -> Result<(), RestError>
+where
+    S: Aggregates,
+{
+    let state = &*state.state;
+    let all_ids = Aggregates::get_price_feed_ids(state).await;
     let missing_ids = price_ids
         .iter()
         .filter(|id| !all_ids.contains(id))

+ 15 - 10
apps/hermes/src/api/rest/get_price_feed.rs

@@ -1,10 +1,6 @@
 use {
     super::verify_price_ids_exist,
     crate::{
-        aggregate::{
-            RequestTime,
-            UnixTimestamp,
-        },
         api::{
             doc_examples,
             rest::RestError,
@@ -12,6 +8,12 @@ use {
                 PriceIdInput,
                 RpcPriceFeed,
             },
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
+            UnixTimestamp,
         },
     },
     anyhow::Result,
@@ -60,16 +62,19 @@ pub struct GetPriceFeedQueryParams {
         GetPriceFeedQueryParams
     )
 )]
-pub async fn get_price_feed(
-    State(state): State<crate::api::ApiState>,
+pub async fn get_price_feed<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<GetPriceFeedQueryParams>,
-) -> Result<Json<RpcPriceFeed>, RestError> {
+) -> Result<Json<RpcPriceFeed>, RestError>
+where
+    S: Aggregates,
+{
     let price_id: PriceIdentifier = params.id.into();
-
     verify_price_ids_exist(&state, &[price_id]).await?;
 
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
+    let state = &*state.state;
+    let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
+        state,
         &[price_id],
         RequestTime::FirstAfter(params.publish_time),
     )

+ 15 - 11
apps/hermes/src/api/rest/get_vaa.rs

@@ -1,15 +1,16 @@
 use {
     super::verify_price_ids_exist,
     crate::{
-        aggregate::{
-            get_price_feeds_with_update_data,
-            RequestTime,
-            UnixTimestamp,
-        },
         api::{
             doc_examples,
             rest::RestError,
             types::PriceIdInput,
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
+            UnixTimestamp,
         },
     },
     anyhow::Result,
@@ -68,16 +69,19 @@ pub struct GetVaaResponse {
         GetVaaQueryParams
     )
 )]
-pub async fn get_vaa(
-    State(state): State<crate::api::ApiState>,
+pub async fn get_vaa<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<GetVaaQueryParams>,
-) -> Result<Json<GetVaaResponse>, RestError> {
+) -> Result<Json<GetVaaResponse>, RestError>
+where
+    S: Aggregates,
+{
     let price_id: PriceIdentifier = params.id.into();
-
     verify_price_ids_exist(&state, &[price_id]).await?;
 
-    let price_feeds_with_update_data = get_price_feeds_with_update_data(
-        &*state.state,
+    let state = &*state.state;
+    let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
+        state,
         &[price_id],
         RequestTime::FirstAfter(params.publish_time),
     )

+ 17 - 9
apps/hermes/src/api/rest/get_vaa_ccip.rs

@@ -1,11 +1,15 @@
 use {
     super::verify_price_ids_exist,
     crate::{
-        aggregate::{
+        api::{
+            rest::RestError,
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
             RequestTime,
             UnixTimestamp,
         },
-        api::rest::RestError,
     },
     anyhow::Result,
     axum::{
@@ -56,25 +60,29 @@ pub struct GetVaaCcipResponse {
         GetVaaCcipQueryParams
     )
 )]
-pub async fn get_vaa_ccip(
-    State(state): State<crate::api::ApiState>,
+pub async fn get_vaa_ccip<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<GetVaaCcipQueryParams>,
-) -> Result<Json<GetVaaCcipResponse>, RestError> {
+) -> Result<Json<GetVaaCcipResponse>, RestError>
+where
+    S: Aggregates,
+{
     let price_id: PriceIdentifier = PriceIdentifier::new(
         params.data[0..32]
             .try_into()
             .map_err(|_| RestError::InvalidCCIPInput)?,
     );
+    verify_price_ids_exist(&state, &[price_id]).await?;
+
     let publish_time = UnixTimestamp::from_be_bytes(
         params.data[32..40]
             .try_into()
             .map_err(|_| RestError::InvalidCCIPInput)?,
     );
 
-    verify_price_ids_exist(&state, &[price_id]).await?;
-
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
+    let state = &*state.state;
+    let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
+        state,
         &[price_id],
         RequestTime::FirstAfter(publish_time),
     )

+ 23 - 19
apps/hermes/src/api/rest/latest_price_feeds.rs

@@ -1,13 +1,17 @@
 use {
     super::verify_price_ids_exist,
     crate::{
-        aggregate::RequestTime,
         api::{
             rest::RestError,
             types::{
                 PriceIdInput,
                 RpcPriceFeed,
             },
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
         },
     },
     anyhow::Result,
@@ -59,28 +63,28 @@ pub struct LatestPriceFeedsQueryParams {
         LatestPriceFeedsQueryParams
     )
 )]
-pub async fn latest_price_feeds(
-    State(state): State<crate::api::ApiState>,
+pub async fn latest_price_feeds<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<LatestPriceFeedsQueryParams>,
-) -> Result<Json<Vec<RpcPriceFeed>>, RestError> {
+) -> Result<Json<Vec<RpcPriceFeed>>, RestError>
+where
+    S: Aggregates,
+{
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-
     verify_price_ids_exist(&state, &price_ids).await?;
 
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
-        &price_ids,
-        RequestTime::Latest,
-    )
-    .await
-    .map_err(|e| {
-        tracing::warn!(
-            "Error getting price feeds {:?} with update data: {:?}",
-            price_ids,
-            e
-        );
-        RestError::UpdateDataNotFound
-    })?;
+    let state = &*state.state;
+    let price_feeds_with_update_data =
+        Aggregates::get_price_feeds_with_update_data(state, &price_ids, RequestTime::Latest)
+            .await
+            .map_err(|e| {
+                tracing::warn!(
+                    "Error getting price feeds {:?} with update data: {:?}",
+                    price_ids,
+                    e
+                );
+                RestError::UpdateDataNotFound
+            })?;
 
     Ok(Json(
         price_feeds_with_update_data

+ 23 - 19
apps/hermes/src/api/rest/latest_vaas.rs

@@ -1,11 +1,15 @@
 use {
     super::verify_price_ids_exist,
     crate::{
-        aggregate::RequestTime,
         api::{
             doc_examples,
             rest::RestError,
             types::PriceIdInput,
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
         },
     },
     anyhow::Result,
@@ -54,28 +58,28 @@ pub struct LatestVaasQueryParams {
         (status = 200, description = "VAAs retrieved successfully", body = Vec<String>, example=json!([doc_examples::vaa_example()]))
     ),
 )]
-pub async fn latest_vaas(
-    State(state): State<crate::api::ApiState>,
+pub async fn latest_vaas<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<LatestVaasQueryParams>,
-) -> Result<Json<Vec<String>>, RestError> {
+) -> Result<Json<Vec<String>>, RestError>
+where
+    S: Aggregates,
+{
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-
     verify_price_ids_exist(&state, &price_ids).await?;
 
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
-        &price_ids,
-        RequestTime::Latest,
-    )
-    .await
-    .map_err(|e| {
-        tracing::warn!(
-            "Error getting price feeds {:?} with update data: {:?}",
-            price_ids,
-            e
-        );
-        RestError::UpdateDataNotFound
-    })?;
+    let state = &*state.state;
+    let price_feeds_with_update_data =
+        Aggregates::get_price_feeds_with_update_data(state, &price_ids, RequestTime::Latest)
+            .await
+            .map_err(|e| {
+                tracing::warn!(
+                    "Error getting price feeds {:?} with update data: {:?}",
+                    price_ids,
+                    e
+                );
+                RestError::UpdateDataNotFound
+            })?;
 
     Ok(Json(
         price_feeds_with_update_data

+ 15 - 7
apps/hermes/src/api/rest/price_feed_ids.rs

@@ -1,7 +1,11 @@
 use {
-    crate::api::{
-        rest::RestError,
-        types::RpcPriceIdentifier,
+    crate::{
+        api::{
+            rest::RestError,
+            types::RpcPriceIdentifier,
+            ApiState,
+        },
+        state::aggregate::Aggregates,
     },
     anyhow::Result,
     axum::{
@@ -21,10 +25,14 @@ use {
         (status = 200, description = "Price feed ids retrieved successfully", body = Vec<RpcPriceIdentifier>)
     ),
 )]
-pub async fn price_feed_ids(
-    State(state): State<crate::api::ApiState>,
-) -> Result<Json<Vec<RpcPriceIdentifier>>, RestError> {
-    let price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state)
+pub async fn price_feed_ids<S>(
+    State(state): State<ApiState<S>>,
+) -> Result<Json<Vec<RpcPriceIdentifier>>, RestError>
+where
+    S: Aggregates,
+{
+    let state = &*state.state;
+    let price_feed_ids = Aggregates::get_price_feed_ids(state)
         .await
         .into_iter()
         .map(RpcPriceIdentifier::from)

+ 18 - 8
apps/hermes/src/api/rest/ready.rs

@@ -1,14 +1,24 @@
-use axum::{
-    extract::State,
-    http::StatusCode,
-    response::{
-        IntoResponse,
-        Response,
+use {
+    crate::{
+        api::ApiState,
+        state::aggregate::Aggregates,
+    },
+    axum::{
+        extract::State,
+        http::StatusCode,
+        response::{
+            IntoResponse,
+            Response,
+        },
     },
 };
 
-pub async fn ready(State(state): State<crate::api::ApiState>) -> Response {
-    match crate::aggregate::is_ready(&state.state).await {
+pub async fn ready<S>(State(state): State<ApiState<S>>) -> Response
+where
+    S: Aggregates,
+{
+    let state = &*state.state;
+    match Aggregates::is_ready(state).await {
         true => (StatusCode::OK, "OK").into_response(),
         false => (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable").into_response(),
     }

+ 23 - 19
apps/hermes/src/api/rest/v2/latest_price_updates.rs

@@ -1,6 +1,5 @@
 use {
     crate::{
-        aggregate::RequestTime,
         api::{
             rest::{
                 verify_price_ids_exist,
@@ -13,6 +12,11 @@ use {
                 PriceIdInput,
                 PriceUpdate,
             },
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
         },
     },
     anyhow::Result,
@@ -73,28 +77,28 @@ fn default_true() -> bool {
         LatestPriceUpdatesQueryParams
     )
 )]
-pub async fn latest_price_updates(
-    State(state): State<crate::api::ApiState>,
+pub async fn latest_price_updates<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<LatestPriceUpdatesQueryParams>,
-) -> Result<Json<PriceUpdate>, RestError> {
+) -> Result<Json<PriceUpdate>, RestError>
+where
+    S: Aggregates,
+{
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
-
     verify_price_ids_exist(&state, &price_ids).await?;
 
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
-        &price_ids,
-        RequestTime::Latest,
-    )
-    .await
-    .map_err(|e| {
-        tracing::warn!(
-            "Error getting price feeds {:?} with update data: {:?}",
-            price_ids,
-            e
-        );
-        RestError::UpdateDataNotFound
-    })?;
+    let state = &*state.state;
+    let price_feeds_with_update_data =
+        Aggregates::get_price_feeds_with_update_data(state, &price_ids, RequestTime::Latest)
+            .await
+            .map_err(|e| {
+                tracing::warn!(
+                    "Error getting price feeds {:?} with update data: {:?}",
+                    price_ids,
+                    e
+                );
+                RestError::UpdateDataNotFound
+            })?;
 
     let price_update_data = price_feeds_with_update_data.update_data;
     let encoded_data: Vec<String> = price_update_data

+ 24 - 14
apps/hermes/src/api/rest/v2/sse.rs

@@ -1,9 +1,5 @@
 use {
     crate::{
-        aggregate::{
-            AggregationEvent,
-            RequestTime,
-        },
         api::{
             rest::{
                 verify_price_ids_exist,
@@ -19,6 +15,11 @@ use {
             },
             ApiState,
         },
+        state::aggregate::{
+            Aggregates,
+            AggregationEvent,
+            RequestTime,
+        },
     },
     anyhow::Result,
     axum::{
@@ -88,16 +89,22 @@ fn default_true() -> bool {
     params(StreamPriceUpdatesQueryParams)
 )]
 /// SSE route handler for streaming price updates.
-pub async fn price_stream_sse_handler(
-    State(state): State<ApiState>,
+pub async fn price_stream_sse_handler<S>(
+    State(state): State<ApiState<S>>,
     QsQuery(params): QsQuery<StreamPriceUpdatesQueryParams>,
-) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError> {
+) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError>
+where
+    S: Aggregates,
+    S: Sync,
+    S: Send,
+    S: 'static,
+{
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
 
     verify_price_ids_exist(&state, &price_ids).await?;
 
     // Clone the update_tx receiver to listen for new price updates
-    let update_rx: broadcast::Receiver<AggregationEvent> = state.update_tx.subscribe();
+    let update_rx: broadcast::Receiver<AggregationEvent> = Aggregates::subscribe(&*state.state);
 
     // Convert the broadcast receiver into a Stream
     let stream = BroadcastStream::new(update_rx);
@@ -134,15 +141,18 @@ pub async fn price_stream_sse_handler(
     Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
 }
 
-async fn handle_aggregation_event(
+async fn handle_aggregation_event<S>(
     event: AggregationEvent,
-    state: ApiState,
+    state: ApiState<S>,
     mut price_ids: Vec<PriceIdentifier>,
     encoding: EncodingType,
     parsed: bool,
     benchmarks_only: bool,
     allow_unordered: bool,
-) -> Result<Option<PriceUpdate>> {
+) -> Result<Option<PriceUpdate>>
+where
+    S: Aggregates,
+{
     // Handle out-of-order events
     if let AggregationEvent::OutOfOrder { .. } = event {
         if !allow_unordered {
@@ -151,11 +161,11 @@ async fn handle_aggregation_event(
     }
 
     // We check for available price feed ids to ensure that the price feed ids provided exists since price feeds can be removed.
-    let available_price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;
+    let available_price_feed_ids = Aggregates::get_price_feed_ids(&*state.state).await;
 
     price_ids.retain(|price_feed_id| available_price_feed_ids.contains(price_feed_id));
 
-    let mut price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
+    let mut price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
         &*state.state,
         &price_ids,
         RequestTime::AtSlot(event.slot()),
@@ -185,7 +195,7 @@ async fn handle_aggregation_event(
                 .iter()
                 .any(|price_feed| price_feed.id == RpcPriceIdentifier::from(*price_id))
         });
-        price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
+        price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
             &*state.state,
             &price_ids,
             RequestTime::AtSlot(event.slot()),

+ 15 - 9
apps/hermes/src/api/rest/v2/timestamp_price_updates.rs

@@ -1,9 +1,5 @@
 use {
     crate::{
-        aggregate::{
-            RequestTime,
-            UnixTimestamp,
-        },
         api::{
             doc_examples,
             rest::{
@@ -17,6 +13,12 @@ use {
                 PriceIdInput,
                 PriceUpdate,
             },
+            ApiState,
+        },
+        state::aggregate::{
+            Aggregates,
+            RequestTime,
+            UnixTimestamp,
         },
     },
     anyhow::Result,
@@ -87,18 +89,22 @@ fn default_true() -> bool {
         TimestampPriceUpdatesQueryParams
     )
 )]
-pub async fn timestamp_price_updates(
-    State(state): State<crate::api::ApiState>,
+pub async fn timestamp_price_updates<S>(
+    State(state): State<ApiState<S>>,
     Path(path_params): Path<TimestampPriceUpdatesPathParams>,
     QsQuery(query_params): QsQuery<TimestampPriceUpdatesQueryParams>,
-) -> Result<Json<PriceUpdate>, RestError> {
+) -> Result<Json<PriceUpdate>, RestError>
+where
+    S: Aggregates,
+{
     let price_ids: Vec<PriceIdentifier> =
         query_params.ids.into_iter().map(|id| id.into()).collect();
 
     verify_price_ids_exist(&state, &price_ids).await?;
 
-    let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
-        &*state.state,
+    let state = &*state.state;
+    let price_feeds_with_update_data = Aggregates::get_price_feeds_with_update_data(
+        state,
         &price_ids,
         RequestTime::FirstAfter(path_params.publish_time),
     )

+ 1 - 1
apps/hermes/src/api/types.rs

@@ -1,6 +1,6 @@
 use {
     super::doc_examples,
-    crate::aggregate::{
+    crate::state::aggregate::{
         PriceFeedUpdate,
         PriceFeedsWithUpdateData,
         Slot,

+ 29 - 23
apps/hermes/src/api/ws.rs

@@ -1,14 +1,18 @@
 use {
-    super::types::{
-        PriceIdInput,
-        RpcPriceFeed,
+    super::{
+        types::{
+            PriceIdInput,
+            RpcPriceFeed,
+        },
+        ApiState,
     },
-    crate::{
+    crate::state::{
         aggregate::{
+            Aggregates,
             AggregationEvent,
             RequestTime,
         },
-        state::State,
+        State,
     },
     anyhow::{
         anyhow,
@@ -212,11 +216,10 @@ pub async fn ws_route_handler(
 }
 
 #[tracing::instrument(skip(stream, state, subscriber_ip))]
-async fn websocket_handler(
-    stream: WebSocket,
-    state: super::ApiState,
-    subscriber_ip: Option<IpAddr>,
-) {
+async fn websocket_handler<S>(stream: WebSocket, state: ApiState<S>, subscriber_ip: Option<IpAddr>)
+where
+    S: Aggregates,
+{
     let ws_state = state.ws.clone();
 
     // Retain the recent rate limit data for the IP addresses to
@@ -235,7 +238,7 @@ async fn websocket_handler(
         })
         .inc();
 
-    let notify_receiver = state.update_tx.subscribe();
+    let notify_receiver = Aggregates::subscribe(&*state.state);
     let (sender, receiver) = stream.split();
     let mut subscriber = Subscriber::new(
         id,
@@ -254,11 +257,11 @@ pub type SubscriberId = usize;
 
 /// Subscriber is an actor that handles a single websocket connection.
 /// It listens to the store for updates and sends them to the client.
-pub struct Subscriber {
+pub struct Subscriber<S> {
     id:                      SubscriberId,
     ip_addr:                 Option<IpAddr>,
     closed:                  bool,
-    store:                   Arc<State>,
+    state:                   Arc<S>,
     ws_state:                Arc<WsState>,
     notify_receiver:         Receiver<AggregationEvent>,
     receiver:                SplitStream<WebSocket>,
@@ -269,11 +272,14 @@ pub struct Subscriber {
     responded_to_ping:       bool,
 }
 
-impl Subscriber {
+impl<S> Subscriber<S>
+where
+    S: Aggregates,
+{
     pub fn new(
         id: SubscriberId,
         ip_addr: Option<IpAddr>,
-        store: Arc<State>,
+        state: Arc<S>,
         ws_state: Arc<WsState>,
         notify_receiver: Receiver<AggregationEvent>,
         receiver: SplitStream<WebSocket>,
@@ -283,7 +289,7 @@ impl Subscriber {
             id,
             ip_addr,
             closed: false,
-            store,
+            state,
             ws_state,
             notify_receiver,
             receiver,
@@ -350,8 +356,9 @@ impl Subscriber {
             .cloned()
             .collect::<Vec<_>>();
 
-        let updates = match crate::aggregate::get_price_feeds_with_update_data(
-            &*self.store,
+        let state = &*self.state;
+        let updates = match Aggregates::get_price_feeds_with_update_data(
+            state,
             &price_feed_ids,
             RequestTime::AtSlot(event.slot()),
         )
@@ -364,8 +371,7 @@ impl Subscriber {
                 // subscription. In this case we just remove the non-existing
                 // price feed from the list and will keep sending updates for
                 // the rest.
-                let available_price_feed_ids =
-                    crate::aggregate::get_price_feed_ids(&*self.store).await;
+                let available_price_feed_ids = Aggregates::get_price_feed_ids(state).await;
 
                 self.price_feeds_with_config
                     .retain(|price_feed_id, _| available_price_feed_ids.contains(price_feed_id));
@@ -376,8 +382,8 @@ impl Subscriber {
                     .cloned()
                     .collect::<Vec<_>>();
 
-                crate::aggregate::get_price_feeds_with_update_data(
-                    &*self.store,
+                Aggregates::get_price_feeds_with_update_data(
+                    state,
                     &price_feed_ids,
                     RequestTime::AtSlot(event.slot()),
                 )
@@ -545,7 +551,7 @@ impl Subscriber {
                 allow_out_of_order,
             }) => {
                 let price_ids: Vec<PriceIdentifier> = ids.into_iter().map(|id| id.into()).collect();
-                let available_price_ids = crate::aggregate::get_price_feed_ids(&*self.store).await;
+                let available_price_ids = Aggregates::get_price_feed_ids(&*self.state).await;
 
                 let not_found_price_ids: Vec<&PriceIdentifier> = price_ids
                     .iter()

+ 5 - 6
apps/hermes/src/main.rs

@@ -17,7 +17,6 @@ use {
     },
 };
 
-mod aggregate;
 mod api;
 mod config;
 mod metrics_server;
@@ -54,7 +53,7 @@ async fn init() -> Result<()> {
             let (update_tx, _) = tokio::sync::broadcast::channel(1000);
 
             // Initialize a cache store with a 1000 element circular buffer.
-            let store = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
+            let state = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
 
             // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
             spawn(async move {
@@ -67,10 +66,10 @@ async fn init() -> Result<()> {
             // Spawn all worker tasks, and wait for all to complete (which will happen if a shutdown
             // signal has been observed).
             let tasks = join_all(vec![
-                spawn(network::wormhole::spawn(opts.clone(), store.clone())),
-                spawn(network::pythnet::spawn(opts.clone(), store.clone())),
-                spawn(metrics_server::run(opts.clone(), store.clone())),
-                spawn(api::spawn(opts.clone(), store.clone(), update_tx)),
+                spawn(network::wormhole::spawn(opts.clone(), state.clone())),
+                spawn(network::pythnet::spawn(opts.clone(), state.clone())),
+                spawn(metrics_server::run(opts.clone(), state.clone())),
+                spawn(api::spawn(opts.clone(), state.clone())),
             ])
             .await;
 

+ 10 - 7
apps/hermes/src/network/pythnet.rs

@@ -4,10 +4,6 @@
 
 use {
     crate::{
-        aggregate::{
-            AccumulatorMessages,
-            Update,
-        },
         api::types::PriceFeedMetadata,
         config::RunOptions,
         network::wormhole::{
@@ -20,7 +16,14 @@ use {
             PriceFeedMeta,
             DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
         },
-        state::State,
+        state::{
+            aggregate::{
+                AccumulatorMessages,
+                Aggregates,
+                Update,
+            },
+            State,
+        },
     },
     anyhow::{
         anyhow,
@@ -182,8 +185,8 @@ pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<()> {
                         if candidate.to_string() == update.value.pubkey {
                             let store = store.clone();
                             tokio::spawn(async move {
-                                if let Err(err) = crate::aggregate::store_update(
-                                    &store,
+                                if let Err(err) = Aggregates::store_update(
+                                    &*store,
                                     Update::AccumulatorMessages(accumulator_messages),
                                 )
                                 .await

+ 12 - 8
apps/hermes/src/network/wormhole.rs

@@ -7,7 +7,13 @@
 use {
     crate::{
         config::RunOptions,
-        state::State,
+        state::{
+            aggregate::{
+                Aggregates,
+                Update,
+            },
+            State,
+        },
     },
     anyhow::{
         anyhow,
@@ -100,10 +106,10 @@ pub struct BridgeConfig {
 /// GuardianSetData extracted from wormhole bridge account, due to no API.
 #[derive(borsh::BorshDeserialize)]
 pub struct GuardianSetData {
-    pub index:           u32,
-    pub keys:            Vec<[u8; 20]>,
-    pub creation_time:   u32,
-    pub expiration_time: u32,
+    pub _index:           u32,
+    pub keys:             Vec<[u8; 20]>,
+    pub _creation_time:   u32,
+    pub _expiration_time: u32,
 }
 
 /// Update the guardian set with the given ID in the state.
@@ -352,9 +358,7 @@ pub async fn store_vaa(state: Arc<State>, sequence: u64, vaa_bytes: Vec<u8>) {
     }
 
     // Hand the VAA to the aggregate store.
-    if let Err(e) =
-        crate::aggregate::store_update(&state, crate::aggregate::Update::Vaa(vaa_bytes)).await
-    {
+    if let Err(e) = Aggregates::store_update(&*state, Update::Vaa(vaa_bytes)).await {
         tracing::error!(error = ?e, "Failed to store VAA in aggregate store.");
     }
 }

+ 2 - 0
apps/hermes/src/price_feeds_metadata.rs

@@ -31,6 +31,7 @@ impl<'a> From<&'a State> for &'a PriceFeedMetaState {
     }
 }
 
+#[async_trait::async_trait]
 pub trait PriceFeedMeta {
     async fn retrieve_price_feeds_metadata(&self) -> Result<Vec<PriceFeedMetadata>>;
     async fn store_price_feeds_metadata(
@@ -44,6 +45,7 @@ pub trait PriceFeedMeta {
     ) -> Result<Vec<PriceFeedMetadata>>;
 }
 
+#[async_trait::async_trait]
 impl<T> PriceFeedMeta for T
 where
     for<'a> &'a T: Into<&'a PriceFeedMetaState>,

+ 9 - 12
apps/hermes/src/state.rs

@@ -2,14 +2,14 @@
 
 use {
     self::{
-        benchmarks::BenchmarksState,
-        cache::CacheState,
-    },
-    crate::{
         aggregate::{
             AggregateState,
             AggregationEvent,
         },
+        benchmarks::BenchmarksState,
+        cache::CacheState,
+    },
+    crate::{
         network::wormhole::GuardianSet,
         price_feeds_metadata::PriceFeedMetaState,
     },
@@ -28,6 +28,7 @@ use {
     },
 };
 
+pub mod aggregate;
 pub mod benchmarks;
 pub mod cache;
 
@@ -41,6 +42,9 @@ pub struct State {
     /// State for the `PriceFeedMeta` service for looking up metadata related to Pyth price feeds.
     pub price_feed_meta: PriceFeedMetaState,
 
+    /// State for accessing/storing Pyth price aggregates.
+    pub aggregates: AggregateState,
+
     /// Sequence numbers of lately observed Vaas. Store uses this set
     /// to ignore the previously observed Vaas as a performance boost.
     pub observed_vaa_seqs: RwLock<BTreeSet<u64>>,
@@ -48,12 +52,6 @@ pub struct State {
     /// Wormhole guardian sets. It is used to verify Vaas before using them.
     pub guardian_set: RwLock<BTreeMap<u32, GuardianSet>>,
 
-    /// The sender to the channel between Store and Api to notify completed updates.
-    pub api_update_tx: Sender<AggregationEvent>,
-
-    /// The aggregate module state.
-    pub aggregate_state: RwLock<AggregateState>,
-
     /// Metrics registry
     pub metrics_registry: RwLock<Registry>,
 }
@@ -69,10 +67,9 @@ impl State {
             cache:             CacheState::new(cache_size),
             benchmarks:        BenchmarksState::new(benchmarks_endpoint),
             price_feed_meta:   PriceFeedMetaState::new(),
+            aggregates:        AggregateState::new(update_tx, &mut metrics_registry),
             observed_vaa_seqs: RwLock::new(Default::default()),
             guardian_set:      RwLock::new(Default::default()),
-            api_update_tx:     update_tx,
-            aggregate_state:   RwLock::new(AggregateState::new(&mut metrics_registry)),
             metrics_registry:  RwLock::new(metrics_registry),
         })
     }

+ 271 - 215
apps/hermes/src/aggregate.rs → apps/hermes/src/state/aggregate.rs

@@ -20,6 +20,7 @@ use {
     },
     crate::{
         network::wormhole::VaaBytes,
+        price_feeds_metadata::PriceFeedMeta,
         state::{
             benchmarks::Benchmarks,
             cache::{
@@ -59,6 +60,13 @@ use {
         collections::HashSet,
         time::Duration,
     },
+    tokio::sync::{
+        broadcast::{
+            Receiver,
+            Sender,
+        },
+        RwLock,
+    },
     wormhole_sdk::Vaa,
 };
 
@@ -102,8 +110,7 @@ impl AggregationEvent {
     }
 }
 
-#[derive(Clone, Debug)]
-pub struct AggregateState {
+pub struct AggregateStateData {
     /// The latest completed slot. This is used to check whether a completed state is new or out of
     /// order.
     pub latest_completed_slot: Option<Slot>,
@@ -119,7 +126,7 @@ pub struct AggregateState {
     pub metrics: metrics::Metrics,
 }
 
-impl AggregateState {
+impl AggregateStateData {
     pub fn new(metrics_registry: &mut Registry) -> Self {
         Self {
             latest_completed_slot:      None,
@@ -130,6 +137,20 @@ impl AggregateState {
     }
 }
 
+pub struct AggregateState {
+    pub data:          RwLock<AggregateStateData>,
+    pub api_update_tx: Sender<AggregationEvent>,
+}
+
+impl AggregateState {
+    pub fn new(update_tx: Sender<AggregationEvent>, metrics_registry: &mut Registry) -> Self {
+        Self {
+            data:          RwLock::new(AggregateStateData::new(metrics_registry)),
+            api_update_tx: update_tx,
+        }
+    }
+}
+
 /// Accumulator messages coming from Pythnet validators.
 ///
 /// The validators writes the accumulator messages using Borsh with
@@ -177,124 +198,220 @@ const READINESS_STALENESS_THRESHOLD: Duration = Duration::from_secs(30);
 /// 10 slots is almost 5 seconds.
 const READINESS_MAX_ALLOWED_SLOT_LAG: Slot = 10;
 
-/// Stores the update data in the store
-#[tracing::instrument(skip(state, update))]
-pub async fn store_update(state: &State, update: Update) -> Result<()> {
-    // The slot that the update is originating from. It should be available
-    // in all the updates.
-    let slot = match update {
-        Update::Vaa(update_vaa) => {
-            let vaa = serde_wormhole::from_slice::<Vaa<&serde_wormhole::RawMessage>>(
-                update_vaa.as_ref(),
-            )?;
-            match WormholeMessage::try_from_bytes(vaa.payload)?.payload {
-                WormholePayload::Merkle(proof) => {
-                    tracing::info!(slot = proof.slot, "Storing VAA Merkle Proof.");
-
-                    store_wormhole_merkle_verified_message(
-                        state,
-                        proof.clone(),
-                        update_vaa.to_owned(),
-                    )
-                    .await?;
+#[async_trait::async_trait]
+pub trait Aggregates
+where
+    Self: Cache,
+    Self: Benchmarks,
+    Self: PriceFeedMeta,
+{
+    fn subscribe(&self) -> Receiver<AggregationEvent>;
+    async fn is_ready(&self) -> bool;
+    async fn store_update(&self, update: Update) -> Result<()>;
+    async fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier>;
+    async fn get_price_feeds_with_update_data(
+        &self,
+        price_ids: &[PriceIdentifier],
+        request_time: RequestTime,
+    ) -> Result<PriceFeedsWithUpdateData>;
+}
+
+/// Allow downcasting State into CacheState for functions that depend on the `Cache` service.
+impl<'a> From<&'a State> for &'a AggregateState {
+    fn from(state: &'a State) -> &'a AggregateState {
+        &state.aggregates
+    }
+}
 
-                    state
-                        .aggregate_state
-                        .write()
-                        .await
-                        .metrics
-                        .observe(proof.slot, metrics::Event::Vaa);
+#[async_trait::async_trait]
+impl<T> Aggregates for T
+where
+    for<'a> &'a T: Into<&'a AggregateState>,
+    T: Sync,
+    T: Send,
+    T: Cache,
+    T: Benchmarks,
+    T: PriceFeedMeta,
+{
+    fn subscribe(&self) -> Receiver<AggregationEvent> {
+        self.into().api_update_tx.subscribe()
+    }
 
-                    proof.slot
+    /// Stores the update data in the store
+    #[tracing::instrument(skip(self, update))]
+    async fn store_update(&self, update: Update) -> Result<()> {
+        // The slot that the update is originating from. It should be available
+        // in all the updates.
+        let slot = match update {
+            Update::Vaa(update_vaa) => {
+                let vaa = serde_wormhole::from_slice::<Vaa<&serde_wormhole::RawMessage>>(
+                    update_vaa.as_ref(),
+                )?;
+                match WormholeMessage::try_from_bytes(vaa.payload)?.payload {
+                    WormholePayload::Merkle(proof) => {
+                        tracing::info!(slot = proof.slot, "Storing VAA Merkle Proof.");
+
+                        store_wormhole_merkle_verified_message(
+                            self,
+                            proof.clone(),
+                            update_vaa.to_owned(),
+                        )
+                        .await?;
+
+                        self.into()
+                            .data
+                            .write()
+                            .await
+                            .metrics
+                            .observe(proof.slot, metrics::Event::Vaa);
+
+                        proof.slot
+                    }
                 }
             }
-        }
-        Update::AccumulatorMessages(accumulator_messages) => {
-            let slot = accumulator_messages.slot;
-            tracing::info!(slot = slot, "Storing Accumulator Messages.");
+            Update::AccumulatorMessages(accumulator_messages) => {
+                let slot = accumulator_messages.slot;
+                tracing::info!(slot = slot, "Storing Accumulator Messages.");
 
-            state
-                .store_accumulator_messages(accumulator_messages)
-                .await?;
+                self.store_accumulator_messages(accumulator_messages)
+                    .await?;
 
-            state
-                .aggregate_state
-                .write()
-                .await
-                .metrics
-                .observe(slot, metrics::Event::AccumulatorMessages);
-            slot
-        }
-    };
+                self.into()
+                    .data
+                    .write()
+                    .await
+                    .metrics
+                    .observe(slot, metrics::Event::AccumulatorMessages);
+                slot
+            }
+        };
 
-    // Update the aggregate state with the latest observed slot
-    {
-        let mut aggregate_state = state.aggregate_state.write().await;
-        aggregate_state.latest_observed_slot = aggregate_state
-            .latest_observed_slot
-            .map(|latest| latest.max(slot))
-            .or(Some(slot));
-    }
+        // Update the aggregate state with the latest observed slot
+        {
+            let mut aggregate_state = self.into().data.write().await;
+            aggregate_state.latest_observed_slot = aggregate_state
+                .latest_observed_slot
+                .map(|latest| latest.max(slot))
+                .or(Some(slot));
+        }
 
-    let accumulator_messages = state.fetch_accumulator_messages(slot).await?;
-    let wormhole_merkle_state = state.fetch_wormhole_merkle_state(slot).await?;
+        let accumulator_messages = self.fetch_accumulator_messages(slot).await?;
+        let wormhole_merkle_state = self.fetch_wormhole_merkle_state(slot).await?;
 
-    let (accumulator_messages, wormhole_merkle_state) =
-        match (accumulator_messages, wormhole_merkle_state) {
-            (Some(accumulator_messages), Some(wormhole_merkle_state)) => {
-                (accumulator_messages, wormhole_merkle_state)
-            }
-            _ => return Ok(()),
-        };
+        let (accumulator_messages, wormhole_merkle_state) =
+            match (accumulator_messages, wormhole_merkle_state) {
+                (Some(accumulator_messages), Some(wormhole_merkle_state)) => {
+                    (accumulator_messages, wormhole_merkle_state)
+                }
+                _ => return Ok(()),
+            };
 
-    tracing::info!(slot = wormhole_merkle_state.root.slot, "Completed Update.");
+        tracing::info!(slot = wormhole_merkle_state.root.slot, "Completed Update.");
 
-    // Once the accumulator reaches a complete state for a specific slot
-    // we can build the message states
-    let message_states = build_message_states(accumulator_messages, wormhole_merkle_state)?;
+        // Once the accumulator reaches a complete state for a specific slot
+        // we can build the message states
+        let message_states = build_message_states(accumulator_messages, wormhole_merkle_state)?;
 
-    let message_state_keys = message_states
-        .iter()
-        .map(|message_state| message_state.key())
-        .collect::<HashSet<_>>();
+        let message_state_keys = message_states
+            .iter()
+            .map(|message_state| message_state.key())
+            .collect::<HashSet<_>>();
 
-    tracing::info!(len = message_states.len(), "Storing Message States.");
-    state.store_message_states(message_states).await?;
+        tracing::info!(len = message_states.len(), "Storing Message States.");
+        self.store_message_states(message_states).await?;
 
-    // Update the aggregate state
-    let mut aggregate_state = state.aggregate_state.write().await;
+        // Update the aggregate state
+        let mut aggregate_state = self.into().data.write().await;
 
-    // Send update event to subscribers. We are purposefully ignoring the result
-    // because there might be no subscribers.
-    let _ = match aggregate_state.latest_completed_slot {
-        None => {
-            aggregate_state.latest_completed_slot.replace(slot);
-            state.api_update_tx.send(AggregationEvent::New { slot })
+        // Check if the update is new or out of order
+        match aggregate_state.latest_completed_slot {
+            None => {
+                aggregate_state.latest_completed_slot.replace(slot);
+                self.into()
+                    .api_update_tx
+                    .send(AggregationEvent::New { slot })?;
+            }
+            Some(latest) if slot > latest => {
+                self.prune_removed_keys(message_state_keys).await;
+                aggregate_state.latest_completed_slot.replace(slot);
+                self.into()
+                    .api_update_tx
+                    .send(AggregationEvent::New { slot })?;
+            }
+            _ => {
+                self.into()
+                    .api_update_tx
+                    .send(AggregationEvent::OutOfOrder { slot })?;
+            }
         }
-        Some(latest) if slot > latest => {
-            state.prune_removed_keys(message_state_keys).await;
-            aggregate_state.latest_completed_slot.replace(slot);
-            state.api_update_tx.send(AggregationEvent::New { slot })
+
+        aggregate_state.latest_completed_slot = aggregate_state
+            .latest_completed_slot
+            .map(|latest| latest.max(slot))
+            .or(Some(slot));
+
+        aggregate_state
+            .latest_completed_update_at
+            .replace(Instant::now());
+
+        aggregate_state
+            .metrics
+            .observe(slot, metrics::Event::CompletedUpdate);
+
+        Ok(())
+    }
+
+    async fn get_price_feeds_with_update_data(
+        &self,
+        price_ids: &[PriceIdentifier],
+        request_time: RequestTime,
+    ) -> Result<PriceFeedsWithUpdateData> {
+        match get_verified_price_feeds(self, price_ids, request_time.clone()).await {
+            Ok(price_feeds_with_update_data) => Ok(price_feeds_with_update_data),
+            Err(e) => {
+                if let RequestTime::FirstAfter(publish_time) = request_time {
+                    return Benchmarks::get_verified_price_feeds(self, price_ids, publish_time)
+                        .await;
+                }
+                Err(e)
+            }
         }
-        _ => state
-            .api_update_tx
-            .send(AggregationEvent::OutOfOrder { slot }),
-    };
+    }
 
-    aggregate_state.latest_completed_slot = aggregate_state
-        .latest_completed_slot
-        .map(|latest| latest.max(slot))
-        .or(Some(slot));
+    async fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier> {
+        Cache::message_state_keys(self)
+            .await
+            .iter()
+            .map(|key| PriceIdentifier::new(key.feed_id))
+            .collect()
+    }
+
+    async fn is_ready(&self) -> bool {
+        let metadata = self.into().data.read().await;
+        let price_feeds_metadata = PriceFeedMeta::retrieve_price_feeds_metadata(self)
+            .await
+            .unwrap();
 
-    aggregate_state
-        .latest_completed_update_at
-        .replace(Instant::now());
+        let has_completed_recently = match metadata.latest_completed_update_at.as_ref() {
+            Some(latest_completed_update_time) => {
+                latest_completed_update_time.elapsed() < READINESS_STALENESS_THRESHOLD
+            }
+            None => false,
+        };
 
-    aggregate_state
-        .metrics
-        .observe(slot, metrics::Event::CompletedUpdate);
+        let is_not_behind = match (
+            metadata.latest_completed_slot,
+            metadata.latest_observed_slot,
+        ) {
+            (Some(latest_completed_slot), Some(latest_observed_slot)) => {
+                latest_observed_slot - latest_completed_slot <= READINESS_MAX_ALLOWED_SLOT_LAG
+            }
+            _ => false,
+        };
 
-    Ok(())
+        let is_metadata_loaded = !price_feeds_metadata.is_empty();
+        has_completed_recently && is_not_behind && is_metadata_loaded
+    }
 }
 
 #[tracing::instrument(skip(accumulator_messages, wormhole_merkle_state))]
@@ -389,73 +506,12 @@ where
     })
 }
 
-pub async fn get_price_feeds_with_update_data<S>(
-    state: &S,
-    price_ids: &[PriceIdentifier],
-    request_time: RequestTime,
-) -> Result<PriceFeedsWithUpdateData>
-where
-    S: Cache,
-    S: Benchmarks,
-{
-    match get_verified_price_feeds(state, price_ids, request_time.clone()).await {
-        Ok(price_feeds_with_update_data) => Ok(price_feeds_with_update_data),
-        Err(e) => {
-            if let RequestTime::FirstAfter(publish_time) = request_time {
-                return state
-                    .get_verified_price_feeds(price_ids, publish_time)
-                    .await;
-            }
-            Err(e)
-        }
-    }
-}
-
-pub async fn get_price_feed_ids<S>(state: &S) -> HashSet<PriceIdentifier>
-where
-    S: Cache,
-{
-    state
-        .message_state_keys()
-        .await
-        .iter()
-        .map(|key| PriceIdentifier::new(key.feed_id))
-        .collect()
-}
-
-pub async fn is_ready(state: &State) -> bool {
-    let metadata = state.aggregate_state.read().await;
-    let price_feeds_metadata = state.price_feed_meta.data.read().await;
-
-    let has_completed_recently = match metadata.latest_completed_update_at.as_ref() {
-        Some(latest_completed_update_time) => {
-            latest_completed_update_time.elapsed() < READINESS_STALENESS_THRESHOLD
-        }
-        None => false,
-    };
-
-    let is_not_behind = match (
-        metadata.latest_completed_slot,
-        metadata.latest_observed_slot,
-    ) {
-        (Some(latest_completed_slot), Some(latest_observed_slot)) => {
-            latest_observed_slot - latest_completed_slot <= READINESS_MAX_ALLOWED_SLOT_LAG
-        }
-        _ => false,
-    };
-
-    let is_metadata_loaded = !price_feeds_metadata.is_empty();
-
-    has_completed_recently && is_not_behind && is_metadata_loaded
-}
-
 #[cfg(test)]
 mod test {
     use {
         super::*,
         crate::{
             api::types::PriceFeedMetadata,
-            price_feeds_metadata::PriceFeedMeta,
             state::test::setup_state,
         },
         futures::future::join_all,
@@ -557,7 +613,7 @@ mod test {
     }
 
     pub async fn store_multiple_concurrent_valid_updates(state: Arc<State>, updates: Vec<Update>) {
-        let res = join_all(updates.into_iter().map(|u| store_update(&state, u))).await;
+        let res = join_all(updates.into_iter().map(|u| (&state).store_update(u))).await;
         // Check that all store_update calls succeeded
         assert!(res.into_iter().all(|r| r.is_ok()));
     }
@@ -583,19 +639,19 @@ mod test {
 
         // Check the price ids are stored correctly
         assert_eq!(
-            get_price_feed_ids(&*state).await,
+            (&*state).get_price_feed_ids().await,
             vec![PriceIdentifier::new([100; 32])].into_iter().collect()
         );
 
         // Check get_price_feeds_with_update_data retrieves the correct
         // price feed with correct update data.
-        let price_feeds_with_update_data = get_price_feeds_with_update_data(
-            &*state,
-            &[PriceIdentifier::new([100; 32])],
-            RequestTime::Latest,
-        )
-        .await
-        .unwrap();
+        let price_feeds_with_update_data = (&*state)
+            .get_price_feeds_with_update_data(
+                &[PriceIdentifier::new([100; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .unwrap();
 
         assert_eq!(
             price_feeds_with_update_data.price_feeds,
@@ -708,7 +764,7 @@ mod test {
 
         // Check the price ids are stored correctly
         assert_eq!(
-            get_price_feed_ids(&*state).await,
+            (&*state).get_price_feed_ids().await,
             vec![
                 PriceIdentifier::new([100; 32]),
                 PriceIdentifier::new([200; 32])
@@ -718,13 +774,13 @@ mod test {
         );
 
         // Check that price feed 2 exists
-        assert!(get_price_feeds_with_update_data(
-            &*state,
-            &[PriceIdentifier::new([200; 32])],
-            RequestTime::Latest,
-        )
-        .await
-        .is_ok());
+        assert!((&*state)
+            .get_price_feeds_with_update_data(
+                &[PriceIdentifier::new([200; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .is_ok());
 
         // Now send an update with only price feed 1 (without price feed 2)
         // and make sure that price feed 2 is not stored anymore.
@@ -745,17 +801,17 @@ mod test {
 
         // Check that price feed 2 does not exist anymore
         assert_eq!(
-            get_price_feed_ids(&*state).await,
+            (&*state).get_price_feed_ids().await,
             vec![PriceIdentifier::new([100; 32]),].into_iter().collect()
         );
 
-        assert!(get_price_feeds_with_update_data(
-            &*state,
-            &[PriceIdentifier::new([200; 32])],
-            RequestTime::Latest,
-        )
-        .await
-        .is_err());
+        assert!((&*state)
+            .get_price_feeds_with_update_data(
+                &[PriceIdentifier::new([200; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .is_err());
     }
 
     #[tokio::test]
@@ -791,13 +847,13 @@ mod test {
         MockClock::advance(Duration::from_secs(1));
 
         // Get the price feeds with update data
-        let price_feeds_with_update_data = get_price_feeds_with_update_data(
-            &*state,
-            &[PriceIdentifier::new([100; 32])],
-            RequestTime::Latest,
-        )
-        .await
-        .unwrap();
+        let price_feeds_with_update_data = (&*state)
+            .get_price_feeds_with_update_data(
+                &[PriceIdentifier::new([100; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .unwrap();
 
         // check received_at is correct
         assert_eq!(price_feeds_with_update_data.price_feeds.len(), 1);
@@ -817,13 +873,13 @@ mod test {
             .unwrap();
 
         // Check the state is ready
-        assert!(is_ready(&state).await);
+        assert!((&state).is_ready().await);
 
         // Advance the clock to make the prices stale
         MockClock::advance_system_time(READINESS_STALENESS_THRESHOLD);
         MockClock::advance(READINESS_STALENESS_THRESHOLD);
         // Check the state is not ready
-        assert!(!is_ready(&state).await);
+        assert!(!(&state).is_ready().await);
     }
 
     /// Test that the state retains the latest slots upon cache eviction.
@@ -866,16 +922,16 @@ mod test {
 
         // Check the last 100 slots are retained
         for slot in 900..1000 {
-            let price_feeds_with_update_data = get_price_feeds_with_update_data(
-                &*state,
-                &[
-                    PriceIdentifier::new([100; 32]),
-                    PriceIdentifier::new([200; 32]),
-                ],
-                RequestTime::FirstAfter(slot as i64),
-            )
-            .await
-            .unwrap();
+            let price_feeds_with_update_data = (&*state)
+                .get_price_feeds_with_update_data(
+                    &[
+                        PriceIdentifier::new([100; 32]),
+                        PriceIdentifier::new([200; 32]),
+                    ],
+                    RequestTime::FirstAfter(slot as i64),
+                )
+                .await
+                .unwrap();
             assert_eq!(price_feeds_with_update_data.price_feeds.len(), 2);
             assert_eq!(price_feeds_with_update_data.price_feeds[0].slot, Some(slot));
             assert_eq!(price_feeds_with_update_data.price_feeds[1].slot, Some(slot));
@@ -883,16 +939,16 @@ mod test {
 
         // Check nothing else is retained
         for slot in 0..900 {
-            assert!(get_price_feeds_with_update_data(
-                &*state,
-                &[
-                    PriceIdentifier::new([100; 32]),
-                    PriceIdentifier::new([200; 32])
-                ],
-                RequestTime::FirstAfter(slot as i64),
-            )
-            .await
-            .is_err());
+            assert!((&*state)
+                .get_price_feeds_with_update_data(
+                    &[
+                        PriceIdentifier::new([100; 32]),
+                        PriceIdentifier::new([200; 32])
+                    ],
+                    RequestTime::FirstAfter(slot as i64),
+                )
+                .await
+                .is_err());
         }
     }
 }

+ 0 - 0
apps/hermes/src/aggregate/metrics.rs → apps/hermes/src/state/aggregate/metrics.rs


+ 0 - 0
apps/hermes/src/aggregate/wormhole_merkle.rs → apps/hermes/src/state/aggregate/wormhole_merkle.rs


+ 5 - 3
apps/hermes/src/state/benchmarks.rs

@@ -1,14 +1,14 @@
 //! This module communicates with Pyth Benchmarks, an API for historical price feeds and their updates.
 
 use {
-    super::State,
-    crate::{
+    super::{
         aggregate::{
             PriceFeedsWithUpdateData,
             UnixTimestamp,
         },
-        api::types::PriceUpdate,
+        State,
     },
+    crate::api::types::PriceUpdate,
     anyhow::Result,
     base64::{
         engine::general_purpose::STANDARD as base64_standard_engine,
@@ -69,6 +69,7 @@ impl<'a> From<&'a State> for &'a BenchmarksState {
     }
 }
 
+#[async_trait::async_trait]
 pub trait Benchmarks {
     async fn get_verified_price_feeds(
         &self,
@@ -77,6 +78,7 @@ pub trait Benchmarks {
     ) -> Result<PriceFeedsWithUpdateData>;
 }
 
+#[async_trait::async_trait]
 impl<T> Benchmarks for T
 where
     for<'a> &'a T: Into<&'a BenchmarksState>,

+ 12 - 10
apps/hermes/src/state/cache.rs

@@ -1,6 +1,6 @@
 use {
     super::State,
-    crate::aggregate::{
+    crate::state::aggregate::{
         wormhole_merkle::WormholeMerkleState,
         AccumulatorMessages,
         ProofSet,
@@ -132,16 +132,10 @@ impl<'a> From<&'a State> for &'a CacheState {
     }
 }
 
+#[async_trait::async_trait]
 pub trait Cache {
-    async fn message_state_keys(&self) -> Vec<MessageStateKey>;
     async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
     async fn prune_removed_keys(&self, current_keys: HashSet<MessageStateKey>);
-    async fn fetch_message_states(
-        &self,
-        ids: Vec<FeedId>,
-        request_time: RequestTime,
-        filter: MessageStateFilter,
-    ) -> Result<Vec<MessageState>>;
     async fn store_accumulator_messages(
         &self,
         accumulator_messages: AccumulatorMessages,
@@ -152,8 +146,16 @@ pub trait Cache {
         wormhole_merkle_state: WormholeMerkleState,
     ) -> Result<()>;
     async fn fetch_wormhole_merkle_state(&self, slot: Slot) -> Result<Option<WormholeMerkleState>>;
+    async fn message_state_keys(&self) -> Vec<MessageStateKey>;
+    async fn fetch_message_states(
+        &self,
+        ids: Vec<FeedId>,
+        request_time: RequestTime,
+        filter: MessageStateFilter,
+    ) -> Result<Vec<MessageState>>;
 }
 
+#[async_trait::async_trait]
 impl<T> Cache for T
 where
     for<'a> &'a T: Into<&'a CacheState>,
@@ -322,9 +324,9 @@ async fn retrieve_message_state(
 mod test {
     use {
         super::*,
-        crate::{
+        crate::state::{
             aggregate::wormhole_merkle::WormholeMerkleMessageProof,
-            state::test::setup_state,
+            test::setup_state,
         },
         pyth_sdk::UnixTimestamp,
         pythnet_sdk::{