
refactor(hermes): watch channel to simplify shutdowns

Reisen 1 year ago
parent
commit
df585e440e
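This commit replaces the `SHOULD_EXIT` atomic flag, which every task polled on a timer, with a `tokio::sync::watch` channel so tasks can simply await a shutdown notification. Below is a minimal sketch of that pattern, using a local channel and a hypothetical `do_work` task purely for illustration (the crate itself exposes the sender as the global `EXIT`):

use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One sender, any number of receivers; the value is just a shutdown flag.
    let (exit_tx, _exit_rx) = watch::channel(false);

    // Each task subscribes and awaits `changed()` instead of sleeping and
    // re-checking a boolean in a loop.
    let mut exit = exit_tx.subscribe();
    let worker = tokio::spawn(async move {
        tokio::select! {
            _ = exit.changed() => println!("shutdown signal received, cleaning up..."),
            _ = do_work() => {}
        }
    });

    // Signal shutdown once (in Hermes this happens in the ctrl-c handler);
    // every subscribed receiver wakes up.
    let _ = exit_tx.send(true);
    let _ = worker.await;
}

// Hypothetical long-running task standing in for a real worker.
async fn do_work() {
    tokio::time::sleep(Duration::from_secs(3600)).await;
}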

+ 2 - 8
hermes/src/api.rs

@@ -13,10 +13,7 @@ use {
     },
     ipnet::IpNet,
     serde_qs::axum::QsQueryConfig,
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
     tokio::sync::broadcast::Sender,
     tower_http::cors::CorsLayer,
     utoipa::OpenApi,
@@ -159,10 +156,7 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
     axum::Server::try_bind(&opts.rpc.listen_addr)?
         .serve(app.into_make_service())
         .with_graceful_shutdown(async {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
-            }
-
+            let _ = crate::EXIT.subscribe().changed().await;
             tracing::info!("Shutting down RPC server...");
         })
         .await?;

+ 10 - 10
hermes/src/api/ws.rs

@@ -69,7 +69,10 @@ use {
         },
         time::Duration,
     },
-    tokio::sync::broadcast::Receiver,
+    tokio::sync::{
+        broadcast::Receiver,
+        watch,
+    },
 };
 
 const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
@@ -262,7 +265,7 @@ pub struct Subscriber {
     sender:                  SplitSink<WebSocket, Message>,
     price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
     ping_interval:           tokio::time::Interval,
-    exit_check_interval:     tokio::time::Interval,
+    exit:                    watch::Receiver<bool>,
     responded_to_ping:       bool,
 }
 
@@ -287,7 +290,7 @@ impl Subscriber {
             sender,
             price_feeds_with_config: HashMap::new(),
             ping_interval: tokio::time::interval(PING_INTERVAL_DURATION),
-            exit_check_interval: tokio::time::interval(Duration::from_secs(5)),
+            exit: crate::EXIT.subscribe(),
             responded_to_ping: true, // We start with true so we don't close the connection immediately
         }
     }
@@ -332,13 +335,10 @@ impl Subscriber {
                 self.sender.send(Message::Ping(vec![])).await?;
                 Ok(())
             },
-            _ = self.exit_check_interval.tick() => {
-                if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                    self.sender.close().await?;
-                    self.closed = true;
-                    return Err(anyhow!("Application is shutting down. Closing connection."));
-                }
-                Ok(())
+            _ = self.exit.changed() => {
+                self.sender.close().await?;
+                self.closed = true;
+                return Err(anyhow!("Application is shutting down. Closing connection."));
             }
         }
     }
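Holding a `watch::Receiver` on the `Subscriber` lets the per-connection `select!` react to shutdown as soon as the signal fires, rather than on the next tick of the old 5-second `exit_check_interval`. A stripped-down sketch of that shape follows; this `Subscriber` is a stand-in with only the exit branch, whereas the real type also carries the socket, ping state, and feed subscriptions:

use anyhow::{anyhow, Result};
use tokio::sync::watch;

// Stand-in subscriber with only the shutdown-related state.
struct Subscriber {
    exit:   watch::Receiver<bool>,
    closed: bool,
}

impl Subscriber {
    async fn handle_next(&mut self) -> Result<()> {
        tokio::select! {
            // ... message, ping and price-update branches elided ...
            _ = self.exit.changed() => {
                // Mark the connection closed and bubble an error up so the
                // per-connection task ends.
                self.closed = true;
                Err(anyhow!("Application is shutting down. Closing connection."))
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (exit_tx, exit_rx) = watch::channel(false);
    let mut sub = Subscriber { exit: exit_rx, closed: false };

    let _ = exit_tx.send(true);

    // The handler returns an error once the shutdown signal is observed.
    assert!(sub.handle_next().await.is_err());
    assert!(sub.closed);
}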

+ 18 - 12
hermes/src/main.rs

@@ -8,12 +8,13 @@ use {
         Parser,
     },
     futures::future::join_all,
+    lazy_static::lazy_static,
     state::State,
-    std::{
-        io::IsTerminal,
-        sync::atomic::AtomicBool,
+    std::io::IsTerminal,
+    tokio::{
+        spawn,
+        sync::watch,
     },
-    tokio::spawn,
 };
 
 mod aggregate;
@@ -25,13 +26,18 @@ mod price_feeds_metadata;
 mod serde;
 mod state;
 
-// A static exit flag to indicate to running threads that we're shutting down. This is used to
-// gracefully shutdown the application.
-//
-// NOTE: A more idiomatic approach would be to use a tokio::sync::broadcast channel, and to send a
-// shutdown signal to all running tasks. However, this is a bit more complicated to implement and
-// we don't rely on global state for anything else.
-pub(crate) static SHOULD_EXIT: AtomicBool = AtomicBool::new(false);
+lazy_static! {
+    /// A static exit flag to indicate to running threads that we're shutting down. This is used to
+    /// gracefully shut down the application.
+    ///
+    /// We make this global based on the facts that:
+    /// - The `Sender` side does not rely on any async runtime.
+    /// - Exit logic doesn't really require carefully threading this value through the app.
+    /// - The `Receiver` side of a watch channel detects changes that happened after it
+    ///   subscribed, so all listeners should always be notified correctly.
+    pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
+}
 
 /// Initialize the Application. This can be invoked either by real main, or by the Geyser plugin.
 #[tracing::instrument]
@@ -55,7 +61,7 @@ async fn init() -> Result<()> {
                 tracing::info!("Registered shutdown signal handler...");
                 tokio::signal::ctrl_c().await.unwrap();
                 tracing::info!("Shut down signal received, waiting for tasks...");
-                SHOULD_EXIT.store(true, std::sync::atomic::Ordering::Release);
+                let _ = EXIT.send(true);
             });
 
             // Spawn all worker tasks, and wait for all to complete (which will happen if a shutdown
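The doc comment's last point leans on how watch receivers track changes: a receiver returned by `subscribe()` treats the value at subscription time as already seen, and `changed()` resolves for any send that happens afterwards, even if the send races ahead of the await. A small sketch of those semantics, using a local channel for illustration only:

use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, _rx) = watch::channel(false);

    // Subscribed before the send: `changed()` will resolve.
    let mut before = tx.subscribe();

    let _ = tx.send(true);

    // Resolves immediately even though the send already happened, because the
    // change occurred after `before` subscribed.
    before.changed().await.unwrap();
    assert!(*before.borrow());

    // Subscribed after the send: the current value counts as already seen, so
    // `changed()` would only resolve on the *next* send.
    let after = tx.subscribe();
    assert!(!after.has_changed().unwrap());
}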

+ 2 - 8
hermes/src/metrics_server.rs

@@ -16,10 +16,7 @@ use {
         Router,
     },
     prometheus_client::encoding::text::encode,
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
 };
 
 
@@ -37,10 +34,7 @@ pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
     axum::Server::try_bind(&opts.metrics.server_listen_addr)?
         .serve(app.into_make_service())
         .with_graceful_shutdown(async {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
-            }
-
+            let _ = crate::EXIT.subscribe().changed().await;
             tracing::info!("Shutting down metrics server...");
         })
         .await?;

+ 40 - 55
hermes/src/network/pythnet.rs

@@ -58,10 +58,7 @@ use {
     },
     std::{
         collections::BTreeMap,
-        sync::{
-            atomic::Ordering,
-            Arc,
-        },
+        sync::Arc,
         time::Duration,
     },
     tokio::time::Instant,
@@ -160,7 +157,7 @@ pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<()> {
         .program_subscribe(&system_program::id(), Some(config))
         .await?;
 
-    while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
+    loop {
         match notif.next().await {
             Some(update) => {
                 let account: Account = match update.value.account.decode() {
@@ -213,8 +210,6 @@ pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<()> {
             }
         }
     }
-
-    Ok(())
 }
 
 /// Fetch existing GuardianSet accounts from Wormhole.
@@ -281,19 +276,21 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
     let task_listener = {
         let store = state.clone();
         let pythnet_ws_endpoint = opts.pythnet.ws_addr.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
+            loop {
                 let current_time = Instant::now();
-
-                if let Err(ref e) = run(store.clone(), pythnet_ws_endpoint.clone()).await {
-                    tracing::error!(error = ?e, "Error in Pythnet network listener.");
-                    if current_time.elapsed() < Duration::from_secs(30) {
-                        tracing::error!("Pythnet listener restarting too quickly. Sleep 1s.");
-                        tokio::time::sleep(Duration::from_secs(1)).await;
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    Err(err) = run(store.clone(), pythnet_ws_endpoint.clone()) => {
+                        tracing::error!(error = ?err, "Error in Pythnet network listener.");
+                        if current_time.elapsed() < Duration::from_secs(30) {
+                            tracing::error!("Pythnet listener restarting too quickly. Sleep 1s.");
+                            tokio::time::sleep(Duration::from_secs(1)).await;
+                        }
                     }
                 }
             }
-
             tracing::info!("Shutting down Pythnet listener...");
         })
     };
@@ -301,32 +298,24 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
     let task_guardian_watcher = {
         let store = state.clone();
         let pythnet_http_endpoint = opts.pythnet.http_addr.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                // Poll for new guardian sets every 60 seconds. We use a short wait time so we can
-                // properly exit if a quit signal was received. This isn't a perfect solution, but
-                // it's good enough for now.
-                for _ in 0..60 {
-                    if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                        break;
-                    }
-                    tokio::time::sleep(Duration::from_secs(1)).await;
-                }
-
-                match fetch_existing_guardian_sets(
-                    store.clone(),
-                    pythnet_http_endpoint.clone(),
-                    opts.wormhole.contract_addr,
-                )
-                .await
-                {
-                    Ok(_) => {}
-                    Err(err) => {
-                        tracing::error!(error = ?err, "Failed to poll for new guardian sets.")
+            loop {
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    _ = tokio::time::sleep(Duration::from_secs(60)) => {
+                        if let Err(err) = fetch_existing_guardian_sets(
+                            store.clone(),
+                            pythnet_http_endpoint.clone(),
+                            opts.wormhole.contract_addr,
+                        )
+                        .await
+                        {
+                            tracing::error!(error = ?err, "Failed to poll for new guardian sets.")
+                        }
                     }
                 }
             }
-
             tracing::info!("Shutting down Pythnet guardian set poller...");
         })
     };
@@ -334,26 +323,22 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
 
     let task_price_feeds_metadata_updater = {
         let price_feeds_state = state.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                if let Err(e) = fetch_and_store_price_feeds_metadata(
-                    price_feeds_state.as_ref(),
-                    &opts.pythnet.mapping_addr,
-                    &rpc_client,
-                )
-                .await
-                {
-                    tracing::error!("Error in fetching and storing price feeds metadata: {}", e);
-                }
-                // This loop with a sleep interval of 1 second allows the task to check for an exit signal at a
-                // fine-grained interval. Instead of sleeping directly for the entire `price_feeds_update_interval`,
-                // which could delay the response to an exit signal, this approach ensures the task can exit promptly
-                // if `crate::SHOULD_EXIT` is set, enhancing the responsiveness of the service to shutdown requests.
-                for _ in 0..DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL {
-                    if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                        break;
+            loop {
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    _ = tokio::time::sleep(Duration::from_secs(DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL)) => {
+                        if let Err(e) = fetch_and_store_price_feeds_metadata(
+                            price_feeds_state.as_ref(),
+                            &opts.pythnet.mapping_addr,
+                            &rpc_client,
+                        )
+                        .await
+                        {
+                            tracing::error!("Error in fetching and storing price feeds metadata: {}", e);
+                        }
                     }
-                    tokio::time::sleep(Duration::from_secs(1)).await;
                 }
             }
         })
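Each of these background loops now follows the same shape: `select!` over the exit signal and a sleep, so the task wakes promptly on shutdown and no longer needs the fine-grained one-second polling loops the removed comments described. A self-contained sketch of that shape, with a hypothetical `poll_once` job standing in for the guardian-set or metadata fetch:

use std::time::Duration;
use tokio::sync::watch;

// Hypothetical periodic job used for illustration.
async fn poll_once() {
    println!("polling...");
}

#[tokio::main]
async fn main() {
    let (exit_tx, _exit_rx) = watch::channel(false);
    let mut exit = exit_tx.subscribe();

    let poller = tokio::spawn(async move {
        loop {
            tokio::select! {
                // Wakes as soon as shutdown is signalled, even mid-sleep.
                _ = exit.changed() => break,
                _ = tokio::time::sleep(Duration::from_secs(60)) => poll_once().await,
            }
        }
        println!("Shutting down poller...");
    });

    let _ = exit_tx.send(true);
    let _ = poller.await;
}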

+ 8 - 15
hermes/src/network/wormhole.rs

@@ -43,10 +43,7 @@ use {
         Digest,
         Keccak256,
     },
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
     tonic::Request,
     wormhole_sdk::{
         vaa::{
@@ -153,16 +150,16 @@ mod proto {
 // Launches the Wormhole gRPC service.
 #[tracing::instrument(skip(opts, state))]
 pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
-    while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-        if let Err(e) = run(opts.clone(), state.clone()).await {
-            tracing::error!(error = ?e, "Wormhole gRPC service failed.");
+    let mut exit = crate::EXIT.subscribe();
+    loop {
+        tokio::select! {
+            _ = exit.changed() => break,
+            Err(err) = run(opts.clone(), state.clone()) => {
+                tracing::error!(error = ?err, "Wormhole gRPC service failed.");
+            }
         }
-
-        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
     }
-
     tracing::info!("Shutting down Wormhole gRPC service...");
-
     Ok(())
 }
 
@@ -182,10 +179,6 @@ async fn run(opts: RunOptions, state: Arc<State>) -> Result<()> {
         .into_inner();
 
     while let Some(Ok(message)) = stream.next().await {
-        if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-            return Ok(());
-        }
-
         if let Err(e) = process_message(state.clone(), message.vaa_bytes).await {
             tracing::debug!(error = ?e, "Skipped VAA.");
         }
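The `spawn` loop in this file relies on `select!`'s pattern-matched branches: `Err(err) = run(...)` only fires when `run` resolves to an error, and a branch whose pattern fails to match is disabled for the remainder of that `select!` call, leaving the loop waiting on the exit signal. A small, self-contained illustration with a hypothetical fallible `run`:

use std::time::Duration;
use tokio::sync::watch;

// Hypothetical fallible service standing in for the gRPC client loop.
async fn run() -> Result<(), String> {
    tokio::time::sleep(Duration::from_millis(10)).await;
    Err("stream ended".to_string())
}

#[tokio::main]
async fn main() {
    let (exit_tx, _exit_rx) = watch::channel(false);
    let mut exit = exit_tx.subscribe();

    // Signal shutdown shortly after startup so the example terminates.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = exit_tx.send(true);
    });

    loop {
        tokio::select! {
            _ = exit.changed() => break,
            // Fires only when `run` resolves to `Err`; an `Ok` result fails the
            // pattern, the branch is disabled for this `select!` call, and the
            // loop keeps waiting on the exit signal.
            Err(err) = run() => {
                eprintln!("service failed: {err}, restarting...");
            }
        }
    }
    println!("Shutting down service...");
}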

+ 1 - 1
hermes/src/price_feeds_metadata.rs

@@ -9,7 +9,7 @@ use {
     anyhow::Result,
 };
 
-pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u16 = 600;
+pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;
 
 pub async fn retrieve_price_feeds_metadata(state: &State) -> Result<Vec<PriceFeedMetadata>> {
     let price_feeds_metadata = state.price_feeds_metadata.read().await;
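The constant's type changes from `u16` to `u64`, most likely because the pythnet updater now passes it straight to `Duration::from_secs` inside its `select!`, and `std::time::Duration::from_secs` takes a `u64`. A minimal illustration:

use std::time::Duration;

pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;

fn main() {
    // `Duration::from_secs` takes a `u64`, so a `u64` constant can be passed
    // directly without a cast.
    let interval = Duration::from_secs(DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL);
    println!("update interval: {interval:?}");
}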