Browse Source

lang: Reuse common override arguments (#3154)

acheron 1 year ago
parent
commit
dfb2de5338

+ 29 - 40
lang/attribute/account/src/lib.rs

@@ -1,12 +1,13 @@
 extern crate proc_macro;
 
+use anchor_syn::Overrides;
 use quote::{quote, ToTokens};
 use syn::{
     parenthesized,
     parse::{Parse, ParseStream},
     parse_macro_input,
     token::{Comma, Paren},
-    Expr, Ident, Lit, LitStr, Token,
+    Ident, LitStr,
 };
 
 mod id;
@@ -99,21 +100,25 @@ pub fn account(
     let account_name_str = account_name.to_string();
     let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
 
-    let discriminator = args.discriminator.unwrap_or_else(|| {
-        // Namespace the discriminator to prevent collisions.
-        let discriminator_preimage = if namespace.is_empty() {
-            format!("account:{account_name}")
-        } else {
-            format!("{namespace}:{account_name}")
-        };
+    let discriminator = args
+        .overrides
+        .and_then(|ov| ov.discriminator)
+        .unwrap_or_else(|| {
+            // Namespace the discriminator to prevent collisions.
+            let discriminator_preimage = if namespace.is_empty() {
+                format!("account:{account_name}")
+            } else {
+                format!("{namespace}:{account_name}")
+            };
 
-        let mut discriminator = [0u8; 8];
-        discriminator.copy_from_slice(
-            &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
-        );
-        let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
-        quote! { &#discriminator }
-    });
+            let mut discriminator = [0u8; 8];
+            discriminator.copy_from_slice(
+                &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
+            );
+            let discriminator: proc_macro2::TokenStream =
+                format!("{discriminator:?}").parse().unwrap();
+            quote! { &#discriminator }
+        });
     let disc = if account_strct.generics.lt_token.is_some() {
         quote! { #account_name::#type_gen::DISCRIMINATOR }
     } else {
@@ -258,8 +263,8 @@ struct AccountArgs {
     zero_copy: Option<bool>,
     /// Account namespace override, `account` if not specified
     namespace: Option<String>,
-    /// Discriminator override
-    discriminator: Option<proc_macro2::TokenStream>,
+    /// Named overrides
+    overrides: Option<Overrides>,
 }
 
 impl Parse for AccountArgs {
@@ -274,8 +279,8 @@ impl Parse for AccountArgs {
                 AccountArg::Namespace(ns) => {
                     parsed.namespace.replace(ns);
                 }
-                AccountArg::Discriminator(disc) => {
-                    parsed.discriminator.replace(disc);
+                AccountArg::Overrides(ov) => {
+                    parsed.overrides.replace(ov);
                 }
             }
         }
@@ -287,7 +292,7 @@ impl Parse for AccountArgs {
 enum AccountArg {
     ZeroCopy { is_unsafe: bool },
     Namespace(String),
-    Discriminator(proc_macro2::TokenStream),
+    Overrides(Overrides),
 }
 
 impl Parse for AccountArg {
@@ -300,8 +305,8 @@ impl Parse for AccountArg {
         }
 
         // Zero copy
-        let ident = input.parse::<Ident>()?;
-        if ident == "zero_copy" {
+        if input.fork().parse::<Ident>()? == "zero_copy" {
+            input.parse::<Ident>()?;
             let is_unsafe = if input.peek(Paren) {
                 let content;
                 parenthesized!(content in input);
@@ -321,24 +326,8 @@ impl Parse for AccountArg {
             return Ok(Self::ZeroCopy { is_unsafe });
         };
 
-        // Named arguments
-        // TODO: Share the common arguments with `#[instruction]`
-        input.parse::<Token![=]>()?;
-        let value = input.parse::<Expr>()?;
-        match ident.to_string().as_str() {
-            "discriminator" => {
-                let value = match value {
-                    // Allow `discriminator = 42`
-                    Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
-                    // Allow `discriminator = [0, 1, 2, 3]`
-                    Expr::Array(arr) => quote! { &#arr },
-                    expr => expr.to_token_stream(),
-                };
-
-                Ok(Self::Discriminator(value))
-            }
-            _ => Err(syn::Error::new(ident.span(), "Invalid argument")),
-        }
+        // Overrides
+        input.parse::<Overrides>().map(Self::Overrides)
     }
 }
 

+ 4 - 56
lang/attribute/event/src/lib.rs

@@ -2,13 +2,9 @@ extern crate proc_macro;
 
 #[cfg(feature = "event-cpi")]
 use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
-use quote::{quote, ToTokens};
-use syn::{
-    parse::{Parse, ParseStream},
-    parse_macro_input,
-    token::Comma,
-    Expr, Ident, Lit, Token,
-};
+use anchor_syn::Overrides;
+use quote::quote;
+use syn::parse_macro_input;
 
 /// The event attribute allows a struct to be used with
 /// [emit!](./macro.emit.html) so that programs can log significant events in
@@ -37,7 +33,7 @@ pub fn event(
     args: proc_macro::TokenStream,
     input: proc_macro::TokenStream,
 ) -> proc_macro::TokenStream {
-    let args = parse_macro_input!(args as EventArgs);
+    let args = parse_macro_input!(args as Overrides);
     let event_strct = parse_macro_input!(input as syn::ItemStruct);
     let event_name = &event_strct.ident;
 
@@ -80,54 +76,6 @@ pub fn event(
     proc_macro::TokenStream::from(ret)
 }
 
-#[derive(Debug, Default)]
-struct EventArgs {
-    /// Discriminator override
-    discriminator: Option<proc_macro2::TokenStream>,
-}
-
-impl Parse for EventArgs {
-    fn parse(input: ParseStream) -> syn::Result<Self> {
-        // TODO: Share impl with `#[instruction]`
-        let mut parsed = Self::default();
-        let args = input.parse_terminated::<_, Comma>(EventArg::parse)?;
-        for arg in args {
-            match arg.name.to_string().as_str() {
-                "discriminator" => {
-                    let value = match &arg.value {
-                        // Allow `discriminator = 42`
-                        Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
-                        // Allow `discriminator = [0, 1, 2, 3]`
-                        Expr::Array(arr) => quote! { &#arr },
-                        expr => expr.to_token_stream(),
-                    };
-                    parsed.discriminator.replace(value);
-                }
-                _ => return Err(syn::Error::new(arg.name.span(), "Invalid argument")),
-            }
-        }
-
-        Ok(parsed)
-    }
-}
-
-struct EventArg {
-    name: Ident,
-    #[allow(dead_code)]
-    eq_token: Token![=],
-    value: Expr,
-}
-
-impl Parse for EventArg {
-    fn parse(input: ParseStream) -> syn::Result<Self> {
-        Ok(Self {
-            name: input.parse()?,
-            eq_token: input.parse()?,
-            value: input.parse()?,
-        })
-    }
-}
-
 // EventIndex is a marker macro. It functionally does nothing other than
 // allow one to mark fields with the `#[index]` inert attribute, which is
 // used to add metadata to IDLs.

+ 3 - 3
lang/syn/src/codegen/program/instruction.rs

@@ -22,9 +22,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                 })
                 .collect();
             let impls = {
-                let discriminator = match ix.ix_attr.as_ref() {
-                    Some(ix_attr) if ix_attr.discriminator.is_some() => {
-                        ix_attr.discriminator.as_ref().unwrap().to_owned()
+                let discriminator = match ix.overrides.as_ref() {
+                    Some(overrides) if overrides.discriminator.is_some() => {
+                        overrides.discriminator.as_ref().unwrap().to_owned()
                     }
                     _ => {
                         // TODO: Remove `interface_discriminator`

+ 11 - 11
lang/syn/src/lib.rs

@@ -69,23 +69,23 @@ pub struct Ix {
     // The ident for the struct deriving Accounts.
     pub anchor_ident: Ident,
     // The discriminator based on the `#[interface]` attribute.
-    // TODO: Remove and use `ix_attr`
+    // TODO: Remove and use `overrides`
     pub interface_discriminator: Option<[u8; 8]>,
-    /// `#[instruction]` attribute
-    pub ix_attr: Option<IxAttr>,
+    /// Overrides coming from the `#[instruction]` attribute
+    pub overrides: Option<Overrides>,
 }
 
-/// `#[instruction]` attribute proc-macro
+/// Common overrides for the `#[instruction]`, `#[account]` and `#[event]` attributes
 #[derive(Debug, Default)]
-pub struct IxAttr {
-    /// Discriminator override
+pub struct Overrides {
+    /// Override the default 8-byte discriminator
     pub discriminator: Option<TokenStream>,
 }
 
-impl Parse for IxAttr {
+impl Parse for Overrides {
     fn parse(input: ParseStream) -> ParseResult<Self> {
         let mut attr = Self::default();
-        let args = input.parse_terminated::<_, Comma>(AttrArg::parse)?;
+        let args = input.parse_terminated::<_, Comma>(NamedArg::parse)?;
         for arg in args {
             match arg.name.to_string().as_str() {
                 "discriminator" => {
@@ -106,14 +106,14 @@ impl Parse for IxAttr {
     }
 }
 
-struct AttrArg {
+struct NamedArg {
     name: Ident,
     #[allow(dead_code)]
-    eq_token: Token!(=),
+    eq_token: Token![=],
     value: Expr,
 }
 
-impl Parse for AttrArg {
+impl Parse for NamedArg {
     fn parse(input: ParseStream) -> ParseResult<Self> {
         Ok(Self {
             name: input.parse()?,

+ 5 - 5
lang/syn/src/parser/program/instructions.rs

@@ -1,7 +1,7 @@
 use crate::parser::docs;
 use crate::parser::program::ctx_accounts_ident;
 use crate::parser::spl_interface;
-use crate::{FallbackFn, Ix, IxArg, IxAttr, IxReturn};
+use crate::{FallbackFn, Ix, IxArg, IxReturn, Overrides};
 use syn::parse::{Error as ParseError, Result as ParseResult};
 use syn::spanned::Spanned;
 
@@ -25,7 +25,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
         })
         .map(|method: &syn::ItemFn| {
             let (ctx, args) = parse_args(method)?;
-            let ix_attr = parse_ix_attr(&method.attrs)?;
+            let overrides = parse_overrides(&method.attrs)?;
             let interface_discriminator = spl_interface::parse(&method.attrs);
             let docs = docs::parse(&method.attrs);
             let returns = parse_return(method)?;
@@ -38,7 +38,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
                 anchor_ident,
                 returns,
                 interface_discriminator,
-                ix_attr,
+                overrides,
             })
         })
         .collect::<ParseResult<Vec<Ix>>>()?;
@@ -73,8 +73,8 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
     Ok((ixs, fallback_fn))
 }
 
-/// Parse `#[instruction]` attribute proc-macro.
-fn parse_ix_attr(attrs: &[syn::Attribute]) -> ParseResult<Option<IxAttr>> {
+/// Parse overrides from the `#[instruction]` attribute proc-macro.
+fn parse_overrides(attrs: &[syn::Attribute]) -> ParseResult<Option<Overrides>> {
     attrs
         .iter()
         .find(|attr| match attr.path.segments.last() {