Переглянути джерело

shaq update to v1.0.0 (#8891)

* shaq update to v1.0.0

* unsafe allocate_and_reserve_message
Andrew Fitzgerald 2 тижнів тому
батько
коміт
178da9add8

+ 2 - 2
Cargo.lock

@@ -6880,9 +6880,9 @@ dependencies = [
 
 [[package]]
 name = "shaq"
-version = "0.2.0"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "014fb38bb8370732f76c67752106d2a4b25cc1891ec489c7fc5ab23b27e90a75"
+checksum = "3722c79f507ee6b701d95e52d9106744d27666187e89e45a52a53745a10f71bd"
 dependencies = [
  "libc",
 ]

+ 1 - 1
Cargo.toml

@@ -375,7 +375,7 @@ serde_yaml = "0.9.34"
 serial_test = "2.0.0"
 sha2 = "0.10.9"
 sha3 = "0.10.8"
-shaq = { version = "0.2.0" }
+shaq = { version = "1.0.0" }
 shuttle = "0.7.1"
 signal-hook = "0.3.18"
 siphasher = "1.0.1"

+ 13 - 29
core/src/banking_stage/consume_worker.rs

@@ -280,7 +280,7 @@ pub(crate) mod external {
                     should_drain_executes = false;
                 }
 
-                match self.receiver.try_read() {
+                match self.receiver.try_read_ptr() {
                     Some(message) => {
                         did_work = true;
                         self.sender.sync();
@@ -527,13 +527,9 @@ pub(crate) mod external {
                 responses,
             };
 
-            let send_ptr = self
-                .sender
-                .reserve()
-                .ok_or(ExternalConsumeWorkerError::SenderDisconnected)?;
-
-            // `reserve` returns valid aligned pointer
-            unsafe { send_ptr.write(response) };
+            self.sender
+                .try_write(response)
+                .map_err(|_| ExternalConsumeWorkerError::SenderDisconnected)?;
 
             Ok(())
         }
@@ -551,12 +547,9 @@ pub(crate) mod external {
                 responses,
             };
 
-            // `reserve` returns valid aligned pointer
-            let send_ptr = self
-                .sender
-                .reserve()
-                .ok_or(ExternalConsumeWorkerError::SenderDisconnected)?;
-            unsafe { send_ptr.write(response) };
+            self.sender
+                .try_write(response)
+                .map_err(|_| ExternalConsumeWorkerError::SenderDisconnected)?;
 
             Ok(())
         }
@@ -630,14 +623,10 @@ pub(crate) mod external {
 
             // Should de-allocate the memory, but this is a non-recoverable
             // error and so it's not needed.
-            let send_message = self
-                .sender
-                .reserve()
-                .ok_or(ExternalConsumeWorkerError::SenderDisconnected)?;
+            self.sender
+                .try_write(response_message)
+                .map_err(|_| ExternalConsumeWorkerError::SenderDisconnected)?;
 
-            unsafe {
-                send_message.write(response_message);
-            }
             Ok(())
         }
 
@@ -745,14 +734,9 @@ pub(crate) mod external {
                 },
             };
 
-            let send_ptr = self
-                .sender
-                .reserve()
-                .ok_or(ExternalConsumeWorkerError::SenderDisconnected)?;
-
-            // SAFETY: `reserve` guarantees a properly aligned space
-            //         for a `WorkerToPackMessage`
-            unsafe { send_ptr.write(response) };
+            self.sender
+                .try_write(response)
+                .map_err(|_| ExternalConsumeWorkerError::SenderDisconnected)?;
 
             Ok(())
         }

+ 1 - 2
core/src/banking_stage/progress_tracker.rs

@@ -89,8 +89,7 @@ impl ProgressTracker {
         message: ProgressMessage,
     ) -> bool {
         producer.sync();
-        if let Some(reserved_ptr) = producer.reserve() {
-            unsafe { reserved_ptr.write(message) };
+        if producer.try_write(message).is_ok() {
             producer.commit();
             true
         } else {

+ 7 - 3
core/src/banking_stage/tpu_to_pack.rs

@@ -96,8 +96,9 @@ fn handle_packet_batches(
             };
             let packet_size = packet_bytes.len();
 
+            // SAFETY: message written by `copy_pack_and_populate_message` below.
             let Some((allocated_ptr, tpu_to_pack_message)) =
-                allocate_and_reserve_message(allocator, producer, packet_size)
+                (unsafe { allocate_and_reserve_message(allocator, producer, packet_size) })
             else {
                 warn!("Failed to allocate/reserve message. Dropping the rest of the batch.");
                 break 'batch_loop;
@@ -127,7 +128,9 @@ fn handle_packet_batches(
     producer.commit();
 }
 
-fn allocate_and_reserve_message(
+/// # Safety
+/// - returned `TpuToPackMessage` pointer must be populated with a valid message.
+unsafe fn allocate_and_reserve_message(
     allocator: &Allocator,
     producer: &mut shaq::Producer<TpuToPackMessage>,
     packet_size: usize,
@@ -136,7 +139,8 @@ fn allocate_and_reserve_message(
     let allocated_ptr = allocator.allocate(packet_size as u32)?;
 
     // Reserve space in the producer queue for the packet message.
-    let Some(tpu_to_pack_message) = producer.reserve() else {
+    // SAFETY: unsafe condition of the function is that the message is populated
+    let Some(tpu_to_pack_message) = (unsafe { producer.reserve() }) else {
         // Free the allocated packet if we can't reserve space in the queue.
         // SAFETY: `allocated_ptr` was allocated from `allocator`.
         unsafe {

+ 2 - 2
dev-bins/Cargo.lock

@@ -5872,9 +5872,9 @@ dependencies = [
 
 [[package]]
 name = "shaq"
-version = "0.2.0"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "014fb38bb8370732f76c67752106d2a4b25cc1891ec489c7fc5ab23b27e90a75"
+checksum = "3722c79f507ee6b701d95e52d9106744d27666187e89e45a52a53745a10f71bd"
 dependencies = [
  "libc",
 ]

+ 2 - 2
programs/sbf/Cargo.lock

@@ -5820,9 +5820,9 @@ dependencies = [
 
 [[package]]
 name = "shaq"
-version = "0.2.0"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "014fb38bb8370732f76c67752106d2a4b25cc1891ec489c7fc5ab23b27e90a75"
+checksum = "3722c79f507ee6b701d95e52d9106744d27666187e89e45a52a53745a10f71bd"
 dependencies = [
  "libc",
 ]

+ 4 - 2
scheduling-utils/src/handshake/server.rs

@@ -234,7 +234,8 @@ impl Server {
             let minimum_file_size = shaq::minimum_file_size::<T>(capacity);
             let file_size = Self::align_file_size(minimum_file_size, huge);
 
-            shaq::Producer::create(&file, file_size).map(|producer| (file, producer))
+            // SAFETY: uniqely creating as producer
+            unsafe { shaq::Producer::create(&file, file_size) }.map(|producer| (file, producer))
         };
 
         // Try to create with huge pages, fallback to regular pages.
@@ -252,7 +253,8 @@ impl Server {
             let minimum_file_size = shaq::minimum_file_size::<PackToWorkerMessage>(capacity);
             let file_size = Self::align_file_size(minimum_file_size, huge);
 
-            shaq::Consumer::create(&file, file_size).map(|producer| (file, producer))
+            // SAFETY: uniquely creating as consumer.
+            unsafe { shaq::Consumer::create(&file, file_size) }.map(|producer| (file, producer))
         };
 
         // Try to create with huge pages, fallback to regular pages.

+ 21 - 20
scheduling-utils/src/handshake/tests.rs

@@ -62,21 +62,22 @@ fn message_passing_on_all_queues() {
         let mut session = server.accept().unwrap();
 
         // Send a tpu_to_pack message.
-        let mut slot = session.tpu_to_pack.producer.reserve().unwrap();
-        unsafe { *slot.as_mut() = tpu_to_pack };
+        session.tpu_to_pack.producer.try_write(tpu_to_pack).unwrap();
         session.tpu_to_pack.producer.commit();
 
         // Send a progress_tracker message.
-        let mut slot = session.progress_tracker.reserve().unwrap();
-        unsafe { *slot.as_mut() = progress_tracker };
+        session
+            .progress_tracker
+            .try_write(progress_tracker)
+            .unwrap();
         session.progress_tracker.commit();
 
         // Receive pack_to_worker messages.
         for (i, worker) in session.workers.iter_mut().enumerate() {
             let msg = loop {
                 worker.pack_to_worker.sync();
-                if let Some(slot) = worker.pack_to_worker.try_read() {
-                    break unsafe { *slot.as_ref() };
+                if let Some(msg) = worker.pack_to_worker.try_read() {
+                    break *msg;
                 }
             };
             assert_eq!(
@@ -90,16 +91,16 @@ fn message_passing_on_all_queues() {
 
         // Send worker_to_pack messages.
         for (i, worker) in session.workers.iter_mut().enumerate() {
-            let mut slot = worker.worker_to_pack.reserve().unwrap();
-            unsafe {
-                *slot.as_mut() = WorkerToPackMessage {
+            worker
+                .worker_to_pack
+                .try_write(WorkerToPackMessage {
                     batch: SharableTransactionBatchRegion {
                         num_transactions: worker_to_pack.batch.num_transactions + i as u8,
                         ..worker_to_pack.batch
                     },
                     ..worker_to_pack
-                }
-            };
+                })
+                .unwrap();
             worker.worker_to_pack.commit();
         }
     });
@@ -124,7 +125,7 @@ fn message_passing_on_all_queues() {
         let msg = loop {
             session.tpu_to_pack.sync();
             if let Some(msg) = session.tpu_to_pack.try_read() {
-                break unsafe { *msg.as_ref() };
+                break *msg;
             };
         };
         assert_eq!(msg, tpu_to_pack);
@@ -133,20 +134,20 @@ fn message_passing_on_all_queues() {
         let msg = loop {
             session.progress_tracker.sync();
             if let Some(msg) = session.progress_tracker.try_read() {
-                break unsafe { *msg.as_ref() };
+                break *msg;
             };
         };
         assert_eq!(msg, progress_tracker);
 
         // Send pack_to_worker messages.
         for (i, worker) in session.workers.iter_mut().enumerate() {
-            let mut slot = worker.pack_to_worker.reserve().unwrap();
-            unsafe {
-                *slot.as_mut() = PackToWorkerMessage {
+            worker
+                .pack_to_worker
+                .try_write(PackToWorkerMessage {
                     max_working_slot: pack_to_worker.max_working_slot + i as u64,
                     ..pack_to_worker
-                }
-            };
+                })
+                .unwrap();
             worker.pack_to_worker.commit();
         }
 
@@ -154,8 +155,8 @@ fn message_passing_on_all_queues() {
         for (i, worker) in session.workers.iter_mut().enumerate() {
             let msg = loop {
                 worker.worker_to_pack.sync();
-                if let Some(slot) = worker.worker_to_pack.try_read() {
-                    break unsafe { *slot.as_ref() };
+                if let Some(msg) = worker.worker_to_pack.try_read() {
+                    break *msg;
                 }
             };
             assert_eq!(