|
@@ -1,6 +1,11 @@
|
|
|
-use core::{mem::MaybeUninit, slice::from_raw_parts, str::from_utf8_unchecked};
|
|
|
+use core::{
|
|
|
+ cmp::max,
|
|
|
+ mem::MaybeUninit,
|
|
|
+ slice::{from_raw_parts, from_raw_parts_mut},
|
|
|
+ str::from_utf8_unchecked,
|
|
|
+};
|
|
|
use pinocchio::{
|
|
|
- account_info::AccountInfo, program_error::ProgramError, pubkey::Pubkey, syscalls::sol_memcpy_,
|
|
|
+ account_info::AccountInfo, memory::sol_memcpy, program_error::ProgramError, pubkey::Pubkey,
|
|
|
ProgramResult,
|
|
|
};
|
|
|
use token_interface::{
|
|
@@ -134,46 +139,53 @@ fn try_ui_amount_into_amount(ui_amount: &str, decimals: u8) -> Result<u64, Progr
|
|
|
let decimals = decimals as usize;
|
|
|
let mut parts = ui_amount.split('.');
|
|
|
|
|
|
- // splitting a string, even an empty one, will always yield an iterator of at
|
|
|
- // least length == 1
|
|
|
+ // Splitting a string, even an empty one, will always yield an iterator of at
|
|
|
+ // least length == 1.
|
|
|
let amount_str = parts.next().unwrap();
|
|
|
- let mut length = amount_str.len();
|
|
|
-
|
|
|
- let mut digits = [UNINIT_BYTE; MAX_DIGITS_U64];
|
|
|
- let mut ptr = digits.as_mut_ptr();
|
|
|
-
|
|
|
- unsafe {
|
|
|
- sol_memcpy_(
|
|
|
- ptr as *mut _,
|
|
|
- amount_str.as_ptr() as *const _,
|
|
|
- length as u64,
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
let after_decimal = parts.next().unwrap_or("");
|
|
|
+ // Clean up trailing zeros.
|
|
|
let after_decimal = after_decimal.trim_end_matches('0');
|
|
|
|
|
|
+ // Validates the input.
|
|
|
+
|
|
|
+ let mut length = amount_str.len();
|
|
|
+ let expected_after_decimal_length = max(after_decimal.len(), decimals);
|
|
|
+
|
|
|
if (amount_str.is_empty() && after_decimal.is_empty())
|
|
|
|| parts.next().is_some()
|
|
|
|| after_decimal.len() > decimals
|
|
|
+ || (length + expected_after_decimal_length) > MAX_DIGITS_U64
|
|
|
{
|
|
|
return Err(ProgramError::InvalidArgument);
|
|
|
}
|
|
|
|
|
|
+ let mut digits = [UNINIT_BYTE; MAX_DIGITS_U64];
|
|
|
+ // SAFETY: `digits` is an array of `MaybeUninit<u8>`, which has the same
|
|
|
+ // memory layout as `u8`.
|
|
|
+ let slice: &mut [u8] =
|
|
|
+ unsafe { from_raw_parts_mut(digits.as_mut_ptr() as *mut _, MAX_DIGITS_U64) };
|
|
|
+
|
|
|
+ // SAFETY: the total length of `amount_str` and `after_decimal` is less than
|
|
|
+ // `MAX_DIGITS_U64`.
|
|
|
unsafe {
|
|
|
- sol_memcpy_(
|
|
|
- ptr.add(length) as *mut _,
|
|
|
- after_decimal.as_ptr() as *const _,
|
|
|
- after_decimal.len() as u64,
|
|
|
- );
|
|
|
+ sol_memcpy(slice, amount_str.as_bytes(), length);
|
|
|
|
|
|
- length += after_decimal.len();
|
|
|
- ptr = ptr.add(length);
|
|
|
+ sol_memcpy(
|
|
|
+ &mut slice[length..],
|
|
|
+ after_decimal.as_bytes(),
|
|
|
+ after_decimal.len(),
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
+ length += after_decimal.len();
|
|
|
let remaining = decimals.saturating_sub(after_decimal.len());
|
|
|
|
|
|
+ // SAFETY: `digits` is an array of `MaybeUninit<u8>`, which has the same memory
|
|
|
+ // layout as `u8`.
|
|
|
+ let ptr = unsafe { digits.as_mut_ptr().add(length) };
|
|
|
+
|
|
|
for offset in 0..remaining {
|
|
|
+ // SAFETY: `ptr` is within the bounds of `digits`.
|
|
|
unsafe {
|
|
|
(ptr.add(offset) as *mut u8).write(b'0');
|
|
|
}
|
|
@@ -181,6 +193,7 @@ fn try_ui_amount_into_amount(ui_amount: &str, decimals: u8) -> Result<u64, Progr
|
|
|
|
|
|
length += remaining;
|
|
|
|
|
|
+ // SAFETY: `digits` only contains valid UTF-8 bytes.
|
|
|
unsafe {
|
|
|
from_utf8_unchecked(from_raw_parts(digits.as_ptr() as _, length))
|
|
|
.parse::<u64>()
|