Procházet zdrojové kódy

refactor(hermes): remove benchmarks_url from state

Reisen před 1 rokem
rodič
revize
96c973b8e6

+ 15 - 1
hermes/src/aggregate.rs

@@ -55,6 +55,7 @@ use {
             },
         },
     },
+    reqwest::Url,
     std::{
         collections::HashSet,
         time::Duration,
@@ -394,6 +395,7 @@ pub async fn get_price_feeds_with_update_data<S>(
     state: &S,
     price_ids: &[PriceIdentifier],
     request_time: RequestTime,
+    benchmarks_url: Option<Url>,
 ) -> Result<PriceFeedsWithUpdateData>
 where
     S: AggregateCache,
@@ -403,7 +405,13 @@ where
         Ok(price_feeds_with_update_data) => Ok(price_feeds_with_update_data),
         Err(e) => {
             if let RequestTime::FirstAfter(publish_time) = request_time {
-                return Benchmarks::get_verified_price_feeds(state, price_ids, publish_time).await;
+                return Benchmarks::get_verified_price_feeds(
+                    state,
+                    price_ids,
+                    publish_time,
+                    benchmarks_url,
+                )
+                .await;
             }
             Err(e)
         }
@@ -595,6 +603,7 @@ mod test {
             &*state,
             &[PriceIdentifier::new([100; 32])],
             RequestTime::Latest,
+            None,
         )
         .await
         .unwrap();
@@ -724,6 +733,7 @@ mod test {
             &*state,
             &[PriceIdentifier::new([200; 32])],
             RequestTime::Latest,
+            None,
         )
         .await
         .is_ok());
@@ -755,6 +765,7 @@ mod test {
             &*state,
             &[PriceIdentifier::new([200; 32])],
             RequestTime::Latest,
+            None,
         )
         .await
         .is_err());
@@ -797,6 +808,7 @@ mod test {
             &*state,
             &[PriceIdentifier::new([100; 32])],
             RequestTime::Latest,
+            None,
         )
         .await
         .unwrap();
@@ -877,6 +889,7 @@ mod test {
                     PriceIdentifier::new([200; 32]),
                 ],
                 RequestTime::FirstAfter(slot as i64),
+                None,
             )
             .await
             .unwrap();
@@ -894,6 +907,7 @@ mod test {
                     PriceIdentifier::new([200; 32])
                 ],
                 RequestTime::FirstAfter(slot as i64),
+                None,
             )
             .await
             .is_err());

+ 2 - 1
hermes/src/api.rs

@@ -149,7 +149,8 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
         .layer(CorsLayer::permissive())
         // Non-strict mode permits escaped [] in URL parameters. 5 is the allowed depth (also the
         // default value for this parameter).
-        .layer(Extension(QsQueryConfig::new(5, false)));
+        .layer(Extension(QsQueryConfig::new(5, false)))
+        .layer(Extension(opts.benchmarks.endpoint));
 
     // Binds the axum's server to the configured address and port. This is a blocking call and will
     // not return until the server is shutdown.

+ 4 - 0
hermes/src/api/rest/get_price_feed.rs

@@ -17,9 +17,11 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
 };
@@ -63,6 +65,7 @@ pub struct GetPriceFeedQueryParams {
 pub async fn get_price_feed(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<GetPriceFeedQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<RpcPriceFeed>, RestError> {
     let price_id: PriceIdentifier = params.id.into();
 
@@ -72,6 +75,7 @@ pub async fn get_price_feed(
         &*state.state,
         &[price_id],
         RequestTime::FirstAfter(params.publish_time),
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/get_vaa.rs

@@ -15,6 +15,7 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     base64::{
@@ -22,6 +23,7 @@ use {
         Engine as _,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde_qs::axum::QsQuery,
     utoipa::{
         IntoParams,
@@ -71,6 +73,7 @@ pub struct GetVaaResponse {
 pub async fn get_vaa(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<GetVaaQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<GetVaaResponse>, RestError> {
     let price_id: PriceIdentifier = params.id.into();
 
@@ -80,6 +83,7 @@ pub async fn get_vaa(
         &*state.state,
         &[price_id],
         RequestTime::FirstAfter(params.publish_time),
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/get_vaa_ccip.rs

@@ -10,6 +10,7 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     derive_more::{
@@ -17,6 +18,7 @@ use {
         DerefMut,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde::{
         Deserialize,
         Serialize,
@@ -59,6 +61,7 @@ pub struct GetVaaCcipResponse {
 pub async fn get_vaa_ccip(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<GetVaaCcipQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<GetVaaCcipResponse>, RestError> {
     let price_id: PriceIdentifier = PriceIdentifier::new(
         params.data[0..32]
@@ -77,6 +80,7 @@ pub async fn get_vaa_ccip(
         &*state.state,
         &[price_id],
         RequestTime::FirstAfter(publish_time),
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/latest_price_feeds.rs

@@ -13,9 +13,11 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
 };
@@ -62,6 +64,7 @@ pub struct LatestPriceFeedsQueryParams {
 pub async fn latest_price_feeds(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<LatestPriceFeedsQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<Vec<RpcPriceFeed>>, RestError> {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
 
@@ -71,6 +74,7 @@ pub async fn latest_price_feeds(
         &*state.state,
         &price_ids,
         RequestTime::Latest,
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/latest_vaas.rs

@@ -11,6 +11,7 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     base64::{
@@ -18,6 +19,7 @@ use {
         Engine as _,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
 };
@@ -57,6 +59,7 @@ pub struct LatestVaasQueryParams {
 pub async fn latest_vaas(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<LatestVaasQueryParams>,
+    Extension(benchmark_url): Extension<Option<Url>>,
 ) -> Result<Json<Vec<String>>, RestError> {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
 
@@ -66,6 +69,7 @@ pub async fn latest_vaas(
         &*state.state,
         &price_ids,
         RequestTime::Latest,
+        benchmark_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/v2/latest_price_updates.rs

@@ -18,6 +18,7 @@ use {
     anyhow::Result,
     axum::{
         extract::State,
+        Extension,
         Json,
     },
     base64::{
@@ -25,6 +26,7 @@ use {
         Engine as _,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde::Deserialize,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
@@ -76,6 +78,7 @@ fn default_true() -> bool {
 pub async fn latest_price_updates(
     State(state): State<crate::api::ApiState>,
     QsQuery(params): QsQuery<LatestPriceUpdatesQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<PriceUpdate>, RestError> {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
 
@@ -85,6 +88,7 @@ pub async fn latest_price_updates(
         &*state.state,
         &price_ids,
         RequestTime::Latest,
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 4 - 0
hermes/src/api/rest/v2/timestamp_price_updates.rs

@@ -25,9 +25,11 @@ use {
             Path,
             State,
         },
+        Extension,
         Json,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde::Deserialize,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
@@ -91,6 +93,7 @@ pub async fn timestamp_price_updates(
     State(state): State<crate::api::ApiState>,
     Path(path_params): Path<TimestampPriceUpdatesPathParams>,
     QsQuery(query_params): QsQuery<TimestampPriceUpdatesQueryParams>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
 ) -> Result<Json<PriceUpdate>, RestError> {
     let price_ids: Vec<PriceIdentifier> =
         query_params.ids.into_iter().map(|id| id.into()).collect();
@@ -101,6 +104,7 @@ pub async fn timestamp_price_updates(
         &*state.state,
         &price_ids,
         RequestTime::FirstAfter(path_params.publish_time),
+        benchmarks_url,
     )
     .await
     .map_err(|e| {

+ 17 - 7
hermes/src/api/ws.rs

@@ -25,6 +25,7 @@ use {
         },
         http::HeaderMap,
         response::IntoResponse,
+        Extension,
     },
     futures::{
         stream::{
@@ -52,6 +53,7 @@ use {
         },
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde::{
         Deserialize,
         Serialize,
@@ -199,6 +201,7 @@ enum ServerResponseMessage {
 pub async fn ws_route_handler(
     ws: WebSocketUpgrade,
     AxumState(state): AxumState<super::ApiState>,
+    Extension(benchmarks_url): Extension<Option<Url>>,
     headers: HeaderMap,
 ) -> impl IntoResponse {
     let requester_ip = headers
@@ -208,7 +211,7 @@ pub async fn ws_route_handler(
         .and_then(|value| value.parse().ok());
 
     ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
-        .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
+        .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip, benchmarks_url))
 }
 
 #[tracing::instrument(skip(stream, state, subscriber_ip))]
@@ -216,6 +219,7 @@ async fn websocket_handler(
     stream: WebSocket,
     state: super::ApiState,
     subscriber_ip: Option<IpAddr>,
+    benchmarks_url: Option<Url>,
 ) {
     let ws_state = state.ws.clone();
 
@@ -247,7 +251,7 @@ async fn websocket_handler(
         sender,
     );
 
-    subscriber.run().await;
+    subscriber.run(benchmarks_url).await;
 }
 
 pub type SubscriberId = usize;
@@ -296,20 +300,20 @@ impl Subscriber {
     }
 
     #[tracing::instrument(skip(self))]
-    pub async fn run(&mut self) {
+    pub async fn run(&mut self, benchmarks_url: Option<Url>) {
         while !self.closed {
-            if let Err(e) = self.handle_next().await {
+            if let Err(e) = self.handle_next(benchmarks_url.clone()).await {
                 tracing::debug!(subscriber = self.id, error = ?e, "Error Handling Subscriber Message.");
                 break;
             }
         }
     }
 
-    async fn handle_next(&mut self) -> Result<()> {
+    async fn handle_next(&mut self, benchmarks_url: Option<Url>) -> Result<()> {
         tokio::select! {
             maybe_update_feeds_event = self.notify_receiver.recv() => {
                 match maybe_update_feeds_event {
-                    Ok(event) => self.handle_price_feeds_update(event).await,
+                    Ok(event) => self.handle_price_feeds_update(event, benchmarks_url).await,
                     Err(e) => Err(anyhow!("Failed to receive update from store: {:?}", e)),
                 }
             },
@@ -343,7 +347,11 @@ impl Subscriber {
         }
     }
 
-    async fn handle_price_feeds_update(&mut self, event: AggregationEvent) -> Result<()> {
+    async fn handle_price_feeds_update(
+        &mut self,
+        event: AggregationEvent,
+        benchmarks_url: Option<Url>,
+    ) -> Result<()> {
         let price_feed_ids = self
             .price_feeds_with_config
             .keys()
@@ -354,6 +362,7 @@ impl Subscriber {
             &*self.store,
             &price_feed_ids,
             RequestTime::AtSlot(event.slot()),
+            benchmarks_url.clone(),
         )
         .await
         {
@@ -380,6 +389,7 @@ impl Subscriber {
                     &*self.store,
                     &price_feed_ids,
                     RequestTime::AtSlot(event.slot()),
+                    benchmarks_url,
                 )
                 .await?
             }

+ 1 - 1
hermes/src/main.rs

@@ -54,7 +54,7 @@ async fn init() -> Result<()> {
             let (update_tx, _) = tokio::sync::broadcast::channel(1000);
 
             // Initialize a cache store with a 1000 element circular buffer.
-            let store = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
+            let store = State::new(update_tx.clone(), 1000);
 
             // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
             spawn(async move {

+ 8 - 17
hermes/src/state.rs

@@ -11,7 +11,6 @@ use {
         network::wormhole::GuardianSet,
     },
     prometheus_client::registry::Registry,
-    reqwest::Url,
     std::{
         collections::{
             BTreeMap,
@@ -46,9 +45,6 @@ pub struct State {
     /// The aggregate module state.
     pub aggregate_state: RwLock<AggregateState>,
 
-    /// Benchmarks endpoint
-    pub benchmarks_endpoint: Option<Url>,
-
     /// Metrics registry
     pub metrics_registry: RwLock<Registry>,
 
@@ -57,20 +53,15 @@ pub struct State {
 }
 
 impl State {
-    pub fn new(
-        update_tx: Sender<AggregationEvent>,
-        cache_size: u64,
-        benchmarks_endpoint: Option<Url>,
-    ) -> Arc<Self> {
+    pub fn new(update_tx: Sender<AggregationEvent>, cache_size: u64) -> Arc<Self> {
         let mut metrics_registry = Registry::default();
         Arc::new(Self {
-            cache: Cache::new(cache_size),
-            observed_vaa_seqs: RwLock::new(Default::default()),
-            guardian_set: RwLock::new(Default::default()),
-            api_update_tx: update_tx,
-            aggregate_state: RwLock::new(AggregateState::new(&mut metrics_registry)),
-            benchmarks_endpoint,
-            metrics_registry: RwLock::new(metrics_registry),
+            cache:                Cache::new(cache_size),
+            observed_vaa_seqs:    RwLock::new(Default::default()),
+            guardian_set:         RwLock::new(Default::default()),
+            api_update_tx:        update_tx,
+            aggregate_state:      RwLock::new(AggregateState::new(&mut metrics_registry)),
+            metrics_registry:     RwLock::new(metrics_registry),
             price_feeds_metadata: RwLock::new(Default::default()),
         })
     }
@@ -86,7 +77,7 @@ pub mod test {
 
     pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<AggregationEvent>) {
         let (update_tx, update_rx) = tokio::sync::broadcast::channel(1000);
-        let state = State::new(update_tx, cache_size, None);
+        let state = State::new(update_tx, cache_size);
 
         // Add an initial guardian set with public key 0
         update_guardian_set(

+ 4 - 2
hermes/src/state/benchmarks.rs

@@ -14,6 +14,7 @@ use {
         Engine as _,
     },
     pyth_sdk::PriceIdentifier,
+    reqwest::Url,
     serde::Deserialize,
 };
 
@@ -56,6 +57,7 @@ pub trait Benchmarks {
         &self,
         price_ids: &[PriceIdentifier],
         publish_time: UnixTimestamp,
+        benchmarks_endpoint: Option<Url>,
     ) -> Result<PriceFeedsWithUpdateData>;
 }
 
@@ -65,9 +67,9 @@ impl Benchmarks for crate::state::State {
         &self,
         price_ids: &[PriceIdentifier],
         publish_time: UnixTimestamp,
+        benchmarks_endpoint: Option<Url>,
     ) -> Result<PriceFeedsWithUpdateData> {
-        let endpoint = self
-            .benchmarks_endpoint
+        let endpoint = benchmarks_endpoint
             .as_ref()
             .ok_or_else(|| anyhow::anyhow!("Benchmarks endpoint is not set"))?
             .join(&format!("/v1/updates/price/{}", publish_time))