|
|
@@ -1,7 +1,7 @@
|
|
|
//! The `weighted_shuffle` module provides an iterator over shuffled weights.
|
|
|
|
|
|
use {
|
|
|
- num_traits::{CheckedAdd, ConstZero},
|
|
|
+ num_traits::CheckedAdd,
|
|
|
rand::{
|
|
|
distributions::uniform::{SampleUniform, UniformSampler},
|
|
|
Rng,
|
|
|
@@ -30,53 +30,37 @@ const BIT_MASK: usize = FANOUT - 1;
|
|
|
/// weight.
|
|
|
/// - Zero weighted indices are shuffled and appear only at the end, after
|
|
|
/// non-zero weighted indices.
|
|
|
-pub struct WeightedShuffle<T> {
|
|
|
+pub struct WeightedShuffle {
|
|
|
// Number of "internal" nodes of the tree.
|
|
|
num_nodes: usize,
|
|
|
// Underlying array implementing the tree.
|
|
|
// Nodes without children are never accessed and don't need to be
|
|
|
// allocated, so tree.len() < num_nodes.
|
|
|
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
|
|
|
- tree: Vec<[T; FANOUT]>,
|
|
|
+ tree: Vec<[u64; FANOUT]>,
|
|
|
// Current sum of all weights, excluding already sampled ones.
|
|
|
- weight: T,
|
|
|
+ weight: u64,
|
|
|
// Indices of zero weighted entries.
|
|
|
zeros: Vec<usize>,
|
|
|
}
|
|
|
|
|
|
-impl<T: ConstZero> WeightedShuffle<T> {
|
|
|
- const ZERO: T = <T as ConstZero>::ZERO;
|
|
|
-}
|
|
|
-
|
|
|
-impl<T> WeightedShuffle<T>
|
|
|
-where
|
|
|
- T: Copy + ConstZero + PartialOrd + AddAssign + CheckedAdd,
|
|
|
-{
|
|
|
- /// If weights are negative or overflow the total sum
|
|
|
- /// they are treated as zero.
|
|
|
+impl WeightedShuffle {
|
|
|
+ /// If weights overflow the total sum they are treated as zero.
|
|
|
pub fn new<I>(name: &'static str, weights: I) -> Self
|
|
|
where
|
|
|
- I: IntoIterator<Item: Borrow<T>>,
|
|
|
+ I: IntoIterator<Item: Borrow<u64>>,
|
|
|
<I as IntoIterator>::IntoIter: ExactSizeIterator,
|
|
|
{
|
|
|
let weights = weights.into_iter();
|
|
|
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
|
|
|
debug_assert!(size <= num_nodes);
|
|
|
- let mut tree = vec![[Self::ZERO; FANOUT]; size];
|
|
|
- let mut sum = Self::ZERO;
|
|
|
+ let mut tree = vec![[0; FANOUT]; size];
|
|
|
+ let mut sum = 0;
|
|
|
let mut zeros = Vec::default();
|
|
|
- let mut num_negative: usize = 0;
|
|
|
let mut num_overflow: usize = 0;
|
|
|
for (k, weight) in weights.enumerate() {
|
|
|
let weight = *weight.borrow();
|
|
|
- #[allow(clippy::neg_cmp_op_on_partial_ord)]
|
|
|
- // weight < zero does not work for NaNs.
|
|
|
- if !(weight >= Self::ZERO) {
|
|
|
- zeros.push(k);
|
|
|
- num_negative += 1;
|
|
|
- continue;
|
|
|
- }
|
|
|
- if weight == Self::ZERO {
|
|
|
+ if weight == 0 {
|
|
|
zeros.push(k);
|
|
|
continue;
|
|
|
}
|
|
|
@@ -104,9 +88,6 @@ where
|
|
|
.add_assign(weight);
|
|
|
}
|
|
|
}
|
|
|
- if num_negative > 0 {
|
|
|
- datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64));
|
|
|
- }
|
|
|
if num_overflow > 0 {
|
|
|
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
|
|
|
}
|
|
|
@@ -119,12 +100,9 @@ where
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-impl<T> WeightedShuffle<T>
|
|
|
-where
|
|
|
- T: Copy + ConstZero + PartialOrd + SubAssign,
|
|
|
-{
|
|
|
+impl WeightedShuffle {
|
|
|
// Removes given weight at index k.
|
|
|
- fn remove(&mut self, k: usize, weight: T) {
|
|
|
+ fn remove(&mut self, k: usize, weight: u64) {
|
|
|
debug_assert!(self.weight >= weight);
|
|
|
self.weight -= weight;
|
|
|
// Traverse the tree from the leaf node upwards to the root,
|
|
|
@@ -145,9 +123,7 @@ where
|
|
|
|
|
|
// Returns smallest index such that sum of weights[..=k] > val,
|
|
|
// along with its respective weight.
|
|
|
- fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
|
|
|
- debug_assert!(val >= Self::ZERO);
|
|
|
- debug_assert!(val < self.weight);
|
|
|
+ fn search(&self, mut val: u64) -> (/*index:*/ usize, /*weight:*/ u64) {
|
|
|
debug_assert!(!self.tree.is_empty());
|
|
|
// Traverse the tree downwards from the root to the target leaf node.
|
|
|
let mut index = 0; // root
|
|
|
@@ -181,7 +157,7 @@ where
|
|
|
error!("WeightedShuffle::remove_index: Invalid index {k}");
|
|
|
return;
|
|
|
};
|
|
|
- if weight == Self::ZERO {
|
|
|
+ if weight == 0 {
|
|
|
self.remove_zero(k);
|
|
|
} else {
|
|
|
self.remove(k, weight);
|
|
|
@@ -195,34 +171,27 @@ where
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-impl<T> WeightedShuffle<T>
|
|
|
-where
|
|
|
- T: Copy + ConstZero + PartialOrd + SampleUniform + SubAssign,
|
|
|
-{
|
|
|
+impl WeightedShuffle {
|
|
|
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
|
|
|
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
|
|
|
- if self.weight > Self::ZERO {
|
|
|
- let sample = <T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
|
|
|
+ if self.weight > 0 {
|
|
|
+ let sample = <u64 as SampleUniform>::Sampler::sample_single(0, self.weight, rng);
|
|
|
let (index, _) = self.search(sample);
|
|
|
return Some(index);
|
|
|
}
|
|
|
if self.zeros.is_empty() {
|
|
|
return None;
|
|
|
}
|
|
|
- let index = <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
|
|
|
- self.zeros.get(index).copied()
|
|
|
+ let index = <u64 as SampleUniform>::Sampler::sample_single(0, self.zeros.len() as u64, rng);
|
|
|
+ self.zeros.get(index as usize).copied()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-impl<T> WeightedShuffle<T>
|
|
|
-where
|
|
|
- T: Copy + ConstZero + PartialOrd + SampleUniform + SubAssign,
|
|
|
-{
|
|
|
+impl WeightedShuffle {
|
|
|
pub fn shuffle<'a, R: Rng>(&'a mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
|
|
|
std::iter::from_fn(move || {
|
|
|
- if self.weight > Self::ZERO {
|
|
|
- let sample =
|
|
|
- <T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
|
|
|
+ if self.weight > 0 {
|
|
|
+ let sample = <u64 as SampleUniform>::Sampler::sample_single(0, self.weight, rng);
|
|
|
let (index, weight) = self.search(sample);
|
|
|
self.remove(index, weight);
|
|
|
return Some(index);
|
|
|
@@ -231,8 +200,8 @@ where
|
|
|
return None;
|
|
|
}
|
|
|
let index =
|
|
|
- <usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
|
|
|
- Some(self.zeros.swap_remove(index))
|
|
|
+ <u64 as SampleUniform>::Sampler::sample_single(0, self.zeros.len() as u64, rng);
|
|
|
+ Some(self.zeros.swap_remove(index as usize))
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
@@ -253,13 +222,13 @@ fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_si
|
|
|
|
|
|
// #[derive(Clone)] does not overwrite clone_from which is used in
|
|
|
// retransmit-stage to minimize allocations.
|
|
|
-impl<T: Clone> Clone for WeightedShuffle<T> {
|
|
|
+impl Clone for WeightedShuffle {
|
|
|
#[inline]
|
|
|
fn clone(&self) -> Self {
|
|
|
Self {
|
|
|
num_nodes: self.num_nodes,
|
|
|
tree: self.tree.clone(),
|
|
|
- weight: self.weight.clone(),
|
|
|
+ weight: self.weight,
|
|
|
zeros: self.zeros.clone(),
|
|
|
}
|
|
|
}
|
|
|
@@ -268,7 +237,7 @@ impl<T: Clone> Clone for WeightedShuffle<T> {
|
|
|
fn clone_from(&mut self, other: &Self) {
|
|
|
self.num_nodes = other.num_nodes;
|
|
|
self.tree.clone_from(&other.tree);
|
|
|
- self.weight = other.weight.clone();
|
|
|
+ self.weight = other.weight;
|
|
|
self.zeros.clone_from(&other.zeros);
|
|
|
}
|
|
|
}
|
|
|
@@ -559,7 +528,7 @@ mod tests {
|
|
|
weights.iter().fold(0u64, |a, &b| a.checked_add(b).unwrap()),
|
|
|
weights.iter().sum::<u64>()
|
|
|
);
|
|
|
- let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
+ let mut shuffle = WeightedShuffle::new("", &weights);
|
|
|
let shuffle1 = shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>();
|
|
|
// Assert that all indices appear in the shuffle.
|
|
|
assert_eq!(shuffle1.len(), num_weights);
|
|
|
@@ -604,13 +573,13 @@ mod tests {
|
|
|
let mut seed = [0u8; 32];
|
|
|
rng.fill(&mut seed[..]);
|
|
|
let mut rng = R::from_seed(seed);
|
|
|
- let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
+ let mut shuffle = WeightedShuffle::new("", &weights);
|
|
|
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
|
|
let mut rng = R::from_seed(seed);
|
|
|
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
|
|
assert_eq!(shuffle, shuffle_slow);
|
|
|
let mut rng = R::from_seed(seed);
|
|
|
- let shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
+ let shuffle = WeightedShuffle::new("", &weights);
|
|
|
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
|
|
|
}
|
|
|
}
|