Browse Source

lang: Add non-8-byte discriminator support in `declare_program!` (#3103)

acheron 1 year ago
parent
commit
e5bed20736

+ 1 - 0
CHANGELOG.md

@@ -19,6 +19,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
 - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
 - lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)).
+- lang: Add non-8-byte discriminator support in `declare_program!` ([#3103](https://github.com/coral-xyz/anchor/pull/3103)).
 
 ### Fixes
 

+ 3 - 3
lang/attribute/program/src/declare_program/mods/accounts.rs

@@ -21,7 +21,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                         return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                     }
 
-                    let given_disc = &buf[..8];
+                    let given_disc = &buf[..#discriminator.len()];
                     if &#discriminator != given_disc {
                         return Err(
                             anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
@@ -51,7 +51,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                         #try_deserialize
 
                         fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                            let mut data: &[u8] = &buf[8..];
+                            let mut data: &[u8] = &buf[#discriminator.len()..];
                             AnchorDeserialize::deserialize(&mut data)
                                 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
                         }
@@ -75,7 +75,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                             #try_deserialize
 
                             fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                                let data: &[u8] = &buf[8..];
+                                let data: &[u8] = &buf[#discriminator.len()..];
                                 let account = anchor_lang::__private::bytemuck::from_bytes(data);
                                 Ok(*account)
                             }

+ 4 - 8
lang/attribute/program/src/declare_program/mods/internal.rs

@@ -46,15 +46,11 @@ fn gen_internal_args_mod(idl: &Idl) -> proc_macro2::TokenStream {
             }
         };
 
-        let impl_discriminator = if ix.discriminator.len() == 8 {
-            let discriminator = gen_discriminator(&ix.discriminator);
-            quote! {
-                impl anchor_lang::Discriminator for #ix_struct_name {
-                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
-                }
+        let discriminator = gen_discriminator(&ix.discriminator);
+        let impl_discriminator = quote! {
+            impl anchor_lang::Discriminator for #ix_struct_name {
+                const DISCRIMINATOR: &'static [u8] = &#discriminator;
             }
-        } else {
-            quote! {}
         };
 
         let impl_ix_data = quote! {

+ 24 - 32
lang/attribute/program/src/declare_program/mods/utils.rs

@@ -24,15 +24,17 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
         .iter()
         .map(|acc| format_ident!("{}", acc.name))
         .map(|name| quote! { #name(#name) });
-    let match_arms = idl.accounts.iter().map(|acc| {
-        let disc = gen_discriminator(&acc.discriminator);
+    let if_statements = idl.accounts.iter().map(|acc| {
         let name = format_ident!("{}", acc.name);
-        let account = quote! {
-            #name::try_from_slice(&value[8..])
-                .map(Self::#name)
-                .map_err(Into::into)
-        };
-        quote! { #disc => #account }
+        let disc = gen_discriminator(&acc.discriminator);
+        let disc_len = acc.discriminator.len();
+        quote! {
+            if value.starts_with(&#disc) {
+                return #name::try_from_slice(&value[#disc_len..])
+                    .map(Self::#name)
+                    .map_err(Into::into)
+            }
+        }
     });
 
     quote! {
@@ -57,14 +59,8 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
             type Error = anchor_lang::error::Error;
 
             fn try_from(value: &[u8]) -> Result<Self> {
-                if value.len() < 8 {
-                    return Err(ProgramError::InvalidArgument.into());
-                }
-
-                match &value[..8] {
-                    #(#match_arms,)*
-                    _ => Err(ProgramError::InvalidArgument.into()),
-                }
+                #(#if_statements)*
+                Err(ProgramError::InvalidArgument.into())
             }
         }
     }
@@ -76,15 +72,17 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
         .iter()
         .map(|ev| format_ident!("{}", ev.name))
         .map(|name| quote! { #name(#name) });
-    let match_arms = idl.events.iter().map(|ev| {
-        let disc = gen_discriminator(&ev.discriminator);
+    let if_statements = idl.events.iter().map(|ev| {
         let name = format_ident!("{}", ev.name);
-        let event = quote! {
-            #name::try_from_slice(&value[8..])
-                .map(Self::#name)
-                .map_err(Into::into)
-        };
-        quote! { #disc => #event }
+        let disc = gen_discriminator(&ev.discriminator);
+        let disc_len = ev.discriminator.len();
+        quote! {
+            if value.starts_with(&#disc) {
+                return #name::try_from_slice(&value[#disc_len..])
+                    .map(Self::#name)
+                    .map_err(Into::into)
+            }
+        }
     });
 
     quote! {
@@ -109,14 +107,8 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
             type Error = anchor_lang::error::Error;
 
             fn try_from(value: &[u8]) -> Result<Self> {
-                if value.len() < 8 {
-                    return Err(ProgramError::InvalidArgument.into());
-                }
-
-                match &value[..8] {
-                    #(#match_arms,)*
-                    _ => Err(ProgramError::InvalidArgument.into()),
-                }
+                #(#if_statements)*
+                Err(ProgramError::InvalidArgument.into())
             }
         }
     }