Selaa lähdekoodia

Cleanup entry.rs packing code (#3303)

carllin 6 vuotta sitten
vanhempi
sitoutus
36763d0802
2 muutettua tiedostoa jossa 110 lisäystä ja 92 poistoa
  1. 7 1
      core/src/banking_stage.rs
  2. 103 91
      core/src/entry.rs

+ 7 - 1
core/src/banking_stage.rs

@@ -3,6 +3,7 @@
 //! can do its processing in parallel with signature verification on the GPU.
 
 use crate::cluster_info::ClusterInfo;
+use crate::entry;
 use crate::entry::Entry;
 use crate::leader_confirmation_service::LeaderConfirmationService;
 use crate::leader_schedule_utils;
@@ -288,7 +289,12 @@ impl BankingStage {
     ) -> Result<(usize)> {
         let mut chunk_start = 0;
         while chunk_start != transactions.len() {
-            let chunk_end = chunk_start + Entry::num_will_fit(&transactions[chunk_start..]);
+            let chunk_end = chunk_start
+                + entry::num_will_fit(
+                    &transactions[chunk_start..],
+                    packet::BLOB_DATA_SIZE as u64,
+                    &Entry::serialized_size,
+                );
 
             let result = Self::process_and_record_transactions(
                 bank,

+ 103 - 91
core/src/entry.rs

@@ -122,42 +122,6 @@ impl Entry {
         (2 * size_of::<u64>() + size_of::<Hash>()) as u64 + txs_size
     }
 
-    pub fn num_will_fit(transactions: &[Transaction]) -> usize {
-        if transactions.is_empty() {
-            return 0;
-        }
-        let mut num = transactions.len();
-        let mut upper = transactions.len();
-        let mut lower = 1; // if one won't fit, we have a lot of TODOs
-        let mut next = transactions.len(); // optimistic
-        loop {
-            debug!(
-                "num {}, upper {} lower {} next {} transactions.len() {}",
-                num,
-                upper,
-                lower,
-                next,
-                transactions.len()
-            );
-            if Self::serialized_size(&transactions[..num]) <= BLOB_DATA_SIZE as u64 {
-                next = (upper + num) / 2;
-                lower = num;
-                debug!("num {} fits, maybe too well? trying {}", num, next);
-            } else {
-                next = (lower + num) / 2;
-                upper = num;
-                debug!("num {} doesn't fit! trying {}", num, next);
-            }
-            // same as last time
-            if next == num {
-                debug!("converged on num {}", num);
-                break;
-            }
-            num = next;
-        }
-        num
-    }
-
     /// Creates the next Tick Entry `num_hashes` after `start_hash`.
     pub fn new_mut(
         start_hash: &mut Hash,
@@ -303,6 +267,67 @@ pub fn next_entry_mut(start: &mut Hash, num_hashes: u64, transactions: Vec<Trans
     entry
 }
 
+pub fn num_will_fit<T, F>(serializables: &[T], max_size: u64, serialized_size: &F) -> usize
+where
+    F: Fn(&[T]) -> u64,
+{
+    if serializables.is_empty() {
+        return 0;
+    }
+    let mut num = serializables.len();
+    let mut upper = serializables.len();
+    let mut lower = 1; // if one won't fit, we have a lot of TODOs
+    let mut next = serializables.len(); // optimistic
+    loop {
+        debug!(
+            "num {}, upper {} lower {} next {} serializables.len() {}",
+            num,
+            upper,
+            lower,
+            next,
+            serializables.len()
+        );
+        if serialized_size(&serializables[..num]) <= max_size {
+            next = (upper + num) / 2;
+            lower = num;
+            debug!("num {} fits, maybe too well? trying {}", num, next);
+        } else {
+            next = (lower + num) / 2;
+            upper = num;
+            debug!("num {} doesn't fit! trying {}", num, next);
+        }
+        // same as last time
+        if next == num {
+            debug!("converged on num {}", num);
+            break;
+        }
+        num = next;
+    }
+    num
+}
+
+pub fn split_serializable_chunks<T, R, F1, F2>(
+    serializables: &[T],
+    max_size: u64,
+    serialized_size: &F1,
+    converter: &mut F2,
+) -> Vec<R>
+where
+    F1: Fn(&[T]) -> u64,
+    F2: FnMut(&[T]) -> R,
+{
+    let mut result = vec![];
+    let mut chunk_start = 0;
+    while chunk_start < serializables.len() {
+        let chunk_end =
+            chunk_start + num_will_fit(&serializables[chunk_start..], max_size, serialized_size);
+        result.push(converter(&serializables[chunk_start..chunk_end]));
+        chunk_start = chunk_end;
+    }
+
+    result
+}
+
 /// Creates the next entries for given transactions, outputs
 /// updates start_hash to hash of last Entry, sets num_hashes to 0
 pub fn next_entries_mut(
@@ -310,61 +335,12 @@ pub fn next_entries_mut(
     num_hashes: &mut u64,
     transactions: Vec<Transaction>,
 ) -> Vec<Entry> {
-    // TODO: ?? find a number that works better than |?
-    //                                               V
-    if transactions.is_empty() || transactions.len() == 1 {
-        vec![Entry::new_mut(start_hash, num_hashes, transactions)]
-    } else {
-        let mut chunk_start = 0;
-        let mut entries = Vec::new();
-
-        while chunk_start < transactions.len() {
-            let mut chunk_end = transactions.len();
-            let mut upper = chunk_end;
-            let mut lower = chunk_start;
-            let mut next = chunk_end; // be optimistic that all will fit
-
-            // binary search for how many transactions will fit in an Entry (i.e. a BLOB)
-            loop {
-                debug!(
-                    "chunk_end {}, upper {} lower {} next {} transactions.len() {}",
-                    chunk_end,
-                    upper,
-                    lower,
-                    next,
-                    transactions.len()
-                );
-                if Entry::serialized_size(&transactions[chunk_start..chunk_end])
-                    <= BLOB_DATA_SIZE as u64
-                {
-                    next = (upper + chunk_end) / 2;
-                    lower = chunk_end;
-                    debug!(
-                        "chunk_end {} fits, maybe too well? trying {}",
-                        chunk_end, next
-                    );
-                } else {
-                    next = (lower + chunk_end) / 2;
-                    upper = chunk_end;
-                    debug!("chunk_end {} doesn't fit! trying {}", chunk_end, next);
-                }
-                // same as last time
-                if next == chunk_end {
-                    debug!("converged on chunk_end {}", chunk_end);
-                    break;
-                }
-                chunk_end = next;
-            }
-            entries.push(Entry::new_mut(
-                start_hash,
-                num_hashes,
-                transactions[chunk_start..chunk_end].to_vec(),
-            ));
-            chunk_start = chunk_end;
-        }
-
-        entries
-    }
+    split_serializable_chunks(
+        &transactions[..],
+        BLOB_DATA_SIZE as u64,
+        &Entry::serialized_size,
+        &mut |txs: &[Transaction]| Entry::new_mut(start_hash, num_hashes, txs.to_vec()),
+    )
 }
 
 /// Creates the next Entries for given transactions
@@ -685,4 +661,40 @@ mod tests {
         assert!(entries0.verify(&hash));
     }
 
+    #[test]
+    fn test_num_will_fit_empty() {
+        let serializables: Vec<u32> = vec![];
+        let result = num_will_fit(&serializables[..], 8, &|_| 4);
+        assert_eq!(result, 0);
+    }
+
+    #[test]
+    fn test_num_fit() {
+        let serializables_vec: Vec<u8> = (0..10).map(|_| 1).collect();
+        let serializables = &serializables_vec[..];
+        let sum = |i: &[u8]| (0..i.len()).into_iter().sum::<usize>() as u64;
+        // sum[0] is = 0, but sum[0..1] > 0, so result contains 1 item
+        let result = num_will_fit(serializables, 0, &sum);
+        assert_eq!(result, 1);
+
+        // sum[0..3] is <= 8, but sum[0..4] > 8, so result contains 3 items
+        let result = num_will_fit(serializables, 8, &sum);
+        assert_eq!(result, 4);
+
+        // sum[0..1] is = 1, but sum[0..2] > 0, so result contains 2 items
+        let result = num_will_fit(serializables, 1, &sum);
+        assert_eq!(result, 2);
+
+        // sum[0..9] = 45, so contains all items
+        let result = num_will_fit(serializables, 45, &sum);
+        assert_eq!(result, 10);
+
+        // sum[0..8] <= 44, but sum[0..9] = 45, so contains all but last item
+        let result = num_will_fit(serializables, 44, &sum);
+        assert_eq!(result, 9);
+
+        // sum[0..9] <= 46, but contains all items
+        let result = num_will_fit(serializables, 46, &sum);
+        assert_eq!(result, 10);
+    }
 }