Browse Source

lang: add the InitSpace macro (#2346)

Jean Marchand (Exotic Markets) 2 years ago
parent
commit
a0ef4ed7a4

+ 1 - 0
CHANGELOG.md

@@ -12,6 +12,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 
 ### Features
 
+- lang: Add the `InitSpace` derive macro to automatically calculate the space at the initialization of an account ([#2346](https://github.com/coral-xyz/anchor/pull/2346)).
 - cli: Add `env` option to verifiable builds ([#2325](https://github.com/coral-xyz/anchor/pull/2325)).
 - cli: Add `idl close` command to close a program's IDL account ([#2329](https://github.com/coral-xyz/anchor/pull/2329)).
 - cli: `idl init` now supports very large IDL files ([#2329](https://github.com/coral-xyz/anchor/pull/2329)).

+ 12 - 2
Cargo.lock

@@ -220,6 +220,15 @@ dependencies = [
  "syn 1.0.103",
 ]
 
+[[package]]
+name = "anchor-derive-space"
+version = "0.26.0"
+dependencies = [
+ "proc-macro2 1.0.47",
+ "quote 1.0.21",
+ "syn 1.0.103",
+]
+
 [[package]]
 name = "anchor-lang"
 version = "0.26.0"
@@ -231,6 +240,7 @@ dependencies = [
  "anchor-attribute-event",
  "anchor-attribute-program",
  "anchor-derive-accounts",
+ "anchor-derive-space",
  "arrayref",
  "base64 0.13.1",
  "bincode",
@@ -2235,9 +2245,9 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.15.0"
+version = "1.16.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
+checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
 
 [[package]]
 name = "opaque-debug"

+ 1 - 0
lang/Cargo.toml

@@ -32,6 +32,7 @@ anchor-attribute-error = { path = "./attribute/error", version = "0.26.0" }
 anchor-attribute-program = { path = "./attribute/program", version = "0.26.0" }
 anchor-attribute-event = { path = "./attribute/event", version = "0.26.0" }
 anchor-derive-accounts = { path = "./derive/accounts", version = "0.26.0" }
+anchor-derive-space = { path = "./derive/space", version = "0.26.0" }
 arrayref = "0.3.6"
 base64 = "0.13.0"
 borsh = "0.9"

+ 17 - 0
lang/derive/space/Cargo.toml

@@ -0,0 +1,17 @@
+[package]
+name = "anchor-derive-space"
+version = "0.26.0"
+authors = ["Serum Foundation <foundation@projectserum.com>"]
+repository = "https://github.com/coral-xyz/anchor"
+license = "Apache-2.0"
+description = "Anchor Derive macro to automatically calculate the size of a structure or an enum"
+rust-version = "1.59"
+edition = "2021"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1.0"
+quote = "1.0"
+syn = "1.0"

+ 180 - 0
lang/derive/space/src/lib.rs

@@ -0,0 +1,180 @@
+use proc_macro::TokenStream;
+use proc_macro2::{Ident, TokenStream as TokenStream2};
+use quote::{quote, quote_spanned, ToTokens};
+use syn::{
+    parse_macro_input,
+    punctuated::{IntoIter, Punctuated},
+    Attribute, DeriveInput, Fields, GenericArgument, LitInt, PathArguments, Token, Type, TypeArray,
+};
+
+/// Implements a [`Space`](./trait.Space.html) trait on the given
+/// struct or enum.
+///
+/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
+/// For nested types, it is necessary to specify a size for each variable type (see example).
+///
+/// # Example
+/// ```ignore
+/// #[account]
+/// #[derive(InitSpace)]
+/// pub struct ExampleAccount {
+///     pub data: u64,
+///     #[max_len(50)]
+///     pub string_one: String,
+///     #[max_len(10, 5)]
+///     pub nested: Vec<Vec<u8>>,
+/// }
+///
+/// #[derive(Accounts)]
+/// pub struct Initialize<'info> {
+///    #[account(mut)]
+///    pub payer: Signer<'info>,
+///    pub system_program: Program<'info, System>,
+///    #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
+///    pub data: Account<'info, ExampleAccount>,
+/// }
+/// ```
+#[proc_macro_derive(InitSpace, attributes(max_len))]
+pub fn derive_anchor_deserialize(item: TokenStream) -> TokenStream {
+    let input = parse_macro_input!(item as DeriveInput);
+    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+    let name = input.ident;
+
+    let expanded: TokenStream2 = match input.data {
+        syn::Data::Struct(strct) => match strct.fields {
+            Fields::Named(named) => {
+                let recurse = named.named.into_iter().map(|f| {
+                    let mut max_len_args = get_max_len_args(&f.attrs);
+                    len_from_type(f.ty, &mut max_len_args)
+                });
+
+                quote! {
+                    #[automatically_derived]
+                    impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
+                        const INIT_SPACE: usize = 0 #(+ #recurse)*;
+                    }
+                }
+            }
+            _ => panic!("Please use named fields in account structure"),
+        },
+        syn::Data::Enum(enm) => {
+            let variants = enm.variants.into_iter().map(|v| {
+                let len = v.fields.into_iter().map(|f| {
+                    let mut max_len_args = get_max_len_args(&f.attrs);
+                    len_from_type(f.ty, &mut max_len_args)
+                });
+
+                quote! {
+                    0 #(+ #len)*
+                }
+            });
+
+            let max = gen_max(variants);
+
+            quote! {
+                #[automatically_derived]
+                impl anchor_lang::Space for #name {
+                    const INIT_SPACE: usize = 1 + #max;
+                }
+            }
+        }
+        _ => unimplemented!(),
+    };
+
+    TokenStream::from(expanded)
+}
+
+fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
+    if let Some(item) = iter.next() {
+        let next_item = gen_max(iter);
+        quote!(anchor_lang::__private::max(#item, #next_item))
+    } else {
+        quote!(0)
+    }
+}
+
+fn len_from_type(ty: Type, attrs: &mut Option<IntoIter<LitInt>>) -> TokenStream2 {
+    match ty {
+        Type::Array(TypeArray { elem, len, .. }) => {
+            let array_len = len.to_token_stream();
+            let type_len = len_from_type(*elem, attrs);
+            quote!((#array_len * #type_len))
+        }
+        Type::Path(ty_path) => {
+            let path_segment = ty_path.path.segments.last().unwrap();
+            let ident = &path_segment.ident;
+            let type_name = ident.to_string();
+            let first_ty = get_first_ty_arg(&path_segment.arguments);
+
+            match type_name.as_str() {
+                "i8" | "u8" | "bool" => quote!(1),
+                "i16" | "u16" => quote!(2),
+                "i32" | "u32" | "f32" => quote!(4),
+                "i64" | "u64" | "f64" => quote!(8),
+                "i128" | "u128" => quote!(16),
+                "String" => {
+                    let max_len = get_next_arg(ident, attrs);
+                    quote!((4 + #max_len))
+                }
+                "Pubkey" => quote!(32),
+                "Option" => {
+                    if let Some(ty) = first_ty {
+                        let type_len = len_from_type(ty, attrs);
+
+                        quote!((1 + #type_len))
+                    } else {
+                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
+                    }
+                }
+                "Vec" => {
+                    if let Some(ty) = first_ty {
+                        let max_len = get_next_arg(ident, attrs);
+                        let type_len = len_from_type(ty, attrs);
+
+                        quote!((4 + #type_len * #max_len))
+                    } else {
+                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
+                    }
+                }
+                _ => {
+                    let ty = &ty_path.path;
+                    quote!(<#ty as anchor_lang::Space>::INIT_SPACE)
+                }
+            }
+        }
+        _ => panic!("Type {:?} is not supported", ty),
+    }
+}
+
+fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
+    match args {
+        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
+            GenericArgument::Type(ty) => Some(ty.to_owned()),
+            _ => None,
+        }),
+        _ => None,
+    }
+}
+
+fn get_max_len_args(attributes: &[Attribute]) -> Option<IntoIter<LitInt>> {
+    attributes
+        .iter()
+        .find(|a| a.path.is_ident("max_len"))
+        .and_then(|a| {
+            a.parse_args_with(Punctuated::<LitInt, Token![,]>::parse_terminated)
+                .ok()
+        })
+        .map(|p| p.into_iter())
+}
+
+fn get_next_arg(ident: &Ident, args: &mut Option<IntoIter<LitInt>>) -> TokenStream2 {
+    if let Some(arg_list) = args {
+        if let Some(arg) = arg_list.next() {
+            quote!(#arg)
+        } else {
+            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
+        }
+    } else {
+        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
+    }
+}

+ 15 - 2
lang/src/lib.rs

@@ -51,6 +51,7 @@ pub use anchor_attribute_error::*;
 pub use anchor_attribute_event::{emit, event};
 pub use anchor_attribute_program::program;
 pub use anchor_derive_accounts::Accounts;
+pub use anchor_derive_space::InitSpace;
 /// Borsh is the default serialization format for instructions and accounts.
 pub use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize};
 pub use solana_program;
@@ -209,6 +210,11 @@ pub trait Discriminator {
     }
 }
 
+/// Defines the space of an account for initialization.
+pub trait Space {
+    const INIT_SPACE: usize;
+}
+
 /// Bump seed for program derived addresses.
 pub trait Bump {
     fn seed(&self) -> u8;
@@ -247,8 +253,8 @@ pub mod prelude {
         require, require_eq, require_gt, require_gte, require_keys_eq, require_keys_neq,
         require_neq, solana_program::bpf_loader_upgradeable::UpgradeableLoaderState, source,
         system_program::System, zero_copy, AccountDeserialize, AccountSerialize, Accounts,
-        AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Id, Key, Owner,
-        ProgramData, Result, ToAccountInfo, ToAccountInfos, ToAccountMetas,
+        AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Id, InitSpace, Key, Owner,
+        ProgramData, Result, Space, ToAccountInfo, ToAccountInfos, ToAccountMetas,
     };
     pub use anchor_attribute_error::*;
     pub use borsh;
@@ -288,6 +294,13 @@ pub mod __private {
 
     use solana_program::pubkey::Pubkey;
 
+    // Used to calculate the maximum between two expressions.
+    // It is necessary for the calculation of the enum space.
+    #[doc(hidden)]
+    pub const fn max(a: usize, b: usize) -> usize {
+        [a, b][(a < b) as usize]
+    }
+
     // Very experimental trait.
     #[doc(hidden)]
     pub trait ZeroCopyAccessor<Ty> {

+ 135 - 0
lang/tests/space.rs

@@ -0,0 +1,135 @@
+use anchor_lang::prelude::*;
+
+// Needed to declare accounts.
+declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");
+
+mod inside_mod {
+    use super::*;
+
+    #[derive(InitSpace)]
+    pub struct Data {
+        pub data: u64,
+    }
+}
+
+#[derive(InitSpace)]
+pub enum TestBasicEnum {
+    Basic1,
+    Basic2 {
+        test_u8: u8,
+    },
+    Basic3 {
+        test_u16: u16,
+    },
+    Basic4 {
+        #[max_len(10)]
+        test_vec: Vec<u8>,
+    },
+}
+
+#[account]
+#[derive(InitSpace)]
+pub struct TestEmptyAccount {}
+
+#[account]
+#[derive(InitSpace)]
+pub struct TestBasicVarAccount {
+    pub test_u8: u8,
+    pub test_u16: u16,
+    pub test_u32: u32,
+    pub test_u64: u64,
+    pub test_u128: u128,
+}
+
+#[account]
+#[derive(InitSpace)]
+pub struct TestComplexeVarAccount {
+    pub test_key: Pubkey,
+    #[max_len(10)]
+    pub test_vec: Vec<u8>,
+    #[max_len(10)]
+    pub test_string: String,
+    pub test_option: Option<u16>,
+}
+
+#[derive(InitSpace)]
+pub struct TestNonAccountStruct {
+    pub test_bool: bool,
+}
+
+#[account(zero_copy)]
+#[derive(InitSpace)]
+pub struct TestZeroCopyStruct {
+    pub test_array: [u8; 10],
+    pub test_u32: u32,
+}
+
+#[derive(InitSpace)]
+pub struct ChildStruct {
+    #[max_len(10)]
+    pub test_string: String,
+}
+
+#[derive(InitSpace)]
+pub struct TestNestedStruct {
+    pub test_struct: ChildStruct,
+    pub test_enum: TestBasicEnum,
+}
+
+#[derive(InitSpace)]
+pub struct TestMatrixStruct {
+    #[max_len(2, 4)]
+    pub test_matrix: Vec<Vec<u8>>,
+}
+
+#[derive(InitSpace)]
+pub struct TestFullPath {
+    pub test_option_path: Option<inside_mod::Data>,
+    pub test_path: inside_mod::Data,
+}
+
+#[test]
+fn test_empty_struct() {
+    assert_eq!(TestEmptyAccount::INIT_SPACE, 0);
+}
+
+#[test]
+fn test_basic_struct() {
+    assert_eq!(TestBasicVarAccount::INIT_SPACE, 1 + 2 + 4 + 8 + 16);
+}
+
+#[test]
+fn test_complexe_struct() {
+    assert_eq!(
+        TestComplexeVarAccount::INIT_SPACE,
+        32 + 4 + 10 + (4 + 10) + 3
+    )
+}
+
+#[test]
+fn test_zero_copy_struct() {
+    assert_eq!(TestZeroCopyStruct::INIT_SPACE, 10 + 4)
+}
+
+#[test]
+fn test_basic_enum() {
+    assert_eq!(TestBasicEnum::INIT_SPACE, 1 + 14);
+}
+
+#[test]
+fn test_nested_struct() {
+    assert_eq!(
+        TestNestedStruct::INIT_SPACE,
+        ChildStruct::INIT_SPACE + TestBasicEnum::INIT_SPACE
+    )
+}
+
+#[test]
+fn test_matrix_struct() {
+    assert_eq!(TestMatrixStruct::INIT_SPACE, 4 + (2 * (4 + 4)))
+}
+
+#[test]
+fn test_full_path() {
+    assert_eq!(TestFullPath::INIT_SPACE, 8 + 9)
+}

+ 3 - 1
tests/chat/programs/chat/src/lib.rs

@@ -46,7 +46,7 @@ pub struct CreateUser<'info> {
         seeds = [authority.key().as_ref()],
         bump,
         payer = authority,
-        space = 320,
+        space = 8 + User::INIT_SPACE,
     )]
     user: Account<'info, User>,
     #[account(mut)]
@@ -74,7 +74,9 @@ pub struct SendMessage<'info> {
 }
 
 #[account]
+#[derive(InitSpace)]
 pub struct User {
+    #[max_len(200)]
     name: String,
     authority: Pubkey,
     bump: u8,

+ 4 - 7
tests/ido-pool/programs/ido-pool/src/lib.rs

@@ -296,7 +296,7 @@ pub struct InitializePool<'info> {
         seeds = [ido_name.as_bytes()],
         bump,
         payer = ido_authority,
-        space = IdoAccount::LEN + 8
+        space = 8 + IdoAccount::INIT_SPACE
     )]
     pub ido_account: Box<Account<'info, IdoAccount>>,
     // TODO Confirm USDC mint address on mainnet or leave open as an option for other stables
@@ -545,6 +545,7 @@ pub struct WithdrawFromEscrow<'info> {
 }
 
 #[account]
+#[derive(InitSpace)]
 pub struct IdoAccount {
     pub ido_name: [u8; 10], // Setting an arbitrary max of ten characters in the ido name. // 10
     pub bumps: PoolBumps,   // 4
@@ -560,11 +561,7 @@ pub struct IdoAccount {
     pub ido_times: IdoTimes, // 32
 }
 
-impl IdoAccount {
-    pub const LEN: usize = 10 + 4 + 32 + 5 * 32 + 8 + 32;
-}
-
-#[derive(AnchorSerialize, AnchorDeserialize, Default, Clone, Copy)]
+#[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Default, Clone, Copy)]
 pub struct IdoTimes {
     pub start_ido: i64,    // 8
     pub end_deposits: i64, // 8
@@ -572,7 +569,7 @@ pub struct IdoTimes {
     pub end_escrow: i64,   // 8
 }
 
-#[derive(AnchorSerialize, AnchorDeserialize, Default, Clone)]
+#[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Default, Clone)]
 pub struct PoolBumps {
     pub ido_account: u8,     // 1
     pub redeemable_mint: u8, // 1