Explorar o código

refactor(hermes): make `State` hidden to force APIs (#1537)

Reisen hai 1 ano
pai
achega
f0adce06c7

+ 29 - 10
apps/hermes/src/api.rs

@@ -1,7 +1,12 @@
 use {
     crate::{
         config::RunOptions,
-        state::State,
+        state::{
+            Aggregates,
+            Benchmarks,
+            Cache,
+            Metrics,
+        },
     },
     anyhow::Result,
     axum::{
@@ -24,7 +29,7 @@ mod rest;
 pub mod types;
 mod ws;
 
-pub struct ApiState<S = State> {
+pub struct ApiState<S> {
     pub state:   Arc<S>,
     pub ws:      Arc<ws::WsState>,
     pub metrics: Arc<metrics_middleware::ApiMetrics>,
@@ -42,12 +47,12 @@ impl<S> Clone for ApiState<S> {
     }
 }
 
-impl ApiState<State> {
-    pub fn new(
-        state: Arc<State>,
-        ws_whitelist: Vec<IpNet>,
-        requester_ip_header_name: String,
-    ) -> Self {
+impl<S> ApiState<S> {
+    pub fn new(state: Arc<S>, ws_whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self
+    where
+        S: Metrics,
+        S: Send + Sync + 'static,
+    {
         Self {
             metrics: Arc::new(metrics_middleware::ApiMetrics::new(state.clone())),
             ws: Arc::new(ws::WsState::new(
@@ -61,7 +66,14 @@ impl ApiState<State> {
 }
 
 #[tracing::instrument(skip(opts, state))]
-pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
+pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
+where
+    S: Aggregates,
+    S: Benchmarks,
+    S: Cache,
+    S: Metrics,
+    S: Send + Sync + 'static,
+{
     let state = {
         let opts = opts.clone();
         ApiState::new(
@@ -79,7 +91,14 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
 /// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
 /// packages they are based on (tokio & hyper).
 #[tracing::instrument(skip(opts, state))]
-pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
+pub async fn run<S>(opts: RunOptions, state: ApiState<S>) -> Result<()>
+where
+    S: Aggregates,
+    S: Benchmarks,
+    S: Cache,
+    S: Metrics,
+    S: Send + Sync + 'static,
+{
     tracing::info!(endpoint = %opts.rpc.listen_addr, "Starting RPC Server.");
 
     #[derive(OpenApi)]

+ 3 - 5
apps/hermes/src/api/metrics_middleware.rs

@@ -31,9 +31,7 @@ impl ApiMetrics {
     pub fn new<S>(state: Arc<S>) -> Self
     where
         S: Metrics,
-        S: Send,
-        S: Sync,
-        S: 'static,
+        S: Send + Sync + 'static,
     {
         let new = Self {
             requests:  Family::default(),
@@ -81,8 +79,8 @@ pub struct Labels {
     pub status: u16,
 }
 
-pub async fn track_metrics<B>(
-    State(api_state): State<ApiState>,
+pub async fn track_metrics<B, S>(
+    State(api_state): State<ApiState<S>>,
     req: Request<B>,
     next: Next<B>,
 ) -> impl IntoResponse {

+ 1 - 1
apps/hermes/src/api/rest/v2/price_feeds_metadata.rs

@@ -8,7 +8,7 @@ use {
             },
             ApiState,
         },
-        price_feeds_metadata::PriceFeedMeta,
+        state::price_feeds_metadata::PriceFeedMeta,
     },
     anyhow::Result,
     axum::{

+ 1 - 3
apps/hermes/src/api/rest/v2/sse.rs

@@ -95,9 +95,7 @@ pub async fn price_stream_sse_handler<S>(
 ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError>
 where
     S: Aggregates,
-    S: Sync,
-    S: Send,
-    S: 'static,
+    S: Send + Sync + 'static,
 {
     let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
 

+ 20 - 8
apps/hermes/src/api/ws.rs

@@ -13,7 +13,9 @@ use {
             RequestTime,
         },
         metrics::Metrics,
-        State,
+        Benchmarks,
+        Cache,
+        PriceFeedMeta,
     },
     anyhow::{
         anyhow,
@@ -124,9 +126,7 @@ impl WsMetrics {
     pub fn new<S>(state: Arc<S>) -> Self
     where
         S: Metrics,
-        S: Send,
-        S: Sync,
-        S: 'static,
+        S: Send + Sync + 'static,
     {
         let new = Self {
             interactions: Family::default(),
@@ -161,7 +161,11 @@ pub struct WsState {
 }
 
 impl WsState {
-    pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<State>) -> Self {
+    pub fn new<S>(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<S>) -> Self
+    where
+        S: Metrics,
+        S: Send + Sync + 'static,
+    {
         Self {
             subscriber_counter: AtomicUsize::new(0),
             rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
@@ -211,11 +215,18 @@ enum ServerResponseMessage {
     Err { error: String },
 }
 
-pub async fn ws_route_handler(
+pub async fn ws_route_handler<S>(
     ws: WebSocketUpgrade,
-    AxumState(state): AxumState<super::ApiState>,
+    AxumState(state): AxumState<ApiState<S>>,
     headers: HeaderMap,
-) -> impl IntoResponse {
+) -> impl IntoResponse
+where
+    S: Aggregates,
+    S: Benchmarks,
+    S: Cache,
+    S: PriceFeedMeta,
+    S: Send + Sync + 'static,
+{
     let requester_ip = headers
         .get(state.ws.requester_ip_header_name.as_str())
         .and_then(|value| value.to_str().ok())
@@ -230,6 +241,7 @@ pub async fn ws_route_handler(
 async fn websocket_handler<S>(stream: WebSocket, state: ApiState<S>, subscriber_ip: Option<IpAddr>)
 where
     S: Aggregates,
+    S: Send,
 {
     let ws_state = state.ws.clone();
 

+ 1 - 3
apps/hermes/src/main.rs

@@ -9,7 +9,6 @@ use {
     },
     futures::future::join_all,
     lazy_static::lazy_static,
-    state::State,
     std::io::IsTerminal,
     tokio::{
         spawn,
@@ -21,7 +20,6 @@ mod api;
 mod config;
 mod metrics_server;
 mod network;
-mod price_feeds_metadata;
 mod serde;
 mod state;
 
@@ -53,7 +51,7 @@ async fn init() -> Result<()> {
             let (update_tx, _) = tokio::sync::broadcast::channel(1000);
 
             // Initialize a cache store with a 1000 element circular buffer.
-            let state = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
+            let state = state::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
 
             // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
             spawn(async move {

+ 10 - 6
apps/hermes/src/metrics_server.rs

@@ -5,10 +5,7 @@
 use {
     crate::{
         config::RunOptions,
-        state::{
-            metrics::Metrics,
-            State as AppState,
-        },
+        state::metrics::Metrics,
     },
     anyhow::Result,
     axum::{
@@ -23,7 +20,11 @@ use {
 
 
 #[tracing::instrument(skip(opts, state))]
-pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
+pub async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
+where
+    S: Metrics,
+    S: Send + Sync + 'static,
+{
     tracing::info!(endpoint = %opts.metrics.server_listen_addr, "Starting Metrics Server.");
 
     let app = Router::new();
@@ -44,7 +45,10 @@ pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
     Ok(())
 }
 
-pub async fn metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
+pub async fn metrics<S>(State(state): State<Arc<S>>) -> impl IntoResponse
+where
+    S: Metrics,
+{
     let buffer = Metrics::encode(&*state).await;
     (
         [(

+ 16 - 7
apps/hermes/src/network/pythnet.rs

@@ -11,18 +11,17 @@ use {
             GuardianSet,
             GuardianSetData,
         },
-        price_feeds_metadata::{
-            PriceFeedMeta,
-            DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
-        },
         state::{
             aggregate::{
                 AccumulatorMessages,
                 Aggregates,
                 Update,
             },
+            price_feeds_metadata::{
+                PriceFeedMeta,
+                DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
+            },
             wormhole::Wormhole,
-            State,
         },
     },
     anyhow::{
@@ -139,7 +138,12 @@ async fn fetch_bridge_data(
     }
 }
 
-pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<!> {
+pub async fn run<S>(store: Arc<S>, pythnet_ws_endpoint: String) -> Result<!>
+where
+    S: Aggregates,
+    S: Wormhole,
+    S: Send + Sync + 'static,
+{
     let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?;
 
     let config = RpcProgramAccountsConfig {
@@ -222,6 +226,7 @@ async fn fetch_existing_guardian_sets<S>(
 ) -> Result<()>
 where
     S: Wormhole,
+    S: Send + Sync + 'static,
 {
     let client = RpcClient::new(pythnet_http_endpoint.to_string());
     let bridge = fetch_bridge_data(&client, &wormhole_contract_addr).await?;
@@ -261,7 +266,11 @@ where
 }
 
 #[tracing::instrument(skip(opts, state))]
-pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
+pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
+where
+    S: Wormhole,
+    S: Send + Sync + 'static,
+{
     tracing::info!(endpoint = opts.pythnet.ws_addr, "Started Pythnet Listener.");
 
     // Create RpcClient instance here

+ 7 - 8
apps/hermes/src/network/wormhole.rs

@@ -7,10 +7,7 @@
 use {
     crate::{
         config::RunOptions,
-        state::{
-            wormhole::Wormhole,
-            State,
-        },
+        state::wormhole::Wormhole,
     },
     anyhow::{
         anyhow,
@@ -118,7 +115,11 @@ mod proto {
 
 // Launches the Wormhole gRPC service.
 #[tracing::instrument(skip(opts, state))]
-pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
+pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
+where
+    S: Wormhole,
+    S: Send + Sync + 'static,
+{
     let mut exit = crate::EXIT.subscribe();
     loop {
         let current_time = Instant::now();
@@ -142,9 +143,7 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
 async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<!>
 where
     S: Wormhole,
-    S: Sync,
-    S: Send,
-    S: 'static,
+    S: Send + Sync + 'static,
 {
     let mut client = SpyRpcServiceClient::connect(opts.wormhole.spy_rpc_addr).await?;
     let mut stream = client

+ 42 - 22
apps/hermes/src/state.rs

@@ -9,9 +9,9 @@ use {
         benchmarks::BenchmarksState,
         cache::CacheState,
         metrics::MetricsState,
+        price_feeds_metadata::PriceFeedMetaState,
         wormhole::WormholeState,
     },
-    crate::price_feeds_metadata::PriceFeedMetaState,
     prometheus_client::registry::Registry,
     reqwest::Url,
     std::sync::Arc,
@@ -22,9 +22,25 @@ pub mod aggregate;
 pub mod benchmarks;
 pub mod cache;
 pub mod metrics;
+pub mod price_feeds_metadata;
 pub mod wormhole;
 
-pub struct State {
+// Expose State interfaces and types for other modules.
+pub use {
+    aggregate::Aggregates,
+    benchmarks::Benchmarks,
+    cache::Cache,
+    metrics::Metrics,
+    price_feeds_metadata::PriceFeedMeta,
+    wormhole::Wormhole,
+};
+
+/// State contains all relevant shared application state.
+///
+/// This type is intentionally not exposed, forcing modules to interface with the
+/// various API's using the provided traits. This is done to enforce separation of
+/// concerns and to avoid direct manipulation of state.
+struct State {
     /// State for the `Cache` service for short-lived storage of updates.
     pub cache: CacheState,
 
@@ -44,36 +60,40 @@ pub struct State {
     pub metrics: MetricsState,
 }
 
-impl State {
-    pub fn new(
-        update_tx: Sender<AggregationEvent>,
-        cache_size: u64,
-        benchmarks_endpoint: Option<Url>,
-    ) -> Arc<Self> {
-        let mut metrics_registry = Registry::default();
-        Arc::new(Self {
-            cache:           CacheState::new(cache_size),
-            benchmarks:      BenchmarksState::new(benchmarks_endpoint),
-            price_feed_meta: PriceFeedMetaState::new(),
-            aggregates:      AggregateState::new(update_tx, &mut metrics_registry),
-            wormhole:        WormholeState::new(),
-            metrics:         MetricsState::new(metrics_registry),
-        })
-    }
+pub fn new(
+    update_tx: Sender<AggregationEvent>,
+    cache_size: u64,
+    benchmarks_endpoint: Option<Url>,
+) -> Arc<impl Metrics + Wormhole> {
+    let mut metrics_registry = Registry::default();
+    Arc::new(State {
+        cache:           CacheState::new(cache_size),
+        benchmarks:      BenchmarksState::new(benchmarks_endpoint),
+        price_feed_meta: PriceFeedMetaState::new(),
+        aggregates:      AggregateState::new(update_tx, &mut metrics_registry),
+        wormhole:        WormholeState::new(),
+        metrics:         MetricsState::new(metrics_registry),
+    })
 }
 
 #[cfg(test)]
 pub mod test {
     use {
-        self::wormhole::Wormhole,
-        super::*,
+        super::{
+            aggregate::AggregationEvent,
+            Aggregates,
+            Wormhole,
+        },
         crate::network::wormhole::GuardianSet,
+        std::sync::Arc,
         tokio::sync::broadcast::Receiver,
     };
 
-    pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<AggregationEvent>) {
+    pub async fn setup_state(
+        cache_size: u64,
+    ) -> (Arc<impl Aggregates>, Receiver<AggregationEvent>) {
         let (update_tx, update_rx) = tokio::sync::broadcast::channel(1000);
-        let state = State::new(update_tx, cache_size, None);
+        let state = super::new(update_tx, cache_size, None);
 
         // Add an initial guardian set with public key 0
         Wormhole::update_guardian_set(

+ 6 - 2
apps/hermes/src/state/aggregate.rs

@@ -20,7 +20,6 @@ use {
     },
     crate::{
         network::wormhole::VaaBytes,
-        price_feeds_metadata::PriceFeedMeta,
         state::{
             benchmarks::Benchmarks,
             cache::{
@@ -28,6 +27,7 @@ use {
                 MessageState,
                 MessageStateFilter,
             },
+            price_feeds_metadata::PriceFeedMeta,
             State,
         },
     },
@@ -612,7 +612,11 @@ mod test {
         }
     }
 
-    pub async fn store_multiple_concurrent_valid_updates(state: Arc<State>, updates: Vec<Update>) {
+    pub async fn store_multiple_concurrent_valid_updates<S>(state: Arc<S>, updates: Vec<Update>)
+    where
+        S: Aggregates,
+        S: Send + Sync + 'static,
+    {
         let res = join_all(updates.into_iter().map(|u| state.store_update(u))).await;
         // Check that all store_update calls succeeded
         assert!(res.into_iter().all(|r| r.is_ok()));

+ 1 - 6
apps/hermes/src/price_feeds_metadata.rs → apps/hermes/src/state/price_feeds_metadata.rs

@@ -12,16 +12,11 @@ use {
 
 pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;
 
+#[derive(Default)]
 pub struct PriceFeedMetaState {
     pub data: RwLock<Vec<PriceFeedMetadata>>,
 }
 
-impl Default for PriceFeedMetaState {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
 impl PriceFeedMetaState {
     pub fn new() -> Self {
         Self {