use {
    crate::{
        nonblocking::{
            connection_rate_limiter::ConnectionRateLimiter,
            qos::{ConnectionContext, OpaqueStreamerCounter, QosController},
        },
        quic::{configure_server, QuicServerError, QuicStreamerConfig, StreamerStats},
        streamer::StakedNodes,
    },
    bytes::{BufMut, Bytes, BytesMut},
    crossbeam_channel::{Sender, TrySendError},
    futures::{stream::FuturesUnordered, Future, StreamExt as _},
    indexmap::map::{Entry, IndexMap},
    quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime},
    rand::{thread_rng, Rng},
    smallvec::SmallVec,
    solana_keypair::Keypair,
    solana_measure::measure::Measure,
    solana_net_utils::token_bucket::TokenBucket,
    solana_packet::{Meta, PACKET_DATA_SIZE},
    solana_perf::packet::{BytesPacket, PacketBatch},
    solana_pubkey::Pubkey,
    solana_signature::Signature,
    solana_tls_utils::get_pubkey_from_tls_certificate,
    solana_transaction_metrics_tracker::signature_if_should_track_packet,
    std::{
        array, fmt,
        iter::repeat_with,
        net::{IpAddr, SocketAddr, UdpSocket},
        pin::Pin,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc, RwLock,
        },
        task::Poll,
        time::{Duration, Instant},
    },
    tokio::{
        // CAUTION: It's somewhat sketchy that we're mixing async and sync locks (see the
        // RwLock above). This is done so that sync code can also access the stake table.
        // Make sure we don't hold a sync lock across an await - including the await to
        // lock an async Mutex. This does not happen now and should not happen as long as we
        // don't hold an async Mutex and sync RwLock at the same time (currently true),
        // but if we do, the scope of the RwLock must always be a subset of the async Mutex
        // (i.e. lock order is always async Mutex -> RwLock). Also, be careful not to
        // introduce any other awaits while holding the RwLock.
        select,
        task::JoinHandle,
        time::timeout,
    },
    tokio_util::{sync::CancellationToken, task::TaskTracker},
};

pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";

const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";

pub(crate) const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
pub(crate) const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";

pub(crate) const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
pub(crate) const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] =
    b"exceed_max_stream_count";

const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";

const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";

/// Total new connection count per second. Heuristically taken from
/// the default staked and unstaked connection limits. Might be adjusted
/// later.
const TOTAL_CONNECTIONS_PER_SECOND: f64 = 2500.0;

/// Max burst of connections above the sustained rate to pass through.
const MAX_CONNECTION_BURST: u64 = 1000;

/// Timeout for the connection handshake. The timer starts once we get the
/// Initial packet from the peer, and is canceled when we get a Handshake
/// packet from them.
const QUIC_CONNECTION_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(2);

// A struct to accumulate the bytes making up
// a packet, along with their offsets, and the
// packet metadata. We use this accumulator to avoid
// multiple copies of the Bytes (when building up
// the Packet and then when copying the Packet into a PacketBatch).
#[derive(Clone)]
struct PacketAccumulator {
    pub meta: Meta,
    pub chunks: SmallVec<[Bytes; 2]>,
    pub start_time: Instant,
}

impl PacketAccumulator {
    fn new(meta: Meta) -> Self {
        Self {
            meta,
            chunks: SmallVec::default(),
            start_time: Instant::now(),
        }
    }
}
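
/// Whether the remote peer is unstaked, or staked with the given stake.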
#[derive(Copy, Clone, Debug)]
pub enum ConnectionPeerType {
    Unstaked,
    Staked(u64),
}

impl ConnectionPeerType {
    pub(crate) fn is_staked(&self) -> bool {
        matches!(self, ConnectionPeerType::Staked(_))
    }
}
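
/// Handles returned by [`spawn_server`]: the QUIC endpoints, the shared
/// stats, the server task handle, and the connection limit in effect.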
pub struct SpawnNonBlockingServerResult {
    pub endpoints: Vec<Endpoint>,
    pub stats: Arc<StreamerStats>,
    pub thread: JoinHandle<()>,
    pub max_concurrent_connections: usize,
}

/// Spawn a streamer instance in the current tokio runtime.
pub(crate) fn spawn_server<Q, C>(
    name: &'static str,
    stats: Arc<StreamerStats>,
    sockets: impl IntoIterator<Item = UdpSocket>,
    keypair: &Keypair,
    packet_sender: Sender<PacketBatch>,
    quic_server_params: QuicStreamerConfig,
    qos: Arc<Q>,
    cancel: CancellationToken,
) -> Result<SpawnNonBlockingServerResult, QuicServerError>
where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let sockets: Vec<_> = sockets.into_iter().collect();
    info!("Start {name} quic server on {sockets:?}");
    let (config, _) = configure_server(keypair)?;
    let endpoints = sockets
        .into_iter()
        .map(|sock| {
            Endpoint::new(
                EndpointConfig::default(),
                Some(config.clone()),
                sock,
                Arc::new(TokioRuntime),
            )
            .map_err(QuicServerError::EndpointFailed)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let max_concurrent_connections = qos.max_concurrent_connections();
    let handle = tokio::spawn({
        let endpoints = endpoints.clone();
        let stats = stats.clone();
        async move {
            let tasks = run_server(
                name,
                endpoints.clone(),
                packet_sender,
                stats.clone(),
                quic_server_params,
                cancel,
                qos,
            )
            .await;
            tasks.close();
            tasks.wait().await;
        }
    });
    Ok(SpawnNonBlockingServerResult {
        endpoints,
        stats,
        thread: handle,
        max_concurrent_connections,
    })
}
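
// A minimal usage sketch of `spawn_server`, kept as comments. `MyQos` is a
// hypothetical `QosController` implementation standing in for a real one;
// `StreamerStats::default()` assumes the stats type implements `Default`, and
// `QuicStreamerConfig::default_for_tests()` is the constructor used by this
// module's tests below.
//
//     let (packet_sender, packet_receiver) = crossbeam_channel::unbounded();
//     let SpawnNonBlockingServerResult { thread, .. } = spawn_server(
//         "quic_streamer_example",
//         Arc::new(StreamerStats::default()),
//         [UdpSocket::bind("127.0.0.1:0")?],
//         &Keypair::new(),
//         packet_sender,
//         QuicStreamerConfig::default_for_tests(),
//         Arc::new(MyQos::new()),
//         CancellationToken::new(),
//     )?;
//     // ... consume PacketBatch values from packet_receiver ...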

/// Helper to ease tracking connections at all stages, so that we do not have
/// to litter the code with open-connection accounting. This is added into the
/// connection table as part of the ConnectionEntry. The open-connection count
/// is automatically decremented when it is dropped.
pub struct ClientConnectionTracker {
    pub(crate) stats: Arc<StreamerStats>,
}

/// This is required by ConnectionEntry to support the Debug format.
impl fmt::Debug for ClientConnectionTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StreamerClientConnection")
            .field(
                "open_connections:",
                &self.stats.open_connections.load(Ordering::Relaxed),
            )
            .finish()
    }
}

impl Drop for ClientConnectionTracker {
    /// When this is dropped, reduce the open connection count.
    fn drop(&mut self) {
        self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
    }
}

impl ClientConnectionTracker {
    /// Check the `max_concurrent_connections` limit; if we are within the
    /// limit, create a ClientConnectionTracker and increment the open
    /// connection count. Otherwise return Err.
    fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
        // Optimistically increment, then roll back if the limit was exceeded.
        let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
        if open_connections >= max_concurrent_connections {
            stats.open_connections.fetch_sub(1, Ordering::Relaxed);
            debug!(
                "There are too many concurrent connections opened already: open: \
                 {open_connections}, max: {max_concurrent_connections}"
            );
            return Err(());
        }
        Ok(Self { stats })
    }
}
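
/// Accept loop: accepts incoming QUIC connections from all endpoints, applies
/// per-IP and overall connection rate limits, and spawns a
/// [`setup_connection`] task for each accepted connection. Returns the
/// tracker of the tasks it spawned.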
#[allow(clippy::too_many_arguments)]
async fn run_server<Q, C>(
    name: &'static str,
    endpoints: Vec<Endpoint>,
    packet_batch_sender: Sender<PacketBatch>,
    stats: Arc<StreamerStats>,
    quic_server_params: QuicStreamerConfig,
    cancel: CancellationToken,
    qos: Arc<Q>,
) -> TaskTracker
where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let quic_server_params = Arc::new(quic_server_params);
    let rate_limiter = Arc::new(ConnectionRateLimiter::new(
        quic_server_params.max_connections_per_ipaddr_per_min,
        // allow for a 10x burst to make sure we can accommodate legitimate
        // bursts from container environments running multiple pods on the same IP
        quic_server_params.max_connections_per_ipaddr_per_min * 10,
        quic_server_params.num_threads.get() * 2,
    ));
    let overall_connection_rate_limiter = Arc::new(TokenBucket::new(
        MAX_CONNECTION_BURST,
        MAX_CONNECTION_BURST,
        TOTAL_CONNECTIONS_PER_SECOND,
    ));
    const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
    debug!("spawn quic server");
    let mut last_datapoint = Instant::now();
    stats
        .quic_endpoints_count
        .store(endpoints.len(), Ordering::Relaxed);
    let mut accepts = endpoints
        .iter()
        .enumerate()
        .map(|(i, incoming)| {
            Box::pin(EndpointAccept {
                accept: incoming.accept(),
                endpoint: i,
            })
        })
        .collect::<FuturesUnordered<_>>();
    let tasks = TaskTracker::new();
    loop {
        let timeout_connection = select! {
            ready = accepts.next() => {
                if let Some((connecting, i)) = ready {
                    accepts.push(Box::pin(EndpointAccept {
                        accept: endpoints[i].accept(),
                        endpoint: i,
                    }));
                    Ok(connecting)
                } else {
                    // we can't really get here - we never poll an empty FuturesUnordered
                    continue
                }
            }
            _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
                Err(())
            }
            _ = cancel.cancelled() => break,
        };
        if last_datapoint.elapsed().as_secs() >= 5 {
            stats.report(name);
            last_datapoint = Instant::now();
        }
        if let Ok(Some(incoming)) = timeout_connection {
            stats
                .total_incoming_connection_attempts
                .fetch_add(1, Ordering::Relaxed);
            // check the overall connection request rate limiter
            if overall_connection_rate_limiter.current_tokens() == 0 {
                stats
                    .connection_rate_limited_across_all
                    .fetch_add(1, Ordering::Relaxed);
                debug!(
                    "Ignoring incoming connection from {} due to overall rate limit.",
                    incoming.remote_address()
                );
                incoming.ignore();
                continue;
            }
            // then perform per-IpAddr rate limiting
            if !rate_limiter.is_allowed(&incoming.remote_address().ip()) {
                stats
                    .connection_rate_limited_per_ipaddr
                    .fetch_add(1, Ordering::Relaxed);
                debug!(
                    "Ignoring incoming connection from {} due to per-IP rate limiting.",
                    incoming.remote_address()
                );
                incoming.ignore();
                continue;
            }
            let Ok(client_connection_tracker) =
                ClientConnectionTracker::new(stats.clone(), qos.max_concurrent_connections())
            else {
                stats
                    .refused_connections_too_many_open_connections
                    .fetch_add(1, Ordering::Relaxed);
                incoming.refuse();
                continue;
            };
            stats
                .outstanding_incoming_connection_attempts
                .fetch_add(1, Ordering::Relaxed);
            let connecting = incoming.accept();
            match connecting {
                Ok(connecting) => {
                    let rate_limiter = rate_limiter.clone();
                    let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
                    tasks.spawn(setup_connection(
                        connecting,
                        rate_limiter,
                        overall_connection_rate_limiter,
                        client_connection_tracker,
                        packet_batch_sender.clone(),
                        stats.clone(),
                        quic_server_params.clone(),
                        qos.clone(),
                        tasks.clone(),
                    ));
                }
                Err(err) => {
                    stats
                        .outstanding_incoming_connection_attempts
                        .fetch_sub(1, Ordering::Relaxed);
                    debug!("Incoming::accept(): error {err:?}");
                }
            }
        } else {
            debug!("accept(): Timed out waiting for connection");
        }
    }
    tasks
}
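
/// Extract the peer's pubkey from the TLS certificate it presented, if any.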
pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
    // Use the client cert only if it is self signed and the chain length is 1.
    connection
        .peer_identity()?
        .downcast::<Vec<rustls::pki_types::CertificateDer>>()
        .ok()
        .filter(|certs| certs.len() == 1)?
        .first()
        .and_then(get_pubkey_from_tls_certificate)
}
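
/// Look up the peer's pubkey and its stake in the given stake table. Returns
/// `(pubkey, stake, total_stake, max_stake, min_stake)`, or None if the peer
/// presented no usable certificate or has no stake entry.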
pub fn get_connection_stake(
    connection: &Connection,
    staked_nodes: &RwLock<StakedNodes>,
) -> Option<(Pubkey, u64, u64, u64, u64)> {
    let pubkey = get_remote_pubkey(connection)?;
    debug!("Peer public key is {pubkey:?}");
    let staked_nodes = staked_nodes.read().unwrap();
    Some((
        pubkey,
        staked_nodes.get_node_stake(&pubkey)?,
        staked_nodes.total_stake(),
        staked_nodes.max_stake(),
        staked_nodes.min_stake(),
    ))
}

#[derive(Debug)]
pub(crate) enum ConnectionHandlerError {
    ConnectionAddError,
    MaxStreamError,
}
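
/// Publish the current and peak table sizes to the staked or unstaked
/// open-connection stats, depending on which table this is.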
pub(crate) fn update_open_connections_stat<S: OpaqueStreamerCounter>(
    stats: &StreamerStats,
    connection_table: &ConnectionTable<S>,
) {
    if connection_table.is_staked() {
        stats
            .open_staked_connections
            .store(connection_table.table_size(), Ordering::Relaxed);
        stats
            .peak_open_staked_connections
            .fetch_max(connection_table.table_size(), Ordering::Relaxed);
    } else {
        stats
            .open_unstaked_connections
            .store(connection_table.table_size(), Ordering::Relaxed);
        stats
            .peak_open_unstaked_connections
            .fetch_max(connection_table.table_size(), Ordering::Relaxed);
    }
}
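
/// Wait for the QUIC handshake to complete (with a timeout), re-check the
/// rate limiters now that the peer's IP is verified, register the connection
/// with the QoS controller, and spawn [`handle_connection`] for it.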
#[allow(clippy::too_many_arguments)]
async fn setup_connection<Q, C>(
    connecting: Connecting,
    rate_limiter: Arc<ConnectionRateLimiter>,
    overall_connection_rate_limiter: Arc<TokenBucket>,
    client_connection_tracker: ClientConnectionTracker,
    packet_sender: Sender<PacketBatch>,
    stats: Arc<StreamerStats>,
    server_params: Arc<QuicStreamerConfig>,
    qos: Arc<Q>,
    tasks: TaskTracker,
) where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let from = connecting.remote_address();
    let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
    stats
        .outstanding_incoming_connection_attempts
        .fetch_sub(1, Ordering::Relaxed);
    if let Ok(connecting_result) = res {
        match connecting_result {
            Ok(new_connection) => {
                debug!("Got a connection {from:?}");
                // Now that we have observed the handshake, we can be certain
                // that the initiator owns its IP address, so we can update the
                // rate limiters on the server.
                if !rate_limiter.register_connection(&from.ip()) {
                    debug!("Reject connection from {from:?} -- rate limiting exceeded");
                    stats
                        .connection_rate_limited_per_ipaddr
                        .fetch_add(1, Ordering::Relaxed);
                    new_connection.close(
                        CONNECTION_CLOSE_CODE_DISALLOWED.into(),
                        CONNECTION_CLOSE_REASON_DISALLOWED,
                    );
                    return;
                }
                if overall_connection_rate_limiter.consume_tokens(1).is_err() {
                    debug!(
                        "Reject connection from {:?} -- total rate limiting exceeded",
                        from.ip()
                    );
                    stats
                        .connection_rate_limited_across_all
                        .fetch_add(1, Ordering::Relaxed);
                    new_connection.close(
                        CONNECTION_CLOSE_CODE_DISALLOWED.into(),
                        CONNECTION_CLOSE_REASON_DISALLOWED,
                    );
                    return;
                }
                stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
                let mut conn_context = qos.build_connection_context(&new_connection);
                if let Some(cancel_connection) = qos
                    .try_add_connection(
                        client_connection_tracker,
                        &new_connection,
                        &mut conn_context,
                    )
                    .await
                {
                    tasks.spawn(handle_connection(
                        packet_sender.clone(),
                        new_connection,
                        stats,
                        server_params.wait_for_chunk_timeout,
                        conn_context.clone(),
                        qos,
                        cancel_connection,
                    ));
                }
            }
            Err(e) => {
                handle_connection_error(e, &stats, from);
            }
        }
    } else {
        stats
            .connection_setup_timeout
            .fetch_add(1, Ordering::Relaxed);
    }
}
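
/// Bump the error counter that matches the kind of connection setup failure.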
fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
    debug!("error: {e:?} from: {from:?}");
    stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
    match e {
        quinn::ConnectionError::TimedOut => {
            stats
                .connection_setup_error_timed_out
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::ConnectionClosed(_) => {
            stats
                .connection_setup_error_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::TransportError(_) => {
            stats
                .connection_setup_error_transport
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::ApplicationClosed(_) => {
            stats
                .connection_setup_error_app_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::Reset => {
            stats
                .connection_setup_error_reset
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::LocallyClosed => {
            stats
                .connection_setup_error_locally_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        _ => {}
    }
}
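
/// Record, for each sampled packet, how long the streamer took from first
/// chunk to hand-off, in a shared histogram. Also tracks the overhead of the
/// measurement itself.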
fn track_streamer_fetch_packet_performance(
    packet_perf_measure: &[([u8; 64], Instant)],
    stats: &StreamerStats,
) {
    if packet_perf_measure.is_empty() {
        return;
    }
    let mut measure = Measure::start("track_perf");
    let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
    let now = Instant::now();
    for (signature, start_time) in packet_perf_measure {
        let duration = now.duration_since(*start_time);
        debug!(
            "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
            Signature::from(*signature)
        );
        process_sampled_packets_us_hist
            .increment(duration.as_micros() as u64)
            .unwrap();
    }
    drop(process_sampled_packets_us_hist);
    measure.stop();
    stats
        .perf_track_overhead_us
        .fetch_add(measure.as_us(), Ordering::Relaxed);
}
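
/// Per-connection task: accept unidirectional streams, assemble each stream's
/// chunks into a packet and forward it to the packet sender, until the peer
/// disconnects or the connection is cancelled.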
async fn handle_connection<Q, C>(
    packet_sender: Sender<PacketBatch>,
    connection: Connection,
    stats: Arc<StreamerStats>,
    wait_for_chunk_timeout: Duration,
    context: C,
    qos: Arc<Q>,
    cancel: CancellationToken,
) where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let peer_type = context.peer_type();
    let remote_addr = connection.remote_address();
    debug!(
        "quic new connection {} streams: {} connections: {}",
        remote_addr,
        stats.active_streams.load(Ordering::Relaxed),
        stats.total_connections.load(Ordering::Relaxed),
    );
    stats.total_connections.fetch_add(1, Ordering::Relaxed);
    'conn: loop {
        // Wait for new streams. If the peer is disconnected we get a cancellation signal and stop
        // the connection task.
        let mut stream = select! {
            stream = connection.accept_uni() => match stream {
                Ok(stream) => stream,
                Err(e) => {
                    debug!("stream error: {e:?}");
                    break;
                }
            },
            _ = cancel.cancelled() => break,
        };
        qos.on_new_stream(&context).await;
        qos.on_stream_accepted(&context);
        stats.active_streams.fetch_add(1, Ordering::Relaxed);
        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);
        let mut meta = Meta::default();
        meta.set_socket_addr(&remote_addr);
        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
        if let Some(pubkey) = context.remote_pubkey() {
            meta.set_remote_pubkey(pubkey);
        }
        let mut accum = PacketAccumulator::new(meta);
        // Virtually all small transactions will fit in 1 chunk. Larger transactions will fit in 1
        // or 2 chunks if the first chunk starts towards the end of a datagram. A small number of
        // transactions will have other protocol frames inserted in the middle. Empirically it's
        // been observed that 4 is the maximum number of chunks txs get split into.
        //
        // Bytes values are small, so overall the array takes only 128 bytes, and the "cost" of
        // overallocating a few bytes is negligible compared to the cost of having to do multiple
        // read_chunks() calls.
        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());
        loop {
            // Read the next chunks, waiting up to `wait_for_chunk_timeout`. If we don't get chunks
            // before then, we assume the stream is dead. This can only happen if there's severe
            // packet loss or the peer stops sending for whatever reason.
            let n_chunks = match tokio::select! {
                chunk = tokio::time::timeout(
                    wait_for_chunk_timeout,
                    stream.read_chunks(&mut chunks)) => chunk,
                // If the peer gets disconnected stop the task right away.
                _ = cancel.cancelled() => break,
            } {
                // read_chunks returned success
                Ok(Ok(chunk)) => chunk.unwrap_or(0),
                // read_chunks returned an error
                Ok(Err(e)) => {
                    debug!("Received stream error: {e:?}");
                    stats
                        .total_stream_read_errors
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
                // timeout elapsed
                Err(_) => {
                    debug!("Timeout in receiving on stream");
                    stats
                        .total_stream_read_timeouts
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
            };
            match handle_chunks(
                // Bytes::clone() is a cheap atomic inc
                chunks.iter().take(n_chunks).cloned(),
                &mut accum,
                &packet_sender,
                &stats,
                peer_type,
            ) {
                // The stream is finished, break out of the loop and close the stream.
                Ok(StreamState::Finished) => {
                    qos.on_stream_finished(&context);
                    break;
                }
                // The stream is still active, continue reading.
                Ok(StreamState::Receiving) => {}
                Err(_) => {
                    // Disconnect peers that send invalid streams.
                    connection.close(
                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
                    );
                    stats.active_streams.fetch_sub(1, Ordering::Relaxed);
                    qos.on_stream_error(&context);
                    break 'conn;
                }
            }
        }
        stats.active_streams.fetch_sub(1, Ordering::Relaxed);
        qos.on_stream_closed(&context);
    }
    let removed_connection_count = qos.remove_connection(&context, connection).await;
    if removed_connection_count > 0 {
        stats
            .connection_removed
            .fetch_add(removed_connection_count, Ordering::Relaxed);
    } else {
        stats
            .connection_remove_failed
            .fetch_add(1, Ordering::Relaxed);
    }
    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
}

enum StreamState {
    // Stream is not finished, keep receiving chunks
    Receiving,
    // Stream is finished
    Finished,
}

// Handle the chunks received from the stream. If the stream is finished, send the packet to the
// packet sender.
//
// Returns Err(()) if the stream is invalid.
fn handle_chunks(
    chunks: impl ExactSizeIterator<Item = Bytes>,
    accum: &mut PacketAccumulator,
    packet_sender: &Sender<PacketBatch>,
    stats: &StreamerStats,
    peer_type: ConnectionPeerType,
) -> Result<StreamState, ()> {
    let n_chunks = chunks.len();
    for chunk in chunks {
        accum.meta.size += chunk.len();
        if accum.meta.size > PACKET_DATA_SIZE {
            // The stream window size is set to PACKET_DATA_SIZE, so one individual chunk can
            // never exceed this size. A peer can send two chunks that together exceed the size
            // though, in which case we report the error.
            stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
            debug!("invalid stream size {}", accum.meta.size);
            return Err(());
        }
        accum.chunks.push(chunk);
        if peer_type.is_staked() {
            stats
                .total_staked_chunks_received
                .fetch_add(1, Ordering::Relaxed);
        } else {
            stats
                .total_unstaked_chunks_received
                .fetch_add(1, Ordering::Relaxed);
        }
    }
    // n_chunks == 0 marks the end of a stream
    if n_chunks != 0 {
        return Ok(StreamState::Receiving);
    }
    if accum.chunks.is_empty() {
        debug!("stream is empty");
        stats
            .total_packet_batches_none
            .fetch_add(1, Ordering::Relaxed);
        return Err(());
    }
    // done receiving chunks
    let bytes_sent = accum.meta.size;
    // 86% of transactions/packets come in one chunk. In that case,
    // we can just move the chunk to the `Packet` and no copy is
    // made.
    // 14% of them come in multiple chunks. In that case, we copy
    // them into one `Bytes` buffer. We make the copy once, with the
    // intention of not doing it again.
    let mut packet = if accum.chunks.len() == 1 {
        BytesPacket::new(
            accum.chunks.pop().expect("expected one chunk"),
            accum.meta.clone(),
        )
    } else {
        let size: usize = accum.chunks.iter().map(Bytes::len).sum();
        let mut buf = BytesMut::with_capacity(size);
        for chunk in &accum.chunks {
            buf.put_slice(chunk);
        }
        BytesPacket::new(buf.freeze(), accum.meta.clone())
    };
    let packet_size = packet.meta().size;
    let mut packet_perf_measure = None;
    if let Some(signature) = signature_if_should_track_packet(&packet).ok().flatten() {
        packet_perf_measure = Some((*signature, accum.start_time));
        // we set the PERF_TRACK_PACKET flag on
        packet.meta_mut().set_track_performance(true);
    }
    let packet_batch = PacketBatch::Single(packet);
    if let Err(err) = packet_sender.try_send(packet_batch) {
        stats
            .total_handle_chunk_to_packet_send_err
            .fetch_add(1, Ordering::Relaxed);
        match err {
            TrySendError::Full(_) => {
                stats
                    .total_handle_chunk_to_packet_send_full_err
                    .fetch_add(1, Ordering::Relaxed);
            }
            TrySendError::Disconnected(_) => {
                stats
                    .total_handle_chunk_to_packet_send_disconnected_err
                    .fetch_add(1, Ordering::Relaxed);
            }
        }
        trace!("packet batch send error {err:?}");
    } else {
        if let Some(ppm) = &packet_perf_measure {
            track_streamer_fetch_packet_performance(core::array::from_ref(ppm), stats);
        }
        stats
            .total_bytes_sent_to_consumer
            .fetch_add(packet_size, Ordering::Relaxed);
        stats
            .total_packets_sent_to_consumer
            .fetch_add(1, Ordering::Relaxed);
        match peer_type {
            ConnectionPeerType::Unstaked => {
                stats
                    .total_unstaked_packets_sent_for_batching
                    .fetch_add(1, Ordering::Relaxed);
            }
            ConnectionPeerType::Staked(_) => {
                stats
                    .total_staked_packets_sent_for_batching
                    .fetch_add(1, Ordering::Relaxed);
            }
        }
        trace!("sent {bytes_sent} byte packet for batching");
    }
    Ok(StreamState::Finished)
}
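
/// An entry in the [`ConnectionTable`]: holds the cancellation token and
/// bookkeeping for one open connection. Dropping the entry closes the
/// connection and cancels its task.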
struct ConnectionEntry<S: OpaqueStreamerCounter> {
    cancel: CancellationToken,
    peer_type: ConnectionPeerType,
    last_update: Arc<AtomicU64>,
    port: u16,
    // We do not explicitly use it, but its drop is triggered when ConnectionEntry is dropped.
    _client_connection_tracker: ClientConnectionTracker,
    connection: Option<Connection>,
    stream_counter: Arc<S>,
}

impl<S: OpaqueStreamerCounter> ConnectionEntry<S> {
    fn new(
        cancel: CancellationToken,
        peer_type: ConnectionPeerType,
        last_update: Arc<AtomicU64>,
        port: u16,
        client_connection_tracker: ClientConnectionTracker,
        connection: Option<Connection>,
        stream_counter: Arc<S>,
    ) -> Self {
        Self {
            cancel,
            peer_type,
            last_update,
            port,
            _client_connection_tracker: client_connection_tracker,
            connection,
            stream_counter,
        }
    }

    fn last_update(&self) -> u64 {
        self.last_update.load(Ordering::Relaxed)
    }

    fn stake(&self) -> u64 {
        match self.peer_type {
            ConnectionPeerType::Unstaked => 0,
            ConnectionPeerType::Staked(stake) => stake,
        }
    }
}

impl<S: OpaqueStreamerCounter> Drop for ConnectionEntry<S> {
    fn drop(&mut self) {
        if let Some(conn) = self.connection.take() {
            conn.close(
                CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
                CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
            );
        }
        self.cancel.cancel();
    }
}

#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) enum ConnectionTableKey {
    IP(IpAddr),
    Pubkey(Pubkey),
}

impl ConnectionTableKey {
    pub(crate) fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
        maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
            ConnectionTableKey::Pubkey(pubkey)
        })
    }
}
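
/// Whether a connection table tracks staked or unstaked peers.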
pub(crate) enum ConnectionTableType {
    Staked,
    Unstaked,
}

// Map of IP to list of connection entries
pub(crate) struct ConnectionTable<S: OpaqueStreamerCounter> {
    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry<S>>>,
    pub(crate) total_size: usize,
    table_type: ConnectionTableType,
    cancel: CancellationToken,
}

impl<S: OpaqueStreamerCounter> ConnectionTable<S> {
    pub(crate) fn new(table_type: ConnectionTableType, cancel: CancellationToken) -> Self {
        Self {
            table: IndexMap::default(),
            total_size: 0,
            table_type,
            cancel,
        }
    }

    fn table_size(&self) -> usize {
        self.total_size
    }

    fn is_staked(&self) -> bool {
        matches!(self.table_type, ConnectionTableType::Staked)
    }

    /// Prune the connections that have the oldest updates, until the table
    /// size drops to `max_size`.
    ///
    /// Returns the number of connections pruned.
    pub(crate) fn prune_oldest(&mut self, max_size: usize) -> usize {
        let mut num_pruned = 0;
        let key = |(_, connections): &(_, &Vec<_>)| {
            connections.iter().map(ConnectionEntry::last_update).min()
        };
        while self.total_size.saturating_sub(num_pruned) > max_size {
            match self.table.values().enumerate().min_by_key(key) {
                None => break,
                Some((index, connections)) => {
                    num_pruned += connections.len();
                    self.table.swap_remove_index(index);
                }
            }
        }
        self.total_size = self.total_size.saturating_sub(num_pruned);
        num_pruned
    }

    // Randomly samples `sample_size` connections, evicts the one with the
    // lowest stake among them, and returns the number of pruned connections.
    // If the stakes of all the sampled connections are higher than
    // `threshold_stake`, rejects the pruning attempt and returns 0.
    pub(crate) fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
        let num_pruned = std::iter::once(self.table.len())
            // if the table is empty, don't sample at all
            .filter(|&size| size > 0)
            // an endless stream of random indices into the table
            .flat_map(|size| {
                let mut rng = thread_rng();
                repeat_with(move || rng.gen_range(0..size))
            })
            // pair each sampled index with the stake of its first entry
            .map(|index| {
                let connection = self.table[index].first();
                let stake = connection.map(|connection: &ConnectionEntry<S>| connection.stake());
                (index, stake)
            })
            .take(sample_size)
            .min_by_key(|&(_, stake)| stake)
            // only evict if the lowest sampled stake is below the threshold
            .filter(|&(_, stake)| stake < Some(threshold_stake))
            .and_then(|(index, _)| self.table.swap_remove_index(index))
            .map(|(_, connections)| connections.len())
            .unwrap_or_default();
        self.total_size = self.total_size.saturating_sub(num_pruned);
        num_pruned
    }
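
    /// Try to add a connection under `key`. Fails (returning None and closing
    /// the connection) if the peer already has `max_connections_per_peer`
    /// entries. On success, returns the shared last-update timestamp, a
    /// cancellation token scoped to this entry, and the per-peer stream
    /// counter, which is reused across a peer's connections.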
    pub(crate) fn try_add_connection<F: FnOnce() -> Arc<S>>(
        &mut self,
        key: ConnectionTableKey,
        port: u16,
        client_connection_tracker: ClientConnectionTracker,
        connection: Option<Connection>,
        peer_type: ConnectionPeerType,
        last_update: Arc<AtomicU64>,
        max_connections_per_peer: usize,
        stream_counter_factory: F,
    ) -> Option<(Arc<AtomicU64>, CancellationToken, Arc<S>)> {
        let connection_entry = self.table.entry(key).or_default();
        let has_connection_capacity = connection_entry
            .len()
            .checked_add(1)
            .map(|c| c <= max_connections_per_peer)
            .unwrap_or(false);
        if has_connection_capacity {
            let cancel = self.cancel.child_token();
            let stream_counter = connection_entry
                .first()
                .map(|entry| entry.stream_counter.clone())
                .unwrap_or_else(stream_counter_factory);
            connection_entry.push(ConnectionEntry::new(
                cancel.clone(),
                peer_type,
                last_update.clone(),
                port,
                client_connection_tracker,
                connection,
                stream_counter.clone(),
            ));
            self.total_size += 1;
            Some((last_update, cancel, stream_counter))
        } else {
            if let Some(connection) = connection {
                connection.close(
                    CONNECTION_CLOSE_CODE_TOO_MANY.into(),
                    CONNECTION_CLOSE_REASON_TOO_MANY,
                );
            }
            None
        }
    }

    // Returns the number of connections that were removed
    pub(crate) fn remove_connection(
        &mut self,
        key: ConnectionTableKey,
        port: u16,
        stable_id: usize,
    ) -> usize {
        if let Entry::Occupied(mut e) = self.table.entry(key) {
            let e_ref = e.get_mut();
            let old_size = e_ref.len();
            e_ref.retain(|connection_entry| {
                // Retain the connection entry if the port is different, or if the connection's
                // stable_id doesn't match the provided stable_id.
                // (Some unit tests do not fill in a valid connection in the table. To support that,
                // if the connection is none, the stable_id check is ignored, i.e. if the port
                // matches, the connection gets removed.)
                connection_entry.port != port
                    || connection_entry
                        .connection
                        .as_ref()
                        .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
                        .is_some()
            });
            let new_size = e_ref.len();
            if e_ref.is_empty() {
                e.swap_remove_entry();
            }
            let connections_removed = old_size.saturating_sub(new_size);
            self.total_size = self.total_size.saturating_sub(connections_removed);
            connections_removed
        } else {
            0
        }
    }
}
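
/// Future adapter that tags each accepted connection with the index of the
/// endpoint it arrived on, so the accept loop can re-arm the right endpoint.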
struct EndpointAccept<'a> {
    endpoint: usize,
    accept: Accept<'a>,
}

impl Future for EndpointAccept<'_> {
    type Output = (Option<quinn::Incoming>, usize);

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        let i = self.endpoint;
        // Safety:
        // self is pinned and accept is a field so it can't get moved out. See safety docs of
        // map_unchecked_mut.
        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
            .poll(cx)
            .map(|r| (r, i))
    }
}

#[cfg(test)]
pub mod test {
    use {
        super::*,
        crate::nonblocking::{
            qos::NullStreamerCounter,
            swqos::SwQosConfig,
            testing_utilities::{
                check_multiple_streams, get_client_config, make_client_endpoint, setup_quic_server,
                spawn_stake_weighted_qos_server, SpawnTestServerResult,
            },
        },
        assert_matches::assert_matches,
        crossbeam_channel::{unbounded, Receiver},
        quinn::{ApplicationClose, ConnectionError},
        solana_keypair::Keypair,
        solana_net_utils::sockets::bind_to_localhost_unique,
        solana_signer::Signer,
        std::collections::HashMap,
        tokio::time::sleep,
    };

    pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let total = 30;
        for i in 0..total {
            let mut s1 = conn1.open_uni().await.unwrap();
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            info!("done {i}");
            sleep(Duration::from_millis(1000)).await;
        }
        let mut received = 0;
        loop {
            if let Ok(_x) = receiver.try_recv() {
                received += 1;
                info!("got {received}");
            } else {
                sleep(Duration::from_millis(500)).await;
            }
            if received >= total {
                break;
            }
        }
    }

    pub async fn check_block_multiple_connections(server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let conn2 = make_client_endpoint(&server_address, None).await;
        let mut s1 = conn1.open_uni().await.unwrap();
        let s2 = conn2.open_uni().await;
        if let Ok(mut s2) = s2 {
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            // Send enough data to create more than one chunk.
            // The first will try to open the connection (which should fail).
            // The following chunks will enable the detection of connection failure.
            let data = vec![1u8; PACKET_DATA_SIZE * 2];
            s2.write_all(&data)
                .await
                .expect_err("shouldn't be able to open 2 connections");
        } else {
            // It has been observed that if there is already a connection open against the
            // server, this open_uni can fail with an ApplicationClosed(ApplicationClose) error
            // due to CONNECTION_CLOSE_CODE_TOO_MANY before writing to the stream -- expect it.
            assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
        }
    }

    pub async fn check_multiple_writes(
        receiver: Receiver<PacketBatch>,
        server_address: SocketAddr,
        client_keypair: Option<&Keypair>,
    ) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
        // Send a full size packet with single byte writes.
        let num_bytes = PACKET_DATA_SIZE;
        let num_expected_packets = 1;
        let mut s1 = conn1.open_uni().await.unwrap();
        for _ in 0..num_bytes {
            s1.write_all(&[0u8]).await.unwrap();
        }
        s1.finish().unwrap();
        check_received_packets(receiver, num_expected_packets, num_bytes).await;
    }

    pub async fn check_multiple_packets(
        receiver: Receiver<PacketBatch>,
        server_address: SocketAddr,
        client_keypair: Option<&Keypair>,
        num_expected_packets: usize,
    ) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
        // Send full size packets, one per stream.
        let num_bytes = PACKET_DATA_SIZE;
        let packet = vec![1u8; num_bytes];
        for _ in 0..num_expected_packets {
            let mut s1 = conn1.open_uni().await.unwrap();
            s1.write_all(&packet).await.unwrap();
            s1.finish().unwrap();
        }
        check_received_packets(receiver, num_expected_packets, num_bytes).await;
    }

    async fn check_received_packets(
        receiver: Receiver<PacketBatch>,
        num_expected_packets: usize,
        num_bytes: usize,
    ) {
        let mut all_packets = vec![];
        let now = Instant::now();
        let mut total_packets = 0;
        while now.elapsed().as_secs() < 5 {
            // We're running in an async environment, we (almost) never
            // want to block
            if let Ok(packets) = receiver.try_recv() {
                total_packets += packets.len();
                all_packets.push(packets)
            } else {
                sleep(Duration::from_secs(1)).await;
            }
            if total_packets >= num_expected_packets {
                break;
            }
        }
        for batch in all_packets {
            for p in batch.iter() {
                assert_eq!(p.meta().size, num_bytes);
            }
        }
        assert_eq!(total_packets, num_expected_packets);
    }

    pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
        // Send a full size packet with single byte writes.
        if let Ok(mut s1) = conn1.open_uni().await {
            for _ in 0..PACKET_DATA_SIZE {
                // Ignoring any errors here. s1.finish() will test the error condition
                s1.write_all(&[0u8]).await.unwrap_or_default();
            }
            s1.finish().unwrap_or_default();
            s1.stopped().await.unwrap_err();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_exit_on_cancel() {
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address: _,
            stats: _,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        cancel.cancel();
        join_handle.await.unwrap();
        // Test that it is stopped by the cancel, not because the receiver was
        // dropped.
        drop(receiver);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_timeout() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats: _,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        check_timeout(receiver, server_address).await;
        cancel.cancel();
        join_handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_stream_timeout() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        let conn1 = make_client_endpoint(&server_address, None).await;
        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
        // Send one byte to start the stream
        let mut s1 = conn1.open_uni().await.unwrap();
        s1.write_all(&[0u8]).await.unwrap_or_default();
        // Wait long enough for the stream to time out in receiving chunks
        let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
        sleep(sleep_time).await;
        // Test that the stream was created, but timed out in read
        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
        assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
        // Test that more writes to the stream will fail (i.e. the stream is no longer writable
        // after the timeouts)
        assert!(s1.write_all(&[0u8]).await.is_err());
        cancel.cancel();
        drop(receiver);
        join_handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_block_multiple_connections() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats: _,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default_for_tests(),
        );
        check_block_multiple_connections(server_address).await;
        cancel.cancel();
        drop(receiver);
        join_handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig {
                ..QuicStreamerConfig::default_for_tests()
            },
            SwQosConfig {
                max_connections_per_unstaked_peer: 2,
                ..SwQosConfig::default_for_tests()
            },
        );
        let client_socket = bind_to_localhost_unique().expect("should bind - client");
        let mut endpoint = quinn::Endpoint::new(
            EndpointConfig::default(),
            None,
            client_socket,
            Arc::new(TokioRuntime),
        )
        .unwrap();
        let default_keypair = Keypair::new();
        endpoint.set_default_client_config(get_client_config(&default_keypair));
        let conn1 = endpoint
            .connect(server_address, "localhost")
            .expect("Failed in connecting")
            .await
            .expect("Failed in waiting");
        let conn2 = endpoint
            .connect(server_address, "localhost")
            .expect("Failed in connecting")
            .await
            .expect("Failed in waiting");
        let mut s1 = conn1.open_uni().await.unwrap();
        s1.write_all(&[0u8]).await.unwrap();
        s1.finish().unwrap();
        let mut s2 = conn2.open_uni().await.unwrap();
        conn1.close(
            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
        );
        let start = Instant::now();
        while stats.connection_removed.load(Ordering::Relaxed) != 1 && start.elapsed().as_secs() < 1
        {
            debug!("First connection not removed yet");
            sleep(Duration::from_millis(10)).await;
        }
        assert!(start.elapsed().as_secs() < 1);
        s2.write_all(&[0u8]).await.unwrap();
        s2.finish().unwrap();
        conn2.close(
            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
        );
        let start = Instant::now();
        while stats.connection_removed.load(Ordering::Relaxed) != 2 && start.elapsed().as_secs() < 1
        {
            debug!("Second connection not removed yet");
            sleep(Duration::from_millis(10)).await;
        }
        assert!(start.elapsed().as_secs() < 1);
        cancel.cancel();
        // Explicitly drop the receiver here so that it doesn't get implicitly
        // dropped earlier. This is necessary to ensure the server stays alive
        // and doesn't issue a cancel to kill the connection earlier than
        // expected.
        drop(receiver);
        join_handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_multiple_writes() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats: _,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        check_multiple_writes(receiver, server_address, None).await;
        cancel.cancel();
        join_handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_staked_connection_removal() {
        agave_logger::setup();
        let client_keypair = Keypair::new();
        let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
        let staked_nodes = StakedNodes::new(
            Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(), // overrides
        );
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            Some(staked_nodes),
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
        cancel.cancel();
        join_handle.await.unwrap();
        assert_eq!(
            stats
                .connection_added_from_staked_peer
                .load(Ordering::Relaxed),
            1
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_zero_staked_connection_removal() {
        // In this test, the client has a pubkey, but its stake is zero, so it
        // is treated as an unstaked peer.
        agave_logger::setup();
        let client_keypair = Keypair::new();
        let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
        let staked_nodes = StakedNodes::new(
            Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(), // overrides
        );
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            Some(staked_nodes),
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
        cancel.cancel();
        join_handle.await.unwrap();
        assert_eq!(
            stats
                .connection_added_from_staked_peer
                .load(Ordering::Relaxed),
            0
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }

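    // Same flow as above, but with no stake table at all; the connection is
    // still tracked and removed cleanly.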
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_unstaked_connection_removal() {
        agave_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        check_multiple_writes(receiver, server_address, None).await;
        cancel.cancel();
        join_handle.await.unwrap();
        assert_eq!(
            stats
                .connection_added_from_staked_peer
                .load(Ordering::Relaxed),
            0
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }

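    // With max_unstaked_connections set to 0, a connection attempt from an
    // unstaked client should be rejected outright.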
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_unstaked_node_connect_failure() {
        agave_logger::setup();
        let s = bind_to_localhost_unique().expect("should bind");
        let (sender, _) = unbounded();
        let keypair = Keypair::new();
        let server_address = s.local_addr().unwrap();
        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
        let cancel = CancellationToken::new();
        let SpawnNonBlockingServerResult {
            endpoints: _,
            stats: _,
            thread: t,
            max_concurrent_connections: _,
        } = spawn_stake_weighted_qos_server(
            "quic_streamer_test",
            [s],
            &keypair,
            sender,
            staked_nodes,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig {
                max_unstaked_connections: 0, // Do not allow any connection from unstaked clients/nodes
                ..Default::default()
            },
            cancel.clone(),
        )
        .unwrap();
        check_unstaked_node_connect_failure(server_address).await;
        cancel.cancel();
        t.await.unwrap();
    }

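    // Allows two connections from the same unstaked peer and drives multiple
    // streams over them; afterwards no streams remain active and, once the
    // server is cancelled, no connections remain either.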
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_multiple_streams() {
        agave_logger::setup();
        let s = bind_to_localhost_unique().expect("should bind");
        let (sender, receiver) = unbounded();
        let keypair = Keypair::new();
        let server_address = s.local_addr().unwrap();
        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
        let cancel = CancellationToken::new();
        let SpawnNonBlockingServerResult {
            endpoints: _,
            stats,
            thread: t,
            max_concurrent_connections: _,
        } = spawn_stake_weighted_qos_server(
            "quic_streamer_test",
            [s],
            &keypair,
            sender,
            staked_nodes,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig {
                max_connections_per_unstaked_peer: 2,
                ..Default::default()
            },
            cancel.clone(),
        )
        .unwrap();
        check_multiple_streams(receiver, server_address, None).await;
        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
        cancel.cancel();
        t.await.unwrap();
        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
    }

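    // Fill an IP-keyed table (including a second entry for the first IP),
    // prune it down to `new_size`, and verify that the oldest entries (by
    // last_update) were the ones evicted.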
    #[test]
    fn test_prune_table_with_ip() {
        use std::net::Ipv4Addr;
        agave_logger::setup();
        let cancel = CancellationToken::new();
        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
        let mut num_entries = 5;
        let max_connections_per_peer = 10;
        let sockets: Vec<_> = (0..num_entries)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats = Arc::new(StreamerStats::default());
        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    Arc::new(AtomicU64::new(i as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
        }
        num_entries += 1;
        table
            .try_add_connection(
                ConnectionTableKey::IP(sockets[0].ip()),
                sockets[0].port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                Arc::new(AtomicU64::new(5)),
                max_connections_per_peer,
                || Arc::new(NullStreamerCounter {}),
            )
            .unwrap();
        let new_size = 3;
        let pruned = table.prune_oldest(new_size);
        assert_eq!(pruned, num_entries as usize - new_size);
        for v in table.table.values() {
            for x in v {
                assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
            }
        }
        assert_eq!(table.table.len(), new_size);
        assert_eq!(table.total_size, new_size);
        for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }

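    // Same pruning flow, but keyed by pubkey: with a distinct pubkey per
    // entry, the per-peer limit never kicks in.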
    #[test]
    fn test_prune_table_with_unique_pubkeys() {
        agave_logger::setup();
        let cancel = CancellationToken::new();
        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
        // We should be able to add more entries than max_connections_per_peer,
        // since each entry is from a different peer pubkey.
        let num_entries = 15;
        let max_connections_per_peer = 10;
        let stats = Arc::new(StreamerStats::default());
        let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
        for (i, pubkey) in pubkeys.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::Pubkey(*pubkey),
                    0,
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    Arc::new(AtomicU64::new(i as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
        }
        let new_size = 3;
        let pruned = table.prune_oldest(new_size);
        assert_eq!(pruned, num_entries as usize - new_size);
        assert_eq!(table.table.len(), new_size);
        assert_eq!(table.total_size, new_size);
        for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
            table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }

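    // With a single pubkey, the per-peer limit caps how many entries can be
    // added for that peer, while a different pubkey can still be added.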
    #[test]
    fn test_prune_table_with_non_unique_pubkeys() {
        agave_logger::setup();
        let cancel = CancellationToken::new();
        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
        let max_connections_per_peer = 10;
        let pubkey = Pubkey::new_unique();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
        (0..max_connections_per_peer).for_each(|i| {
            table
                .try_add_connection(
                    ConnectionTableKey::Pubkey(pubkey),
                    0,
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    Arc::new(AtomicU64::new(i as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
        });

        // We should NOT be able to add more entries than max_connections_per_peer,
        // since we are using the same peer pubkey.
        assert!(table
            .try_add_connection(
                ConnectionTableKey::Pubkey(pubkey),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                Arc::new(AtomicU64::new(10)),
                max_connections_per_peer,
                || Arc::new(NullStreamerCounter {})
            )
            .is_none());

        // We should be able to add an entry from another peer pubkey.
        let num_entries = max_connections_per_peer + 1;
        let pubkey2 = Pubkey::new_unique();
        assert!(table
            .try_add_connection(
                ConnectionTableKey::Pubkey(pubkey2),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                Arc::new(AtomicU64::new(10)),
                max_connections_per_peer,
                || Arc::new(NullStreamerCounter {})
            )
            .is_some());
        assert_eq!(table.total_size, num_entries);

        let new_max_size = 3;
        let pruned = table.prune_oldest(new_max_size);
        assert!(pruned >= num_entries - new_max_size);
        assert!(table.table.len() <= new_max_size);
        assert!(table.total_size <= new_max_size);
        table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }

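    // prune_random samples a couple of entries and evicts one whose stake
    // falls below the threshold: a threshold of 0 prunes nothing, while a
    // threshold above every stake (1..=5 here) prunes exactly one entry.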
    #[test]
    fn test_prune_table_random() {
        use std::net::Ipv4Addr;
        agave_logger::setup();
        let cancel = CancellationToken::new();
        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
        let num_entries = 5;
        let max_connections_per_peer = 10;
        let sockets: Vec<_> = (0..num_entries)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Staked((i + 1) as u64),
                    Arc::new(AtomicU64::new(i as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
        }

        // Try pruning with a threshold stake lower than that of every entry in
        // the table. It should fail to prune (i.e. return 0 pruned entries).
        let pruned = table.prune_random(/*sample_size:*/ 2, /*threshold_stake:*/ 0);
        assert_eq!(pruned, 0);

        // Try pruning with a threshold stake higher than that of every entry
        // in the table. It should prune exactly one entry (i.e. return 1).
        let pruned = table.prune_random(
            2,                      // sample_size
            num_entries as u64 + 1, // threshold_stake
        );
        assert_eq!(pruned, 1);

        // We had 5 connections and pruned 1, so 4 should remain.
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
    }

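    // Add two connections per IP plus one single-connection address, then
    // remove everything (including an address that was never added) and
    // verify the table and the open-connection count drain to zero.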
    #[test]
    fn test_remove_connections() {
        use std::net::Ipv4Addr;
        agave_logger::setup();
        let cancel = CancellationToken::new();
        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
        let num_ips = 5;
        let max_connections_per_peer = 10;
        let mut sockets: Vec<_> = (0..num_ips)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    Arc::new(AtomicU64::new((i * 2) as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    Arc::new(AtomicU64::new((i * 2 + 1) as u64)),
                    max_connections_per_peer,
                    || Arc::new(NullStreamerCounter {}),
                )
                .unwrap();
        }
        let single_connection_addr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
        table
            .try_add_connection(
                ConnectionTableKey::IP(single_connection_addr.ip()),
                single_connection_addr.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                Arc::new(AtomicU64::new((num_ips * 2) as u64)),
                max_connections_per_peer,
                || Arc::new(NullStreamerCounter {}),
            )
            .unwrap();
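        // This address is never added to the table; removing it below
        // exercises removal of a key with no entries.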
        let zero_connection_addr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);
        sockets.push(single_connection_addr);
        sockets.push(zero_connection_addr);
        for socket in sockets.iter() {
            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }

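    // Sends exactly as many transactions as the unstaked throttle admits per
    // second and verifies that throttling only delays streams rather than
    // dropping packets: everything sent is eventually received.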
    #[tokio::test(flavor = "multi_thread")]
    async fn test_throttling_check_no_packet_drop() {
        agave_logger::setup_with_default_filter();
        let SpawnTestServerResult {
            join_handle,
            receiver,
            server_address,
            stats,
            cancel,
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        let client_connection = make_client_endpoint(&server_address, None).await;

        // An unstaked connection is throttled to roughly 100 TPS, so sending
        // 100 transactions should take about one second.
        let expected_num_txs = 100;
        let start_time = tokio::time::Instant::now();
        for i in 0..expected_num_txs {
            let mut send_stream = client_connection.open_uni().await.unwrap();
            let data = format!("{i}").into_bytes();
            send_stream.write_all(&data).await.unwrap();
            send_stream.finish().unwrap();
        }
        let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
        info!("Elapsed sending: {elapsed_sending}");

        // Check that all of them were delivered.
        let start_time = tokio::time::Instant::now();
        let mut num_txs_received = 0;
        while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
            if let Ok(packets) = receiver.try_recv() {
                num_txs_received += packets.len();
            } else {
                sleep(Duration::from_millis(100)).await;
            }
        }
        assert_eq!(expected_num_txs, num_txs_received);
        cancel.cancel();
        join_handle.await.unwrap();
        assert_eq!(
            stats.total_new_streams.load(Ordering::Relaxed),
            expected_num_txs
        );
        assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
    }

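    // ClientConnectionTracker acts as an RAII guard over the open-connection
    // count: creation fails once the limit is reached, and dropping a tracker
    // releases its slot.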
    #[test]
    fn test_client_connection_tracker() {
        let stats = Arc::new(StreamerStats::default());
        let tracker_1 = ClientConnectionTracker::new(stats.clone(), 1);
        assert!(tracker_1.is_ok());
        assert!(ClientConnectionTracker::new(stats.clone(), 1).is_err());
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);
        // After dropping the tracker, the open connection count should return
        // to 0.
        drop(tracker_1);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }

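    // A stream carrying more than PACKET_DATA_SIZE bytes is invalid; the
    // server should close the connection with the INVALID_STREAM code and
    // reason and bump the invalid_stream_size stat.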
    #[tokio::test(flavor = "multi_thread")]
    async fn test_client_connection_close_invalid_stream() {
        let SpawnTestServerResult {
            join_handle,
            server_address,
            stats,
            cancel,
            ..
        } = setup_quic_server(
            None,
            QuicStreamerConfig::default_for_tests(),
            SwQosConfig::default(),
        );
        let client_connection = make_client_endpoint(&server_address, None).await;
        let mut send_stream = client_connection.open_uni().await.unwrap();
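        // Write one byte more than the largest valid packet to trigger the
        // invalid-stream handling.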
        send_stream
            .write_all(&[42; PACKET_DATA_SIZE + 1])
            .await
            .unwrap();
        match client_connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
                assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
                assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
            }
            _ => panic!("unexpected close"),
        }
        assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
        cancel.cancel();
        join_handle.await.unwrap();
    }
}