|
|
@@ -279,7 +279,7 @@ mod tests {
|
|
|
super::*,
|
|
|
itertools::Itertools,
|
|
|
rand::SeedableRng,
|
|
|
- rand_chacha::ChaChaRng,
|
|
|
+ rand_chacha::{ChaCha8Rng, ChaChaRng},
|
|
|
solana_hash::Hash,
|
|
|
std::{
|
|
|
convert::TryInto,
|
|
|
@@ -384,18 +384,33 @@ mod tests {
|
|
|
}
|
|
|
|
|
|
// Asserts that zero weights will be shuffled.
|
|
|
- #[test]
|
|
|
- fn test_weighted_shuffle_zero_weights() {
|
|
|
+ #[test_case(8)]
|
|
|
+ #[test_case(20)]
|
|
|
+ fn test_weighted_shuffle_zero_weights(cha_cha_variant: u8) {
|
|
|
let weights = vec![0u64; 5];
|
|
|
let seed = [37u8; 32];
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
let shuffle = WeightedShuffle::new("", weights);
|
|
|
- assert_eq!(
|
|
|
- shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
|
|
- [1, 4, 2, 3, 0]
|
|
|
- );
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- assert_eq!(shuffle.first(&mut rng), Some(1));
|
|
|
+ match cha_cha_variant {
|
|
|
+ 8 => {
|
|
|
+ let mut rng = ChaCha8Rng::from_seed(seed);
|
|
|
+ assert_eq!(
|
|
|
+ shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
|
|
+ [4, 3, 1, 2, 0],
|
|
|
+ );
|
|
|
+ let mut rng = ChaCha8Rng::from_seed(seed);
|
|
|
+ assert_eq!(shuffle.first(&mut rng), Some(4));
|
|
|
+ }
|
|
|
+ 20 => {
|
|
|
+ let mut rng = ChaChaRng::from_seed(seed);
|
|
|
+ assert_eq!(
|
|
|
+ shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
|
|
|
+ [1, 4, 2, 3, 0],
|
|
|
+ );
|
|
|
+ let mut rng = ChaChaRng::from_seed(seed);
|
|
|
+ assert_eq!(shuffle.first(&mut rng), Some(1));
|
|
|
+ }
|
|
|
+ _ => unreachable!(),
|
|
|
+ };
|
|
|
}
|
|
|
|
|
|
// Asserts that each index is selected proportional to its weight.
|
|
|
@@ -404,46 +419,70 @@ mod tests {
|
|
|
let seed: Vec<_> = (1..).step_by(3).take(32).collect();
|
|
|
let seed: [u8; 32] = seed.try_into().unwrap();
|
|
|
let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
|
|
|
- let mut counts = [0; 8];
|
|
|
- for _ in 0..100000 {
|
|
|
- let mut weighted_shuffle = WeightedShuffle::new("", weights);
|
|
|
- let mut shuffle = weighted_shuffle.shuffle(&mut rng);
|
|
|
- counts[shuffle.next().unwrap()] += 1;
|
|
|
- let _ = shuffle.count(); // consume the rest.
|
|
|
- }
|
|
|
- assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
|
|
|
- let mut counts = [0; 8];
|
|
|
- for _ in 0..100000 {
|
|
|
- let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
- shuffle.remove_index(5);
|
|
|
- shuffle.remove_index(3);
|
|
|
- shuffle.remove_index(1);
|
|
|
- let mut shuffle = shuffle.shuffle(&mut rng);
|
|
|
- counts[shuffle.next().unwrap()] += 1;
|
|
|
- let _ = shuffle.count(); // consume the rest.
|
|
|
+ test_weighted_shuffle_sanity_impl(
|
|
|
+ &mut rng,
|
|
|
+ &[95, 0, 90069, 0, 0, 908, 8928, 0],
|
|
|
+ &[97, 0, 90862, 0, 0, 0, 9041, 0],
|
|
|
+ );
|
|
|
+ let mut rng = ChaCha8Rng::from_seed(seed);
|
|
|
+ test_weighted_shuffle_sanity_impl(
|
|
|
+ &mut rng,
|
|
|
+ &[93, 0, 90185, 0, 0, 892, 8830, 0],
|
|
|
+ &[89, 0, 90741, 0, 0, 0, 9170, 0],
|
|
|
+ );
|
|
|
+ fn test_weighted_shuffle_sanity_impl<R: Rng>(
|
|
|
+ rng: &mut R,
|
|
|
+ counts1: &[i32],
|
|
|
+ counts2: &[i32],
|
|
|
+ ) {
|
|
|
+ let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
|
|
|
+ let mut counts = [0; 8];
|
|
|
+ for _ in 0..100000 {
|
|
|
+ let mut weighted_shuffle = WeightedShuffle::new("", weights);
|
|
|
+ let mut shuffle = weighted_shuffle.shuffle(rng);
|
|
|
+ counts[shuffle.next().unwrap()] += 1;
|
|
|
+ let _ = shuffle.count(); // consume the rest.
|
|
|
+ }
|
|
|
+ assert_eq!(counts, counts1);
|
|
|
+ let mut counts = [0; 8];
|
|
|
+ for _ in 0..100000 {
|
|
|
+ let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
+ shuffle.remove_index(5);
|
|
|
+ shuffle.remove_index(3);
|
|
|
+ shuffle.remove_index(1);
|
|
|
+ let mut shuffle = shuffle.shuffle(rng);
|
|
|
+ counts[shuffle.next().unwrap()] += 1;
|
|
|
+ let _ = shuffle.count(); // consume the rest.
|
|
|
+ }
|
|
|
+ assert_eq!(counts, counts2);
|
|
|
}
|
|
|
- assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]);
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
fn test_weighted_shuffle_negative_overflow() {
|
|
|
- const SEED: [u8; 32] = [48u8; 32];
|
|
|
- let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29];
|
|
|
- let mut rng = ChaChaRng::from_seed(SEED);
|
|
|
- let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
- assert_eq!(
|
|
|
- shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
|
|
|
- [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
|
|
|
- );
|
|
|
- // Negative weights and overflowing ones are treated as zero.
|
|
|
- let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29];
|
|
|
- let mut rng = ChaChaRng::from_seed(SEED);
|
|
|
- let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
- assert_eq!(
|
|
|
- shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
|
|
|
- [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
|
|
|
- );
|
|
|
+ test_weighted_shuffle_negative_overflow_impl::<ChaChaRng>(&[
|
|
|
+ 8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7,
|
|
|
+ ]);
|
|
|
+ test_weighted_shuffle_negative_overflow_impl::<ChaCha8Rng>(&[
|
|
|
+ 5, 11, 2, 0, 10, 1, 6, 8, 7, 3, 9, 4,
|
|
|
+ ]);
|
|
|
+
|
|
|
+ fn test_weighted_shuffle_negative_overflow_impl<
|
|
|
+ R: Rng + rand::SeedableRng<Seed = [u8; 32]>,
|
|
|
+ >(
|
|
|
+ counts: &[usize],
|
|
|
+ ) {
|
|
|
+ const SEED: [u8; 32] = [48u8; 32];
|
|
|
+ let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29];
|
|
|
+ let mut rng = R::from_seed(SEED);
|
|
|
+ let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
+ assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), counts);
|
|
|
+ // Negative weights and overflowing ones are treated as zero.
|
|
|
+ let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29];
|
|
|
+ let mut rng = R::from_seed(SEED);
|
|
|
+ let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
+ assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), counts);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
@@ -569,36 +608,47 @@ mod tests {
|
|
|
|
|
|
#[test]
|
|
|
fn test_weighted_shuffle_match_slow() {
|
|
|
- let mut rng = rand::thread_rng();
|
|
|
- let weights: Vec<u64> = repeat_with(|| rng.gen_range(0..1000)).take(997).collect();
|
|
|
- for _ in 0..10 {
|
|
|
- let mut seed = [0u8; 32];
|
|
|
- rng.fill(&mut seed[..]);
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
- let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
|
|
- assert_eq!(shuffle, shuffle_slow);
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- let shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
- assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
|
|
|
+ test_weighted_shuffle_match_slow_impl::<ChaChaRng>();
|
|
|
+ test_weighted_shuffle_match_slow_impl::<ChaCha8Rng>();
|
|
|
+
|
|
|
+ fn test_weighted_shuffle_match_slow_impl<R: Rng + rand::SeedableRng<Seed = [u8; 32]>>() {
|
|
|
+ let mut rng = rand::thread_rng();
|
|
|
+ let weights: Vec<u64> = repeat_with(|| rng.gen_range(0..1000)).take(997).collect();
|
|
|
+ for _ in 0..10 {
|
|
|
+ let mut seed = [0u8; 32];
|
|
|
+ rng.fill(&mut seed[..]);
|
|
|
+ let mut rng = R::from_seed(seed);
|
|
|
+ let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
+ let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
|
|
+ let mut rng = R::from_seed(seed);
|
|
|
+ let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
|
|
+ assert_eq!(shuffle, shuffle_slow);
|
|
|
+ let mut rng = R::from_seed(seed);
|
|
|
+ let shuffle = WeightedShuffle::<u64>::new("", &weights);
|
|
|
+ assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
fn test_weighted_shuffle_paranoid() {
|
|
|
let mut rng = rand::thread_rng();
|
|
|
- for size in 0..1351 {
|
|
|
- let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect();
|
|
|
- let seed = rng.gen::<[u8; 32]>();
|
|
|
- let mut rng = ChaChaRng::from_seed(seed);
|
|
|
- let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
|
|
|
- let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
- if size > 0 {
|
|
|
- assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
|
|
|
+ let seed = rng.gen::<[u8; 32]>();
|
|
|
+ let rng = ChaCha8Rng::from_seed(seed);
|
|
|
+ test_weighted_shuffle_paranoid_impl(rng);
|
|
|
+ let rng = ChaChaRng::from_seed(seed);
|
|
|
+ test_weighted_shuffle_paranoid_impl(rng);
|
|
|
+
|
|
|
+ fn test_weighted_shuffle_paranoid_impl<R: Rng + Clone>(mut rng: R) {
|
|
|
+ for size in 0..1351 {
|
|
|
+ let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect();
|
|
|
+ let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
|
|
|
+ let mut shuffle = WeightedShuffle::new("", weights);
|
|
|
+ if size > 0 {
|
|
|
+ assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
|
|
|
+ }
|
|
|
+ assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), shuffle_slow);
|
|
|
}
|
|
|
- assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), shuffle_slow);
|
|
|
}
|
|
|
}
|
|
|
}
|