
Use TaskTracker in streamer for graceful exit and to avoid flaky tests (#8066)

When the streamer service is launched, it spawns a task and returns a handle for that task. The task internally spawns further tasks to handle connections, without tracking them.
This means that once the root task has stopped, we have no idea whether its child tasks have completed.
This leads to flaky tests. For example, in nonblocking::quic::test::test_quic_server_multiple_streams we stop the streamer and then check stats.total_connections.load(Ordering::Relaxed). This counter tracks how many tasks are handling connections at a given moment. Those tasks may not have finished by the time the check runs, so every assertion below task_handle.await is potentially flaky.
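A minimal, self-contained sketch of why waiting on a TaskTracker makes such a check deterministic (the counter here is an illustrative stand-in for stats.total_connections, not the streamer's own code):

use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let active = Arc::new(AtomicUsize::new(0));
    let tracker = TaskTracker::new();

    for _ in 0..8 {
        let active = active.clone();
        // Spawning through the tracker both starts the task and registers it.
        tracker.spawn(async move {
            active.fetch_add(1, Ordering::Relaxed);
            // ... handle one connection ...
            active.fetch_sub(1, Ordering::Relaxed);
        });
    }

    // close() stops the tracker from accepting new tasks; wait() resolves
    // only once the tracker is closed and every tracked task has finished.
    tracker.close();
    tracker.wait().await;

    // Safe only because wait() returned: no handler task is still running.
    assert_eq!(active.load(Ordering::Relaxed), 0);
}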
kirill lykov 1 month ago
parent
commit
10b07bb764
4 changed files with 35 additions and 14 deletions
  1. Cargo.lock (+1 -0)
  2. programs/sbf/Cargo.lock (+1 -0)
  3. streamer/Cargo.toml (+1 -1)
  4. streamer/src/nonblocking/quic.rs (+32 -13)

+ 1 - 0
Cargo.lock

@@ -12937,6 +12937,7 @@ dependencies = [
  "futures-core",
  "futures-io",
  "futures-sink",
+ "futures-util",
  "pin-project-lite",
  "tokio",
 ]

+ 1 - 0
programs/sbf/Cargo.lock

@@ -11009,6 +11009,7 @@ dependencies = [
  "futures-core",
  "futures-io",
  "futures-sink",
+ "futures-util",
  "pin-project-lite",
  "tokio",
 ]

+ 1 - 1
streamer/Cargo.toml

@@ -58,7 +58,7 @@ solana-transaction-error = { workspace = true }
 solana-transaction-metrics-tracker = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
-tokio-util = { workspace = true }
+tokio-util = { workspace = true, features = ["rt"] }
 x509-parser = { workspace = true }
 
 [dev-dependencies]
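A note on the feature flag: TaskTracker's spawn helpers are only available when tokio-util is built with the "rt" feature (they need tokio's runtime to spawn onto), which is what the one-line change above enables for the new tokio_util::task::TaskTracker import in quic.rs.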

+ 32 - 13
streamer/src/nonblocking/quic.rs

@@ -62,7 +62,7 @@ use {
         task::{self, JoinHandle},
         time::{sleep, timeout},
     },
-    tokio_util::sync::CancellationToken,
+    tokio_util::{sync::CancellationToken, task::TaskTracker},
 };
 
 pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
@@ -240,15 +240,24 @@ pub fn spawn_server_with_cancel(
     });
 
     let max_concurrent_connections = quic_server_params.max_concurrent_connections();
-    let handle = tokio::spawn(run_server(
-        name,
-        endpoints.clone(),
-        packet_batch_sender,
-        staked_nodes,
-        stats.clone(),
-        quic_server_params,
-        cancel,
-    ));
+    let handle = tokio::spawn({
+        let endpoints = endpoints.clone();
+        let stats = stats.clone();
+        async move {
+            let tasks = run_server(
+                name,
+                endpoints.clone(),
+                packet_batch_sender,
+                staked_nodes,
+                stats.clone(),
+                quic_server_params,
+                cancel,
+            )
+            .await;
+            tasks.close();
+            tasks.wait().await;
+        }
+    });
 
     Ok(SpawnNonBlockingServerResult {
         endpoints,
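The close() before wait() in the wrapper above is load-bearing: per tokio-util's documentation, wait() resolves only once the tracker is both closed and empty, so without close() the returned handle would never complete even after every connection task had exited.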
@@ -312,7 +321,7 @@ async fn run_server(
     stats: Arc<StreamerStats>,
     quic_server_params: QuicServerParams,
     cancel: CancellationToken,
-) {
+) -> TaskTracker {
     let rate_limiter = Arc::new(ConnectionRateLimiter::new(
         quic_server_params.max_connections_per_ipaddr_per_min,
     ));
@@ -349,6 +358,7 @@ async fn run_server(
         })
         .collect::<FuturesUnordered<_>>();
 
+    let tasks = TaskTracker::new();
     loop {
         let timeout_connection = select! {
             ready = accepts.next() => {
@@ -408,7 +418,7 @@ async fn run_server(
                 Ok(connecting) => {
                     let rate_limiter = rate_limiter.clone();
                     let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
-                    tokio::spawn(setup_connection(
+                    tasks.spawn(setup_connection(
                         connecting,
                         rate_limiter,
                         overall_connection_rate_limiter,
@@ -420,6 +430,7 @@ async fn run_server(
                         stats.clone(),
                         stream_load_ema.clone(),
                         quic_server_params.clone(),
+                        tasks.clone(),
                     ));
                 }
                 Err(err) => {
@@ -433,6 +444,7 @@ async fn run_server(
             debug!("accept(): Timed out waiting for connection");
         }
     }
+    tasks
 }
 
 fn prune_unstaked_connection_table(
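One detail that makes the plumbing in the hunks below work: cloning a TaskTracker produces a new handle to the same shared set of tasks, so the tasks.clone() threaded through setup_connection registers work in the very tracker that run_server returns. A minimal sketch of that behavior, relying only on tokio-util's documented API:

use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();
    let clone = tracker.clone();

    // A task spawned through the clone is tracked by the original too.
    clone.spawn(async { /* stand-in for a connection handler */ });

    tracker.close();
    tracker.wait().await; // waits for the task spawned via the clone
    assert!(tracker.is_empty());
}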
@@ -567,6 +579,7 @@ fn handle_and_cache_new_connection(
     connection_table: Arc<Mutex<ConnectionTable>>,
     params: &NewConnectionHandlerParams,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
+    tasks: TaskTracker,
 ) -> Result<(), ConnectionHandlerError> {
     if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
         params.peer_type,
@@ -605,7 +618,7 @@ fn handle_and_cache_new_connection(
             }
             connection.set_max_concurrent_uni_streams(max_uni_streams);
 
-            tokio::spawn(handle_connection(
+            tasks.spawn(handle_connection(
                 connection,
                 remote_addr,
                 last_update,
@@ -642,6 +655,7 @@ async fn prune_unstaked_connections_and_add_new_connection(
     connection_table: Arc<Mutex<ConnectionTable>>,
     params: &NewConnectionHandlerParams,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
+    tasks: TaskTracker,
 ) -> Result<(), ConnectionHandlerError> {
     let stats = params.stats.clone();
     if params.max_connections > 0 {
@@ -655,6 +669,7 @@ async fn prune_unstaked_connections_and_add_new_connection(
             connection_table_clone,
             params,
             stream_load_ema,
+            tasks,
         )
     } else {
         connection.close(
@@ -721,6 +736,7 @@ async fn setup_connection(
     stats: Arc<StreamerStats>,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     quic_server_params: QuicServerParams,
+    tasks: TaskTracker,
 ) {
     const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
     let from = connecting.remote_address();
@@ -821,6 +837,7 @@ async fn setup_connection(
                                 staked_connection_table.clone(),
                                 &params,
                                 stream_load_ema.clone(),
+                                tasks,
                             ) {
                                 stats
                                     .connection_added_from_staked_peer
@@ -836,6 +853,7 @@ async fn setup_connection(
                                 unstaked_connection_table.clone(),
                                 &params,
                                 stream_load_ema.clone(),
+                                tasks,
                             )
                             .await
                             {
@@ -859,6 +877,7 @@ async fn setup_connection(
                             unstaked_connection_table.clone(),
                             &params,
                             stream_load_ema.clone(),
+                            tasks,
                         )
                         .await
                         {