@@ -1,15 +1,12 @@
 use {
     crate::{
         nonblocking::{
-            qos::{ConnectionContext, QosController},
+            qos::{ConnectionContext, OpaqueStreamerCounter, QosController},
             quic::{
                 get_connection_stake, update_open_connections_stat, ClientConnectionTracker,
                 ConnectionHandlerError, ConnectionPeerType, ConnectionTable, ConnectionTableKey,
                 ConnectionTableType,
             },
-            stream_throttle::{
-                throttle_stream, ConnectionStreamCounter, STREAM_THROTTLING_INTERVAL,
-            },
         },
         quic::{
             StreamerStats, DEFAULT_MAX_QUIC_CONNECTIONS_PER_STAKED_PEER,
@@ -18,6 +15,7 @@ use {
         streamer::StakedNodes,
     },
     quinn::Connection,
+    solana_net_utils::token_bucket::TokenBucket,
     solana_time_utils as timing,
     std::{
         future::Future,
@@ -25,8 +23,12 @@ use {
             atomic::{AtomicU64, Ordering},
             Arc, RwLock,
         },
+        time::Duration,
+    },
+    tokio::{
+        sync::{Mutex, MutexGuard},
+        time::sleep,
     },
-    tokio::sync::{Mutex, MutexGuard},
     tokio_util::sync::CancellationToken,
 };

@@ -47,10 +49,12 @@ impl Default for SimpleQosConfig {
     }
 }

+impl OpaqueStreamerCounter for TokenBucket {}
+
 pub struct SimpleQos {
     config: SimpleQosConfig,
     stats: Arc<StreamerStats>,
-    staked_connection_table: Arc<Mutex<ConnectionTable>>,
+    staked_connection_table: Arc<Mutex<ConnectionTable<TokenBucket>>>,
     staked_nodes: Arc<RwLock<StakedNodes>>,
 }

@@ -76,16 +80,9 @@ impl SimpleQos {
         &self,
         client_connection_tracker: ClientConnectionTracker,
         connection: &Connection,
-        mut connection_table_l: MutexGuard<ConnectionTable>,
+        mut connection_table_l: MutexGuard<ConnectionTable<TokenBucket>>,
         conn_context: &SimpleQosConnectionContext,
-    ) -> Result<
-        (
-            Arc<AtomicU64>,
-            CancellationToken,
-            Arc<ConnectionStreamCounter>,
-        ),
-        ConnectionHandlerError,
-    > {
+    ) -> Result<(Arc<AtomicU64>, CancellationToken, Arc<TokenBucket>), ConnectionHandlerError> {
         let remote_addr = connection.remote_address();

         debug!(
@@ -93,21 +90,27 @@ impl SimpleQos {
             conn_context.peer_type(),
             remote_addr,
         );
-
+        let key = ConnectionTableKey::new(remote_addr.ip(), conn_context.remote_pubkey);
         if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l
             .try_add_connection(
-                ConnectionTableKey::new(remote_addr.ip(), conn_context.remote_pubkey),
+                key,
                 remote_addr.port(),
                 client_connection_tracker,
                 Some(connection.clone()),
                 conn_context.peer_type(),
                 conn_context.last_update.clone(),
                 self.config.max_connections_per_peer,
+                || {
+                    Arc::new(TokenBucket::new(
+                        self.config.max_streams_per_second,
+                        self.config.max_streams_per_second,
+                        self.config.max_streams_per_second as f64,
+                    ))
+                },
             )
         {
             update_open_connections_stat(&self.stats, &connection_table_l);
             drop(connection_table_l);
-
             Ok((last_update, cancel_connection, stream_counter))
         } else {
             self.stats
@@ -116,11 +119,6 @@ impl SimpleQos {
             Err(ConnectionHandlerError::ConnectionAddError)
         }
     }
-
-    fn max_streams_per_throttling_interval(&self, _context: &SimpleQosConnectionContext) -> u64 {
-        let interval_ms = STREAM_THROTTLING_INTERVAL.as_millis() as u64;
-        (self.config.max_streams_per_second * interval_ms / 1000).max(1)
-    }
 }

 #[derive(Clone)]
@@ -129,7 +127,7 @@ pub struct SimpleQosConnectionContext {
     remote_pubkey: Option<solana_pubkey::Pubkey>,
     remote_address: std::net::SocketAddr,
     last_update: Arc<AtomicU64>,
-    stream_counter: Option<Arc<ConnectionStreamCounter>>,
+    stream_counter: Option<Arc<TokenBucket>>,
 }

 impl ConnectionContext for SimpleQosConnectionContext {
@@ -214,14 +212,7 @@ impl QosController<SimpleQosConnectionContext> for SimpleQos {
         }
     }

-    fn on_stream_accepted(&self, conn_context: &SimpleQosConnectionContext) {
-        conn_context
-            .stream_counter
-            .as_ref()
-            .unwrap()
-            .stream_count
-            .fetch_add(1, Ordering::Relaxed);
-    }
+    fn on_stream_accepted(&self, _conn_context: &SimpleQosConnectionContext) {}

     fn on_stream_error(&self, _conn_context: &SimpleQosConnectionContext) {}

@@ -262,20 +253,32 @@ impl QosController<SimpleQosConnectionContext> for SimpleQos {
         async move {
             let peer_type = context.peer_type();
             let remote_addr = context.remote_address;
-            let stream_counter: &Arc<ConnectionStreamCounter> =
-                context.stream_counter.as_ref().unwrap();
-
-            let max_streams_per_throttling_interval =
-                self.max_streams_per_throttling_interval(context);
-
-            throttle_stream(
-                &self.stats,
-                peer_type,
-                remote_addr,
-                stream_counter,
-                max_streams_per_throttling_interval,
-            )
-            .await;
+            let stream_counter = context
+                .stream_counter
+                .as_ref()
+                .expect("This will always be populated before streams are opened");
+
+            while stream_counter.consume_tokens(1).is_err() {
+                debug!("Throttling stream from {remote_addr:?}");
+                self.stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
+                match peer_type {
+                    ConnectionPeerType::Unstaked => {
+                        self.stats
+                            .throttled_unstaked_streams
+                            .fetch_add(1, Ordering::Relaxed);
+                    }
+                    ConnectionPeerType::Staked(_) => {
+                        self.stats
+                            .throttled_staked_streams
+                            .fetch_add(1, Ordering::Relaxed);
+                    }
+                }
+                let min_sleep = stream_counter.us_to_have_tokens(1).expect(
+                    "Valid QoS configurations guarantee the token bucket fits at least one \
+                     token",
+                );
+                sleep(Duration::from_micros(min_sleep)).await;
+            }
         }
     }

@@ -423,9 +426,8 @@ mod tests {

         // Verify success
         assert!(result.is_ok());
-        let (_last_update, cancel_token, stream_counter) = result.unwrap();
+        let (_last_update, cancel_token, _stream_counter) = result.unwrap();
         assert!(!cancel_token.is_cancelled());
-        assert_eq!(stream_counter.stream_count.load(Ordering::Relaxed), 0);
     }

     #[tokio::test]
@@ -467,6 +469,7 @@ mod tests {
             ConnectionPeerType::Staked(1000),
             Arc::new(AtomicU64::new(0)),
             1, // max_connections_per_peer
+            || Arc::new(TokenBucket::new(1, 1, 1.0)),
         );

         let connection_table_guard = tokio::sync::Mutex::new(connection_table);
@@ -497,10 +500,6 @@ mod tests {

         // Verify failure due to connection limit
         assert!(result.is_err());
-        assert!(matches!(
-            result.unwrap_err(),
-            ConnectionHandlerError::ConnectionAddError
-        ));

         // Verify stats were updated
         assert_eq!(stats.connection_add_failed.load(Ordering::Relaxed), 1);
@@ -674,15 +673,6 @@ mod tests {

         // Verify context was updated with stream counter
         assert!(conn_context.stream_counter.is_some());
-        assert_eq!(
-            conn_context
-                .stream_counter
-                .as_ref()
-                .unwrap()
-                .stream_count
-                .load(Ordering::Relaxed),
-            0
-        );

         // Verify stats were updated
         assert_eq!(
@@ -926,17 +916,6 @@ mod tests {
         // Verify last_update was updated (should be same or newer)
         let updated_last_update = conn_context.last_update.load(Ordering::Relaxed);
         assert!(updated_last_update >= initial_last_update);
-
-        // Verify stream counter starts at 0
-        assert_eq!(
-            conn_context
-                .stream_counter
-                .as_ref()
-                .unwrap()
-                .stream_count
-                .load(Ordering::Relaxed),
-            0
-        );
     }

     #[tokio::test]
@@ -976,27 +955,6 @@ mod tests {

         assert!(result.is_some()); // Connection should be added successfully
         assert!(conn_context.stream_counter.is_some()); // Stream counter should be set
-
-        // Record initial stream count
-        let initial_stream_count = conn_context
-            .stream_counter
-            .as_ref()
-            .unwrap()
-            .stream_count
-            .load(Ordering::Relaxed);
-        assert_eq!(initial_stream_count, 0);
-
-        // Test - call on_stream_accepted
-        simple_qos.on_stream_accepted(&conn_context);
-
-        // Verify stream count was incremented
-        let updated_stream_count = conn_context
-            .stream_counter
-            .as_ref()
-            .unwrap()
-            .stream_count
-            .load(Ordering::Relaxed);
-        assert_eq!(updated_stream_count, initial_stream_count + 1);
     }

     #[tokio::test]
@@ -1069,8 +1027,9 @@ mod tests {
             create_staked_nodes_with_keypairs(&server_keypair, &client_keypair, stake_amount);

         // Set a specific max_streams_per_second for testing
+        let max_streams_per_second = 10;
         let qos_config = SimpleQosConfig {
-            max_streams_per_second: 10,
+            max_streams_per_second,
             max_staked_connections: 100,
             max_connections_per_peer: 10,
         };
@@ -1097,12 +1056,16 @@ mod tests {
         // Test - call on_new_stream and measure timing
         let start_time = std::time::Instant::now();

-        simple_qos.on_new_stream(&conn_context).await;
+        // This should take roughly 1 second to complete
+        // due to the rate limit (since we allow an initial burst)
+        for _ in 0..max_streams_per_second * 2 {
+            simple_qos.on_new_stream(&conn_context).await;
+        }

         let elapsed = start_time.elapsed();

-        // The function should complete (may or may not sleep depending on current throttling state)
-        // We just verify it doesn't panic and completes successfully
-        assert!(elapsed < std::time::Duration::from_secs(1)); // Should not take too long
+        // We cannot verify the timing precisely, so we check rough bounds
+        assert!(elapsed > std::time::Duration::from_millis(950)); // Should not take too little time!
+        assert!(elapsed < std::time::Duration::from_millis(1200)); // Should not take too long!
     }
 }
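Aside, not part of the patch: the new `on_new_stream` path boils down to "consume one token or sleep until one accrues." The self-contained sketch below mimics that loop with a hypothetical `SimpleTokenBucket`; the constructor and method names of the real `solana_net_utils::token_bucket::TokenBucket` are only assumed here from their usage in the hunks above.

use std::time::{Duration, Instant};

// Illustration only: a naive stand-in for solana_net_utils::token_bucket::TokenBucket.
struct SimpleTokenBucket {
    capacity: f64,
    tokens: f64,
    refill_per_sec: f64,
    last_refill: Instant,
}

impl SimpleTokenBucket {
    fn new(capacity: u64, initial: u64, refill_per_sec: f64) -> Self {
        Self {
            capacity: capacity as f64,
            tokens: initial as f64,
            refill_per_sec,
            last_refill: Instant::now(),
        }
    }

    // Accrue tokens based on elapsed time, capped at capacity.
    fn refill(&mut self) {
        let elapsed = self.last_refill.elapsed().as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last_refill = Instant::now();
    }

    // Err(()) means the caller must wait, mirroring the consume_tokens() call in the diff.
    fn consume_tokens(&mut self, n: u64) -> Result<(), ()> {
        self.refill();
        if self.tokens >= n as f64 {
            self.tokens -= n as f64;
            Ok(())
        } else {
            Err(())
        }
    }

    // Microseconds until n tokens are available, mirroring us_to_have_tokens() in the diff.
    fn us_to_have_tokens(&mut self, n: u64) -> u64 {
        self.refill();
        let missing = (n as f64 - self.tokens).max(0.0);
        (missing / self.refill_per_sec * 1_000_000.0) as u64
    }
}

#[tokio::main]
async fn main() {
    // 10 streams/second with a full initial burst, like the test config above.
    let mut bucket = SimpleTokenBucket::new(10, 10, 10.0);
    let start = Instant::now();
    for i in 0..20u32 {
        // Same shape as the new on_new_stream loop: consume a token or sleep until one accrues.
        while bucket.consume_tokens(1).is_err() {
            tokio::time::sleep(Duration::from_micros(bucket.us_to_have_tokens(1))).await;
        }
        println!("stream {i} admitted at {:?}", start.elapsed());
    }
    // The second batch of ten streams is spread over roughly one second.
}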