浏览代码

Refactor key update notifier interface to support remove (#6196)

Lijun Wang 6 月之前
父节点
当前提交
4daf16d925
共有 4 个文件被更改,包括 113 次插入37 次删除
  1. 56 2
      core/src/admin_rpc_post_init.rs
  2. 25 25
      core/src/tpu.rs
  3. 10 5
      core/src/validator.rs
  4. 22 5
      validator/src/admin_rpc_service.rs

+ 56 - 2
core/src/admin_rpc_post_init.rs

@@ -8,19 +8,73 @@ use {
     solana_quic_definitions::NotifyKeyUpdate,
     solana_runtime::bank_forks::BankForks,
     std::{
-        collections::HashSet,
+        collections::{HashMap, HashSet},
         net::UdpSocket,
         sync::{Arc, RwLock},
     },
 };
 
+/// Key updaters:
+#[derive(PartialEq, Eq, Hash, Clone, Debug)]
+pub enum KeyUpdaterType {
+    /// TPU key updater
+    Tpu,
+    /// TPU forwards key updater
+    TpuForwards,
+    /// TPU vote key updater
+    TpuVote,
+    /// Forward key updater
+    Forward,
+    /// For the RPC service
+    RpcService,
+}
+
+/// Responsible for managing the updaters for identity key change
+#[derive(Default)]
+pub struct KeyUpdaters {
+    updaters: HashMap<KeyUpdaterType, Arc<dyn NotifyKeyUpdate + Sync + Send>>,
+}
+
+impl KeyUpdaters {
+    /// Add a new key updater to the list
+    pub fn add(
+        &mut self,
+        updater_type: KeyUpdaterType,
+        updater: Arc<dyn NotifyKeyUpdate + Sync + Send>,
+    ) {
+        self.updaters.insert(updater_type, updater);
+    }
+
+    /// Remove a key updater by its key
+    pub fn remove(&mut self, updater_type: &KeyUpdaterType) {
+        self.updaters.remove(updater_type);
+    }
+}
+
+/// Implement the Iterator trait for KeyUpdaters
+impl<'a> IntoIterator for &'a KeyUpdaters {
+    type Item = (
+        &'a KeyUpdaterType,
+        &'a Arc<dyn NotifyKeyUpdate + Sync + Send>,
+    );
+    type IntoIter = std::collections::hash_map::Iter<
+        'a,
+        KeyUpdaterType,
+        Arc<dyn NotifyKeyUpdate + Sync + Send>,
+    >;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.updaters.iter()
+    }
+}
+
 #[derive(Clone)]
 pub struct AdminRpcRequestMetadataPostInit {
     pub cluster_info: Arc<ClusterInfo>,
     pub bank_forks: Arc<RwLock<BankForks>>,
     pub vote_account: Pubkey,
     pub repair_whitelist: Arc<RwLock<HashSet<Pubkey>>>,
-    pub notifies: Vec<Arc<dyn NotifyKeyUpdate + Sync + Send>>,
+    pub notifies: Arc<RwLock<KeyUpdaters>>,
     pub repair_socket: Arc<UdpSocket>,
     pub outstanding_repair_requests: Arc<RwLock<OutstandingRequests<ShredRepairType>>>,
     pub cluster_slots: Arc<ClusterSlots>,

+ 25 - 25
core/src/tpu.rs

@@ -12,6 +12,7 @@ pub use {
 };
 use {
     crate::{
+        admin_rpc_post_init::{KeyUpdaterType, KeyUpdaters},
         banking_stage::BankingStage,
         banking_trace::{Channels, TracerThread},
         cluster_info_vote_listener::{
@@ -44,7 +45,6 @@ use {
         transaction_recorder::TransactionRecorder,
     },
     solana_pubkey::Pubkey,
-    solana_quic_definitions::NotifyKeyUpdate,
     solana_rpc::{
         optimistically_confirmed_bank_tracker::BankNotificationSender,
         rpc_subscriptions::RpcSubscriptions,
@@ -156,7 +156,8 @@ impl Tpu {
         transaction_struct: TransactionStructure,
         enable_block_production_forwarding: bool,
         _generator_config: Option<GeneratorConfig>, /* vestigial code for replay invalidator */
-    ) -> (Self, Vec<Arc<dyn NotifyKeyUpdate + Sync + Send>>) {
+        key_notifiers: Arc<RwLock<KeyUpdaters>>,
+    ) -> Self {
         let TpuSockets {
             transactions: transactions_sockets,
             transaction_forwards: tpu_forwards_sockets,
@@ -373,33 +374,32 @@ impl Tpu {
             turbine_quic_endpoint_sender,
         );
 
-        let mut key_updaters: Vec<Arc<dyn NotifyKeyUpdate + Send + Sync>> = Vec::new();
+        let mut key_notifiers = key_notifiers.write().unwrap();
         if let Some(key_updater) = key_updater {
-            key_updaters.push(key_updater);
+            key_notifiers.add(KeyUpdaterType::Tpu, key_updater);
         }
         if let Some(forwards_key_updater) = forwards_key_updater {
-            key_updaters.push(forwards_key_updater);
+            key_notifiers.add(KeyUpdaterType::TpuForwards, forwards_key_updater);
+        }
+        key_notifiers.add(KeyUpdaterType::TpuVote, vote_streamer_key_updater);
+
+        key_notifiers.add(KeyUpdaterType::Forward, client_updater);
+
+        Self {
+            fetch_stage,
+            sig_verifier,
+            vote_sigverify_stage,
+            banking_stage,
+            forwarding_stage,
+            cluster_info_vote_listener,
+            broadcast_stage,
+            tpu_quic_t,
+            tpu_forwards_quic_t,
+            tpu_entry_notifier,
+            staked_nodes_updater_service,
+            tracer_thread_hdl,
+            tpu_vote_quic_t,
         }
-        key_updaters.push(vote_streamer_key_updater);
-        key_updaters.push(client_updater);
-        (
-            Self {
-                fetch_stage,
-                sig_verifier,
-                vote_sigverify_stage,
-                banking_stage,
-                forwarding_stage,
-                cluster_info_vote_listener,
-                broadcast_stage,
-                tpu_quic_t,
-                tpu_forwards_quic_t,
-                tpu_entry_notifier,
-                staked_nodes_updater_service,
-                tracer_thread_hdl,
-                tpu_vote_quic_t,
-            },
-            key_updaters,
-        )
     }
 
     pub fn join(self) -> thread::Result<()> {

+ 10 - 5
core/src/validator.rs

@@ -4,7 +4,7 @@ pub use solana_perf::report_target_features;
 use {
     crate::{
         accounts_hash_verifier::AccountsHashVerifier,
-        admin_rpc_post_init::AdminRpcRequestMetadataPostInit,
+        admin_rpc_post_init::{AdminRpcRequestMetadataPostInit, KeyUpdaterType, KeyUpdaters},
         banking_trace::{self, BankingTracer, TraceError},
         cluster_info_vote_listener::VoteTracker,
         completed_data_sets_service::CompletedDataSetsService,
@@ -1581,7 +1581,8 @@ impl Validator {
             return Err(ValidatorError::WenRestartFinished.into());
         }
 
-        let forwarding_tpu_client = if let Some(connection_cache) = connection_cache {
+        let key_notifiers = Arc::new(RwLock::new(KeyUpdaters::default()));
+        let forwarding_tpu_client = if let Some(connection_cache) = &connection_cache {
             ForwardingClientOption::ConnectionCache(connection_cache.clone())
         } else {
             let runtime_handle = tpu_client_next_runtime
@@ -1596,7 +1597,7 @@ impl Validator {
                 runtime_handle.clone(),
             ))
         };
-        let (tpu, mut key_notifies) = Tpu::new_with_client(
+        let tpu = Tpu::new_with_client(
             &cluster_info,
             &poh_recorder,
             transaction_recorder,
@@ -1646,6 +1647,7 @@ impl Validator {
             config.transaction_struct.clone(),
             config.enable_block_production_forwarding,
             config.generator_config.clone(),
+            key_notifiers.clone(),
         );
 
         datapoint_info!(
@@ -1661,7 +1663,10 @@ impl Validator {
         *start_progress.write().unwrap() = ValidatorStartProgress::Running;
         if config.use_tpu_client_next {
             if let Some(json_rpc_service) = &json_rpc_service {
-                key_notifies.push(json_rpc_service.get_client_key_updater())
+                key_notifiers.write().unwrap().add(
+                    KeyUpdaterType::RpcService,
+                    json_rpc_service.get_client_key_updater(),
+                );
             }
             // note, that we don't need to add ConnectionClient to key_notifiers
             // because it is added inside Tpu.
@@ -1672,7 +1677,7 @@ impl Validator {
             cluster_info: cluster_info.clone(),
             vote_account: *vote_account,
             repair_whitelist: config.repair_whitelist.clone(),
-            notifies: key_notifies,
+            notifies: key_notifiers,
             repair_socket: Arc::new(node.sockets.repair),
             outstanding_repair_requests,
             cluster_slots,

+ 22 - 5
validator/src/admin_rpc_service.rs

@@ -750,9 +750,9 @@ impl AdminRpcImpl {
                     })?;
             }
 
-            for n in post_init.notifies.iter() {
-                if let Err(err) = n.update_key(&identity_keypair) {
-                    error!("Error updating network layer keypair: {err}");
+            for (key, notifier) in &*post_init.notifies.read().unwrap() {
+                if let Err(err) = notifier.update_key(&identity_keypair) {
+                    error!("Error updating network layer keypair: {err} on {key:?}");
                 }
             }
 
@@ -903,6 +903,7 @@ mod tests {
             accounts_index::AccountSecondaryIndexes,
         },
         solana_core::{
+            admin_rpc_post_init::{KeyUpdaterType, KeyUpdaters},
             consensus::tower_storage::NullTowerStorage,
             validator::{Validator, ValidatorConfig, ValidatorTpuConfig},
         },
@@ -981,7 +982,7 @@ mod tests {
                     bank_forks: bank_forks.clone(),
                     vote_account,
                     repair_whitelist,
-                    notifies: Vec::new(),
+                    notifies: Arc::new(RwLock::new(KeyUpdaters::default())),
                     repair_socket: Arc::new(bind_to_unspecified().unwrap()),
                     outstanding_repair_requests: Arc::<
                         RwLock<repair_service::OutstandingShredRepairs>,
@@ -1427,13 +1428,29 @@ mod tests {
                 start_progress.clone(),
                 SocketAddrSpace::Unspecified,
                 ValidatorTpuConfig::new_for_tests(DEFAULT_TPU_ENABLE_UDP),
-                post_init,
+                post_init.clone(),
             )
             .expect("assume successful validator start");
             assert_eq!(
                 *start_progress.read().unwrap(),
                 ValidatorStartProgress::Running
             );
+            let post_init = post_init.read().unwrap();
+
+            assert!(post_init.is_some());
+            let post_init = post_init.as_ref().unwrap();
+            let notifies = post_init.notifies.read().unwrap();
+            let updater_keys: HashSet<KeyUpdaterType> =
+                notifies.into_iter().map(|(key, _)| key.clone()).collect();
+            assert_eq!(
+                updater_keys,
+                HashSet::from_iter(vec![
+                    KeyUpdaterType::Tpu,
+                    KeyUpdaterType::TpuForwards,
+                    KeyUpdaterType::TpuVote,
+                    KeyUpdaterType::Forward,
+                ])
+            );
             let mut io = MetaIoHandler::default();
             io.extend_with(AdminRpcImpl.to_delegate());
             Self {