Selaa lähdekoodia

lang: Make discriminator type unsized (#3098)

acheron 1 vuosi sitten
vanhempi
sitoutus
14cec14617

+ 1 - 0
CHANGELOG.md

@@ -37,6 +37,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - client: Add `tokio` support to `RequestBuilder` with `async` feature ([#3057](https://github.com/coral-xyz/anchor/pull/3057)).
 - lang: Remove `EventData` trait ([#3083](https://github.com/coral-xyz/anchor/pull/3083)).
 - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
+- lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
 
 ## [0.30.1] - 2024-06-20
 

+ 1 - 1
client/src/lib.rs

@@ -257,7 +257,7 @@ impl<C: Deref<Target = impl Signer> + Clone> Program<C> {
         filters: Vec<RpcFilterType>,
     ) -> Result<ProgramAccountsIterator<T>, ClientError> {
         let account_type_filter =
-            RpcFilterType::Memcmp(Memcmp::new_base58_encoded(0, &T::discriminator()));
+            RpcFilterType::Memcmp(Memcmp::new_base58_encoded(0, T::DISCRIMINATOR));
         let config = RpcProgramAccountsConfig {
             filters: Some([vec![account_type_filter], filters].concat()),
             account_config: RpcAccountInfoConfig {

+ 2 - 2
lang/attribute/account/src/lib.rs

@@ -169,7 +169,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    const DISCRIMINATOR: [u8; 8] = #discriminator;
+                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
                 }
 
                 // This trait is useful for clients deserializing accounts.
@@ -239,7 +239,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    const DISCRIMINATOR: [u8; 8] = #discriminator;
+                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
                 }
 
                 #owner_impl

+ 6 - 2
lang/attribute/event/src/lib.rs

@@ -40,7 +40,7 @@ pub fn event(
         }
 
         impl anchor_lang::Discriminator for #event_name {
-            const DISCRIMINATOR: [u8; 8] = #discriminator;
+            const DISCRIMINATOR: &'static [u8] = &#discriminator;
         }
     };
 
@@ -161,7 +161,11 @@ pub fn emit_cpi(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
 
             let disc = anchor_lang::event::EVENT_IX_TAG_LE;
             let inner_data = anchor_lang::Event::data(&#event_struct);
-            let ix_data: Vec<u8> = disc.into_iter().chain(inner_data.into_iter()).collect();
+            let ix_data: Vec<u8> = disc
+                .into_iter()
+                .map(|b| *b)
+                .chain(inner_data.into_iter())
+                .collect();
 
             let ix = anchor_lang::solana_program::instruction::Instruction::new_with_bytes(
                 crate::ID,

+ 1 - 1
lang/attribute/program/src/declare_program/mods/accounts.rs

@@ -96,7 +96,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
             #impls
 
             impl anchor_lang::Discriminator for #name {
-                const DISCRIMINATOR: [u8; 8] = #discriminator;
+                const DISCRIMINATOR: &'static [u8] = &#discriminator;
             }
 
             impl anchor_lang::Owner for #name {

+ 1 - 1
lang/attribute/program/src/declare_program/mods/events.rs

@@ -29,7 +29,7 @@ pub fn gen_events_mod(idl: &Idl) -> proc_macro2::TokenStream {
             }
 
             impl anchor_lang::Discriminator for #name {
-                const DISCRIMINATOR: [u8; 8] = #discriminator;
+                const DISCRIMINATOR: &'static [u8] = &#discriminator;
             }
         }
     });

+ 1 - 1
lang/attribute/program/src/declare_program/mods/internal.rs

@@ -50,7 +50,7 @@ fn gen_internal_args_mod(idl: &Idl) -> proc_macro2::TokenStream {
             let discriminator = gen_discriminator(&ix.discriminator);
             quote! {
                 impl anchor_lang::Discriminator for #ix_struct_name {
-                    const DISCRIMINATOR: [u8; 8] = #discriminator;
+                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
                 }
             }
         } else {

+ 16 - 13
lang/src/accounts/account_loader.rs

@@ -6,7 +6,6 @@ use crate::{
     Accounts, AccountsClose, AccountsExit, Key, Owner, Result, ToAccountInfo, ToAccountInfos,
     ToAccountMetas, ZeroCopy,
 };
-use arrayref::array_ref;
 use solana_program::account_info::AccountInfo;
 use solana_program::instruction::AccountMeta;
 use solana_program::pubkey::Pubkey;
@@ -123,13 +122,15 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
             return Err(Error::from(ErrorCode::AccountOwnedByWrongProgram)
                 .with_pubkeys((*acc_info.owner, T::owner())));
         }
-        let data: &[u8] = &acc_info.try_borrow_data()?;
-        if data.len() < T::discriminator().len() {
+
+        let data = &acc_info.try_borrow_data()?;
+        let disc = T::DISCRIMINATOR;
+        if data.len() < disc.len() {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
-        // Discriminator must match.
-        let disc_bytes = array_ref![data, 0, 8];
-        if disc_bytes != &T::discriminator() {
+
+        let given_disc = &data[..8];
+        if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
 
@@ -152,12 +153,13 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
     /// Returns a Ref to the account data structure for reading.
     pub fn load(&self) -> Result<Ref<T>> {
         let data = self.acc_info.try_borrow_data()?;
-        if data.len() < T::discriminator().len() {
+        let disc = T::DISCRIMINATOR;
+        if data.len() < disc.len() {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
 
-        let disc_bytes = array_ref![data, 0, 8];
-        if disc_bytes != &T::discriminator() {
+        let given_disc = &data[..8];
+        if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
 
@@ -175,12 +177,13 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
         }
 
         let data = self.acc_info.try_borrow_mut_data()?;
-        if data.len() < T::discriminator().len() {
+        let disc = T::DISCRIMINATOR;
+        if data.len() < disc.len() {
             return Err(ErrorCode::AccountDiscriminatorNotFound.into());
         }
 
-        let disc_bytes = array_ref![data, 0, 8];
-        if disc_bytes != &T::discriminator() {
+        let given_disc = &data[..8];
+        if given_disc != disc {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
 
@@ -241,7 +244,7 @@ impl<'info, T: ZeroCopy + Owner> AccountsExit<'info> for AccountLoader<'info, T>
             let mut data = self.acc_info.try_borrow_mut_data()?;
             let dst: &mut [u8] = &mut data;
             let mut writer = BpfWriter::new(dst);
-            writer.write_all(&T::discriminator()).unwrap();
+            writer.write_all(T::DISCRIMINATOR).unwrap();
         }
         Ok(())
     }

+ 1 - 1
lang/src/bpf_upgradeable_state.rs

@@ -79,6 +79,6 @@ mod idl_build {
 
     impl crate::IdlBuild for ProgramData {}
     impl crate::Discriminator for ProgramData {
-        const DISCRIMINATOR: [u8; 8] = [u8::MAX; 8];
+        const DISCRIMINATOR: &'static [u8] = &[];
     }
 }

+ 1 - 1
lang/src/event.rs

@@ -1,3 +1,3 @@
 // Sha256(anchor:event)[..8]
 pub const EVENT_IX_TAG: u64 = 0x1d9acb512ea545e4;
-pub const EVENT_IX_TAG_LE: [u8; 8] = EVENT_IX_TAG.to_le_bytes();
+pub const EVENT_IX_TAG_LE: &[u8] = EVENT_IX_TAG.to_le_bytes().as_slice();

+ 1 - 1
lang/src/idl.rs

@@ -25,7 +25,7 @@ use crate::prelude::*;
 //
 // Sha256(anchor:idl)[..8];
 pub const IDL_IX_TAG: u64 = 0x0a69e9a778bcf440;
-pub const IDL_IX_TAG_LE: [u8; 8] = IDL_IX_TAG.to_le_bytes();
+pub const IDL_IX_TAG_LE: &[u8] = IDL_IX_TAG.to_le_bytes().as_slice();
 
 // The Pubkey that is stored as the 'authority' on the IdlAccount when the authority
 // is "erased".

+ 4 - 4
lang/src/lib.rs

@@ -279,7 +279,7 @@ pub trait ZeroCopy: Discriminator + Copy + Clone + Zeroable + Pod {}
 pub trait InstructionData: Discriminator + AnchorSerialize {
     fn data(&self) -> Vec<u8> {
         let mut data = Vec::with_capacity(256);
-        data.extend_from_slice(&Self::discriminator());
+        data.extend_from_slice(Self::DISCRIMINATOR);
         self.serialize(&mut data).unwrap();
         data
     }
@@ -290,7 +290,7 @@ pub trait InstructionData: Discriminator + AnchorSerialize {
     /// necessary), and because the data field in `Instruction` expects a `Vec<u8>`.
     fn write_to(&self, mut data: &mut Vec<u8>) {
         data.clear();
-        data.extend_from_slice(&Self::DISCRIMINATOR);
+        data.extend_from_slice(Self::DISCRIMINATOR);
         self.serialize(&mut data).unwrap()
     }
 }
@@ -302,8 +302,8 @@ pub trait Event: AnchorSerialize + AnchorDeserialize + Discriminator {
 
 /// 8 byte unique identifier for a type.
 pub trait Discriminator {
-    const DISCRIMINATOR: [u8; 8];
-    fn discriminator() -> [u8; 8] {
+    const DISCRIMINATOR: &'static [u8];
+    fn discriminator() -> &'static [u8] {
         Self::DISCRIMINATOR
     }
 }

+ 1 - 2
lang/syn/src/codegen/program/dispatch.rs

@@ -11,7 +11,6 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
             let ix_method_name = &ix.raw_method.sig.ident;
             let ix_name_camel: proc_macro2::TokenStream = ix_method_name
                 .to_string()
-                .as_str()
                 .to_camel_case()
                 .parse()
                 .expect("Failed to parse ix method name in camel as `TokenStream`");
@@ -65,7 +64,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                 sighash
             };
 
-            match sighash {
+            match sighash.as_slice() {
                 #(#global_dispatch_arms)*
                 anchor_lang::idl::IDL_IX_TAG_LE => {
                     // If the method identifier is the IDL tag, then execute an IDL

+ 5 - 6
lang/syn/src/codegen/program/instruction.rs

@@ -22,14 +22,14 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                 })
                 .collect();
             let ix_data_trait = {
-                let sighash_arr = ix
+                let discriminator = ix
                     .interface_discriminator
-                    .unwrap_or(sighash(SIGHASH_GLOBAL_NAMESPACE, name));
-                let sighash_tts: proc_macro2::TokenStream =
-                    format!("{sighash_arr:?}").parse().unwrap();
+                    .unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
+                let discriminator: proc_macro2::TokenStream =
+                    format!("{discriminator:?}").parse().unwrap();
                 quote! {
                     impl anchor_lang::Discriminator for #ix_name_camel {
-                        const DISCRIMINATOR: [u8; 8] = #sighash_tts;
+                        const DISCRIMINATOR: &'static [u8] = &#discriminator;
                     }
                     impl anchor_lang::InstructionData for #ix_name_camel {}
                     impl anchor_lang::Owner for #ix_name_camel {
@@ -72,7 +72,6 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
         pub mod instruction {
             use super::*;
 
-
             #(#variants)*
         }
     }

+ 2 - 2
lang/tests/serialization.rs

@@ -9,7 +9,7 @@ fn test_instruction_data() {
         bar: String,
     }
     impl Discriminator for MyType {
-        const DISCRIMINATOR: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
+        const DISCRIMINATOR: &'static [u8] = &[1, 2, 3, 4, 5, 6, 7, 8];
     }
     impl InstructionData for MyType {}
 
@@ -25,7 +25,7 @@ fn test_instruction_data() {
     instance.write_to(&mut write);
 
     // Check that one is correct and that they are equal (implies other is correct)
-    let correct_disc = data[0..8] == MyType::DISCRIMINATOR;
+    let correct_disc = &data[0..8] == MyType::DISCRIMINATOR;
     let correct_data = MyType::deserialize(&mut &data[8..]).is_ok_and(|result| result == instance);
     let correct_serialization = correct_disc & correct_data;
     assert!(correct_serialization, "serialization was not correct");

+ 1 - 1
spl/src/governance.rs

@@ -60,7 +60,7 @@ macro_rules! vote_weight_record {
 
         #[cfg(feature = "idl-build")]
         impl anchor_lang::Discriminator for VoterWeightRecord {
-            const DISCRIMINATOR: [u8; 8] = [0; 8];
+            const DISCRIMINATOR: &[u8] = &[];
         }
     };
 }

+ 1 - 1
spl/src/idl_build.rs

@@ -10,7 +10,7 @@ macro_rules! impl_idl_build {
         //
         // TODO: Find a better way to handle discriminators of wrapped external accounts.
         impl anchor_lang::Discriminator for $ty {
-            const DISCRIMINATOR: [u8; 8] = [0; 8];
+            const DISCRIMINATOR: &[u8] = &[];
         }
     };
 }