Browse Source

log: Update precision logic (#252)

* Add precision cap

* Add tests

* Refactor precision logic

* Fix miri warning

* Avoid duplication

* Add missing syscall

* More tests

* Fix truncate logic

* Fix review comments
Fernando Otero 1 tháng trước cách đây
mục cha
commit
a271912984
2 tập tin đã thay đổi với 285 bổ sung86 xóa
  1. 139 1
      sdk/log/crate/src/lib.rs
  2. 146 85
      sdk/log/crate/src/logger.rs

+ 139 - 1
sdk/log/crate/src/lib.rs

@@ -199,9 +199,107 @@ mod tests {
 
         logger.clear();
 
-        // This should have no effect.
+        // This should have no effect since it is a string.
         logger.append_with_args("0123456789", &[Argument::Precision(2)]);
         assert!(&*logger == "0123456789".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args(2u8, &[Argument::Precision(8)]);
+        assert!(&*logger == "0.00000002".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args(2u8, &[Argument::Precision(u8::MAX)]);
+        assert!(&*logger == "0.0000000@".as_bytes());
+
+        let mut logger = Logger::<20>::default();
+
+        logger.append_with_args(2u8, &[Argument::Precision(u8::MAX)]);
+        assert!(&*logger == "0.00000000000000000@".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args(20_000u16, &[Argument::Precision(10)]);
+        assert!(&*logger == "0.0000020000".as_bytes());
+
+        let mut logger = Logger::<3>::default();
+
+        logger.append_with_args(2u64, &[Argument::Precision(u8::MAX)]);
+        assert!(&*logger == "0.@".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args(2u64, &[Argument::Precision(1)]);
+        assert!(&*logger == "0.2".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args(-2i64, &[Argument::Precision(1)]);
+        assert!(&*logger == "-0@".as_bytes());
+
+        let mut logger = Logger::<1>::default();
+
+        logger.append_with_args(-2i64, &[Argument::Precision(1)]);
+        assert!(&*logger == "@".as_bytes());
+
+        let mut logger = Logger::<2>::default();
+
+        logger.append_with_args(-2i64, &[Argument::Precision(1)]);
+        assert!(&*logger == "-@".as_bytes());
+
+        let mut logger = Logger::<20>::default();
+
+        logger.append_with_args(u64::MAX, &[Argument::Precision(u8::MAX)]);
+        assert!(&*logger == "0.00000000000000000@".as_bytes());
+
+        // 255 precision + leading 0 + decimal point
+        let mut logger = Logger::<257>::default();
+        logger.append_with_args(u64::MAX, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("18446744073709551615".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(u32::MAX, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("4294967295".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(u16::MAX, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("65535".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(u8::MAX, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("255".as_bytes()));
+
+        // 255 precision + sign + leading 0 + decimal point
+        let mut logger = Logger::<258>::default();
+        logger.append_with_args(i64::MIN, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("-0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("9223372036854775808".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(i32::MIN, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("-0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("2147483648".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(i16::MIN, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("-0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("32768".as_bytes()));
+
+        logger.clear();
+
+        logger.append_with_args(i8::MIN, &[Argument::Precision(u8::MAX)]);
+        assert!(logger.starts_with("-0.00000000000000".as_bytes()));
+        assert!(logger.ends_with("128".as_bytes()));
     }
 
     #[test]
@@ -235,6 +333,46 @@ mod tests {
 
         logger.append_with_args("0123456789", &[Argument::TruncateStart(9)]);
         assert!(&*logger == "..@".as_bytes());
+
+        let mut logger = Logger::<1>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateStart(0)]);
+        assert!(&*logger == "".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args("test", &[Argument::TruncateStart(1)]);
+        assert!(&*logger == "@".as_bytes());
+
+        let mut logger = Logger::<2>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateStart(2)]);
+        assert!(&*logger == ".@".as_bytes());
+
+        let mut logger = Logger::<3>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateStart(3)]);
+        assert!(&*logger == "..@".as_bytes());
+
+        let mut logger = Logger::<1>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateEnd(0)]);
+        assert!(&*logger == "".as_bytes());
+
+        logger.clear();
+
+        logger.append_with_args("test", &[Argument::TruncateEnd(1)]);
+        assert!(&*logger == "@".as_bytes());
+
+        let mut logger = Logger::<2>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateEnd(2)]);
+        assert!(&*logger == ".@".as_bytes());
+
+        let mut logger = Logger::<3>::default();
+
+        logger.append_with_args("test", &[Argument::TruncateEnd(3)]);
+        assert!(&*logger == "..@".as_bytes());
     }
 
     #[test]

+ 146 - 85
sdk/log/crate/src/logger.rs

@@ -1,4 +1,6 @@
-use core::{mem::MaybeUninit, ops::Deref, slice::from_raw_parts};
+use core::{
+    cmp::min, mem::MaybeUninit, ops::Deref, ptr::copy_nonoverlapping, slice::from_raw_parts,
+};
 
 #[cfg(all(target_os = "solana", not(target_feature = "static-syscalls")))]
 mod syscalls {
@@ -8,6 +10,8 @@ mod syscalls {
 
         pub fn sol_memcpy_(dst: *mut u8, src: *const u8, n: u64);
 
+        pub fn sol_memset_(s: *mut u8, c: u8, n: u64);
+
         pub fn sol_remaining_compute_units() -> u64;
     }
 }
@@ -26,6 +30,12 @@ mod syscalls {
         syscall(dest, src, n)
     }
 
+    pub(crate) fn sol_memset_(s: *mut u8, c: u8, n: u64) {
+        let syscall: extern "C" fn(*mut u8, u8, u64) =
+            unsafe { core::mem::transmute(930151202u64) }; // murmur32 hash of "sol_memset_"
+        syscall(s, c, n)
+    }
+
     pub(crate) fn sol_remaining_compute_units() -> u64 {
         let syscall: extern "C" fn() -> u64 = unsafe { core::mem::transmute(3991886574u64) }; // murmur32 hash of "sol_remaining_compute_units"
         syscall()
@@ -248,7 +258,7 @@ macro_rules! impl_log_for_unsigned_integer {
                             value /= 10;
                             offset -= 1;
                             // SAFETY: the offset is always within the bounds of the array since
-                            // the `offset` is initialized with the maximum number of digits that
+                            // `offset` is initialized with the maximum number of digits that
                             // the type can have and decremented on each iteration; `remainder`
                             // is always less than 10.
                             unsafe {
@@ -267,99 +277,132 @@ macro_rules! impl_log_for_unsigned_integer {
                             0
                         };
 
-                        // Number of digits written.
-                        let mut written = MAX_DIGITS - offset;
-
-                        if precision > 0 {
-                            while precision >= written {
-                                written += 1;
-                                offset -= 1;
-                                // SAFETY: the offset is always within the bounds of the array since
-                                // the `offset` is initialized with the maximum number of digits that
-                                // the type can have and decremented on each iteration.
-                                unsafe {
-                                    digits.get_unchecked_mut(offset).write(b'0');
-                                }
-                            }
-                            // Space for the decimal point.
-                            written += 1;
-                        }
-
-                        // Size of the buffer.
+                        let written = MAX_DIGITS - offset;
                         let length = buffer.len();
-                        // Determines if the value was truncated or not by calculating the
-                        // number of digits that can be written.
-                        let (overflow, written, fraction) = if written <= length {
-                            (false, written, precision)
-                        } else {
-                            (true, length, precision.saturating_sub(written - length))
+
+                        // Space required with the specified precision. We might need
+                        // to add leading zeros and a decimal point, but this is only
+                        // if the precision is greater than zero.
+                        let required = match precision {
+                            0 => written,
+                            // decimal point
+                            _precision if precision < written => written + 1,
+                            // decimal point + one leading zero
+                            _ => precision + 2,
                         };
+                        // Determines whether the value will be truncated or not.
+                        let is_truncated = required > length;
+                        // Cap the number of digits to write to the buffer length.
+                        let digits_to_write = min(MAX_DIGITS - offset, length);
+
                         // SAFETY: the length of both `digits` and `buffer` arrays are guaranteed
-                        // to be within bounds and the `written` value is always less than their
-                        // maximum length.
+                        // to be within bounds and the `digits_to_write` value is capped to the
+                        // length of the `buffer`.
                         unsafe {
                             let source = digits.as_ptr().add(offset);
                             let ptr = buffer.as_mut_ptr();
 
-                            #[cfg(target_os = "solana")]
-                            {
-                                if precision == 0 {
-                                    syscalls::sol_memcpy_(
-                                        ptr as *mut _,
-                                        source as *const _,
-                                        written as u64,
-                                    );
-                                } else {
-                                    // Integer part of the number.
-                                    let integer_part = written - (fraction + 1);
-                                    syscalls::sol_memcpy_(
-                                        ptr as *mut _,
-                                        source as *const _,
-                                        integer_part as u64,
+                            // Copy the number to the buffer if no precision is specified.
+                            if precision == 0 {
+                                #[cfg(target_os = "solana")]
+                                syscalls::sol_memcpy_(
+                                    ptr as *mut _,
+                                    source as *const _,
+                                    digits_to_write as u64,
+                                );
+                                #[cfg(not(target_os = "solana"))]
+                                copy_nonoverlapping(source, ptr, digits_to_write);
+                            }
+                            // If padding is needed to satisfy the precision, add leading zeros
+                            // and a decimal point.
+                            else if precision >= digits_to_write {
+                                // Prefix.
+                                (ptr as *mut u8).write(b'0');
+
+                                if length > 2 {
+                                    (ptr.add(1) as *mut u8).write(b'.');
+                                    let padding = min(length - 2, precision - digits_to_write);
+
+                                    // Precision padding.
+                                    #[cfg(target_os = "solana")]
+                                    syscalls::sol_memset_(
+                                        ptr.add(2) as *mut _,
+                                        b'0',
+                                        padding as u64,
                                     );
-
-                                    // Decimal point.
-                                    (ptr.add(integer_part) as *mut u8).write(b'.');
+                                    #[cfg(not(target_os = "solana"))]
+                                    (ptr.add(2) as *mut u8).write_bytes(b'0', padding);
+
+                                    let current = 2 + padding;
+
+                                    // If there is still space, copy (part of) the number.
+                                    if current < length {
+                                        let remaining = min(digits_to_write, length - current);
+
+                                        // Number part.
+                                        #[cfg(target_os = "solana")]
+                                        syscalls::sol_memcpy_(
+                                            ptr.add(current) as *mut _,
+                                            source as *const _,
+                                            remaining as u64,
+                                        );
+                                        #[cfg(not(target_os = "solana"))]
+                                        copy_nonoverlapping(source, ptr.add(current), remaining);
+                                    }
+                                }
+                            }
+                            // No padding is needed, calculate the integer and fractional
+                            // parts and add a decimal point.
+                            else {
+                                let integer_part = digits_to_write - precision;
+
+                                // Integer part of the number.
+                                #[cfg(target_os = "solana")]
+                                syscalls::sol_memcpy_(
+                                    ptr as *mut _,
+                                    source as *const _,
+                                    integer_part as u64,
+                                );
+                                #[cfg(not(target_os = "solana"))]
+                                copy_nonoverlapping(source, ptr, integer_part);
+
+                                // Decimal point.
+                                (ptr.add(integer_part) as *mut u8).write(b'.');
+                                let current = integer_part + 1;
+
+                                // If there is still space, copy (part of) the remaining.
+                                if current < length {
+                                    let remaining = min(precision, length - current);
 
                                     // Fractional part of the number.
+                                    #[cfg(target_os = "solana")]
                                     syscalls::sol_memcpy_(
-                                        ptr.add(integer_part + 1) as *mut _,
+                                        ptr.add(current) as *mut _,
                                         source.add(integer_part) as *const _,
-                                        fraction as u64,
+                                        remaining as u64,
                                     );
-                                }
-                            }
-
-                            #[cfg(not(target_os = "solana"))]
-                            {
-                                if precision == 0 {
-                                    core::ptr::copy_nonoverlapping(source, ptr, written);
-                                } else {
-                                    // Integer part of the number.
-                                    let integer_part = written - (fraction + 1);
-                                    core::ptr::copy_nonoverlapping(source, ptr, integer_part);
-
-                                    // Decimal point.
-                                    (ptr.add(integer_part) as *mut u8).write(b'.');
-
-                                    // Fractional part of the number.
-                                    core::ptr::copy_nonoverlapping(
+                                    #[cfg(not(target_os = "solana"))]
+                                    copy_nonoverlapping(
                                         source.add(integer_part),
-                                        ptr.add(integer_part + 1),
-                                        fraction,
+                                        ptr.add(current),
+                                        remaining,
                                     );
                                 }
                             }
                         }
 
-                        // There might not have been space for all the value.
-                        if overflow {
-                            // SAFETY: the buffer is checked to be within `written` bounds.
+                        let written = min(required, length);
+
+                        // There might not have been space.
+                        if is_truncated {
+                            // SAFETY: `written` is capped to the length of the buffer and
+                            // the required length (`required` is always greater than zero);
+                            // `buffer` is guaranteed  to have a length of at least 1.
                             unsafe {
-                                let last = buffer.get_unchecked_mut(written - 1);
-                                last.write(TRUNCATED);
+                                buffer.get_unchecked_mut(written - 1).write(TRUNCATED);
                             }
                         }
+
                         written
                     }
                 }
@@ -399,6 +442,15 @@ macro_rules! impl_log_for_signed {
                         let mut prefix = 0;
 
                         if *self < 0 {
+                            if buffer.len() == 1 {
+                                // SAFETY: the buffer is checked to be non-empty.
+                                unsafe {
+                                    buffer.get_unchecked_mut(0).write(TRUNCATED);
+                                }
+                                // There is no space for the number, so just return.
+                                return 1;
+                            }
+
                             // SAFETY: the buffer is checked to be non-empty.
                             unsafe {
                                 buffer.get_unchecked_mut(0).write(b'-');
@@ -511,7 +563,7 @@ unsafe impl Log for &str {
             // No truncate arguments were provided, so the entire `str` is copied to the buffer
             // if it fits; otherwise indicates that the `str` was truncated.
             if truncate_end.is_none() {
-                let length = core::cmp::min(size, self.len());
+                let length = min(size, self.len());
                 (
                     buffer.as_mut_ptr(),
                     self.as_ptr(),
@@ -520,7 +572,7 @@ unsafe impl Log for &str {
                     length != self.len(),
                 )
             } else {
-                let max_length = core::cmp::min(size, buffer.len());
+                let max_length = min(size, buffer.len());
                 let ptr = buffer.as_mut_ptr();
 
                 // The buffer is large enough to hold the entire `str`, so no need to use the
@@ -546,7 +598,7 @@ unsafe impl Log for &str {
                             )
                         };
                         // Copy the truncated slice to the buffer.
-                        core::ptr::copy_nonoverlapping(
+                        copy_nonoverlapping(
                             TRUNCATED_SLICE.as_ptr(),
                             ptr.add(offset) as *mut _,
                             TRUNCATED_SLICE.len(),
@@ -562,17 +614,26 @@ unsafe impl Log for &str {
                 }
             };
 
-        // SAFETY: the `destination` is always within `length_to_write` bounds.
-        unsafe {
-            core::ptr::copy_nonoverlapping(source, destination as *mut _, length_to_write);
-        }
-
-        // There might not have been space for all the value.
-        if truncated {
+        if length_to_write > 0 {
             // SAFETY: the `destination` is always within `length_to_write` bounds.
             unsafe {
-                let last = buffer.get_unchecked_mut(length_to_write - 1);
-                last.write(TRUNCATED);
+                #[cfg(target_os = "solana")]
+                syscalls::sol_memcpy_(
+                    destination as *mut _,
+                    source as *const _,
+                    length_to_write as u64,
+                );
+                #[cfg(not(target_os = "solana"))]
+                copy_nonoverlapping(source, destination as *mut _, length_to_write);
+            }
+
+            // There might not have been space for all the value.
+            if truncated {
+                // SAFETY: the `destination` is always within `length_to_write` bounds.
+                unsafe {
+                    let last = buffer.get_unchecked_mut(length_to_write - 1);
+                    last.write(TRUNCATED);
+                }
             }
         }