فهرست منبع

lang: Add `discriminator` argument to `#[event]` attribute (#3152)

acheron 1 سال پیش
والد
کامیت
9117bbc001

+ 1 - 0
CHANGELOG.md

@@ -32,6 +32,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
 - lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).
 - lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)).
+- lang: Add `discriminator` argument to `#[event]` attribute ([#3152](https://github.com/coral-xyz/anchor/pull/3152)).
 
 ### Fixes
 

+ 80 - 9
lang/attribute/event/src/lib.rs

@@ -2,29 +2,52 @@ extern crate proc_macro;
 
 #[cfg(feature = "event-cpi")]
 use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
-use quote::quote;
-use syn::parse_macro_input;
+use quote::{quote, ToTokens};
+use syn::{
+    parse::{Parse, ParseStream},
+    parse_macro_input,
+    token::Comma,
+    Expr, Ident, Lit, Token,
+};
 
 /// The event attribute allows a struct to be used with
 /// [emit!](./macro.emit.html) so that programs can log significant events in
 /// their programs that clients can subscribe to. Currently, this macro is for
 /// structs only.
 ///
+/// # Args
+///
+/// - `discriminator`: Override the default 8-byte discriminator
+///
+///     **Usage:** `discriminator = <CONST_EXPR>`
+///
+///     All constant expressions are supported.
+///
+///     **Examples:**
+///
+///     - `discriminator = 0` (shortcut for `[0]`)
+///     - `discriminator = [1, 2, 3, 4]`
+///     - `discriminator = b"hi"`
+///     - `discriminator = MY_DISC`
+///     - `discriminator = get_disc(...)`
+///
 /// See the [`emit!` macro](emit!) for an example.
 #[proc_macro_attribute]
 pub fn event(
-    _args: proc_macro::TokenStream,
+    args: proc_macro::TokenStream,
     input: proc_macro::TokenStream,
 ) -> proc_macro::TokenStream {
+    let args = parse_macro_input!(args as EventArgs);
     let event_strct = parse_macro_input!(input as syn::ItemStruct);
-
     let event_name = &event_strct.ident;
 
-    let discriminator: proc_macro2::TokenStream = {
+    let discriminator = args.discriminator.unwrap_or_else(|| {
         let discriminator_preimage = format!("event:{event_name}").into_bytes();
         let discriminator = anchor_syn::hash::hash(&discriminator_preimage);
-        format!("{:?}", &discriminator.0[..8]).parse().unwrap()
-    };
+        let discriminator: proc_macro2::TokenStream =
+            format!("{:?}", &discriminator.0[..8]).parse().unwrap();
+        quote! { &#discriminator }
+    });
 
     let ret = quote! {
         #[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)]
@@ -33,14 +56,14 @@ pub fn event(
         impl anchor_lang::Event for #event_name {
             fn data(&self) -> Vec<u8> {
                 let mut data = Vec::with_capacity(256);
-                data.extend_from_slice(&#discriminator);
+                data.extend_from_slice(#event_name::DISCRIMINATOR);
                 self.serialize(&mut data).unwrap();
                 data
             }
         }
 
         impl anchor_lang::Discriminator for #event_name {
-            const DISCRIMINATOR: &'static [u8] = &#discriminator;
+            const DISCRIMINATOR: &'static [u8] = #discriminator;
         }
     };
 
@@ -57,6 +80,54 @@ pub fn event(
     proc_macro::TokenStream::from(ret)
 }
 
+#[derive(Debug, Default)]
+struct EventArgs {
+    /// Discriminator override
+    discriminator: Option<proc_macro2::TokenStream>,
+}
+
+impl Parse for EventArgs {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        // TODO: Share impl with `#[instruction]`
+        let mut parsed = Self::default();
+        let args = input.parse_terminated::<_, Comma>(EventArg::parse)?;
+        for arg in args {
+            match arg.name.to_string().as_str() {
+                "discriminator" => {
+                    let value = match &arg.value {
+                        // Allow `discriminator = 42`
+                        Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
+                        // Allow `discriminator = [0, 1, 2, 3]`
+                        Expr::Array(arr) => quote! { &#arr },
+                        expr => expr.to_token_stream(),
+                    };
+                    parsed.discriminator.replace(value);
+                }
+                _ => return Err(syn::Error::new(arg.name.span(), "Invalid argument")),
+            }
+        }
+
+        Ok(parsed)
+    }
+}
+
+struct EventArg {
+    name: Ident,
+    #[allow(dead_code)]
+    eq_token: Token![=],
+    value: Expr,
+}
+
+impl Parse for EventArg {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        Ok(Self {
+            name: input.parse()?,
+            eq_token: input.parse()?,
+            value: input.parse()?,
+        })
+    }
+}
+
 // EventIndex is a marker macro. It functionally does nothing other than
 // allow one to mark fields with the `#[index]` inert attribute, which is
 // used to add metadata to IDLs.

+ 10 - 0
tests/custom-discriminator/programs/custom-discriminator/src/lib.rs

@@ -44,6 +44,11 @@ pub mod custom_discriminator {
         ctx.accounts.my_account.field = field;
         Ok(())
     }
+
+    pub fn event(_ctx: Context<DefaultIx>, field: u8) -> Result<()> {
+        emit!(MyEvent { field });
+        Ok(())
+    }
 }
 
 #[derive(Accounts)]
@@ -70,3 +75,8 @@ pub struct CustomAccountIx<'info> {
 pub struct MyAccount {
     pub field: u8,
 }
+
+#[event(discriminator = 1)]
+pub struct MyEvent {
+    field: u8,
+}

+ 20 - 1
tests/custom-discriminator/tests/custom-discriminator.ts

@@ -47,7 +47,26 @@ describe("custom-discriminator", () => {
       const myAccount = await program.account.myAccount.fetch(
         pubkeys.myAccount
       );
-      assert.strictEqual(field, myAccount.field);
+      assert.strictEqual(myAccount.field, field);
+    });
+  });
+
+  describe("Events", () => {
+    it("Works", async () => {
+      // Verify discriminator
+      const event = program.idl.events.find((acc) => acc.name === "myEvent")!;
+      assert(event.discriminator.length < 8);
+
+      // Verify regular event works
+      await new Promise<void>((res) => {
+        const field = 5;
+        const id = program.addEventListener("myEvent", (ev) => {
+          assert.strictEqual(ev.field, field);
+          program.removeEventListener(id);
+          res();
+        });
+        program.methods.event(field).rpc();
+      });
     });
   });
 });