Browse Source

lang: Use associated discriminator constants instead of hardcoding in `#[account]` (#3144)

acheron 1 year ago
parent
commit
dc6ac2d631

+ 1 - 0
CHANGELOG.md

@@ -30,6 +30,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - cli: Warn if `anchor-spl/idl-build` is missing ([#3133](https://github.com/coral-xyz/anchor/pull/3133)).
 - client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
 - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
+- lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).
 
 ### Fixes
 

+ 10 - 9
lang/attribute/account/src/lib.rs

@@ -100,6 +100,7 @@ pub fn account(
         );
         format!("{discriminator:?}").parse().unwrap()
     };
+    let disc = quote! { #account_name::DISCRIMINATOR };
 
     let owner_impl = {
         if namespace.is_empty() {
@@ -162,18 +163,18 @@ pub fn account(
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
                     fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        if buf.len() < #discriminator.len() {
+                        if buf.len() < #disc.len() {
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
-                        let given_disc = &buf[..#discriminator.len()];
-                        if &#discriminator != given_disc {
+                        let given_disc = &buf[..#disc.len()];
+                        if #disc != given_disc {
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                         }
                         Self::try_deserialize_unchecked(buf)
                     }
 
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        let data: &[u8] = &buf[#discriminator.len()..];
+                        let data: &[u8] = &buf[#disc.len()..];
                         // Re-interpret raw bytes into the POD data structure.
                         let account = anchor_lang::__private::bytemuck::from_bytes(data);
                         // Copy out the bytes into a new, owned data structure.
@@ -191,7 +192,7 @@ pub fn account(
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
                     fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
-                        if writer.write_all(&#discriminator).is_err() {
+                        if writer.write_all(#disc).is_err() {
                             return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
                         }
 
@@ -205,18 +206,18 @@ pub fn account(
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
                     fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        if buf.len() < #discriminator.len() {
+                        if buf.len() < #disc.len() {
                             return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
-                        let given_disc = &buf[..#discriminator.len()];
-                        if &#discriminator != given_disc {
+                        let given_disc = &buf[..#disc.len()];
+                        if #disc != given_disc {
                             return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
                         }
                         Self::try_deserialize_unchecked(buf)
                     }
 
                     fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                        let mut data: &[u8] = &buf[#discriminator.len()..];
+                        let mut data: &[u8] = &buf[#disc.len()..];
                         AnchorDeserialize::deserialize(&mut data)
                             .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
                     }

+ 7 - 6
lang/attribute/program/src/declare_program/mods/accounts.rs

@@ -7,6 +7,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
     let accounts = idl.accounts.iter().map(|acc| {
         let name = format_ident!("{}", acc.name);
         let discriminator = gen_discriminator(&acc.discriminator);
+        let disc = quote! { #name::DISCRIMINATOR };
 
         let ty_def = idl
             .types
@@ -17,12 +18,12 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
         let impls = {
             let try_deserialize = quote! {
                 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                    if buf.len() < #discriminator.len() {
+                    if buf.len() < #disc.len() {
                         return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
                     }
 
-                    let given_disc = &buf[..#discriminator.len()];
-                    if &#discriminator != given_disc {
+                    let given_disc = &buf[..#disc.len()];
+                    if #disc != given_disc {
                         return Err(
                             anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
                             .with_account_name(stringify!(#name))
@@ -36,7 +37,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                 IdlSerialization::Borsh => quote! {
                     impl anchor_lang::AccountSerialize for #name {
                         fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
-                            if writer.write_all(&#discriminator).is_err() {
+                            if writer.write_all(#disc).is_err() {
                                 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
                             }
                             if AnchorSerialize::serialize(self, writer).is_err() {
@@ -51,7 +52,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                         #try_deserialize
 
                         fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                            let mut data: &[u8] = &buf[#discriminator.len()..];
+                            let mut data: &[u8] = &buf[#disc.len()..];
                             AnchorDeserialize::deserialize(&mut data)
                                 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
                         }
@@ -75,7 +76,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
                             #try_deserialize
 
                             fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
-                                let data: &[u8] = &buf[#discriminator.len()..];
+                                let data: &[u8] = &buf[#disc.len()..];
                                 let account = anchor_lang::__private::bytemuck::from_bytes(data);
                                 Ok(*account)
                             }