Forráskód Böngészése

Remove generic type from WeightedShuffle forcing it to u64 (#9147)

Kamil Skalski 1 napja
szülő
commit
ba3d045d65

+ 1 - 1
gossip/benches/weighted_shuffle.rs

@@ -17,7 +17,7 @@ fn bench_weighted_shuffle_new(c: &mut Criterion) {
     c.bench_function("bench_weighted_shuffle_new", |b| {
         b.iter(|| {
             let weights = make_weights(&mut rng);
-            black_box(WeightedShuffle::<u64>::new("", &weights));
+            black_box(WeightedShuffle::new("", &weights));
         })
     });
 }

+ 1 - 1
gossip/src/push_active_set.rs

@@ -143,7 +143,7 @@ impl PushActiveSetEntry {
     ) {
         debug_assert_eq!(nodes.len(), weights.len());
         debug_assert!(weights.iter().all(|&weight| weight != 0u64));
-        let mut weighted_shuffle = WeightedShuffle::<u64>::new("rotate-active-set", weights);
+        let mut weighted_shuffle = WeightedShuffle::new("rotate-active-set", weights);
         for node in weighted_shuffle.shuffle(rng).map(|k| &nodes[k]) {
             // We intend to discard the oldest/first entry in the index-map.
             if self.0.len() > size {

+ 30 - 61
gossip/src/weighted_shuffle.rs

@@ -1,7 +1,7 @@
 //! The `weighted_shuffle` module provides an iterator over shuffled weights.
 
 use {
-    num_traits::{CheckedAdd, ConstZero},
+    num_traits::CheckedAdd,
     rand::{
         distributions::uniform::{SampleUniform, UniformSampler},
         Rng,
@@ -30,53 +30,37 @@ const BIT_MASK: usize = FANOUT - 1;
 ///     weight.
 ///   - Zero weighted indices are shuffled and appear only at the end, after
 ///     non-zero weighted indices.
-pub struct WeightedShuffle<T> {
+pub struct WeightedShuffle {
     // Number of "internal" nodes of the tree.
     num_nodes: usize,
     // Underlying array implementing the tree.
     // Nodes without children are never accessed and don't need to be
     // allocated, so tree.len() < num_nodes.
     // tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
-    tree: Vec<[T; FANOUT]>,
+    tree: Vec<[u64; FANOUT]>,
     // Current sum of all weights, excluding already sampled ones.
-    weight: T,
+    weight: u64,
     // Indices of zero weighted entries.
     zeros: Vec<usize>,
 }
 
-impl<T: ConstZero> WeightedShuffle<T> {
-    const ZERO: T = <T as ConstZero>::ZERO;
-}
-
-impl<T> WeightedShuffle<T>
-where
-    T: Copy + ConstZero + PartialOrd + AddAssign + CheckedAdd,
-{
-    /// If weights are negative or overflow the total sum
-    /// they are treated as zero.
+impl WeightedShuffle {
+    /// If weights overflow the total sum they are treated as zero.
     pub fn new<I>(name: &'static str, weights: I) -> Self
     where
-        I: IntoIterator<Item: Borrow<T>>,
+        I: IntoIterator<Item: Borrow<u64>>,
         <I as IntoIterator>::IntoIter: ExactSizeIterator,
     {
         let weights = weights.into_iter();
         let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
         debug_assert!(size <= num_nodes);
-        let mut tree = vec![[Self::ZERO; FANOUT]; size];
-        let mut sum = Self::ZERO;
+        let mut tree = vec![[0; FANOUT]; size];
+        let mut sum = 0;
         let mut zeros = Vec::default();
-        let mut num_negative: usize = 0;
         let mut num_overflow: usize = 0;
         for (k, weight) in weights.enumerate() {
             let weight = *weight.borrow();
-            #[allow(clippy::neg_cmp_op_on_partial_ord)]
-            // weight < zero does not work for NaNs.
-            if !(weight >= Self::ZERO) {
-                zeros.push(k);
-                num_negative += 1;
-                continue;
-            }
-            if weight == Self::ZERO {
+            if weight == 0 {
                 zeros.push(k);
                 continue;
             }
@@ -104,9 +88,6 @@ where
                     .add_assign(weight);
             }
         }
-        if num_negative > 0 {
-            datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64));
-        }
         if num_overflow > 0 {
             datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
         }
@@ -119,12 +100,9 @@ where
     }
 }
 
-impl<T> WeightedShuffle<T>
-where
-    T: Copy + ConstZero + PartialOrd + SubAssign,
-{
+impl WeightedShuffle {
     // Removes given weight at index k.
-    fn remove(&mut self, k: usize, weight: T) {
+    fn remove(&mut self, k: usize, weight: u64) {
         debug_assert!(self.weight >= weight);
         self.weight -= weight;
         // Traverse the tree from the leaf node upwards to the root,
@@ -145,9 +123,7 @@ where
 
     // Returns smallest index such that sum of weights[..=k] > val,
     // along with its respective weight.
-    fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
-        debug_assert!(val >= Self::ZERO);
-        debug_assert!(val < self.weight);
+    fn search(&self, mut val: u64) -> (/*index:*/ usize, /*weight:*/ u64) {
         debug_assert!(!self.tree.is_empty());
         // Traverse the tree downwards from the root to the target leaf node.
         let mut index = 0; // root
@@ -181,7 +157,7 @@ where
             error!("WeightedShuffle::remove_index: Invalid index {k}");
             return;
         };
-        if weight == Self::ZERO {
+        if weight == 0 {
             self.remove_zero(k);
         } else {
             self.remove(k, weight);
@@ -195,34 +171,27 @@ where
     }
 }
 
-impl<T> WeightedShuffle<T>
-where
-    T: Copy + ConstZero + PartialOrd + SampleUniform + SubAssign,
-{
+impl WeightedShuffle {
     // Equivalent to weighted_shuffle.shuffle(&mut rng).next()
     pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
-        if self.weight > Self::ZERO {
-            let sample = <T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
+        if self.weight > 0 {
+            let sample = <u64 as SampleUniform>::Sampler::sample_single(0, self.weight, rng);
             let (index, _) = self.search(sample);
             return Some(index);
         }
         if self.zeros.is_empty() {
             return None;
         }
-        let index = <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
-        self.zeros.get(index).copied()
+        let index = <u64 as SampleUniform>::Sampler::sample_single(0, self.zeros.len() as u64, rng);
+        self.zeros.get(index as usize).copied()
     }
 }
 
-impl<T> WeightedShuffle<T>
-where
-    T: Copy + ConstZero + PartialOrd + SampleUniform + SubAssign,
-{
+impl WeightedShuffle {
     pub fn shuffle<'a, R: Rng>(&'a mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
         std::iter::from_fn(move || {
-            if self.weight > Self::ZERO {
-                let sample =
-                    <T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
+            if self.weight > 0 {
+                let sample = <u64 as SampleUniform>::Sampler::sample_single(0, self.weight, rng);
                 let (index, weight) = self.search(sample);
                 self.remove(index, weight);
                 return Some(index);
@@ -231,8 +200,8 @@ where
                 return None;
             }
             let index =
-                <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
-            Some(self.zeros.swap_remove(index))
+                <u64 as SampleUniform>::Sampler::sample_single(0, self.zeros.len() as u64, rng);
+            Some(self.zeros.swap_remove(index as usize))
         })
     }
 }
@@ -253,13 +222,13 @@ fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_si
 
 // #[derive(Clone)] does not overwrite clone_from which is used in
 // retransmit-stage to minimize allocations.
-impl<T: Clone> Clone for WeightedShuffle<T> {
+impl Clone for WeightedShuffle {
     #[inline]
     fn clone(&self) -> Self {
         Self {
             num_nodes: self.num_nodes,
             tree: self.tree.clone(),
-            weight: self.weight.clone(),
+            weight: self.weight,
             zeros: self.zeros.clone(),
         }
     }
@@ -268,7 +237,7 @@ impl<T: Clone> Clone for WeightedShuffle<T> {
     fn clone_from(&mut self, other: &Self) {
         self.num_nodes = other.num_nodes;
         self.tree.clone_from(&other.tree);
-        self.weight = other.weight.clone();
+        self.weight = other.weight;
         self.zeros.clone_from(&other.zeros);
     }
 }
@@ -559,7 +528,7 @@ mod tests {
             weights.iter().fold(0u64, |a, &b| a.checked_add(b).unwrap()),
             weights.iter().sum::<u64>()
         );
-        let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
+        let mut shuffle = WeightedShuffle::new("", &weights);
         let shuffle1 = shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>();
         // Assert that all indices appear in the shuffle.
         assert_eq!(shuffle1.len(), num_weights);
@@ -604,13 +573,13 @@ mod tests {
                 let mut seed = [0u8; 32];
                 rng.fill(&mut seed[..]);
                 let mut rng = R::from_seed(seed);
-                let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
+                let mut shuffle = WeightedShuffle::new("", &weights);
                 let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
                 let mut rng = R::from_seed(seed);
                 let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
                 assert_eq!(shuffle, shuffle_slow);
                 let mut rng = R::from_seed(seed);
-                let shuffle = WeightedShuffle::<u64>::new("", &weights);
+                let shuffle = WeightedShuffle::new("", &weights);
                 assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
             }
         }

+ 3 - 2
turbine/src/cluster_nodes.rs

@@ -39,7 +39,7 @@ use {
 };
 
 thread_local! {
-    static THREAD_LOCAL_WEIGHTED_SHUFFLE: RefCell<WeightedShuffle<u64>> = RefCell::new(
+    static THREAD_LOCAL_WEIGHTED_SHUFFLE: RefCell<WeightedShuffle> = RefCell::new(
         WeightedShuffle::new::<[u64; 0]>("get_retransmit_addrs", []),
     );
 }
@@ -84,7 +84,8 @@ pub struct ClusterNodes<T> {
     nodes: Vec<Node>,
     // Reverse index from nodes pubkey to their index in self.nodes.
     index: HashMap<Pubkey, /*index:*/ usize>,
-    weighted_shuffle: WeightedShuffle</*stake:*/ u64>,
+    // Shuffles by weights = stakes
+    weighted_shuffle: WeightedShuffle,
     use_cha_cha_8: bool,
     _phantom: PhantomData<T>,
 }