|
@@ -1,12 +1,13 @@
|
|
|
extern crate proc_macro;
|
|
|
|
|
|
+use anchor_syn::Overrides;
|
|
|
use quote::{quote, ToTokens};
|
|
|
use syn::{
|
|
|
parenthesized,
|
|
|
parse::{Parse, ParseStream},
|
|
|
parse_macro_input,
|
|
|
token::{Comma, Paren},
|
|
|
- Expr, Ident, Lit, LitStr, Token,
|
|
|
+ Ident, LitStr,
|
|
|
};
|
|
|
|
|
|
mod id;
|
|
@@ -99,21 +100,25 @@ pub fn account(
|
|
|
let account_name_str = account_name.to_string();
|
|
|
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
|
|
|
|
|
|
- let discriminator = args.discriminator.unwrap_or_else(|| {
|
|
|
- // Namespace the discriminator to prevent collisions.
|
|
|
- let discriminator_preimage = if namespace.is_empty() {
|
|
|
- format!("account:{account_name}")
|
|
|
- } else {
|
|
|
- format!("{namespace}:{account_name}")
|
|
|
- };
|
|
|
+ let discriminator = args
|
|
|
+ .overrides
|
|
|
+ .and_then(|ov| ov.discriminator)
|
|
|
+ .unwrap_or_else(|| {
|
|
|
+ // Namespace the discriminator to prevent collisions.
|
|
|
+ let discriminator_preimage = if namespace.is_empty() {
|
|
|
+ format!("account:{account_name}")
|
|
|
+ } else {
|
|
|
+ format!("{namespace}:{account_name}")
|
|
|
+ };
|
|
|
|
|
|
- let mut discriminator = [0u8; 8];
|
|
|
- discriminator.copy_from_slice(
|
|
|
- &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
|
|
|
- );
|
|
|
- let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
|
|
|
- quote! { &#discriminator }
|
|
|
- });
|
|
|
+ let mut discriminator = [0u8; 8];
|
|
|
+ discriminator.copy_from_slice(
|
|
|
+ &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
|
|
|
+ );
|
|
|
+ let discriminator: proc_macro2::TokenStream =
|
|
|
+ format!("{discriminator:?}").parse().unwrap();
|
|
|
+ quote! { &#discriminator }
|
|
|
+ });
|
|
|
let disc = if account_strct.generics.lt_token.is_some() {
|
|
|
quote! { #account_name::#type_gen::DISCRIMINATOR }
|
|
|
} else {
|
|
@@ -258,8 +263,8 @@ struct AccountArgs {
|
|
|
zero_copy: Option<bool>,
|
|
|
/// Account namespace override, `account` if not specified
|
|
|
namespace: Option<String>,
|
|
|
- /// Discriminator override
|
|
|
- discriminator: Option<proc_macro2::TokenStream>,
|
|
|
+ /// Named overrides
|
|
|
+ overrides: Option<Overrides>,
|
|
|
}
|
|
|
|
|
|
impl Parse for AccountArgs {
|
|
@@ -274,8 +279,8 @@ impl Parse for AccountArgs {
|
|
|
AccountArg::Namespace(ns) => {
|
|
|
parsed.namespace.replace(ns);
|
|
|
}
|
|
|
- AccountArg::Discriminator(disc) => {
|
|
|
- parsed.discriminator.replace(disc);
|
|
|
+ AccountArg::Overrides(ov) => {
|
|
|
+ parsed.overrides.replace(ov);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -287,7 +292,7 @@ impl Parse for AccountArgs {
|
|
|
enum AccountArg {
|
|
|
ZeroCopy { is_unsafe: bool },
|
|
|
Namespace(String),
|
|
|
- Discriminator(proc_macro2::TokenStream),
|
|
|
+ Overrides(Overrides),
|
|
|
}
|
|
|
|
|
|
impl Parse for AccountArg {
|
|
@@ -300,8 +305,8 @@ impl Parse for AccountArg {
|
|
|
}
|
|
|
|
|
|
// Zero copy
|
|
|
- let ident = input.parse::<Ident>()?;
|
|
|
- if ident == "zero_copy" {
|
|
|
+ if input.fork().parse::<Ident>()? == "zero_copy" {
|
|
|
+ input.parse::<Ident>()?;
|
|
|
let is_unsafe = if input.peek(Paren) {
|
|
|
let content;
|
|
|
parenthesized!(content in input);
|
|
@@ -321,24 +326,8 @@ impl Parse for AccountArg {
|
|
|
return Ok(Self::ZeroCopy { is_unsafe });
|
|
|
};
|
|
|
|
|
|
- // Named arguments
|
|
|
- // TODO: Share the common arguments with `#[instruction]`
|
|
|
- input.parse::<Token![=]>()?;
|
|
|
- let value = input.parse::<Expr>()?;
|
|
|
- match ident.to_string().as_str() {
|
|
|
- "discriminator" => {
|
|
|
- let value = match value {
|
|
|
- // Allow `discriminator = 42`
|
|
|
- Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
|
|
|
- // Allow `discriminator = [0, 1, 2, 3]`
|
|
|
- Expr::Array(arr) => quote! { &#arr },
|
|
|
- expr => expr.to_token_stream(),
|
|
|
- };
|
|
|
-
|
|
|
- Ok(Self::Discriminator(value))
|
|
|
- }
|
|
|
- _ => Err(syn::Error::new(ident.span(), "Invalid argument")),
|
|
|
- }
|
|
|
+ // Overrides
|
|
|
+ input.parse::<Overrides>().map(Self::Overrides)
|
|
|
}
|
|
|
}
|
|
|
|