Browse Source

lang: Refactor discriminator generation (#3182)

acheron 1 year ago
parent
commit
546945b69c

+ 5 - 11
lang/attribute/account/src/lib.rs

@@ -1,6 +1,6 @@
 extern crate proc_macro;
 extern crate proc_macro;
 
 
-use anchor_syn::Overrides;
+use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
 use quote::{quote, ToTokens};
 use quote::{quote, ToTokens};
 use syn::{
 use syn::{
     parenthesized,
     parenthesized,
@@ -105,19 +105,13 @@ pub fn account(
         .and_then(|ov| ov.discriminator)
         .and_then(|ov| ov.discriminator)
         .unwrap_or_else(|| {
         .unwrap_or_else(|| {
             // Namespace the discriminator to prevent collisions.
             // Namespace the discriminator to prevent collisions.
-            let discriminator_preimage = if namespace.is_empty() {
-                format!("account:{account_name}")
+            let namespace = if namespace.is_empty() {
+                "account"
             } else {
             } else {
-                format!("{namespace}:{account_name}")
+                &namespace
             };
             };
 
 
-            let mut discriminator = [0u8; 8];
-            discriminator.copy_from_slice(
-                &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
-            );
-            let discriminator: proc_macro2::TokenStream =
-                format!("{discriminator:?}").parse().unwrap();
-            quote! { &#discriminator }
+            gen_discriminator(namespace, account_name)
         });
         });
     let disc = if account_strct.generics.lt_token.is_some() {
     let disc = if account_strct.generics.lt_token.is_some() {
         quote! { #account_name::#type_gen::DISCRIMINATOR }
         quote! { #account_name::#type_gen::DISCRIMINATOR }

+ 4 - 8
lang/attribute/event/src/lib.rs

@@ -2,7 +2,7 @@ extern crate proc_macro;
 
 
 #[cfg(feature = "event-cpi")]
 #[cfg(feature = "event-cpi")]
 use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
 use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
-use anchor_syn::Overrides;
+use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
 use quote::quote;
 use quote::quote;
 use syn::parse_macro_input;
 use syn::parse_macro_input;
 
 
@@ -37,13 +37,9 @@ pub fn event(
     let event_strct = parse_macro_input!(input as syn::ItemStruct);
     let event_strct = parse_macro_input!(input as syn::ItemStruct);
     let event_name = &event_strct.ident;
     let event_name = &event_strct.ident;
 
 
-    let discriminator = args.discriminator.unwrap_or_else(|| {
-        let discriminator_preimage = format!("event:{event_name}").into_bytes();
-        let discriminator = anchor_syn::hash::hash(&discriminator_preimage);
-        let discriminator: proc_macro2::TokenStream =
-            format!("{:?}", &discriminator.0[..8]).parse().unwrap();
-        quote! { &#discriminator }
-    });
+    let discriminator = args
+        .discriminator
+        .unwrap_or_else(|| gen_discriminator("event", event_name));
 
 
     let ret = quote! {
     let ret = quote! {
         #[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)]
         #[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)]

+ 5 - 0
lang/syn/src/codegen/program/common.rs

@@ -18,6 +18,11 @@ pub fn sighash(namespace: &str, name: &str) -> [u8; 8] {
     sighash
     sighash
 }
 }
 
 
+pub fn gen_discriminator(namespace: &str, name: impl ToString) -> proc_macro2::TokenStream {
+    let discriminator = sighash(namespace, name.to_string().as_str());
+    format!("&{:?}", discriminator).parse().unwrap()
+}
+
 pub fn generate_ix_variant(name: String, args: &[IxArg]) -> proc_macro2::TokenStream {
 pub fn generate_ix_variant(name: String, args: &[IxArg]) -> proc_macro2::TokenStream {
     let ix_arg_names: Vec<&syn::Ident> = args.iter().map(|arg| &arg.name).collect();
     let ix_arg_names: Vec<&syn::Ident> = args.iter().map(|arg| &arg.name).collect();
     let ix_name_camel: proc_macro2::TokenStream = {
     let ix_name_camel: proc_macro2::TokenStream = {

+ 7 - 7
lang/syn/src/codegen/program/cpi.rs

@@ -1,4 +1,6 @@
-use crate::codegen::program::common::{generate_ix_variant, sighash, SIGHASH_GLOBAL_NAMESPACE};
+use crate::codegen::program::common::{
+    gen_discriminator, generate_ix_variant, SIGHASH_GLOBAL_NAMESPACE,
+};
 use crate::Program;
 use crate::Program;
 use heck::SnakeCase;
 use heck::SnakeCase;
 use quote::{quote, ToTokens};
 use quote::{quote, ToTokens};
@@ -11,13 +13,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
         .map(|ix| {
         .map(|ix| {
             let accounts_ident: proc_macro2::TokenStream = format!("crate::cpi::accounts::{}", &ix.anchor_ident.to_string()).parse().unwrap();
             let accounts_ident: proc_macro2::TokenStream = format!("crate::cpi::accounts::{}", &ix.anchor_ident.to_string()).parse().unwrap();
             let cpi_method = {
             let cpi_method = {
-                let ix_variant = generate_ix_variant(ix.raw_method.sig.ident.to_string(), &ix.args);
+                let name = &ix.raw_method.sig.ident;
+                let ix_variant = generate_ix_variant(name.to_string(), &ix.args);
                 let method_name = &ix.ident;
                 let method_name = &ix.ident;
                 let args: Vec<&syn::PatType> = ix.args.iter().map(|arg| &arg.raw_arg).collect();
                 let args: Vec<&syn::PatType> = ix.args.iter().map(|arg| &arg.raw_arg).collect();
-                let name = &ix.raw_method.sig.ident.to_string();
-                let sighash_arr = sighash(SIGHASH_GLOBAL_NAMESPACE, name);
-                let sighash_tts: proc_macro2::TokenStream =
-                    format!("{sighash_arr:?}").parse().unwrap();
+                let discriminator = gen_discriminator(SIGHASH_GLOBAL_NAMESPACE, name);
                 let ret_type = &ix.returns.ty.to_token_stream();
                 let ret_type = &ix.returns.ty.to_token_stream();
                 let (method_ret, maybe_return) = match ret_type.to_string().as_str() {
                 let (method_ret, maybe_return) = match ret_type.to_string().as_str() {
                     "()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
                     "()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
@@ -35,7 +35,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                         let ix = {
                         let ix = {
                             let ix = instruction::#ix_variant;
                             let ix = instruction::#ix_variant;
                             let mut data = Vec::with_capacity(256);
                             let mut data = Vec::with_capacity(256);
-                            data.extend_from_slice(&#sighash_tts);
+                            data.extend_from_slice(#discriminator);
                             AnchorSerialize::serialize(&ix, &mut data)
                             AnchorSerialize::serialize(&ix, &mut data)
                                 .map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotSerialize)?;
                                 .map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotSerialize)?;
                             let accounts = ctx.to_account_metas(None);
                             let accounts = ctx.to_account_metas(None);

+ 5 - 9
lang/syn/src/codegen/program/instruction.rs

@@ -26,15 +26,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                     Some(overrides) if overrides.discriminator.is_some() => {
                     Some(overrides) if overrides.discriminator.is_some() => {
                         overrides.discriminator.as_ref().unwrap().to_owned()
                         overrides.discriminator.as_ref().unwrap().to_owned()
                     }
                     }
-                    _ => {
-                        // TODO: Remove `interface_discriminator`
-                        let discriminator = ix
-                            .interface_discriminator
-                            .unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
-                        let discriminator: proc_macro2::TokenStream =
-                            format!("{discriminator:?}").parse().unwrap();
-                        quote! { &#discriminator }
-                    }
+                    // TODO: Remove `interface_discriminator`
+                    _ => match &ix.interface_discriminator {
+                        Some(disc) => format!("&{disc:?}").parse().unwrap(),
+                        _ => gen_discriminator(SIGHASH_GLOBAL_NAMESPACE, name),
+                    },
                 };
                 };
 
 
                 quote! {
                 quote! {