Browse Source

Fix amount conversion

febo 10 months ago
parent
commit
df8b6aa38a

+ 11 - 3
p-token/src/processor/amount_to_ui_amount.rs

@@ -7,6 +7,12 @@ use token_interface::state::mint::Mint;
 
 use super::{check_account_owner, MAX_DIGITS_U64};
 
+/// Maximum length of the UI amount string.
+///
+/// The length includes the maximum number of digits in a `u64`` (20)
+/// and the maximum number of punctuation characters (2).
+const MAX_UI_AMOUNT_LENGTH: usize = MAX_DIGITS_U64 + 2;
+
 #[inline(always)]
 pub fn process_amount_to_ui_amount(
     accounts: &[AccountInfo],
@@ -20,12 +26,14 @@ pub fn process_amount_to_ui_amount(
 
     let mint_info = accounts.first().ok_or(ProgramError::NotEnoughAccountKeys)?;
     check_account_owner(mint_info)?;
-
+    // SAFETY: there is a single borrow to the `Mint` account.
     let mint = unsafe { Mint::from_bytes(mint_info.borrow_data_unchecked()) };
 
-    let mut logger = Logger::<MAX_DIGITS_U64>::default();
+    let mut logger = Logger::<MAX_UI_AMOUNT_LENGTH>::default();
     logger.append_with_args(amount, &[Argument::Precision(mint.decimals)]);
-
+    // "Extract" the formatted string from the logger.
+    //
+    // SAFETY: the logger is guaranteed to be a valid UTF-8 string.
     let mut s = unsafe { from_utf8_unchecked(&logger) };
 
     if mint.decimals > 0 {

+ 37 - 24
p-token/src/processor/mod.rs

@@ -1,6 +1,11 @@
-use core::{mem::MaybeUninit, slice::from_raw_parts, str::from_utf8_unchecked};
+use core::{
+    cmp::max,
+    mem::MaybeUninit,
+    slice::{from_raw_parts, from_raw_parts_mut},
+    str::from_utf8_unchecked,
+};
 use pinocchio::{
-    account_info::AccountInfo, program_error::ProgramError, pubkey::Pubkey, syscalls::sol_memcpy_,
+    account_info::AccountInfo, memory::sol_memcpy, program_error::ProgramError, pubkey::Pubkey,
     ProgramResult,
 };
 use token_interface::{
@@ -134,46 +139,53 @@ fn try_ui_amount_into_amount(ui_amount: &str, decimals: u8) -> Result<u64, Progr
     let decimals = decimals as usize;
     let mut parts = ui_amount.split('.');
 
-    // splitting a string, even an empty one, will always yield an iterator of at
-    // least length == 1
+    // Splitting a string, even an empty one, will always yield an iterator of at
+    // least length == 1.
     let amount_str = parts.next().unwrap();
-    let mut length = amount_str.len();
-
-    let mut digits = [UNINIT_BYTE; MAX_DIGITS_U64];
-    let mut ptr = digits.as_mut_ptr();
-
-    unsafe {
-        sol_memcpy_(
-            ptr as *mut _,
-            amount_str.as_ptr() as *const _,
-            length as u64,
-        );
-    }
-
     let after_decimal = parts.next().unwrap_or("");
+    // Clean up trailing zeros.
     let after_decimal = after_decimal.trim_end_matches('0');
 
+    // Validates the input.
+
+    let mut length = amount_str.len();
+    let expected_after_decimal_length = max(after_decimal.len(), decimals);
+
     if (amount_str.is_empty() && after_decimal.is_empty())
         || parts.next().is_some()
         || after_decimal.len() > decimals
+        || (length + expected_after_decimal_length) > MAX_DIGITS_U64
     {
         return Err(ProgramError::InvalidArgument);
     }
 
+    let mut digits = [UNINIT_BYTE; MAX_DIGITS_U64];
+    // SAFETY: `digits` is an array of `MaybeUninit<u8>`, which has the same
+    // memory layout as `u8`.
+    let slice: &mut [u8] =
+        unsafe { from_raw_parts_mut(digits.as_mut_ptr() as *mut _, MAX_DIGITS_U64) };
+
+    // SAFETY: the total length of `amount_str` and `after_decimal` is less than
+    // `MAX_DIGITS_U64`.
     unsafe {
-        sol_memcpy_(
-            ptr.add(length) as *mut _,
-            after_decimal.as_ptr() as *const _,
-            after_decimal.len() as u64,
-        );
+        sol_memcpy(slice, amount_str.as_bytes(), length);
 
-        length += after_decimal.len();
-        ptr = ptr.add(length);
+        sol_memcpy(
+            &mut slice[length..],
+            after_decimal.as_bytes(),
+            after_decimal.len(),
+        );
     }
 
+    length += after_decimal.len();
     let remaining = decimals.saturating_sub(after_decimal.len());
 
+    // SAFETY: `digits` is an array of `MaybeUninit<u8>`, which has the same memory
+    // layout as `u8`.
+    let ptr = unsafe { digits.as_mut_ptr().add(length) };
+
     for offset in 0..remaining {
+        // SAFETY: `ptr` is within the bounds of `digits`.
         unsafe {
             (ptr.add(offset) as *mut u8).write(b'0');
         }
@@ -181,6 +193,7 @@ fn try_ui_amount_into_amount(ui_amount: &str, decimals: u8) -> Result<u64, Progr
 
     length += remaining;
 
+    // SAFETY: `digits` only contains valid UTF-8 bytes.
     unsafe {
         from_utf8_unchecked(from_raw_parts(digits.as_ptr() as _, length))
             .parse::<u64>()

+ 1 - 1
p-token/src/processor/ui_amount_to_amount.rs

@@ -16,7 +16,7 @@ pub fn process_ui_amount_to_amount(
 
     let mint_info = accounts.first().ok_or(ProgramError::NotEnoughAccountKeys)?;
     check_account_owner(mint_info)?;
-
+    // SAFETY: there is a single borrow to the `Mint` account.
     let mint = unsafe { Mint::from_bytes(mint_info.borrow_data_unchecked()) };
 
     let amount = try_ui_amount_into_amount(ui_amount, mint.decimals)?;