feat(scheduling-utils): client & server handshake implementations (#8572)

feat(scheduling-utils): external scheduler handshake protocol
OliverNChalk 4 weeks ago
parent commit 04f8e0a6e3

+ 9 - 4
Cargo.lock

@@ -235,8 +235,13 @@ dependencies = [
  "agave-scheduler-bindings",
  "agave-transaction-view",
  "ahash 0.8.11",
+ "libc",
+ "nix",
  "rts-alloc",
+ "shaq",
  "solana-pubkey",
+ "tempfile",
+ "thiserror 2.0.17",
 ]
 
 [[package]]
@@ -6037,9 +6042,9 @@ dependencies = [
 
 [[package]]
 name = "rts-alloc"
-version = "0.1.1"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7ccc6c6ddf12ced79aeae78454bcb4d0f5e31e5bc0aaca490e2f1b012f56b9d"
+checksum = "9c55727ea58e2c9c131d8f003dab5aaa7056d99f8292bc6a5dfb299cefe55e60"
 dependencies = [
  "libc",
 ]
@@ -6620,9 +6625,9 @@ dependencies = [
 
 [[package]]
 name = "shaq"
-version = "0.1.0"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c451e9289cd55fd2406917d16abe8af5f83c74b3ba3736398f6aadd2ab1fba2"
+checksum = "014fb38bb8370732f76c67752106d2a4b25cc1891ec489c7fc5ab23b27e90a75"
 dependencies = [
  "libc",
 ]

+ 2 - 2
Cargo.toml

@@ -356,7 +356,7 @@ reqwest = { version = "0.12.24", default-features = false }
 reqwest-middleware = "0.4.2"
 rolling-file = "0.2.0"
 rpassword = "7.4"
-rts-alloc = { version = "0.1.1" }
+rts-alloc = { version = "0.2.0" }
 rustls = { version = "0.23.34", features = ["std"], default-features = false }
 scopeguard = "1.2.0"
 semver = "1.0.27"
@@ -370,7 +370,7 @@ serde_yaml = "0.9.34"
 serial_test = "2.0.0"
 sha2 = "0.10.9"
 sha3 = "0.10.8"
-shaq = { version = "0.1.0" }
+shaq = { version = "0.2.0" }
 shuttle = "0.7.1"
 signal-hook = "0.3.18"
 siphasher = "1.0.1"

+ 2 - 15
core/src/banking_stage/progress_tracker.rs

@@ -7,7 +7,6 @@ use {
     solana_cost_model::cost_tracker::SharedBlockCost,
     solana_poh::poh_recorder::SharedLeaderState,
     std::{
-        path::{Path, PathBuf},
         sync::{
             atomic::{AtomicBool, Ordering},
             Arc,
@@ -19,16 +18,14 @@ use {
 /// Spawns a thread to track and send progress updates.
 pub fn spawn(
     exit: Arc<AtomicBool>,
-    queue_path: PathBuf,
+    mut producer: shaq::Producer<ProgressMessage>,
     shared_leader_state: SharedLeaderState,
     ticks_per_slot: u64,
 ) -> JoinHandle<()> {
     std::thread::Builder::new()
         .name("solProgTrker".to_string())
         .spawn(move || {
-            if let Some(mut producer) = setup(queue_path) {
-                ProgressTracker::new(exit, shared_leader_state, ticks_per_slot).run(&mut producer);
-            }
+            ProgressTracker::new(exit, shared_leader_state, ticks_per_slot).run(&mut producer);
         })
         .unwrap()
 }
@@ -141,16 +138,6 @@ impl ProgressTracker {
     }
 }
 
-fn setup(queue_path: impl AsRef<Path>) -> Option<shaq::Producer<ProgressMessage>> {
-    let producer = shaq::Producer::join(queue_path)
-        .map_err(|err| {
-            error!("Failed to join queue: {err:?}");
-        })
-        .ok()?;
-
-    Some(producer)
-}
-
 /// Calculate progress through a slot based on tick-height.
 fn progress(slot: Slot, tick_height: u64, ticks_per_slot: u64) -> u8 {
     debug_assert!(ticks_per_slot < u8::MAX as u64 && ticks_per_slot > 0);

+ 4 - 38
core/src/banking_stage/tpu_to_pack.rs

@@ -9,7 +9,6 @@ use {
     solana_perf::packet::PacketBatch,
     std::{
         net::IpAddr,
-        path::{Path, PathBuf},
         ptr::NonNull,
         sync::{
             atomic::{AtomicBool, Ordering},
@@ -26,26 +25,16 @@ pub struct BankingPacketReceivers {
 }
 
 /// Spawns a thread to receive packets from TPU and send them to the external scheduler.
-///
-/// # Safety:
-/// - `allocator_worker_id` must be unique among all processes using the same allocator path.
-pub unsafe fn spawn(
+pub fn spawn(
     exit: Arc<AtomicBool>,
     receivers: BankingPacketReceivers,
-    allocator_path: PathBuf,
-    allocator_worker_id: u32,
-    queue_path: PathBuf,
+    allocator: rts_alloc::Allocator,
+    producer: shaq::Producer<TpuToPackMessage>,
 ) -> JoinHandle<()> {
     std::thread::Builder::new()
         .name("solTpu2Pack".to_string())
         .spawn(move || {
-            // Setup allocator and queue
-            // SAFETY: The caller must ensure that no other process is using the same worker id.
-            if let Some((allocator, producer)) =
-                unsafe { setup(allocator_path, allocator_worker_id, queue_path) }
-            {
-                tpu_to_pack(exit, receivers, allocator, producer);
-            }
+            tpu_to_pack(exit, receivers, allocator, producer);
         })
         .unwrap()
 }
@@ -222,29 +211,6 @@ fn map_src_addr(addr: IpAddr) -> [u8; 16] {
     }
 }
 
-/// # Safety:
-/// - `allocator_worker_id` must be unique among all processes using the same allocator path.
-unsafe fn setup(
-    allocator_path: impl AsRef<Path>,
-    allocator_worker_id: u32,
-    queue_path: impl AsRef<Path>,
-) -> Option<(Allocator, shaq::Producer<TpuToPackMessage>)> {
-    // SAFETY: The caller must ensure that no other process is using the same worker id.
-    let allocator = unsafe { Allocator::join(allocator_path, allocator_worker_id) }
-        .map_err(|err| {
-            error!("Failed to join allocator: {err:?}");
-        })
-        .ok()?;
-
-    let producer = shaq::Producer::join(queue_path)
-        .map_err(|err| {
-            error!("Failed to join queue: {err:?}");
-        })
-        .ok()?;
-
-    Some((allocator, producer))
-}
-
 #[cfg(test)]
 mod tests {
     use {super::*, std::net::Ipv4Addr};

+ 8 - 4
programs/sbf/Cargo.lock

@@ -138,8 +138,12 @@ dependencies = [
  "agave-scheduler-bindings",
  "agave-transaction-view",
  "ahash 0.8.11",
+ "libc",
+ "nix",
  "rts-alloc",
+ "shaq",
  "solana-pubkey",
+ "thiserror 2.0.17",
 ]
 
 [[package]]
@@ -5089,9 +5093,9 @@ dependencies = [
 
 [[package]]
 name = "rts-alloc"
-version = "0.1.1"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7ccc6c6ddf12ced79aeae78454bcb4d0f5e31e5bc0aaca490e2f1b012f56b9d"
+checksum = "9c55727ea58e2c9c131d8f003dab5aaa7056d99f8292bc6a5dfb299cefe55e60"
 dependencies = [
  "libc",
 ]
@@ -5543,9 +5547,9 @@ dependencies = [
 
 [[package]]
 name = "shaq"
-version = "0.1.0"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c451e9289cd55fd2406917d16abe8af5f83c74b3ba3736398f6aadd2ab1fba2"
+checksum = "014fb38bb8370732f76c67752106d2a4b25cc1891ec489c7fc5ab23b27e90a75"
 dependencies = [
  "libc",
 ]

+ 1 - 0
scheduler-bindings/Cargo.toml

@@ -11,6 +11,7 @@ edition = { workspace = true }
 
 [features]
 agave-unstable-api = []
+dev-context-only-utils = []
 
 [dependencies]
 

+ 41 - 1
scheduler-bindings/src/lib.rs

@@ -57,6 +57,10 @@
 //!
 
 /// Reference to a transaction that can be shared safely across processes.
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct SharableTransactionRegion {
     /// Offset within the shared memory allocator.
@@ -66,6 +70,10 @@ pub struct SharableTransactionRegion {
 }
 
 /// Reference to an array of Pubkeys that can be shared safely across processes.
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct SharablePubkeys {
     /// Offset within the shared memory allocator.
@@ -86,6 +94,10 @@ pub struct SharablePubkeys {
 /// 4. External pack process frees all transaction memory pointed to by the
 ///    [`SharableTransactionRegion`] in the batch, then frees the memory for
 ///    the array of [`SharableTransactionRegion`].
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct SharableTransactionBatchRegion {
     /// Number of transactions in the batch.
@@ -101,6 +113,10 @@ pub struct SharableTransactionBatchRegion {
 /// 2. agave sends a [`WorkerToPackMessage`] with `responses`.
 /// 3. External pack process processes the inner messages. Potentially freeing
 ///    any memory within each inner message (see [`worker_message_types`] for details).
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct TransactionResponseRegion {
     /// Tag indicating the type of message.
@@ -125,6 +141,10 @@ pub struct TransactionResponseRegion {
 /// TPU passes transactions to the external pack process.
 /// This is also a transfer of ownership of the transaction:
 ///   the external pack process is responsible for freeing the memory.
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct TpuToPackMessage {
     pub transaction: SharableTransactionRegion,
@@ -151,6 +171,10 @@ pub mod tpu_message_flags {
 
 /// Message: [Agave -> Pack]
 /// Agave passes leader status to the external pack process.
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct ProgressMessage {
     /// The current slot.
@@ -181,6 +205,10 @@ pub const MAX_TRANSACTIONS_PER_MESSAGE: usize = 64;
 ///
 /// These messages do not transfer ownership of the transactions.
 /// The external pack process is still responsible for freeing the memory.
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct PackToWorkerMessage {
     /// Flags on how to handle this message.
@@ -219,6 +247,10 @@ pub mod pack_message_flags {
 
 /// Message: [Worker -> Pack]
 /// Message from worker threads in response to a [`PackToWorkerMessage`].
+#[cfg_attr(
+    feature = "dev-context-only-utils",
+    derive(Debug, Clone, Copy, PartialEq, Eq)
+)]
 #[repr(C)]
 pub struct WorkerToPackMessage {
     /// Offset and number of transactions in the batch.
@@ -244,12 +276,16 @@ pub struct WorkerToPackMessage {
 pub mod worker_message_types {
     use crate::SharablePubkeys;
 
-    /// Tag indicating [`ExecutionResonse`] inner message.
+    /// Tag indicating [`ExecutionResponse`] inner message.
     pub const EXECUTION_RESPONSE: u8 = 0;
 
     /// Response to pack for a transaction that attempted execution.
     /// This response will only be sent if the original message flags
     /// requested execution i.e. not [`super::pack_message_flags::RESOLVE`].
+    #[cfg_attr(
+        feature = "dev-context-only-utils",
+        derive(Debug, Clone, Copy, PartialEq, Eq)
+    )]
     #[repr(C)]
     pub struct ExecutionResponse {
         /// Indicates if the transaction was included in the block or not.
@@ -292,6 +328,10 @@ pub mod worker_message_types {
     /// Tag indicating [`Resolved`] inner message.
     pub const RESOLVED: u8 = 1;
 
+    #[cfg_attr(
+        feature = "dev-context-only-utils",
+        derive(Debug, Clone, Copy, PartialEq, Eq)
+    )]
     #[repr(C)]
     pub struct Resolved {
         /// Indicates if resolution was successful.

+ 12 - 2
scheduling-utils/Cargo.toml

@@ -13,13 +13,23 @@ edition = { workspace = true }
 agave-unstable-api = []
 
 [dependencies]
-agave-scheduler-bindings = { workspace = true }
-agave-transaction-view = { workspace = true }
 ahash = { workspace = true }
 solana-pubkey = { workspace = true }
 
 [target."cfg(unix)".dependencies]
+agave-scheduler-bindings = { workspace = true }
+agave-transaction-view = { workspace = true }
+libc = { workspace = true }
+nix = { workspace = true, features = ["socket", "uio"] }
 rts-alloc = { workspace = true }
+shaq = { workspace = true }
+thiserror = { workspace = true }
+
+[dev-dependencies]
+tempfile = { workspace = true }
+
+[target."cfg(unix)".dev-dependencies]
+agave-scheduler-bindings = { workspace = true, features = ["dev-context-only-utils"] }
 
 [lints]
 workspace = true

+ 233 - 0
scheduling-utils/src/handshake/client.rs

@@ -0,0 +1,233 @@
+use {
+    crate::handshake::{
+        shared::{GLOBAL_ALLOCATORS, LOGON_FAILURE, MAX_WORKERS, VERSION},
+        ClientLogon,
+    },
+    agave_scheduler_bindings::{
+        PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
+    },
+    libc::CMSG_LEN,
+    nix::sys::socket::{self, ControlMessageOwned, MsgFlags, UnixAddr},
+    rts_alloc::Allocator,
+    std::{
+        fs::File,
+        io::{IoSliceMut, Write},
+        os::{
+            fd::{AsRawFd, FromRawFd},
+            unix::net::UnixStream,
+        },
+        path::Path,
+        time::Duration,
+    },
+    thiserror::Error,
+};
+
+type RtsError = rts_alloc::error::Error;
+type ShaqError = shaq::error::Error;
+
+/// Number of global shared memory objects (in addition to per worker objects).
+const GLOBAL_SHMEM: usize = 3;
+
+/// The maximum size in bytes of the control message containing the queues assuming [`MAX_WORKERS`]
+/// is respected.
+///
+/// Each FD is 4 bytes so we simply multiply the number of shmem objects by 4 to get the control
+/// message buffer size.
+const CMSG_MAX_SIZE: usize = (GLOBAL_SHMEM + MAX_WORKERS * 2) * 4;
+
+/// Connects to the scheduler server on the given IPC path.
+///
+/// # Timeout
+///
+/// Timeout is enforced at the syscall level. In the typical case, this function will do two
+/// syscalls, one to send the logon message and one to receive the response. However, if for
+/// whatever reason the OS does not accept 1024 bytes in a single syscall, then multiple writes
+/// could be needed. As such, this timeout is meant to guard against a broken server rather than to
+/// guarantee this function always returns before the timeout (this is somewhat in line with typical
+/// timeouts, because there is no guarantee of being rescheduled).
+pub fn connect(
+    path: impl AsRef<Path>,
+    logon: ClientLogon,
+    timeout: Duration,
+) -> Result<ClientSession, ClientHandshakeError> {
+    connect_path(path.as_ref(), logon, timeout)
+}
+
+fn connect_path(
+    path: &Path,
+    logon: ClientLogon,
+    timeout: Duration,
+) -> Result<ClientSession, ClientHandshakeError> {
+    // NB: Technically this connect call can block indefinitely if the receiver's connection queue
+    // is full. In practice this should almost never happen. If it does, workarounds include:
+    //
+    // - Users can spawn off a thread to handle the connect call and then just poll that thread
+    //   exiting.
+    // - This library could drop to raw unix sockets and use select/poll to enforce a timeout on the
+    //   IO operation.
+    let mut stream = UnixStream::connect(path)?;
+    stream.set_read_timeout(Some(timeout))?;
+    stream.set_write_timeout(Some(timeout))?;
+
+    // Send the logon message to the server.
+    send_logon(&mut stream, logon)?;
+
+    // Receive the server's response & on success the FDs for the newly allocated shared memory.
+    let fds = recv_response(&mut stream)?;
+
+    // Join the shared memory regions.
+    let session = setup_session(&logon, fds)?;
+
+    Ok(session)
+}
+
+fn send_logon(stream: &mut UnixStream, logon: ClientLogon) -> Result<(), ClientHandshakeError> {
+    // Send the logon message.
+    let mut buf = [0; 1024];
+    buf[..8].copy_from_slice(&VERSION.to_le_bytes());
+    const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
+    let ptr = buf[8..LOGON_END].as_mut_ptr().cast::<ClientLogon>();
+    // SAFETY:
+    // - `buf` is valid for writes.
+    // - `buf.len()` has enough space for logon's size in memory.
+    unsafe {
+        core::ptr::write_unaligned(ptr, logon);
+    }
+    stream.write_all(&buf)?;
+
+    Ok(())
+}
+
+fn recv_response(stream: &mut UnixStream) -> Result<Vec<i32>, ClientHandshakeError> {
+    // Receive the requested FDs.
+    let mut buf = [0; 1024];
+    let mut iov = [IoSliceMut::new(&mut buf)];
+    // SAFETY: CMSG_LEN is always safe (const expression).
+    let mut cmsgs = [0u8; unsafe { CMSG_LEN(CMSG_MAX_SIZE as u32) as usize }];
+    let msg = socket::recvmsg::<UnixAddr>(
+        stream.as_raw_fd(),
+        &mut iov,
+        Some(&mut cmsgs),
+        MsgFlags::empty(),
+    )?;
+
+    // Check for failure.
+    let buf = msg.iovs().next().unwrap();
+    if buf[0] == LOGON_FAILURE {
+        let reason_len = usize::from(buf[1]);
+        #[allow(clippy::arithmetic_side_effects)]
+        let reason = std::str::from_utf8(&buf[2..2 + reason_len]).unwrap();
+
+        return Err(ClientHandshakeError::Rejected(reason.to_string()));
+    }
+
+    // Extract FDs.
+    let mut cmsgs = msg.cmsgs().unwrap();
+    let fds = match cmsgs.next() {
+        Some(ControlMessageOwned::ScmRights(fds)) => fds,
+        Some(msg) => panic!("Unexpected; msg={msg:?}"),
+        None => panic!(),
+    };
+
+    Ok(fds)
+}
+
+fn setup_session(
+    logon: &ClientLogon,
+    fds: Vec<i32>,
+) -> Result<ClientSession, ClientHandshakeError> {
+    let [allocator_fd, tpu_to_pack_fd, progress_tracker_fd] = fds[..GLOBAL_SHMEM] else {
+        panic!();
+    };
+    // SAFETY: `allocator_fd` represents a valid file descriptor that was just returned to us via
+    // `ScmRights`.
+    let allocator_file = unsafe { File::from_raw_fd(allocator_fd) };
+    let worker_fds = &fds[GLOBAL_SHMEM..];
+
+    // Setup requested allocators.
+    let allocators = (0..logon.allocator_handles)
+        .map(|offset| {
+            // NB: Server validates all requested counts are within expected bands so this should
+            // never panic.
+            let id = GLOBAL_ALLOCATORS
+                .checked_add(logon.worker_count)
+                .unwrap()
+                .checked_add(offset)
+                .unwrap();
+
+            unsafe { Allocator::join(&allocator_file, u32::try_from(id).unwrap()) }
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+
+    // Ensure worker_fds length matches expectations.
+    if worker_fds.is_empty()
+        || worker_fds.len() % 2 != 0
+        || worker_fds.len() / 2 != logon.worker_count
+    {
+        return Err(ClientHandshakeError::ProtocolViolation);
+    }
+
+    // NB: After creating & mapping the queues we are fine to drop the FDs as mmap will keep the
+    // underlying object alive until process exit or munmap.
+    let session = ClientSession {
+        allocators,
+        tpu_to_pack: unsafe { shaq::Consumer::join(&File::from_raw_fd(tpu_to_pack_fd))? },
+        progress_tracker: unsafe { shaq::Consumer::join(&File::from_raw_fd(progress_tracker_fd))? },
+        workers: worker_fds
+            .chunks(2)
+            .map(|window| {
+                let [pack_to_worker, worker_to_pack] = window else {
+                    panic!();
+                };
+
+                Ok(ClientWorkerSession {
+                    pack_to_worker: unsafe {
+                        shaq::Producer::join(&File::from_raw_fd(*pack_to_worker))?
+                    },
+                    worker_to_pack: unsafe {
+                        shaq::Consumer::join(&File::from_raw_fd(*worker_to_pack))?
+                    },
+                })
+            })
+            .collect::<Result<_, ClientHandshakeError>>()?,
+    };
+
+    Ok(session)
+}
+
+/// The complete initialized scheduling session.
+pub struct ClientSession {
+    pub allocators: Vec<Allocator>,
+    pub tpu_to_pack: shaq::Consumer<TpuToPackMessage>,
+    pub progress_tracker: shaq::Consumer<ProgressMessage>,
+    pub workers: Vec<ClientWorkerSession>,
+}
+
+/// A per-worker scheduling session.
+pub struct ClientWorkerSession {
+    pub pack_to_worker: shaq::Producer<PackToWorkerMessage>,
+    pub worker_to_pack: shaq::Consumer<WorkerToPackMessage>,
+}
+
+/// Potential errors that can occur during the client's side of the handshake.
+#[derive(Debug, Error)]
+pub enum ClientHandshakeError {
+    #[error("Io; err={0}")]
+    Io(#[from] std::io::Error),
+    #[error("Timed out")]
+    TimedOut,
+    #[error("Protocol violation")]
+    ProtocolViolation,
+    #[error("Rejected; reason={0}")]
+    Rejected(String),
+    #[error("Rts alloc; err={0}")]
+    RtsAlloc(#[from] RtsError),
+    #[error("Shaq; err={0}")]
+    Shaq(#[from] ShaqError),
+}
+
+impl From<nix::Error> for ClientHandshakeError {
+    fn from(value: nix::Error) -> Self {
+        Self::Io(value.into())
+    }
+}
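
For orientation, a minimal client-side usage sketch of the connect API added above (not part of this commit): the crate path agave_scheduling_utils and the socket path are assumptions, and the ClientLogon sizes mirror the values used in the tests below.

use {
    agave_scheduling_utils::handshake::{client::connect, ClientLogon},
    std::time::Duration,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the validator's IPC socket and request 4 worker queue pairs.
    let session = connect(
        "/tmp/agave-scheduler.sock",
        ClientLogon {
            worker_count: 4,
            allocator_size: 1024 * 1024 * 1024,
            allocator_handles: 1,
            tpu_to_pack_size: 65536 * 1024,
            progress_tracker_size: 16 * 1024,
            pack_to_worker_size: 1024 * 1024,
            worker_to_pack_size: 1024 * 1024,
        },
        Duration::from_secs(1),
    )?;

    // One allocator handle, plus one queue pair per requested worker.
    assert_eq!(session.allocators.len(), 1);
    assert_eq!(session.workers.len(), 4);
    Ok(())
}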

+ 7 - 0
scheduling-utils/src/handshake/mod.rs

@@ -0,0 +1,7 @@
+pub mod client;
+pub mod server;
+mod shared;
+#[cfg(test)]
+mod tests;
+
+pub use shared::ClientLogon;

+ 335 - 0
scheduling-utils/src/handshake/server.rs

@@ -0,0 +1,335 @@
+use {
+    crate::handshake::{
+        shared::{
+            GLOBAL_ALLOCATORS, LOGON_FAILURE, LOGON_SUCCESS, MAX_ALLOCATOR_HANDLES, MAX_WORKERS,
+            VERSION,
+        },
+        ClientLogon,
+    },
+    agave_scheduler_bindings::{
+        PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
+    },
+    nix::sys::socket::{self, ControlMessage, MsgFlags, UnixAddr},
+    rts_alloc::Allocator,
+    std::{
+        ffi::CStr,
+        fs::File,
+        io::{IoSlice, Read, Write},
+        os::{
+            fd::{AsRawFd, FromRawFd},
+            unix::net::{UnixListener, UnixStream},
+        },
+        path::Path,
+        time::{Duration, Instant},
+    },
+    thiserror::Error,
+};
+
+type ShaqError = shaq::error::Error;
+type RtsAllocError = rts_alloc::error::Error;
+
+const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(1);
+const SHMEM_NAME: &CStr = c"/agave-scheduler-bindings";
+
+/// Implements the Agave side of the scheduler bindings handshake protocol.
+pub struct Server {
+    listener: UnixListener,
+
+    buffer: [u8; 1024],
+}
+
+impl Server {
+    pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
+        let listener = UnixListener::bind(path)?;
+
+        Ok(Self {
+            listener,
+            buffer: [0; 1024],
+        })
+    }
+
+    pub fn accept(&mut self) -> Result<AgaveSession, AgaveHandshakeError> {
+        // Wait for next stream.
+        let (mut stream, _) = self.listener.accept()?;
+        stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT))?;
+
+        match self.handle_logon(&mut stream) {
+            Ok(session) => Ok(session),
+            Err(err) => {
+                let reason = err.to_string();
+                let reason_len = u8::try_from(reason.len()).unwrap_or(u8::MAX);
+
+                let buffer_len = 2usize.checked_add(usize::from(reason_len)).unwrap();
+                self.buffer[0] = LOGON_FAILURE;
+                self.buffer[1] = reason_len;
+                self.buffer[2..buffer_len]
+                    .copy_from_slice(&reason.as_bytes()[..usize::from(reason_len)]);
+
+                stream.set_nonblocking(true)?;
+                // NB: Caller will still error out even if our write fails so it's fine to ignore the
+                // result.
+                let _ = stream.write(&self.buffer[..buffer_len])?;
+
+                Err(err)
+            }
+        }
+    }
+
+    fn handle_logon(
+        &mut self,
+        stream: &mut UnixStream,
+    ) -> Result<AgaveSession, AgaveHandshakeError> {
+        // Receive & validate the logon message.
+        let logon = self.recv_logon(stream)?;
+
+        // Setup the requested shared memory regions.
+        let (session, files) = Self::setup_session(logon)?;
+
+        // Send the file descriptors to the client.
+        let fds_raw: Vec<_> = files.iter().map(|file| file.as_raw_fd()).collect();
+        let iov = [IoSlice::new(&[LOGON_SUCCESS])];
+        let cmsgs = [ControlMessage::ScmRights(&fds_raw)];
+        let sent =
+            socket::sendmsg::<UnixAddr>(stream.as_raw_fd(), &iov, &cmsgs, MsgFlags::empty(), None)
+                .map_err(std::io::Error::from)?;
+        debug_assert_eq!(sent, 1);
+
+        Ok(session)
+    }
+
+    fn recv_logon(&mut self, stream: &mut UnixStream) -> Result<ClientLogon, AgaveHandshakeError> {
+        // Read the logon message.
+        let handshake_start = Instant::now();
+        let mut buffer_len = 0;
+        while buffer_len < self.buffer.len() {
+            let read = stream.read(&mut self.buffer[buffer_len..])?;
+            if read == 0 {
+                return Err(AgaveHandshakeError::EofDuringHandshake);
+            }
+
+            // SAFETY: We cannot read a value greater than buffer.len() which itself is a usize.
+            buffer_len = buffer_len.checked_add(read).unwrap();
+
+            if handshake_start.elapsed() > HANDSHAKE_TIMEOUT {
+                return Err(AgaveHandshakeError::Timeout);
+            }
+        }
+
+        // Ensure an exact version match; the version will be bumped any time a backwards-incompatible
+        // change is made to the handshake/shared memory objects.
+        let version = u64::from_le_bytes(self.buffer[..8].try_into().unwrap());
+        if version != VERSION {
+            return Err(AgaveHandshakeError::Version {
+                server: VERSION,
+                client: version,
+            });
+        }
+
+        // Read the logon message; this cannot panic as we ensure the correct buf size at compile time
+        // (hence the const just below).
+        const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
+        let logon = ClientLogon::try_from_bytes(&self.buffer[8..LOGON_END]).unwrap();
+
+        // Put a hard limit of 64 worker threads for now.
+        if !(1..=MAX_WORKERS).contains(&logon.worker_count) {
+            return Err(AgaveHandshakeError::WorkerCount(logon.worker_count));
+        }
+
+        // Hard limit allocator handles to 128.
+        if !(1..=MAX_ALLOCATOR_HANDLES).contains(&logon.allocator_handles) {
+            return Err(AgaveHandshakeError::AllocatorHandles(
+                logon.allocator_handles,
+            ));
+        }
+
+        Ok(logon)
+    }
+
+    fn setup_session(logon: ClientLogon) -> Result<(AgaveSession, Vec<File>), AgaveHandshakeError> {
+        // Setup the allocator in shared memory (`worker_count` & `allocator_handles` have been
+        // validated so this won't panic).
+        let allocator_count = GLOBAL_ALLOCATORS
+            .checked_add(logon.worker_count)
+            .unwrap()
+            .checked_add(logon.allocator_handles)
+            .unwrap();
+        let allocator_file = Self::create_shmem()?;
+        let tpu_to_pack_allocator = Allocator::create(
+            &allocator_file,
+            logon.allocator_size,
+            u32::try_from(allocator_count).unwrap(),
+            2 * 1024 * 1024,
+            0,
+        )?;
+
+        // Setup the global queues.
+        let (tpu_to_pack_file, tpu_to_pack_queue) = Self::create_producer(logon.tpu_to_pack_size)?;
+        let (progress_tracker_file, progress_tracker) =
+            Self::create_producer(logon.progress_tracker_size)?;
+
+        // Setup the worker sessions.
+        let (worker_files, workers) = (0..logon.worker_count).try_fold(
+            (Vec::default(), Vec::default()),
+            |(mut fds, mut workers), offset| {
+                // NB: Server validates all requested counts are within expected bands so this
+                // should never panic.
+                let worker_index = GLOBAL_ALLOCATORS.checked_add(offset).unwrap();
+                let worker_index = u32::try_from(worker_index).unwrap();
+                // SAFETY: Worker index is guaranteed to be unique.
+                let allocator = unsafe { Allocator::join(&allocator_file, worker_index) }?;
+
+                let (pack_to_worker_file, pack_to_worker) =
+                    Self::create_consumer(logon.pack_to_worker_size)?;
+                let (worker_to_pack_file, worker_to_pack) =
+                    Self::create_producer(logon.worker_to_pack_size)?;
+
+                fds.extend([pack_to_worker_file, worker_to_pack_file]);
+                workers.push(AgaveWorkerSession {
+                    allocator,
+                    pack_to_worker,
+                    worker_to_pack,
+                });
+
+                Ok::<_, AgaveHandshakeError>((fds, workers))
+            },
+        )?;
+
+        Ok((
+            AgaveSession {
+                tpu_to_pack: AgaveTpuToPackSession {
+                    allocator: tpu_to_pack_allocator,
+                    queue: tpu_to_pack_queue,
+                },
+                progress_tracker,
+                workers,
+            },
+            [allocator_file, tpu_to_pack_file, progress_tracker_file]
+                .into_iter()
+                .chain(worker_files)
+                .collect(),
+        ))
+    }
+
+    fn create_producer<T>(size: usize) -> Result<(File, shaq::Producer<T>), ShaqError> {
+        let file = Self::create_shmem()?;
+        let queue = shaq::Producer::create(&file, size)?;
+
+        Ok((file, queue))
+    }
+
+    fn create_consumer(
+        size: usize,
+    ) -> Result<(File, shaq::Consumer<PackToWorkerMessage>), ShaqError> {
+        let file = Self::create_shmem()?;
+        let queue = shaq::Consumer::create(&file, size)?;
+
+        Ok((file, queue))
+    }
+
+    #[cfg(any(
+        target_os = "linux",
+        target_os = "l4re",
+        target_os = "android",
+        target_os = "emscripten"
+    ))]
+    fn create_shmem() -> Result<File, std::io::Error> {
+        unsafe {
+            let ret = libc::memfd_create(SHMEM_NAME.as_ptr(), 0);
+            if ret == -1 {
+                return Err(std::io::Error::last_os_error());
+            }
+
+            Ok(File::from_raw_fd(ret))
+        }
+    }
+
+    #[cfg(not(any(
+        target_os = "linux",
+        target_os = "l4re",
+        target_os = "android",
+        target_os = "emscripten"
+    )))]
+    fn create_shmem() -> Result<File, std::io::Error> {
+        unsafe {
+            // Clean up the previous link if one exists.
+            let ret = libc::shm_unlink(SHMEM_NAME.as_ptr());
+            if ret == -1 {
+                let err = std::io::Error::last_os_error();
+                if err.kind() != std::io::ErrorKind::NotFound {
+                    return Err(err);
+                }
+            }
+
+            // Create a new shared memory object.
+            let ret = libc::shm_open(
+                SHMEM_NAME.as_ptr(),
+                libc::O_CREAT | libc::O_EXCL | libc::O_RDWR,
+                #[cfg(not(target_os = "macos"))]
+                {
+                    libc::S_IRUSR | libc::S_IWUSR
+                },
+                #[cfg(any(target_os = "macos", target_os = "ios"))]
+                {
+                    (libc::S_IRUSR | libc::S_IWUSR) as libc::c_uint
+                },
+            );
+            if ret == -1 {
+                return Err(std::io::Error::last_os_error());
+            }
+            let file = File::from_raw_fd(ret);
+
+            // Clean up after ourself.
+            let ret = libc::shm_unlink(SHMEM_NAME.as_ptr());
+            if ret == -1 {
+                return Err(std::io::Error::last_os_error());
+            }
+
+            Ok(file)
+        }
+    }
+}
+
+/// An initialized scheduling session.
+pub struct AgaveSession {
+    pub tpu_to_pack: AgaveTpuToPackSession,
+    pub progress_tracker: shaq::Producer<ProgressMessage>,
+    pub workers: Vec<AgaveWorkerSession>,
+}
+
+/// Shared memory objects for the tpu to pack worker.
+pub struct AgaveTpuToPackSession {
+    pub allocator: Allocator,
+    pub queue: shaq::Producer<TpuToPackMessage>,
+}
+
+/// Shared memory objects for a single banking worker.
+pub struct AgaveWorkerSession {
+    pub allocator: Allocator,
+    pub pack_to_worker: shaq::Consumer<PackToWorkerMessage>,
+    pub worker_to_pack: shaq::Producer<WorkerToPackMessage>,
+}
+
+/// Potential errors that can occur during the Agave side of the handshake.
+///
+/// # Note
+///
+/// These errors are stringified (truncated to at most 255 bytes) and sent to the client.
+#[derive(Debug, Error)]
+pub enum AgaveHandshakeError {
+    #[error("Io; err={0}")]
+    Io(#[from] std::io::Error),
+    #[error("Timeout")]
+    Timeout,
+    #[error("Close during handshake")]
+    EofDuringHandshake,
+    #[error("Version; server={server}; client={client}")]
+    Version { server: u64, client: u64 },
+    #[error("Worker count; count={0}")]
+    WorkerCount(usize),
+    #[error("Allocator handles; count={0}")]
+    AllocatorHandles(usize),
+    #[error("Rts alloc; err={0:?}")]
+    RtsAlloc(#[from] RtsAllocError),
+    #[error("Shaq; err={0:?}")]
+    Shaq(#[from] ShaqError),
+}
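
A matching server-side sketch (again not part of this commit; the crate path and socket path are assumptions): Agave binds the IPC path, blocks in accept until a client completes the logon, and receives the fully constructed shared-memory session.

use agave_scheduling_utils::handshake::server::Server;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut server = Server::new("/tmp/agave-scheduler.sock")?;

    // Blocks until a client connects, then validates the logon, creates the
    // shared memory objects, and sends their FDs back over the socket.
    let session = server.accept()?;

    // The session carries the global tpu_to_pack/progress_tracker producers
    // plus one allocator and queue pair per requested worker.
    println!("negotiated {} worker sessions", session.workers.len());
    Ok(())
}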

+ 45 - 0
scheduling-utils/src/handshake/shared.rs

@@ -0,0 +1,45 @@
+pub(crate) const VERSION: u64 = 1;
+pub(crate) const LOGON_SUCCESS: u8 = 0x01;
+pub(crate) const LOGON_FAILURE: u8 = 0x02;
+pub(crate) const MAX_WORKERS: usize = 64;
+pub(crate) const MAX_ALLOCATOR_HANDLES: usize = 128;
+pub(crate) const GLOBAL_ALLOCATORS: usize = 1;
+
+/// The logon message sent by the client to the server.
+#[derive(Debug, Default, Clone, Copy)]
+#[repr(C)]
+pub struct ClientLogon {
+    /// The number of Agave worker threads that will be spawned to handle packing requests.
+    pub worker_count: usize,
+    /// The allocator file size in bytes; this is shared by all allocator handles.
+    pub allocator_size: usize,
+    /// The number of [`rts_alloc::Allocator`] handles the external process is requesting.
+    pub allocator_handles: usize,
+    /// The size of the `tpu_to_pack` queue in bytes.
+    pub tpu_to_pack_size: usize,
+    /// The size of the `progress_tracker` queue in bytes.
+    pub progress_tracker_size: usize,
+    /// The size of the `pack_to_worker` queue in bytes.
+    pub pack_to_worker_size: usize,
+    /// The size of the `worker_to_pack` queue in bytes.
+    pub worker_to_pack_size: usize,
+    // NB: If adding more fields please ensure:
+    // - The fields are zeroable.
+    // - If possible the fields are backwards compatible:
+    //   - Added to the end of the struct.
+    //   - 0 bytes is valid default (older clients will not have the field and thus send zeroes).
+    // - If not backwards compatible, increment the version counter.
+}
+
+impl ClientLogon {
+    pub fn try_from_bytes(buffer: &[u8]) -> Option<Self> {
+        if buffer.len() != core::mem::size_of::<Self>() {
+            return None;
+        }
+
+        // SAFETY:
+        // - buffer is correctly sized, initialized and readable.
+        // - `Self` is valid for any byte pattern
+        Some(unsafe { core::ptr::read_unaligned(buffer.as_ptr().cast()) })
+    }
+}
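
A hedged sketch of the logon framing these constants and ClientLogon describe, mirroring send_logon/recv_logon above: an 8-byte little-endian VERSION followed by the raw ClientLogon bytes, zero-padded to the 1024-byte handshake buffer. The crate path is an assumption.

use agave_scheduling_utils::handshake::ClientLogon;

fn main() {
    let logon = ClientLogon {
        worker_count: 4,
        allocator_size: 1 << 30,
        allocator_handles: 1,
        tpu_to_pack_size: 65536 * 1024,
        progress_tracker_size: 16 * 1024,
        pack_to_worker_size: 1 << 20,
        worker_to_pack_size: 1 << 20,
    };

    // Frame layout: [VERSION as u64 LE][ClientLogon bytes][zero padding to 1024].
    let mut frame = [0u8; 1024];
    frame[..8].copy_from_slice(&1u64.to_le_bytes()); // VERSION = 1 in this commit
    let end = 8 + core::mem::size_of::<ClientLogon>();
    // SAFETY: the destination range is exactly size_of::<ClientLogon>() writable bytes.
    unsafe {
        core::ptr::write_unaligned(frame[8..end].as_mut_ptr().cast::<ClientLogon>(), logon);
    }

    // The server-side parse accepts exactly size_of::<ClientLogon>() bytes.
    let parsed = ClientLogon::try_from_bytes(&frame[8..end]).unwrap();
    assert_eq!(parsed.worker_count, logon.worker_count);
}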

+ 312 - 0
scheduling-utils/src/handshake/tests.rs

@@ -0,0 +1,312 @@
+use {
+    crate::handshake::{
+        client::{connect, ClientHandshakeError},
+        server::{AgaveHandshakeError, Server},
+        shared::MAX_WORKERS,
+        ClientLogon,
+    },
+    agave_scheduler_bindings::{
+        PackToWorkerMessage, ProgressMessage, SharableTransactionBatchRegion,
+        SharableTransactionRegion, TpuToPackMessage, TransactionResponseRegion,
+        WorkerToPackMessage,
+    },
+    std::time::Duration,
+    tempfile::NamedTempFile,
+};
+
+#[test]
+fn message_passing_on_all_queues() {
+    let ipc = NamedTempFile::new().unwrap();
+    std::fs::remove_file(ipc.path()).unwrap();
+    let mut server = Server::new(ipc.path()).unwrap();
+
+    // Test messages.
+    let tpu_to_pack = TpuToPackMessage {
+        transaction: SharableTransactionRegion {
+            offset: 10,
+            length: 5,
+        },
+        flags: 21,
+        src_addr: [4; 16],
+    };
+    let progress_tracker = ProgressMessage {
+        current_slot: 3,
+        next_leader_slot: 12,
+        remaining_cost_units: 12_000_000,
+        current_slot_progress: 32,
+    };
+    let pack_to_worker = PackToWorkerMessage {
+        flags: 123,
+        max_execution_slot: 100,
+        batch: SharableTransactionBatchRegion {
+            num_transactions: 5,
+            transactions_offset: 100,
+        },
+    };
+    let worker_to_pack = WorkerToPackMessage {
+        batch: SharableTransactionBatchRegion {
+            num_transactions: 5,
+            transactions_offset: 100,
+        },
+        processed: 0x01,
+        responses: TransactionResponseRegion {
+            tag: 3,
+            num_transaction_responses: 2,
+            transaction_responses_offset: 1,
+        },
+    };
+
+    let server_handle = std::thread::spawn(move || {
+        let mut session = server.accept().unwrap();
+
+        // Send a tpu_to_pack message.
+        let mut slot = session.tpu_to_pack.queue.reserve().unwrap();
+        unsafe { *slot.as_mut() = tpu_to_pack };
+        session.tpu_to_pack.queue.commit();
+
+        // Send a progress_tracker message.
+        let mut slot = session.progress_tracker.reserve().unwrap();
+        unsafe { *slot.as_mut() = progress_tracker };
+        session.progress_tracker.commit();
+
+        // Receive pack_to_worker messages.
+        for (i, worker) in session.workers.iter_mut().enumerate() {
+            let msg = loop {
+                worker.pack_to_worker.sync();
+                if let Some(slot) = worker.pack_to_worker.try_read() {
+                    break unsafe { *slot.as_ref() };
+                }
+            };
+            assert_eq!(
+                PackToWorkerMessage {
+                    max_execution_slot: pack_to_worker.max_execution_slot + i as u64,
+                    ..pack_to_worker
+                },
+                msg
+            );
+        }
+
+        // Send worker_to_pack messages.
+        for (i, worker) in session.workers.iter_mut().enumerate() {
+            let mut slot = worker.worker_to_pack.reserve().unwrap();
+            unsafe {
+                *slot.as_mut() = WorkerToPackMessage {
+                    batch: SharableTransactionBatchRegion {
+                        num_transactions: worker_to_pack.batch.num_transactions + i as u8,
+                        ..worker_to_pack.batch
+                    },
+                    ..worker_to_pack
+                }
+            };
+            worker.worker_to_pack.commit();
+        }
+    });
+    let client_handle = std::thread::spawn(move || {
+        let mut session = connect(
+            ipc,
+            ClientLogon {
+                worker_count: 4,
+                allocator_size: 1024 * 1024 * 1024,
+                allocator_handles: 3,
+                tpu_to_pack_size: 65536 * 1024,
+                progress_tracker_size: 16 * 1024,
+                pack_to_worker_size: 1024 * 1024,
+                worker_to_pack_size: 1024 * 1024,
+            },
+            Duration::from_secs(1),
+        )
+        .unwrap();
+
+        // Receive tpu_to_pack message.
+        let msg = loop {
+            session.tpu_to_pack.sync();
+            if let Some(msg) = session.tpu_to_pack.try_read() {
+                break unsafe { *msg.as_ref() };
+            };
+        };
+        assert_eq!(msg, tpu_to_pack);
+
+        // Receive progress_tracker message.
+        let msg = loop {
+            session.progress_tracker.sync();
+            if let Some(msg) = session.progress_tracker.try_read() {
+                break unsafe { *msg.as_ref() };
+            };
+        };
+        assert_eq!(msg, progress_tracker);
+
+        // Send pack_to_worker messages.
+        for (i, worker) in session.workers.iter_mut().enumerate() {
+            let mut slot = worker.pack_to_worker.reserve().unwrap();
+            unsafe {
+                *slot.as_mut() = PackToWorkerMessage {
+                    max_execution_slot: pack_to_worker.max_execution_slot + i as u64,
+                    ..pack_to_worker
+                }
+            };
+            worker.pack_to_worker.commit();
+        }
+
+        // Receive worker_to_pack messages.
+        for (i, worker) in session.workers.iter_mut().enumerate() {
+            let msg = loop {
+                worker.worker_to_pack.sync();
+                if let Some(slot) = worker.worker_to_pack.try_read() {
+                    break unsafe { *slot.as_ref() };
+                }
+            };
+            assert_eq!(
+                WorkerToPackMessage {
+                    batch: SharableTransactionBatchRegion {
+                        num_transactions: worker_to_pack.batch.num_transactions + i as u8,
+                        ..worker_to_pack.batch
+                    },
+                    ..worker_to_pack
+                },
+                msg
+            );
+        }
+    });
+
+    client_handle.join().unwrap();
+    server_handle.join().unwrap();
+}
+
+#[test]
+fn accept_worker_count_max() {
+    let ipc = NamedTempFile::new().unwrap();
+    std::fs::remove_file(ipc.path()).unwrap();
+    let mut server = Server::new(ipc.path()).unwrap();
+
+    let server_handle = std::thread::spawn(move || {
+        let res = server.accept();
+        assert!(res.is_ok());
+    });
+    let client_handle = std::thread::spawn(move || {
+        let res = connect(
+            ipc,
+            ClientLogon {
+                worker_count: MAX_WORKERS,
+                allocator_size: 1024 * 1024 * 1024,
+                allocator_handles: 3,
+                tpu_to_pack_size: 65536 * 1024,
+                progress_tracker_size: 16 * 1024,
+                pack_to_worker_size: 1024 * 1024,
+                worker_to_pack_size: 1024 * 1024,
+            },
+            Duration::from_secs(1),
+        );
+        assert!(res.is_ok());
+    });
+
+    client_handle.join().unwrap();
+    server_handle.join().unwrap();
+}
+
+#[test]
+fn reject_worker_count_low() {
+    let ipc = NamedTempFile::new().unwrap();
+    std::fs::remove_file(ipc.path()).unwrap();
+    let mut server = Server::new(ipc.path()).unwrap();
+
+    let server_handle = std::thread::spawn(move || {
+        let res = server.accept();
+        let Err(AgaveHandshakeError::WorkerCount(count)) = res else {
+            panic!();
+        };
+        assert_eq!(count, 0);
+    });
+    let client_handle = std::thread::spawn(move || {
+        let res = connect(
+            ipc,
+            ClientLogon {
+                worker_count: 0,
+                allocator_size: 1024 * 1024 * 1024,
+                allocator_handles: 3,
+                tpu_to_pack_size: 65536 * 1024,
+                progress_tracker_size: 16 * 1024,
+                pack_to_worker_size: 1024 * 1024,
+                worker_to_pack_size: 1024 * 1024,
+            },
+            Duration::from_secs(1),
+        );
+        let Err(ClientHandshakeError::Rejected(reason)) = res else {
+            panic!();
+        };
+        assert_eq!(reason, "Worker count; count=0");
+    });
+
+    client_handle.join().unwrap();
+    server_handle.join().unwrap();
+}
+
+#[test]
+fn reject_worker_count_high() {
+    let ipc = NamedTempFile::new().unwrap();
+    std::fs::remove_file(ipc.path()).unwrap();
+    let mut server = Server::new(ipc.path()).unwrap();
+
+    let server_handle = std::thread::spawn(move || {
+        let res = server.accept();
+        let Err(AgaveHandshakeError::WorkerCount(count)) = res else {
+            panic!();
+        };
+        assert_eq!(count, 100);
+    });
+    let client_handle = std::thread::spawn(move || {
+        let res = connect(
+            ipc,
+            ClientLogon {
+                worker_count: 100,
+                allocator_size: 1024 * 1024 * 1024,
+                allocator_handles: 3,
+                tpu_to_pack_size: 65536 * 1024,
+                progress_tracker_size: 16 * 1024,
+                pack_to_worker_size: 1024 * 1024,
+                worker_to_pack_size: 1024 * 1024,
+            },
+            Duration::from_secs(1),
+        );
+        let Err(ClientHandshakeError::Rejected(reason)) = res else {
+            panic!();
+        };
+        assert_eq!(reason, "Worker count; count=100");
+    });
+
+    client_handle.join().unwrap();
+    server_handle.join().unwrap();
+}
+
+#[test]
+fn reject_invalid_queue_size() {
+    let ipc = NamedTempFile::new().unwrap();
+    std::fs::remove_file(ipc.path()).unwrap();
+    let mut server = Server::new(ipc.path()).unwrap();
+
+    let server_handle = std::thread::spawn(move || {
+        let res = server.accept();
+        assert!(matches!(res, Err(AgaveHandshakeError::Shaq(_))));
+    });
+    let client_handle = std::thread::spawn(move || {
+        let res = connect(
+            ipc,
+            ClientLogon {
+                worker_count: 4,
+                allocator_size: 1024 * 1024 * 1024,
+                allocator_handles: 3,
+                tpu_to_pack_size: 0,
+                progress_tracker_size: 16 * 1024,
+                pack_to_worker_size: 1024 * 1024,
+                worker_to_pack_size: 1024 * 1024,
+            },
+            Duration::from_secs(1),
+        );
+        let Err(ClientHandshakeError::Rejected(reason)) = res else {
+            panic!();
+        };
+        assert_eq!(reason, "Shaq; err=InvalidBufferSize");
+    });
+
+    client_handle.join().unwrap();
+    server_handle.join().unwrap();
+}

+ 2 - 0
scheduling-utils/src/lib.rs

@@ -9,5 +9,7 @@
 )]
 pub mod thread_aware_account_locks;
 
+#[cfg(unix)]
+pub mod handshake;
 #[cfg(unix)]
 pub mod transaction_ptr;