Bladeren bron

Use TaskTracket in streamer for graceful exit and avoiding flaky tests (#8066)

When streamer service is launched, it spawns task and returns handle for this task. This task internally spawns tasks to handle connections without tracking them.
Which means that when the root task has stopped we have no idea if it's child tasks have completed or not.
This leads to flaky tests. For example, in nonblocking::quic::test::test_quic_server_multiple_streams we stop the streamer and after that check stats.total_connections.load(Ordering::Relaxed). This counter counts how many task are handling at giving moment connections. It might happen that these tasks haven't finished before we do this check, so all the checks below task_handle.await are potentially flaky.
kirill lykov 1 maand geleden
bovenliggende
commit
10b07bb764
4 gewijzigde bestanden met toevoegingen van 35 en 14 verwijderingen
  1. 1 0
      Cargo.lock
  2. 1 0
      programs/sbf/Cargo.lock
  3. 1 1
      streamer/Cargo.toml
  4. 32 13
      streamer/src/nonblocking/quic.rs

+ 1 - 0
Cargo.lock

@@ -12937,6 +12937,7 @@ dependencies = [
  "futures-core",
  "futures-core",
  "futures-io",
  "futures-io",
  "futures-sink",
  "futures-sink",
+ "futures-util",
  "pin-project-lite",
  "pin-project-lite",
  "tokio",
  "tokio",
 ]
 ]

+ 1 - 0
programs/sbf/Cargo.lock

@@ -11009,6 +11009,7 @@ dependencies = [
  "futures-core",
  "futures-core",
  "futures-io",
  "futures-io",
  "futures-sink",
  "futures-sink",
+ "futures-util",
  "pin-project-lite",
  "pin-project-lite",
  "tokio",
  "tokio",
 ]
 ]

+ 1 - 1
streamer/Cargo.toml

@@ -58,7 +58,7 @@ solana-transaction-error = { workspace = true }
 solana-transaction-metrics-tracker = { workspace = true }
 solana-transaction-metrics-tracker = { workspace = true }
 thiserror = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
 tokio = { workspace = true, features = ["full"] }
-tokio-util = { workspace = true }
+tokio-util = { workspace = true, features = ["rt"] }
 x509-parser = { workspace = true }
 x509-parser = { workspace = true }
 
 
 [dev-dependencies]
 [dev-dependencies]

+ 32 - 13
streamer/src/nonblocking/quic.rs

@@ -62,7 +62,7 @@ use {
         task::{self, JoinHandle},
         task::{self, JoinHandle},
         time::{sleep, timeout},
         time::{sleep, timeout},
     },
     },
-    tokio_util::sync::CancellationToken,
+    tokio_util::{sync::CancellationToken, task::TaskTracker},
 };
 };
 
 
 pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
 pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
@@ -240,15 +240,24 @@ pub fn spawn_server_with_cancel(
     });
     });
 
 
     let max_concurrent_connections = quic_server_params.max_concurrent_connections();
     let max_concurrent_connections = quic_server_params.max_concurrent_connections();
-    let handle = tokio::spawn(run_server(
-        name,
-        endpoints.clone(),
-        packet_batch_sender,
-        staked_nodes,
-        stats.clone(),
-        quic_server_params,
-        cancel,
-    ));
+    let handle = tokio::spawn({
+        let endpoints = endpoints.clone();
+        let stats = stats.clone();
+        async move {
+            let tasks = run_server(
+                name,
+                endpoints.clone(),
+                packet_batch_sender,
+                staked_nodes,
+                stats.clone(),
+                quic_server_params,
+                cancel,
+            )
+            .await;
+            tasks.close();
+            tasks.wait().await;
+        }
+    });
 
 
     Ok(SpawnNonBlockingServerResult {
     Ok(SpawnNonBlockingServerResult {
         endpoints,
         endpoints,
@@ -312,7 +321,7 @@ async fn run_server(
     stats: Arc<StreamerStats>,
     stats: Arc<StreamerStats>,
     quic_server_params: QuicServerParams,
     quic_server_params: QuicServerParams,
     cancel: CancellationToken,
     cancel: CancellationToken,
-) {
+) -> TaskTracker {
     let rate_limiter = Arc::new(ConnectionRateLimiter::new(
     let rate_limiter = Arc::new(ConnectionRateLimiter::new(
         quic_server_params.max_connections_per_ipaddr_per_min,
         quic_server_params.max_connections_per_ipaddr_per_min,
     ));
     ));
@@ -349,6 +358,7 @@ async fn run_server(
         })
         })
         .collect::<FuturesUnordered<_>>();
         .collect::<FuturesUnordered<_>>();
 
 
+    let tasks = TaskTracker::new();
     loop {
     loop {
         let timeout_connection = select! {
         let timeout_connection = select! {
             ready = accepts.next() => {
             ready = accepts.next() => {
@@ -408,7 +418,7 @@ async fn run_server(
                 Ok(connecting) => {
                 Ok(connecting) => {
                     let rate_limiter = rate_limiter.clone();
                     let rate_limiter = rate_limiter.clone();
                     let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
                     let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
-                    tokio::spawn(setup_connection(
+                    tasks.spawn(setup_connection(
                         connecting,
                         connecting,
                         rate_limiter,
                         rate_limiter,
                         overall_connection_rate_limiter,
                         overall_connection_rate_limiter,
@@ -420,6 +430,7 @@ async fn run_server(
                         stats.clone(),
                         stats.clone(),
                         stream_load_ema.clone(),
                         stream_load_ema.clone(),
                         quic_server_params.clone(),
                         quic_server_params.clone(),
+                        tasks.clone(),
                     ));
                     ));
                 }
                 }
                 Err(err) => {
                 Err(err) => {
@@ -433,6 +444,7 @@ async fn run_server(
             debug!("accept(): Timed out waiting for connection");
             debug!("accept(): Timed out waiting for connection");
         }
         }
     }
     }
+    tasks
 }
 }
 
 
 fn prune_unstaked_connection_table(
 fn prune_unstaked_connection_table(
@@ -567,6 +579,7 @@ fn handle_and_cache_new_connection(
     connection_table: Arc<Mutex<ConnectionTable>>,
     connection_table: Arc<Mutex<ConnectionTable>>,
     params: &NewConnectionHandlerParams,
     params: &NewConnectionHandlerParams,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
+    tasks: TaskTracker,
 ) -> Result<(), ConnectionHandlerError> {
 ) -> Result<(), ConnectionHandlerError> {
     if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
     if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
         params.peer_type,
         params.peer_type,
@@ -605,7 +618,7 @@ fn handle_and_cache_new_connection(
             }
             }
             connection.set_max_concurrent_uni_streams(max_uni_streams);
             connection.set_max_concurrent_uni_streams(max_uni_streams);
 
 
-            tokio::spawn(handle_connection(
+            tasks.spawn(handle_connection(
                 connection,
                 connection,
                 remote_addr,
                 remote_addr,
                 last_update,
                 last_update,
@@ -642,6 +655,7 @@ async fn prune_unstaked_connections_and_add_new_connection(
     connection_table: Arc<Mutex<ConnectionTable>>,
     connection_table: Arc<Mutex<ConnectionTable>>,
     params: &NewConnectionHandlerParams,
     params: &NewConnectionHandlerParams,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
+    tasks: TaskTracker,
 ) -> Result<(), ConnectionHandlerError> {
 ) -> Result<(), ConnectionHandlerError> {
     let stats = params.stats.clone();
     let stats = params.stats.clone();
     if params.max_connections > 0 {
     if params.max_connections > 0 {
@@ -655,6 +669,7 @@ async fn prune_unstaked_connections_and_add_new_connection(
             connection_table_clone,
             connection_table_clone,
             params,
             params,
             stream_load_ema,
             stream_load_ema,
+            tasks,
         )
         )
     } else {
     } else {
         connection.close(
         connection.close(
@@ -721,6 +736,7 @@ async fn setup_connection(
     stats: Arc<StreamerStats>,
     stats: Arc<StreamerStats>,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     quic_server_params: QuicServerParams,
     quic_server_params: QuicServerParams,
+    tasks: TaskTracker,
 ) {
 ) {
     const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
     const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
     let from = connecting.remote_address();
     let from = connecting.remote_address();
@@ -821,6 +837,7 @@ async fn setup_connection(
                                 staked_connection_table.clone(),
                                 staked_connection_table.clone(),
                                 &params,
                                 &params,
                                 stream_load_ema.clone(),
                                 stream_load_ema.clone(),
+                                tasks,
                             ) {
                             ) {
                                 stats
                                 stats
                                     .connection_added_from_staked_peer
                                     .connection_added_from_staked_peer
@@ -836,6 +853,7 @@ async fn setup_connection(
                                 unstaked_connection_table.clone(),
                                 unstaked_connection_table.clone(),
                                 &params,
                                 &params,
                                 stream_load_ema.clone(),
                                 stream_load_ema.clone(),
+                                tasks,
                             )
                             )
                             .await
                             .await
                             {
                             {
@@ -859,6 +877,7 @@ async fn setup_connection(
                             unstaked_connection_table.clone(),
                             unstaked_connection_table.clone(),
                             &params,
                             &params,
                             stream_load_ema.clone(),
                             stream_load_ema.clone(),
+                            tasks,
                         )
                         )
                         .await
                         .await
                         {
                         {