Bläddra i källkod

refactor: update `certificate_limits_and_vote_types()` (#9082)

* use vote instead of vote type to remove some unwraps

* remove vec
Akhilesh Singhania 2 dagar sedan
förälder
incheckning
98c168ed6c

+ 20 - 11
votor/src/common.rs

@@ -32,21 +32,30 @@ impl VoteType {
     }
 }
 
-/// Lookup from `CertificateId` to the `VoteType`s that contribute,
-/// as well as the stake fraction required for certificate completion.
+/// For a given [`CertificateType`], returns the fractional stake, the [`Vote`], and the optional fallback [`Vote`] required to construct it.
 ///
-/// Must be in sync with `vote_to_certificate_ids`
-pub const fn certificate_limits_and_vote_types(
+/// Must be in sync with [`vote_to_certificate_ids`].
+pub(crate) fn certificate_limits_and_votes(
     cert_type: &CertificateType,
-) -> (f64, &'static [VoteType]) {
+) -> (f64, Vote, Option<Vote>) {
     match cert_type {
-        CertificateType::Notarize(_, _) => (0.6, &[VoteType::Notarize]),
-        CertificateType::NotarizeFallback(_, _) => {
-            (0.6, &[VoteType::Notarize, VoteType::NotarizeFallback])
+        CertificateType::Notarize(slot, block_id) => {
+            (0.6, Vote::new_notarization_vote(*slot, *block_id), None)
         }
-        CertificateType::FinalizeFast(_, _) => (0.8, &[VoteType::Notarize]),
-        CertificateType::Finalize(_) => (0.6, &[VoteType::Finalize]),
-        CertificateType::Skip(_) => (0.6, &[VoteType::Skip, VoteType::SkipFallback]),
+        CertificateType::NotarizeFallback(slot, block_id) => (
+            0.6,
+            Vote::new_notarization_vote(*slot, *block_id),
+            Some(Vote::new_notarization_fallback_vote(*slot, *block_id)),
+        ),
+        CertificateType::FinalizeFast(slot, block_id) => {
+            (0.8, Vote::new_notarization_vote(*slot, *block_id), None)
+        }
+        CertificateType::Finalize(slot) => (0.6, Vote::new_finalization_vote(*slot), None),
+        CertificateType::Skip(slot) => (
+            0.6,
+            Vote::new_skip_vote(*slot),
+            Some(Vote::new_skip_fallback_vote(*slot)),
+        ),
     }
 }
 

+ 14 - 18
votor/src/consensus_pool.rs

@@ -2,7 +2,7 @@
 
 use {
     crate::{
-        common::{certificate_limits_and_vote_types, vote_to_certificate_ids, Stake},
+        common::{certificate_limits_and_votes, vote_to_certificate_ids, Stake},
         consensus_pool::{
             certificate_builder::{BuildError as CertificateBuilderError, CertificateBuilder},
             parent_ready_tracker::ParentReadyTracker,
@@ -147,7 +147,9 @@ impl ConsensusPool {
         events: &mut Vec<VotorEvent>,
         total_stake: Stake,
     ) -> Result<Vec<Arc<Certificate>>, AddVoteError> {
-        let slot = vote.slot();
+        let Some(vote_pool) = self.vote_pools.get(&vote.slot()) else {
+            return Ok(vec![]);
+        };
         let mut new_certificates_to_send = Vec::new();
         for cert_type in vote_to_certificate_ids(vote) {
             // If the certificate is already complete, skip it
@@ -155,32 +157,26 @@ impl ConsensusPool {
                 continue;
             }
             // Otherwise check whether the certificate is complete
-            let (limit, vote_types) = certificate_limits_and_vote_types(&cert_type);
-            let accumulated_stake = vote_types
-                .iter()
-                .map(|vote_type| {
-                    self.vote_pools
-                        .get(&slot)
-                        .map_or(0, |p| p.get_stake(vote_type, vote.block_id()))
-                })
-                .sum::<Stake>();
+            let (limit, vote, fallback_vote) = certificate_limits_and_votes(&cert_type);
+            let accumulated_stake = vote_pool
+                .get_stake(&vote)
+                .saturating_add(fallback_vote.map_or(0, |v| vote_pool.get_stake(&v)));
 
             if accumulated_stake as f64 / (total_stake as f64) < limit {
                 continue;
             }
             let mut cert_builder = CertificateBuilder::new(cert_type);
-            for vote_type in vote_types {
-                if let Some(vote_pool) = self.vote_pools.get(&slot) {
-                    cert_builder
-                        .aggregate(&vote_pool.get_votes(vote_type, vote.block_id()))
-                        .unwrap();
-                }
+            cert_builder.aggregate(&vote_pool.get_votes(&vote)).unwrap();
+            if let Some(v) = fallback_vote {
+                cert_builder.aggregate(&vote_pool.get_votes(&v)).unwrap();
             }
             let new_cert = Arc::new(cert_builder.build()?);
-            self.insert_certificate(cert_type, new_cert.clone(), events);
             self.stats.incr_cert_type(&new_cert.cert_type, true);
             new_certificates_to_send.push(new_cert);
         }
+        for cert in &new_certificates_to_send {
+            self.insert_certificate(cert.cert_type, cert.clone(), events);
+        }
         Ok(new_certificates_to_send)
     }
 

+ 12 - 14
votor/src/consensus_pool/certificate_builder.rs

@@ -1,5 +1,5 @@
 use {
-    crate::common::{certificate_limits_and_vote_types, VoteType},
+    crate::common::certificate_limits_and_votes,
     agave_votor_messages::consensus_message::{Certificate, CertificateType, VoteMessage},
     bitvec::prelude::*,
     solana_bls_signatures::{BlsError, SignatureProjective},
@@ -146,24 +146,23 @@ impl BuilderType {
         cert_type: &CertificateType,
         msgs: &[VoteMessage],
     ) -> Result<(), AggregateError> {
-        let vote_types = certificate_limits_and_vote_types(cert_type).1;
+        let (_, vote, fallback_vote) = certificate_limits_and_votes(cert_type);
         match self {
             Self::DoubleVote {
                 signature,
                 bitmap0,
                 bitmap1,
             } => {
-                debug_assert_eq!(vote_types.len(), 2);
-                if vote_types.len() != 2 {
+                debug_assert!(fallback_vote.is_some());
+                let Some(fallback_vote) = fallback_vote else {
                     return Err(AggregateError::InvalidVoteTypes);
-                }
+                };
                 for msg in msgs {
-                    let vote_type = VoteType::get_type(&msg.vote);
-                    if vote_type == vote_types[0] {
+                    if msg.vote == vote {
                         try_set_bitmap(bitmap0, msg.rank)?;
                     } else {
-                        debug_assert_eq!(vote_type, vote_types[1]);
-                        if vote_type != vote_types[1] {
+                        debug_assert_eq!(msg.vote, fallback_vote);
+                        if msg.vote != fallback_vote {
                             return Err(AggregateError::InvalidVoteTypes);
                         }
                         match bitmap1 {
@@ -180,14 +179,13 @@ impl BuilderType {
             }
 
             Self::SingleVote { signature, bitmap } => {
-                debug_assert_eq!(vote_types.len(), 1);
-                if vote_types.len() != 1 {
+                debug_assert!(fallback_vote.is_none());
+                if fallback_vote.is_some() {
                     return Err(AggregateError::InvalidVoteTypes);
                 }
                 for msg in msgs {
-                    let vote_type = VoteType::get_type(&msg.vote);
-                    debug_assert_eq!(vote_type, vote_types[0]);
-                    if vote_type != vote_types[0] {
+                    debug_assert_eq!(msg.vote, vote);
+                    if msg.vote != vote {
                         return Err(AggregateError::InvalidVoteTypes);
                     }
                     try_set_bitmap(bitmap, msg.rank)?;

+ 39 - 60
votor/src/consensus_pool/vote_pool.rs

@@ -4,7 +4,7 @@
 //! Further detects duplicate votes which are defined as identical vote from the same sender received multiple times.
 
 use {
-    crate::common::{Stake, VoteType},
+    crate::common::Stake,
     agave_votor_messages::{consensus_message::VoteMessage, vote::Vote},
     solana_clock::Slot,
     solana_hash::Hash,
@@ -146,31 +146,29 @@ impl InternalVotePool {
         }
     }
 
-    /// Get votes for the corresponding [`VoteType`] and block id.
+    /// Get [`VoteMessage`]s for the corresponding [`Vote`].
     ///
     // TODO: figure out how to return an iterator here instead which would require `CertificateBuilder::aggregate()` to accept an iterator.
-    // TODO: instead of passing vote_type and block_id, pass in `Vote` which will remove some unwraps below.
-    fn get_votes(&self, vote_type: &VoteType, block_id: Option<&Hash>) -> Vec<VoteMessage> {
-        match vote_type {
-            VoteType::Finalize => self.finalize.values().cloned().collect(),
-            VoteType::Notarize => {
-                self.notar
-                    .values()
-                    .filter(|vote| {
-                        // unwrap on the stored vote should be safe as we should only store notar type votes here
-                        vote.vote.block_id().unwrap() == block_id.unwrap()
-                    })
-                    .cloned()
-                    .collect()
-            }
-            VoteType::NotarizeFallback => self
+    fn get_votes(&self, vote: &Vote) -> Vec<VoteMessage> {
+        match vote {
+            Vote::Finalize(_) => self.finalize.values().cloned().collect(),
+            Vote::Notarize(notar) => self
+                .notar
+                .values()
+                .filter(|vote| {
+                    // unwrap should be safe as we should only store notar votes here
+                    vote.vote.block_id().unwrap() == &notar.block_id
+                })
+                .cloned()
+                .collect(),
+            Vote::NotarizeFallback(nf) => self
                 .notar_fallback
                 .values()
-                .filter_map(|map| map.get(block_id.unwrap()))
+                .filter_map(|map| map.get(&nf.block_id))
                 .cloned()
                 .collect(),
-            VoteType::Skip => self.skip.values().cloned().collect(),
-            VoteType::SkipFallback => self.skip_fallback.values().cloned().collect(),
+            Vote::Skip(_) => self.skip.values().cloned().collect(),
+            Vote::SkipFallback(_) => self.skip_fallback.values().cloned().collect(),
         }
     }
 }
@@ -238,16 +236,14 @@ impl Stakes {
         }
     }
 
-    /// Get the stake corresponding to the [`VoteType`] and block id.
-    //
-    // TODO: instead of passing vote_type and block_id, pass in `Vote` which will remove unwraps below.
-    fn get_stake(&self, vote_type: &VoteType, block_id: Option<&Hash>) -> Stake {
-        match vote_type {
-            VoteType::Notarize => *self.notar.get(block_id.unwrap()).unwrap_or(&0),
-            VoteType::NotarizeFallback => *self.notar_fallback.get(block_id.unwrap()).unwrap_or(&0),
-            VoteType::Skip => self.skip,
-            VoteType::SkipFallback => self.skip_fallback,
-            VoteType::Finalize => self.finalize,
+    /// Get the stake corresponding to the [`Vote`].
+    fn get_stake(&self, vote: &Vote) -> Stake {
+        match vote {
+            Vote::Notarize(notar) => *self.notar.get(&notar.block_id).unwrap_or(&0),
+            Vote::NotarizeFallback(nf) => *self.notar_fallback.get(&nf.block_id).unwrap_or(&0),
+            Vote::Skip(_) => self.skip,
+            Vote::SkipFallback(_) => self.skip_fallback,
+            Vote::Finalize(_) => self.finalize,
         }
     }
 }
@@ -289,17 +285,13 @@ impl VotePool {
     }
 
     /// Returns the [`Stake`] corresponding to the specific [`Vote`].
-    pub(super) fn get_stake(&self, vote_type: &VoteType, block_id: Option<&Hash>) -> Stake {
-        self.stakes.get_stake(vote_type, block_id)
+    pub(super) fn get_stake(&self, vote: &Vote) -> Stake {
+        self.stakes.get_stake(vote)
     }
 
     /// Returns a list of votes corresponding to the specific [`Vote`].
-    pub(super) fn get_votes(
-        &self,
-        vote_type: &VoteType,
-        block_id: Option<&Hash>,
-    ) -> Vec<VoteMessage> {
-        self.votes.get_votes(vote_type, block_id)
+    pub(super) fn get_votes(&self, vote: &Vote) -> Vec<VoteMessage> {
+        self.votes.get_votes(vote)
     }
 }
 
@@ -573,17 +565,17 @@ mod test {
         let mut stakes = Stakes::new(slot);
         let vote = Vote::new_skip_vote(slot);
         assert_eq!(stakes.add_stake(stake, &vote), stake);
-        assert_eq!(stakes.get_stake(&VoteType::get_type(&vote), None), stake);
+        assert_eq!(stakes.get_stake(&vote), stake);
 
         let mut stakes = Stakes::new(slot);
         let vote = Vote::new_skip_fallback_vote(slot);
         assert_eq!(stakes.add_stake(stake, &vote), stake);
-        assert_eq!(stakes.get_stake(&VoteType::get_type(&vote), None), stake);
+        assert_eq!(stakes.get_stake(&vote), stake);
 
         let mut stakes = Stakes::new(slot);
         let vote = Vote::new_finalization_vote(slot);
         assert_eq!(stakes.add_stake(stake, &vote), stake);
-        assert_eq!(stakes.get_stake(&VoteType::get_type(&vote), None), stake);
+        assert_eq!(stakes.get_stake(&vote), stake);
 
         let mut stakes = Stakes::new(slot);
         let stake0 = 10;
@@ -594,14 +586,8 @@ mod test {
         let vote1 = Vote::new_notarization_vote(slot, hash1);
         assert_eq!(stakes.add_stake(stake0, &vote0), stake0);
         assert_eq!(stakes.add_stake(stake1, &vote1), stake1);
-        assert_eq!(
-            stakes.get_stake(&VoteType::get_type(&vote0), Some(&hash0)),
-            stake0
-        );
-        assert_eq!(
-            stakes.get_stake(&VoteType::get_type(&vote1), Some(&hash1)),
-            stake1
-        );
+        assert_eq!(stakes.get_stake(&vote0), stake0);
+        assert_eq!(stakes.get_stake(&vote1), stake1);
 
         let mut stakes = Stakes::new(slot);
         let stake0 = 10;
@@ -612,14 +598,8 @@ mod test {
         let vote1 = Vote::new_notarization_fallback_vote(slot, hash1);
         assert_eq!(stakes.add_stake(stake0, &vote0), stake0);
         assert_eq!(stakes.add_stake(stake1, &vote1), stake1);
-        assert_eq!(
-            stakes.get_stake(&VoteType::get_type(&vote0), Some(&hash0)),
-            stake0
-        );
-        assert_eq!(
-            stakes.get_stake(&VoteType::get_type(&vote1), Some(&hash1)),
-            stake1
-        );
+        assert_eq!(stakes.get_stake(&vote0), stake0);
+        assert_eq!(stakes.get_stake(&vote1), stake1);
     }
 
     #[test]
@@ -643,9 +623,8 @@ mod test {
                 .unwrap(),
             stake
         );
-        let vote_type = VoteType::get_type(&vote);
-        assert_eq!(vote_pool.get_stake(&vote_type, None), stake);
-        let returned_votes = vote_pool.get_votes(&vote_type, None);
+        assert_eq!(vote_pool.get_stake(&vote), stake);
+        let returned_votes = vote_pool.get_votes(&vote);
         assert_eq!(returned_votes.len(), 1);
         assert_eq!(returned_votes[0], vote_message);
     }