Browse Source

lang: Get discriminator length dynamically (#3101)

acheron 1 year ago
parent
commit
ba33d5e974

+ 1 - 0
CHANGELOG.md

@@ -18,6 +18,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - lang: Export `Discriminator` trait from `prelude` ([#3075](https://github.com/coral-xyz/anchor/pull/3075)).
 - lang: Export `Discriminator` trait from `prelude` ([#3075](https://github.com/coral-xyz/anchor/pull/3075)).
 - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
 - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
 - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
 - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
+- lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)).
 
 
 ### Fixes
 ### Fixes
 
 

+ 4 - 4
lang/attribute/account/src/lib.rs

@@ -180,7 +180,7 @@ pub fn account(
                         if buf.len() < #discriminator.len() {
                         if buf.len() < #discriminator.len() {
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
                         }
-                        let given_disc = &buf[..8];
+                        let given_disc = &buf[..#discriminator.len()];
                         if &#discriminator != given_disc {
                         if &#discriminator != given_disc {
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                         }
                         }
@@ -188,7 +188,7 @@ pub fn account(
                     }
                     }
 
 
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        let data: &[u8] = &buf[8..];
+                        let data: &[u8] = &buf[#discriminator.len()..];
                         // Re-interpret raw bytes into the POD data structure.
                         // Re-interpret raw bytes into the POD data structure.
                         let account = anchor_lang::__private::bytemuck::from_bytes(data);
                         let account = anchor_lang::__private::bytemuck::from_bytes(data);
                         // Copy out the bytes into a new, owned data structure.
                         // Copy out the bytes into a new, owned data structure.
@@ -223,7 +223,7 @@ pub fn account(
                         if buf.len() < #discriminator.len() {
                         if buf.len() < #discriminator.len() {
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
                         }
-                        let given_disc = &buf[..8];
+                        let given_disc = &buf[..#discriminator.len()];
                         if &#discriminator != given_disc {
                         if &#discriminator != given_disc {
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                         }
                         }
@@ -231,7 +231,7 @@ pub fn account(
                     }
                     }
 
 
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        let mut data: &[u8] = &buf[8..];
+                        let mut data: &[u8] = &buf[#discriminator.len()..];
                         AnchorDeserialize::deserialize(&mut data)
                         AnchorDeserialize::deserialize(&mut data)
                             .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
                             .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
                     }
                     }

+ 14 - 10
lang/src/accounts/account_loader.rs

@@ -129,7 +129,7 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
         }
 
 
-        let given_disc = &data[..8];
+        let given_disc = &data[..disc.len()];
         if given_disc != disc {
         if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
         }
@@ -158,13 +158,13 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
         }
 
 
-        let given_disc = &data[..8];
+        let given_disc = &data[..disc.len()];
         if given_disc != disc {
         if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
         }
 
 
         Ok(Ref::map(data, |data| {
         Ok(Ref::map(data, |data| {
-            bytemuck::from_bytes(&data[8..mem::size_of::<T>() + 8])
+            bytemuck::from_bytes(&data[disc.len()..mem::size_of::<T>() + disc.len()])
         }))
         }))
     }
     }
 
 
@@ -182,13 +182,15 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
         }
 
 
-        let given_disc = &data[..8];
+        let given_disc = &data[..disc.len()];
         if given_disc != disc {
         if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
         }
 
 
         Ok(RefMut::map(data, |data| {
         Ok(RefMut::map(data, |data| {
-            bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::<T>() + 8])
+            bytemuck::from_bytes_mut(
+                &mut data.deref_mut()[disc.len()..mem::size_of::<T>() + disc.len()],
+            )
         }))
         }))
     }
     }
 
 
@@ -204,15 +206,17 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
         let data = self.acc_info.try_borrow_mut_data()?;
         let data = self.acc_info.try_borrow_mut_data()?;
 
 
         // The discriminator should be zero, since we're initializing.
         // The discriminator should be zero, since we're initializing.
-        let mut disc_bytes = [0u8; 8];
-        disc_bytes.copy_from_slice(&data[..8]);
-        let discriminator = u64::from_le_bytes(disc_bytes);
-        if discriminator != 0 {
+        let disc = T::DISCRIMINATOR;
+        let given_disc = &data[..disc.len()];
+        let has_disc = given_disc.iter().any(|b| *b != 0);
+        if has_disc {
             return Err(ErrorCode::AccountDiscriminatorAlreadySet.into());
             return Err(ErrorCode::AccountDiscriminatorAlreadySet.into());
         }
         }
 
 
         Ok(RefMut::map(data, |data| {
         Ok(RefMut::map(data, |data| {
-            bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::<T>() + 8])
+            bytemuck::from_bytes_mut(
+                &mut data.deref_mut()[disc.len()..mem::size_of::<T>() + disc.len()],
+            )
         }))
         }))
     }
     }
 }
 }

+ 4 - 1
lang/syn/src/codegen/program/idl.rs

@@ -147,7 +147,10 @@ pub fn idl_accounts_and_functions() -> proc_macro2::TokenStream {
             let owner = accounts.program.key;
             let owner = accounts.program.key;
             let to = Pubkey::create_with_seed(&base, seed, owner).unwrap();
             let to = Pubkey::create_with_seed(&base, seed, owner).unwrap();
             // Space: account discriminator || authority pubkey || vec len || vec data
             // Space: account discriminator || authority pubkey || vec len || vec data
-            let space = std::cmp::min(8 + 32 + 4 + data_len as usize, 10_000);
+            let space = std::cmp::min(
+                IdlAccount::DISCRIMINATOR.len() + 32 + 4 + data_len as usize,
+                10_000
+            );
             let rent = Rent::get()?;
             let rent = Rent::get()?;
             let lamports = rent.minimum_balance(space);
             let lamports = rent.minimum_balance(space);
             let seeds = &[&[nonce][..]];
             let seeds = &[&[nonce][..]];