Browse Source

lang: Add `discriminator` argument to `#[account]` attribute (#3149)

acheron 1 year ago
parent
commit
d73983d3db

+ 1 - 1
.github/workflows/reusable-tests.yaml

@@ -439,7 +439,7 @@ jobs:
             path: tests/safety-checks
           - cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit
             path: tests/custom-coder
-          - cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit
+          - cmd: cd tests/custom-discriminator && anchor test
             path: tests/custom-discriminator
           - cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit
             path: tests/validator-clone

+ 1 - 0
CHANGELOG.md

@@ -31,6 +31,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
 - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
 - lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).
+- lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)).
 
 ### Fixes
 

+ 52 - 14
lang/attribute/account/src/lib.rs

@@ -6,7 +6,7 @@ use syn::{
     parse::{Parse, ParseStream},
     parse_macro_input,
     token::{Comma, Paren},
-    Ident, LitStr,
+    Expr, Ident, Lit, LitStr, Token,
 };
 
 mod id;
@@ -31,6 +31,22 @@ mod id;
 /// check this discriminator. If it doesn't match, an invalid account was given,
 /// and the account deserialization will exit with an error.
 ///
+/// # Args
+///
+/// - `discriminator`: Override the default 8-byte discriminator
+///
+///     **Usage:** `discriminator = <CONST_EXPR>`
+///
+///     All constant expressions are supported.
+///
+///     **Examples:**
+///
+///     - `discriminator = 0` (shortcut for `[0]`)
+///     - `discriminator = [1, 2, 3, 4]`
+///     - `discriminator = b"hi"`
+///     - `discriminator = MY_DISC`
+///     - `discriminator = get_disc(...)`
+///
 /// # Zero Copy Deserialization
 ///
 /// **WARNING**: Zero copy deserialization is an experimental feature. It's
@@ -83,23 +99,21 @@ pub fn account(
     let account_name_str = account_name.to_string();
     let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
 
-    let discriminator: proc_macro2::TokenStream = {
+    let discriminator = args.discriminator.unwrap_or_else(|| {
         // Namespace the discriminator to prevent collisions.
-        let discriminator_preimage = {
-            // For now, zero copy accounts can't be namespaced.
-            if namespace.is_empty() {
-                format!("account:{account_name}")
-            } else {
-                format!("{namespace}:{account_name}")
-            }
+        let discriminator_preimage = if namespace.is_empty() {
+            format!("account:{account_name}")
+        } else {
+            format!("{namespace}:{account_name}")
         };
 
         let mut discriminator = [0u8; 8];
         discriminator.copy_from_slice(
             &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
         );
-        format!("{discriminator:?}").parse().unwrap()
-    };
+        let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
+        quote! { &#discriminator }
+    });
     let disc = if account_strct.generics.lt_token.is_some() {
         quote! { #account_name::#type_gen::DISCRIMINATOR }
     } else {
@@ -159,7 +173,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
+                    const DISCRIMINATOR: &'static [u8] = #discriminator;
                 }
 
                 // This trait is useful for clients deserializing accounts.
@@ -229,7 +243,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    const DISCRIMINATOR: &'static [u8] = &#discriminator;
+                    const DISCRIMINATOR: &'static [u8] = #discriminator;
                 }
 
                 #owner_impl
@@ -242,7 +256,10 @@ pub fn account(
 struct AccountArgs {
     /// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
     zero_copy: Option<bool>,
+    /// Account namespace override, `account` if not specified
     namespace: Option<String>,
+    /// Discriminator override
+    discriminator: Option<proc_macro2::TokenStream>,
 }
 
 impl Parse for AccountArgs {
@@ -257,6 +274,9 @@ impl Parse for AccountArgs {
                 AccountArg::Namespace(ns) => {
                     parsed.namespace.replace(ns);
                 }
+                AccountArg::Discriminator(disc) => {
+                    parsed.discriminator.replace(disc);
+                }
             }
         }
 
@@ -267,6 +287,7 @@ impl Parse for AccountArgs {
 enum AccountArg {
     ZeroCopy { is_unsafe: bool },
     Namespace(String),
+    Discriminator(proc_macro2::TokenStream),
 }
 
 impl Parse for AccountArg {
@@ -300,7 +321,24 @@ impl Parse for AccountArg {
             return Ok(Self::ZeroCopy { is_unsafe });
         };
 
-        Err(syn::Error::new(ident.span(), "Unexpected argument"))
+        // Named arguments
+        // TODO: Share the common arguments with `#[instruction]`
+        input.parse::<Token![=]>()?;
+        let value = input.parse::<Expr>()?;
+        match ident.to_string().as_str() {
+            "discriminator" => {
+                let value = match value {
+                    // Allow `discriminator = 42`
+                    Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
+                    // Allow `discriminator = [0, 1, 2, 3]`
+                    Expr::Array(arr) => quote! { &#arr },
+                    expr => expr.to_token_stream(),
+                };
+
+                Ok(Self::Discriminator(value))
+            }
+            _ => Err(syn::Error::new(ident.span(), "Invalid argument")),
+        }
     }
 }
 

+ 25 - 0
tests/custom-discriminator/programs/custom-discriminator/src/lib.rs

@@ -39,9 +39,34 @@ pub mod custom_discriminator {
     pub fn const_fn(_ctx: Context<DefaultIx>) -> Result<()> {
         Ok(())
     }
+
+    pub fn account(ctx: Context<CustomAccountIx>, field: u8) -> Result<()> {
+        ctx.accounts.my_account.field = field;
+        Ok(())
+    }
 }
 
 #[derive(Accounts)]
 pub struct DefaultIx<'info> {
     pub signer: Signer<'info>,
 }
+
+#[derive(Accounts)]
+pub struct CustomAccountIx<'info> {
+    #[account(mut)]
+    pub signer: Signer<'info>,
+    #[account(
+        init,
+        payer = signer,
+        space = MyAccount::DISCRIMINATOR.len() + core::mem::size_of::<MyAccount>(),
+        seeds = [b"my_account"],
+        bump
+    )]
+    pub my_account: Account<'info, MyAccount>,
+    pub system_program: Program<'info, System>,
+}
+
+#[account(discriminator = 1)]
+pub struct MyAccount {
+    pub field: u8,
+}

+ 23 - 1
tests/custom-discriminator/tests/custom-discriminator.ts

@@ -8,7 +8,7 @@ describe("custom-discriminator", () => {
   const program: anchor.Program<CustomDiscriminator> =
     anchor.workspace.customDiscriminator;
 
-  describe("Can use custom instruction discriminators", () => {
+  describe("Instructions", () => {
     const testCommon = async (ixName: keyof typeof program["methods"]) => {
       const tx = await program.methods[ixName]().transaction();
 
@@ -28,4 +28,26 @@ describe("custom-discriminator", () => {
     it("Constant", () => testCommon("constant"));
     it("Const Fn", () => testCommon("constFn"));
   });
+
+  describe("Accounts", () => {
+    it("Works", async () => {
+      // Verify discriminator
+      const acc = program.idl.accounts.find((acc) => acc.name === "myAccount")!;
+      assert(acc.discriminator.length < 8);
+
+      // Verify regular `init` ix works
+      const field = 5;
+      const { pubkeys, signature } = await program.methods
+        .account(field)
+        .rpcAndKeys();
+      await program.provider.connection.confirmTransaction(
+        signature,
+        "confirmed"
+      );
+      const myAccount = await program.account.myAccount.fetch(
+        pubkeys.myAccount
+      );
+      assert.strictEqual(field, myAccount.field);
+    });
+  });
 });