Explorar el Código

lang: Add `#[instruction]` attribute proc-macro (#3137)

acheron hace 1 año
padre
commit
3f945f682c

+ 2 - 0
.github/workflows/reusable-tests.yaml

@@ -439,6 +439,8 @@ jobs:
             path: tests/safety-checks
           - cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit
             path: tests/custom-coder
+          - cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit
+            path: tests/custom-discriminator
           - cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit
             path: tests/validator-clone
           - cmd: cd tests/cpi-returns && anchor test --skip-lint && npx tsc --noEmit

+ 1 - 0
CHANGELOG.md

@@ -29,6 +29,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - ts: Add optional `wallet` property to the `Provider` interface ([#3130](https://github.com/coral-xyz/anchor/pull/3130)).
 - cli: Warn if `anchor-spl/idl-build` is missing ([#3133](https://github.com/coral-xyz/anchor/pull/3133)).
 - client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
+- lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
 
 ### Fixes
 

+ 50 - 0
lang/attribute/program/src/lib.rs

@@ -103,3 +103,53 @@ pub fn interface(
     // discriminator.
     input
 }
+
+/// This attribute is used to override the Anchor defaults of program instructions.
+///
+/// # Args
+///
+/// - `discriminator`: Override the default 8-byte discriminator
+///
+///     **Usage:** `discriminator = <CONST_EXPR>`
+///
+///     All constant expressions are supported.
+///
+///     **Examples:**
+///
+///     - `discriminator = 0` (shortcut for `[0]`)
+///     - `discriminator = [1, 2, 3, 4]`
+///     - `discriminator = b"hi"`
+///     - `discriminator = MY_DISC`
+///     - `discriminator = get_disc(...)`
+///
+/// # Example
+///
+/// ```ignore
+/// use anchor_lang::prelude::*;
+///
+/// declare_id!("CustomDiscriminator111111111111111111111111");
+///
+/// #[program]
+/// pub mod custom_discriminator {
+///     use super::*;
+///
+///     #[instruction(discriminator = [1, 2, 3, 4])]
+///     pub fn my_ix(_ctx: Context<MyIx>) -> Result<()> {
+///         Ok(())
+///     }
+/// }
+///
+/// #[derive(Accounts)]
+/// pub struct MyIx<'info> {
+///     pub signer: Signer<'info>,
+/// }
+/// ```
+#[proc_macro_attribute]
+pub fn instruction(
+    _args: proc_macro::TokenStream,
+    input: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+    // This macro itself is a no-op, but the `#[program]` macro will detect this attribute and use
+    // the arguments to transform the instruction.
+    input
+}

+ 3 - 3
lang/src/lib.rs

@@ -52,7 +52,7 @@ pub use anchor_attribute_account::{account, declare_id, pubkey, zero_copy};
 pub use anchor_attribute_constant::constant;
 pub use anchor_attribute_error::*;
 pub use anchor_attribute_event::{emit, event};
-pub use anchor_attribute_program::{declare_program, program};
+pub use anchor_attribute_program::{declare_program, instruction, program};
 pub use anchor_derive_accounts::Accounts;
 pub use anchor_derive_serde::{AnchorDeserialize, AnchorSerialize};
 pub use anchor_derive_space::InitSpace;
@@ -392,8 +392,8 @@ pub mod prelude {
         accounts::signer::Signer, accounts::system_account::SystemAccount,
         accounts::sysvar::Sysvar, accounts::unchecked_account::UncheckedAccount, constant,
         context::Context, context::CpiContext, declare_id, declare_program, emit, err, error,
-        event, program, pubkey, require, require_eq, require_gt, require_gte, require_keys_eq,
-        require_keys_neq, require_neq,
+        event, instruction, program, pubkey, require, require_eq, require_gt, require_gte,
+        require_keys_eq, require_keys_neq, require_neq,
         solana_program::bpf_loader_upgradeable::UpgradeableLoaderState, source,
         system_program::System, zero_copy, AccountDeserialize, AccountSerialize, Accounts,
         AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Discriminator, Id,

+ 19 - 9
lang/syn/src/codegen/program/instruction.rs

@@ -21,15 +21,25 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                         .unwrap()
                 })
                 .collect();
-            let ix_data_trait = {
-                let discriminator = ix
-                    .interface_discriminator
-                    .unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
-                let discriminator: proc_macro2::TokenStream =
-                    format!("{discriminator:?}").parse().unwrap();
+            let impls = {
+                let discriminator = match ix.ix_attr.as_ref() {
+                    Some(ix_attr) if ix_attr.discriminator.is_some() => {
+                        ix_attr.discriminator.as_ref().unwrap().to_owned()
+                    }
+                    _ => {
+                        // TODO: Remove `interface_discriminator`
+                        let discriminator = ix
+                            .interface_discriminator
+                            .unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
+                        let discriminator: proc_macro2::TokenStream =
+                            format!("{discriminator:?}").parse().unwrap();
+                        quote! { &#discriminator }
+                    }
+                };
+
                 quote! {
                     impl anchor_lang::Discriminator for #ix_name_camel {
-                        const DISCRIMINATOR: &'static [u8] = &#discriminator;
+                        const DISCRIMINATOR: &'static [u8] = #discriminator;
                     }
                     impl anchor_lang::InstructionData for #ix_name_camel {}
                     impl anchor_lang::Owner for #ix_name_camel {
@@ -46,7 +56,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                     #[derive(AnchorSerialize, AnchorDeserialize)]
                     pub struct #ix_name_camel;
 
-                    #ix_data_trait
+                    #impls
                 }
             } else {
                 quote! {
@@ -56,7 +66,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                         #(#raw_args),*
                     }
 
-                    #ix_data_trait
+                    #impls
                 }
             }
         })

+ 52 - 0
lang/syn/src/lib.rs

@@ -25,6 +25,7 @@ use syn::parse::{Error as ParseError, Parse, ParseStream, Result as ParseResult}
 use syn::punctuated::Punctuated;
 use syn::spanned::Spanned;
 use syn::token::Comma;
+use syn::Lit;
 use syn::{
     Expr, Generics, Ident, ItemEnum, ItemFn, ItemMod, ItemStruct, LitInt, PatType, Token, Type,
     TypePath,
@@ -68,7 +69,58 @@ pub struct Ix {
     // The ident for the struct deriving Accounts.
     pub anchor_ident: Ident,
     // The discriminator based on the `#[interface]` attribute.
+    // TODO: Remove and use `ix_attr`
     pub interface_discriminator: Option<[u8; 8]>,
+    /// `#[instruction]` attribute
+    pub ix_attr: Option<IxAttr>,
+}
+
+/// `#[instruction]` attribute proc-macro
+#[derive(Debug, Default)]
+pub struct IxAttr {
+    /// Discriminator override
+    pub discriminator: Option<TokenStream>,
+}
+
+impl Parse for IxAttr {
+    fn parse(input: ParseStream) -> ParseResult<Self> {
+        let mut attr = Self::default();
+        let args = input.parse_terminated::<_, Comma>(AttrArg::parse)?;
+        for arg in args {
+            match arg.name.to_string().as_str() {
+                "discriminator" => {
+                    let value = match &arg.value {
+                        // Allow `discriminator = 42`
+                        Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
+                        // Allow `discriminator = [0, 1, 2, 3]`
+                        Expr::Array(arr) => quote! { &#arr },
+                        expr => expr.to_token_stream(),
+                    };
+                    attr.discriminator.replace(value)
+                }
+                _ => return Err(ParseError::new(arg.name.span(), "Invalid argument")),
+            };
+        }
+
+        Ok(attr)
+    }
+}
+
+struct AttrArg {
+    name: Ident,
+    #[allow(dead_code)]
+    eq_token: Token!(=),
+    value: Expr,
+}
+
+impl Parse for AttrArg {
+    fn parse(input: ParseStream) -> ParseResult<Self> {
+        Ok(Self {
+            name: input.parse()?,
+            eq_token: input.parse()?,
+            value: input.parse()?,
+        })
+    }
 }
 
 #[derive(Debug)]

+ 15 - 1
lang/syn/src/parser/program/instructions.rs

@@ -1,7 +1,7 @@
 use crate::parser::docs;
 use crate::parser::program::ctx_accounts_ident;
 use crate::parser::spl_interface;
-use crate::{FallbackFn, Ix, IxArg, IxReturn};
+use crate::{FallbackFn, Ix, IxArg, IxAttr, IxReturn};
 use syn::parse::{Error as ParseError, Result as ParseResult};
 use syn::spanned::Spanned;
 
@@ -25,6 +25,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
         })
         .map(|method: &syn::ItemFn| {
             let (ctx, args) = parse_args(method)?;
+            let ix_attr = parse_ix_attr(&method.attrs)?;
             let interface_discriminator = spl_interface::parse(&method.attrs);
             let docs = docs::parse(&method.attrs);
             let returns = parse_return(method)?;
@@ -37,6 +38,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
                 anchor_ident,
                 returns,
                 interface_discriminator,
+                ix_attr,
             })
         })
         .collect::<ParseResult<Vec<Ix>>>()?;
@@ -71,6 +73,18 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
     Ok((ixs, fallback_fn))
 }
 
+/// Parse `#[instruction]` attribute proc-macro.
+fn parse_ix_attr(attrs: &[syn::Attribute]) -> ParseResult<Option<IxAttr>> {
+    attrs
+        .iter()
+        .find(|attr| match attr.path.segments.last() {
+            Some(seg) => seg.ident == "instruction",
+            _ => false,
+        })
+        .map(|attr| attr.parse_args())
+        .transpose()
+}
+
 pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {
     let mut args: Vec<IxArg> = method
         .sig

+ 9 - 0
tests/custom-discriminator/Anchor.toml

@@ -0,0 +1,9 @@
+[programs.localnet]
+custom_discriminator = "CustomDiscriminator111111111111111111111111"
+
+[provider]
+cluster = "localnet"
+wallet = "~/.config/solana/id.json"
+
+[scripts]
+test = "yarn run ts-mocha -p ./tsconfig.json -t 1000000 tests/**/*.ts"

+ 14 - 0
tests/custom-discriminator/Cargo.toml

@@ -0,0 +1,14 @@
+[workspace]
+members = [
+    "programs/*"
+]
+resolver = "2"
+
+[profile.release]
+overflow-checks = true
+lto = "fat"
+codegen-units = 1
+[profile.release.build-override]
+opt-level = 3
+incremental = false
+codegen-units = 1

+ 16 - 0
tests/custom-discriminator/package.json

@@ -0,0 +1,16 @@
+{
+  "name": "custom-discriminator",
+  "version": "0.30.1",
+  "license": "(MIT OR Apache-2.0)",
+  "homepage": "https://github.com/coral-xyz/anchor#readme",
+  "bugs": {
+    "url": "https://github.com/coral-xyz/anchor/issues"
+  },
+  "repository": {
+    "type": "git",
+    "url": "https://github.com/coral-xyz/anchor.git"
+  },
+  "engines": {
+    "node": ">=17"
+  }
+}

+ 19 - 0
tests/custom-discriminator/programs/custom-discriminator/Cargo.toml

@@ -0,0 +1,19 @@
+[package]
+name = "custom-discriminator"
+version = "0.1.0"
+description = "Created with Anchor"
+edition = "2021"
+
+[lib]
+crate-type = ["cdylib", "lib"]
+name = "custom_discriminator"
+
+[features]
+no-entrypoint = []
+no-idl = []
+cpi = ["no-entrypoint"]
+default = []
+idl-build = ["anchor-lang/idl-build"]
+
+[dependencies]
+anchor-lang = { path = "../../../../lang" }

+ 2 - 0
tests/custom-discriminator/programs/custom-discriminator/Xargo.toml

@@ -0,0 +1,2 @@
+[target.bpfel-unknown-unknown.dependencies.std]
+features = []

+ 47 - 0
tests/custom-discriminator/programs/custom-discriminator/src/lib.rs

@@ -0,0 +1,47 @@
+use anchor_lang::prelude::*;
+
+declare_id!("CustomDiscriminator111111111111111111111111");
+
+const CONST_DISC: &'static [u8] = &[55, 66, 77, 88];
+
+const fn get_disc(input: &str) -> &'static [u8] {
+    match input.as_bytes() {
+        b"wow" => &[4 + 5, 55 / 5],
+        _ => unimplemented!(),
+    }
+}
+
+#[program]
+pub mod custom_discriminator {
+    use super::*;
+
+    #[instruction(discriminator = 0)]
+    pub fn int(_ctx: Context<DefaultIx>) -> Result<()> {
+        Ok(())
+    }
+
+    #[instruction(discriminator = [1, 2, 3, 4])]
+    pub fn array(_ctx: Context<DefaultIx>) -> Result<()> {
+        Ok(())
+    }
+
+    #[instruction(discriminator = b"hi")]
+    pub fn byte_str(_ctx: Context<DefaultIx>) -> Result<()> {
+        Ok(())
+    }
+
+    #[instruction(discriminator = CONST_DISC)]
+    pub fn constant(_ctx: Context<DefaultIx>) -> Result<()> {
+        Ok(())
+    }
+
+    #[instruction(discriminator = get_disc("wow"))]
+    pub fn const_fn(_ctx: Context<DefaultIx>) -> Result<()> {
+        Ok(())
+    }
+}
+
+#[derive(Accounts)]
+pub struct DefaultIx<'info> {
+    pub signer: Signer<'info>,
+}

+ 31 - 0
tests/custom-discriminator/tests/custom-discriminator.ts

@@ -0,0 +1,31 @@
+import * as anchor from "@coral-xyz/anchor";
+import assert from "assert";
+
+import type { CustomDiscriminator } from "../target/types/custom_discriminator";
+
+describe("custom-discriminator", () => {
+  anchor.setProvider(anchor.AnchorProvider.env());
+  const program: anchor.Program<CustomDiscriminator> =
+    anchor.workspace.customDiscriminator;
+
+  describe("Can use custom instruction discriminators", () => {
+    const testCommon = async (ixName: keyof typeof program["methods"]) => {
+      const tx = await program.methods[ixName]().transaction();
+
+      // Verify discriminator
+      const ix = program.idl.instructions.find((ix) => ix.name === ixName)!;
+      assert(ix.discriminator.length < 8);
+      const data = tx.instructions[0].data;
+      assert(data.equals(Buffer.from(ix.discriminator)));
+
+      // Verify tx runs
+      await program.provider.sendAndConfirm!(tx);
+    };
+
+    it("Integer", () => testCommon("int"));
+    it("Array", () => testCommon("array"));
+    it("Byte string", () => testCommon("byteStr"));
+    it("Constant", () => testCommon("constant"));
+    it("Const Fn", () => testCommon("constFn"));
+  });
+});

+ 11 - 0
tests/custom-discriminator/tsconfig.json

@@ -0,0 +1,11 @@
+{
+  "compilerOptions": {
+    "types": ["mocha", "chai"],
+    "lib": ["es2015"],
+    "module": "commonjs",
+    "target": "es6",
+    "esModuleInterop": true,
+    "strict": true,
+    "skipLibCheck": true
+  }
+}

+ 1 - 0
tests/package.json

@@ -15,6 +15,7 @@
     "chat",
     "composite",
     "custom-coder",
+    "custom-discriminator",
     "declare-id",
     "declare-program",
     "errors",