ソースを参照

Quic update identity (#33865)

Update the Quic transport layer keypair and identity when the Validator's identity keypair is updated
ryleung-solana 1 年間 前
コミット
132c910f81

+ 20 - 3
client/src/connection_cache.rs

@@ -9,7 +9,10 @@ use {
         },
     },
     solana_quic_client::{QuicConfig, QuicConnectionManager, QuicPool},
-    solana_sdk::{pubkey::Pubkey, signature::Keypair, transport::Result as TransportResult},
+    solana_sdk::{
+        pubkey::Pubkey, quic::NotifyKeyUpdate, signature::Keypair,
+        transport::Result as TransportResult,
+    },
     solana_streamer::streamer::StakedNodes,
     solana_udp_client::{UdpConfig, UdpConnectionManager, UdpPool},
     std::{
@@ -43,6 +46,15 @@ pub enum NonblockingClientConnection {
     Udp(Arc<<UdpBaseClientConnection as BaseClientConnection>::NonblockingClientConnection>),
 }
 
+impl NotifyKeyUpdate for ConnectionCache {
+    fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+        match self {
+            Self::Udp(_) => Ok(()),
+            Self::Quic(backend) => backend.update_key(key),
+        }
+    }
+}
+
 impl ConnectionCache {
     pub fn new(name: &'static str) -> Self {
         if DEFAULT_CONNECTION_CACHE_USE_QUIC {
@@ -217,7 +229,8 @@ mod tests {
         crossbeam_channel::unbounded,
         solana_sdk::{net::DEFAULT_TPU_COALESCE, signature::Keypair},
         solana_streamer::{
-            nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT, streamer::StakedNodes,
+            nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT, quic::SpawnServerResult,
+            streamer::StakedNodes,
         },
         std::{
             net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
@@ -245,7 +258,11 @@ mod tests {
 
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
 
-        let (response_recv_endpoint, response_recv_thread) = solana_streamer::quic::spawn_server(
+        let SpawnServerResult {
+            endpoint: response_recv_endpoint,
+            thread: response_recv_thread,
+            key_updater: _,
+        } = solana_streamer::quic::spawn_server(
             "quic_streamer_test",
             response_recv_socket,
             &keypair2,

+ 11 - 1
connection-cache/src/connection_cache.rs

@@ -9,7 +9,7 @@ use {
     log::*,
     rand::{thread_rng, Rng},
     solana_measure::measure::Measure,
-    solana_sdk::timing::AtomicInterval,
+    solana_sdk::{signature::Keypair, timing::AtomicInterval},
     std::{
         net::SocketAddr,
         sync::{atomic::Ordering, Arc, RwLock},
@@ -38,6 +38,7 @@ pub trait ConnectionManager: Send + Sync + 'static {
 
     fn new_connection_pool(&self) -> Self::ConnectionPool;
     fn new_connection_config(&self) -> Self::NewConnectionConfig;
+    fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>>;
 }
 
 pub struct ConnectionCache<
@@ -137,6 +138,11 @@ where
             .unwrap()
     }
 
+    pub fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+        let mut map = self.map.write().unwrap();
+        map.clear();
+        self.connection_manager.update_key(key)
+    }
     /// Create a lazy connection object under the exclusive lock of the cache map if there is not
     /// enough used connections in the connection pool for the specified address.
     /// Returns CreateConnectionResult.
@@ -636,6 +642,10 @@ mod tests {
         fn new_connection_config(&self) -> Self::NewConnectionConfig {
             MockUdpConfig::new().unwrap()
         }
+
+        fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+            Ok(())
+        }
     }
 
     impl BlockingClientConnection for MockUdpConnection {

+ 2 - 1
core/src/admin_rpc_post_init.rs

@@ -1,7 +1,7 @@
 use {
     solana_gossip::cluster_info::ClusterInfo,
     solana_runtime::bank_forks::BankForks,
-    solana_sdk::pubkey::Pubkey,
+    solana_sdk::{pubkey::Pubkey, quic::NotifyKeyUpdate},
     std::{
         collections::HashSet,
         sync::{Arc, RwLock},
@@ -14,4 +14,5 @@ pub struct AdminRpcRequestMetadataPostInit {
     pub bank_forks: Arc<RwLock<BankForks>>,
     pub vote_account: Pubkey,
     pub repair_whitelist: Arc<RwLock<HashSet<Pubkey>>>,
+    pub notifies: Vec<Arc<dyn NotifyKeyUpdate + Sync + Send>>,
 }

+ 29 - 18
core/src/tpu.rs

@@ -31,10 +31,10 @@ use {
         rpc_subscriptions::RpcSubscriptions,
     },
     solana_runtime::{bank_forks::BankForks, prioritization_fee_cache::PrioritizationFeeCache},
-    solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Keypair},
+    solana_sdk::{clock::Slot, pubkey::Pubkey, quic::NotifyKeyUpdate, signature::Keypair},
     solana_streamer::{
         nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
-        quic::{spawn_server, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS},
+        quic::{spawn_server, SpawnServerResult, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS},
         streamer::StakedNodes,
     },
     solana_turbine::broadcast_stage::{BroadcastStage, BroadcastStageType},
@@ -111,7 +111,7 @@ impl Tpu {
         prioritization_fee_cache: &Arc<PrioritizationFeeCache>,
         block_production_method: BlockProductionMethod,
         _generator_config: Option<GeneratorConfig>, /* vestigial code for replay invalidator */
-    ) -> Self {
+    ) -> (Self, Vec<Arc<dyn NotifyKeyUpdate + Sync + Send>>) {
         let TpuSockets {
             transactions: transactions_sockets,
             transaction_forwards: tpu_forwards_sockets,
@@ -148,7 +148,11 @@ impl Tpu {
 
         let (non_vote_sender, non_vote_receiver) = banking_tracer.create_channel_non_vote();
 
-        let (_, tpu_quic_t) = spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: tpu_quic_t,
+            key_updater,
+        } = spawn_server(
             "quic_streamer_tpu",
             transactions_quic_sockets,
             keypair,
@@ -168,7 +172,11 @@ impl Tpu {
         )
         .unwrap();
 
-        let (_, tpu_forwards_quic_t) = spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: tpu_forwards_quic_t,
+            key_updater: forwards_key_updater,
+        } = spawn_server(
             "quic_streamer_tpu_forwards",
             transactions_forwards_quic_sockets,
             keypair,
@@ -259,19 +267,22 @@ impl Tpu {
             turbine_quic_endpoint_sender,
         );
 
-        Self {
-            fetch_stage,
-            sigverify_stage,
-            vote_sigverify_stage,
-            banking_stage,
-            cluster_info_vote_listener,
-            broadcast_stage,
-            tpu_quic_t,
-            tpu_forwards_quic_t,
-            tpu_entry_notifier,
-            staked_nodes_updater_service,
-            tracer_thread_hdl,
-        }
+        (
+            Self {
+                fetch_stage,
+                sigverify_stage,
+                vote_sigverify_stage,
+                banking_stage,
+                cluster_info_vote_listener,
+                broadcast_stage,
+                tpu_quic_t,
+                tpu_forwards_quic_t,
+                tpu_entry_notifier,
+                staked_nodes_updater_service,
+                tracer_thread_hdl,
+            },
+            vec![key_updater, forwards_key_updater],
+        )
     }
 
     pub fn join(self) -> thread::Result<()> {

+ 11 - 8
core/src/validator.rs

@@ -1080,13 +1080,6 @@ impl Validator {
             exit.clone(),
         );
 
-        *admin_rpc_service_post_init.write().unwrap() = Some(AdminRpcRequestMetadataPostInit {
-            bank_forks: bank_forks.clone(),
-            cluster_info: cluster_info.clone(),
-            vote_account: *vote_account,
-            repair_whitelist: config.repair_whitelist.clone(),
-        });
-
         let waited_for_supermajority = wait_for_supermajority(
             config,
             Some(&mut process_blockstore),
@@ -1295,7 +1288,7 @@ impl Validator {
             };
         }
 
-        let tpu = Tpu::new(
+        let (tpu, mut key_notifies) = Tpu::new(
             &cluster_info,
             &poh_recorder,
             entry_receiver,
@@ -1346,6 +1339,16 @@ impl Validator {
         );
 
         *start_progress.write().unwrap() = ValidatorStartProgress::Running;
+        key_notifies.push(connection_cache);
+
+        *admin_rpc_service_post_init.write().unwrap() = Some(AdminRpcRequestMetadataPostInit {
+            bank_forks: bank_forks.clone(),
+            cluster_info: cluster_info.clone(),
+            vote_account: *vote_account,
+            repair_whitelist: config.repair_whitelist.clone(),
+            notifies: key_notifies,
+        });
+
         Ok(Self {
             stats_reporter_service,
             gossip_service,

+ 45 - 11
quic-client/src/lib.rs

@@ -84,39 +84,52 @@ impl ConnectionPool for QuicPool {
     }
 }
 
-#[derive(Clone)]
 pub struct QuicConfig {
-    client_certificate: Arc<QuicClientCertificate>,
+    // Arc to prevent having to copy the struct
+    client_certificate: RwLock<Arc<QuicClientCertificate>>,
     maybe_staked_nodes: Option<Arc<RwLock<StakedNodes>>>,
     maybe_client_pubkey: Option<Pubkey>,
 
     // The optional specified endpoint for the quic based client connections
     // If not specified, the connection cache will create as needed.
     client_endpoint: Option<Endpoint>,
+    addr: IpAddr,
+}
+
+impl Clone for QuicConfig {
+    fn clone(&self) -> Self {
+        let cert_guard = self.client_certificate.read().unwrap();
+        QuicConfig {
+            client_certificate: RwLock::new(cert_guard.clone()),
+            maybe_staked_nodes: self.maybe_staked_nodes.clone(),
+            maybe_client_pubkey: self.maybe_client_pubkey,
+            client_endpoint: self.client_endpoint.clone(),
+            addr: self.addr,
+        }
+    }
 }
 
 impl NewConnectionConfig for QuicConfig {
     fn new() -> Result<Self, ClientError> {
-        let (cert, priv_key) =
-            new_self_signed_tls_certificate(&Keypair::new(), IpAddr::V4(Ipv4Addr::UNSPECIFIED))?;
+        let addr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
+        let (cert, priv_key) = new_self_signed_tls_certificate(&Keypair::new(), addr)?;
         Ok(Self {
-            client_certificate: Arc::new(QuicClientCertificate {
+            client_certificate: RwLock::new(Arc::new(QuicClientCertificate {
                 certificate: cert,
                 key: priv_key,
-            }),
+            })),
             maybe_staked_nodes: None,
             maybe_client_pubkey: None,
             client_endpoint: None,
+            addr,
         })
     }
 }
 
 impl QuicConfig {
     fn create_endpoint(&self) -> QuicLazyInitializedEndpoint {
-        QuicLazyInitializedEndpoint::new(
-            self.client_certificate.clone(),
-            self.client_endpoint.as_ref().cloned(),
-        )
+        let cert_guard = self.client_certificate.read().unwrap();
+        QuicLazyInitializedEndpoint::new(cert_guard.clone(), self.client_endpoint.as_ref().cloned())
     }
 
     fn compute_max_parallel_streams(&self) -> usize {
@@ -143,7 +156,23 @@ impl QuicConfig {
         ipaddr: IpAddr,
     ) -> Result<(), RcgenError> {
         let (cert, priv_key) = new_self_signed_tls_certificate(keypair, ipaddr)?;
-        self.client_certificate = Arc::new(QuicClientCertificate {
+        self.addr = ipaddr;
+
+        let mut cert_guard = self.client_certificate.write().unwrap();
+
+        *cert_guard = Arc::new(QuicClientCertificate {
+            certificate: cert,
+            key: priv_key,
+        });
+        Ok(())
+    }
+
+    pub fn update_keypair(&self, keypair: &Keypair) -> Result<(), RcgenError> {
+        let (cert, priv_key) = new_self_signed_tls_certificate(keypair, self.addr)?;
+
+        let mut cert_guard = self.client_certificate.write().unwrap();
+
+        *cert_guard = Arc::new(QuicClientCertificate {
             certificate: cert,
             key: priv_key,
         });
@@ -212,6 +241,11 @@ impl ConnectionManager for QuicConnectionManager {
     fn new_connection_config(&self) -> QuicConfig {
         self.connection_config.clone()
     }
+
+    fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+        self.connection_config.update_keypair(key)?;
+        Ok(())
+    }
 }
 
 impl QuicConnectionManager {

+ 17 - 5
quic-client/tests/quic_client.rs

@@ -10,8 +10,8 @@ mod tests {
         },
         solana_sdk::{net::DEFAULT_TPU_COALESCE, packet::PACKET_DATA_SIZE, signature::Keypair},
         solana_streamer::{
-            nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT, streamer::StakedNodes,
-            tls_certificates::new_self_signed_tls_certificate,
+            nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT, quic::SpawnServerResult,
+            streamer::StakedNodes, tls_certificates::new_self_signed_tls_certificate,
         },
         std::{
             net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
@@ -68,7 +68,11 @@ mod tests {
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         let (s, exit, keypair, ip) = server_args();
-        let (_, t) = solana_streamer::quic::spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: t,
+            key_updater: _,
+        } = solana_streamer::quic::spawn_server(
             "quic_streamer_test",
             s.try_clone().unwrap(),
             &keypair,
@@ -204,7 +208,11 @@ mod tests {
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         let (request_recv_socket, request_recv_exit, keypair, request_recv_ip) = server_args();
-        let (request_recv_endpoint, request_recv_thread) = solana_streamer::quic::spawn_server(
+        let SpawnServerResult {
+            endpoint: request_recv_endpoint,
+            thread: request_recv_thread,
+            key_updater: _,
+        } = solana_streamer::quic::spawn_server(
             "quic_streamer_test",
             request_recv_socket.try_clone().unwrap(),
             &keypair,
@@ -228,7 +236,11 @@ mod tests {
         let addr = response_recv_socket.local_addr().unwrap().ip();
         let port = response_recv_socket.local_addr().unwrap().port();
         let server_addr = SocketAddr::new(addr, port);
-        let (response_recv_endpoint, response_recv_thread) = solana_streamer::quic::spawn_server(
+        let SpawnServerResult {
+            endpoint: response_recv_endpoint,
+            thread: response_recv_thread,
+            key_updater: _,
+        } = solana_streamer::quic::spawn_server(
             "quic_streamer_test",
             response_recv_socket,
             &keypair2,

+ 6 - 1
sdk/src/quic.rs

@@ -1,5 +1,6 @@
+#![cfg(feature = "full")]
 //! Definitions related to Solana over QUIC.
-use std::time::Duration;
+use {crate::signer::keypair::Keypair, std::time::Duration};
 
 pub const QUIC_PORT_OFFSET: u16 = 6;
 // Empirically found max number of concurrent streams
@@ -35,3 +36,7 @@ pub const QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO: u64 = 128;
 /// The receive window for QUIC connection from maximum staked nodes is
 /// set to this ratio times [`solana_sdk::packet::PACKET_DATA_SIZE`]
 pub const QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO: u64 = 512;
+
+pub trait NotifyKeyUpdate {
+    fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>>;
+}

+ 45 - 6
streamer/src/quic.rs

@@ -10,7 +10,7 @@ use {
     solana_perf::packet::PacketBatch,
     solana_sdk::{
         packet::PACKET_DATA_SIZE,
-        quic::{QUIC_MAX_TIMEOUT, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS},
+        quic::{NotifyKeyUpdate, QUIC_MAX_TIMEOUT, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS},
         signature::Keypair,
     },
     std::{
@@ -36,6 +36,12 @@ impl SkipClientVerification {
     }
 }
 
+pub struct SpawnServerResult {
+    pub endpoint: Endpoint,
+    pub thread: thread::JoinHandle<()>,
+    pub key_updater: Arc<EndpointKeyUpdater>,
+}
+
 impl rustls::server::ClientCertVerifier for SkipClientVerification {
     fn client_auth_root_subjects(&self) -> &[DistinguishedName] {
         &[]
@@ -113,6 +119,19 @@ pub enum QuicServerError {
     TlsError(#[from] rustls::Error),
 }
 
+pub struct EndpointKeyUpdater {
+    endpoint: Endpoint,
+    gossip_host: IpAddr,
+}
+
+impl NotifyKeyUpdate for EndpointKeyUpdater {
+    fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+        let (config, _) = configure_server(key, self.gossip_host)?;
+        self.endpoint.set_server_config(Some(config));
+        Ok(())
+    }
+}
+
 #[derive(Default)]
 pub struct StreamStats {
     pub(crate) total_connections: AtomicUsize,
@@ -404,7 +423,7 @@ pub fn spawn_server(
     max_unstaked_connections: usize,
     wait_for_chunk_timeout: Duration,
     coalesce: Duration,
-) -> Result<(Endpoint, thread::JoinHandle<()>), QuicServerError> {
+) -> Result<SpawnServerResult, QuicServerError> {
     let runtime = rt();
     let (endpoint, _stats, task) = {
         let _guard = runtime.enter();
@@ -431,7 +450,15 @@ pub fn spawn_server(
             }
         })
         .unwrap();
-    Ok((endpoint, handle))
+    let updater = EndpointKeyUpdater {
+        endpoint: endpoint.clone(),
+        gossip_host,
+    };
+    Ok(SpawnServerResult {
+        endpoint,
+        thread: handle,
+        key_updater: Arc::new(updater),
+    })
 }
 
 #[cfg(test)]
@@ -457,7 +484,11 @@ mod test {
         let ip = "127.0.0.1".parse().unwrap();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (_, t) = spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: t,
+            key_updater: _,
+        } = spawn_server(
             "quic_streamer_test",
             s,
             &keypair,
@@ -513,7 +544,11 @@ mod test {
         let ip = "127.0.0.1".parse().unwrap();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (_, t) = spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: t,
+            key_updater: _,
+        } = spawn_server(
             "quic_streamer_test",
             s,
             &keypair,
@@ -556,7 +591,11 @@ mod test {
         let ip = "127.0.0.1".parse().unwrap();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (_, t) = spawn_server(
+        let SpawnServerResult {
+            endpoint: _,
+            thread: t,
+            key_updater: _,
+        } = spawn_server(
             "quic_streamer_test",
             s,
             &keypair,

+ 5 - 0
udp-client/src/lib.rs

@@ -15,6 +15,7 @@ use {
         },
         connection_cache_stats::ConnectionCacheStats,
     },
+    solana_sdk::signature::Keypair,
     std::{
         net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
         sync::Arc,
@@ -112,4 +113,8 @@ impl ConnectionManager for UdpConnectionManager {
     fn new_connection_config(&self) -> Self::NewConnectionConfig {
         UdpConfig::new().unwrap()
     }
+
+    fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
+        Ok(())
+    }
 }

+ 7 - 0
validator/src/admin_rpc_service.rs

@@ -682,6 +682,12 @@ impl AdminRpcImpl {
                     })?;
             }
 
+            for n in post_init.notifies.iter() {
+                if let Err(err) = n.update_key(&identity_keypair) {
+                    error!("Error updating network layer keypair: {err}");
+                }
+            }
+
             solana_metrics::set_host_id(identity_keypair.pubkey().to_string());
             post_init
                 .cluster_info
@@ -888,6 +894,7 @@ mod tests {
                     bank_forks: bank_forks.clone(),
                     vote_account,
                     repair_whitelist,
+                    notifies: Vec::new(),
                 }))),
                 staked_nodes_overrides: Arc::new(RwLock::new(HashMap::new())),
                 rpc_to_plugin_manager_sender: None,