|
@@ -6,7 +6,7 @@ use syn::{
|
|
|
parse::{Parse, ParseStream},
|
|
|
parse_macro_input,
|
|
|
token::{Comma, Paren},
|
|
|
- Ident, LitStr,
|
|
|
+ Expr, Ident, Lit, LitStr, Token,
|
|
|
};
|
|
|
|
|
|
mod id;
|
|
@@ -31,6 +31,22 @@ mod id;
|
|
|
/// check this discriminator. If it doesn't match, an invalid account was given,
|
|
|
/// and the account deserialization will exit with an error.
|
|
|
///
|
|
|
+/// # Args
|
|
|
+///
|
|
|
+/// - `discriminator`: Override the default 8-byte discriminator
|
|
|
+///
|
|
|
+/// **Usage:** `discriminator = <CONST_EXPR>`
|
|
|
+///
|
|
|
+/// All constant expressions are supported.
|
|
|
+///
|
|
|
+/// **Examples:**
|
|
|
+///
|
|
|
+/// - `discriminator = 0` (shortcut for `[0]`)
|
|
|
+/// - `discriminator = [1, 2, 3, 4]`
|
|
|
+/// - `discriminator = b"hi"`
|
|
|
+/// - `discriminator = MY_DISC`
|
|
|
+/// - `discriminator = get_disc(...)`
|
|
|
+///
|
|
|
/// # Zero Copy Deserialization
|
|
|
///
|
|
|
/// **WARNING**: Zero copy deserialization is an experimental feature. It's
|
|
@@ -83,23 +99,21 @@ pub fn account(
|
|
|
let account_name_str = account_name.to_string();
|
|
|
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
|
|
|
|
|
|
- let discriminator: proc_macro2::TokenStream = {
|
|
|
+ let discriminator = args.discriminator.unwrap_or_else(|| {
|
|
|
// Namespace the discriminator to prevent collisions.
|
|
|
- let discriminator_preimage = {
|
|
|
- // For now, zero copy accounts can't be namespaced.
|
|
|
- if namespace.is_empty() {
|
|
|
- format!("account:{account_name}")
|
|
|
- } else {
|
|
|
- format!("{namespace}:{account_name}")
|
|
|
- }
|
|
|
+ let discriminator_preimage = if namespace.is_empty() {
|
|
|
+ format!("account:{account_name}")
|
|
|
+ } else {
|
|
|
+ format!("{namespace}:{account_name}")
|
|
|
};
|
|
|
|
|
|
let mut discriminator = [0u8; 8];
|
|
|
discriminator.copy_from_slice(
|
|
|
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
|
|
|
);
|
|
|
- format!("{discriminator:?}").parse().unwrap()
|
|
|
- };
|
|
|
+ let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
|
|
|
+ quote! { &#discriminator }
|
|
|
+ });
|
|
|
let disc = if account_strct.generics.lt_token.is_some() {
|
|
|
quote! { #account_name::#type_gen::DISCRIMINATOR }
|
|
|
} else {
|
|
@@ -159,7 +173,7 @@ pub fn account(
|
|
|
|
|
|
#[automatically_derived]
|
|
|
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
|
|
|
- const DISCRIMINATOR: &'static [u8] = &#discriminator;
|
|
|
+ const DISCRIMINATOR: &'static [u8] = #discriminator;
|
|
|
}
|
|
|
|
|
|
// This trait is useful for clients deserializing accounts.
|
|
@@ -229,7 +243,7 @@ pub fn account(
|
|
|
|
|
|
#[automatically_derived]
|
|
|
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
|
|
|
- const DISCRIMINATOR: &'static [u8] = &#discriminator;
|
|
|
+ const DISCRIMINATOR: &'static [u8] = #discriminator;
|
|
|
}
|
|
|
|
|
|
#owner_impl
|
|
@@ -242,7 +256,10 @@ pub fn account(
|
|
|
struct AccountArgs {
|
|
|
/// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
|
|
|
zero_copy: Option<bool>,
|
|
|
+ /// Account namespace override, `account` if not specified
|
|
|
namespace: Option<String>,
|
|
|
+ /// Discriminator override
|
|
|
+ discriminator: Option<proc_macro2::TokenStream>,
|
|
|
}
|
|
|
|
|
|
impl Parse for AccountArgs {
|
|
@@ -257,6 +274,9 @@ impl Parse for AccountArgs {
|
|
|
AccountArg::Namespace(ns) => {
|
|
|
parsed.namespace.replace(ns);
|
|
|
}
|
|
|
+ AccountArg::Discriminator(disc) => {
|
|
|
+ parsed.discriminator.replace(disc);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -267,6 +287,7 @@ impl Parse for AccountArgs {
|
|
|
enum AccountArg {
|
|
|
ZeroCopy { is_unsafe: bool },
|
|
|
Namespace(String),
|
|
|
+ Discriminator(proc_macro2::TokenStream),
|
|
|
}
|
|
|
|
|
|
impl Parse for AccountArg {
|
|
@@ -300,7 +321,24 @@ impl Parse for AccountArg {
|
|
|
return Ok(Self::ZeroCopy { is_unsafe });
|
|
|
};
|
|
|
|
|
|
- Err(syn::Error::new(ident.span(), "Unexpected argument"))
|
|
|
+ // Named arguments
|
|
|
+ // TODO: Share the common arguments with `#[instruction]`
|
|
|
+ input.parse::<Token![=]>()?;
|
|
|
+ let value = input.parse::<Expr>()?;
|
|
|
+ match ident.to_string().as_str() {
|
|
|
+ "discriminator" => {
|
|
|
+ let value = match value {
|
|
|
+ // Allow `discriminator = 42`
|
|
|
+ Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
|
|
|
+ // Allow `discriminator = [0, 1, 2, 3]`
|
|
|
+ Expr::Array(arr) => quote! { &#arr },
|
|
|
+ expr => expr.to_token_stream(),
|
|
|
+ };
|
|
|
+
|
|
|
+ Ok(Self::Discriminator(value))
|
|
|
+ }
|
|
|
+ _ => Err(syn::Error::new(ident.span(), "Invalid argument")),
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|