lib.rs 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. use std::collections::VecDeque;
  2. use proc_macro::TokenStream;
  3. use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
  4. use quote::{quote, quote_spanned, ToTokens};
  5. use syn::{
  6. parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
  7. DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
  8. };
  9. /// Implements a [`Space`](./trait.Space.html) trait on the given
  10. /// struct or enum.
  11. ///
  12. /// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
  13. /// For nested types, it is necessary to specify a size for each variable type (see example).
  14. ///
  15. /// # Example
  16. /// ```ignore
  17. /// #[account]
  18. /// #[derive(InitSpace)]
  19. /// pub struct ExampleAccount {
  20. /// pub data: u64,
  21. /// #[max_len(50)]
  22. /// pub string_one: String,
  23. /// #[max_len(10, 5)]
  24. /// pub nested: Vec<Vec<u8>>,
  25. /// }
  26. ///
  27. /// #[derive(Accounts)]
  28. /// pub struct Initialize<'info> {
  29. /// #[account(mut)]
  30. /// pub payer: Signer<'info>,
  31. /// pub system_program: Program<'info, System>,
  32. /// #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
  33. /// pub data: Account<'info, ExampleAccount>,
  34. /// }
  35. /// ```
  36. #[proc_macro_derive(InitSpace, attributes(max_len))]
  37. pub fn derive_init_space(item: TokenStream) -> TokenStream {
  38. let input = parse_macro_input!(item as DeriveInput);
  39. let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
  40. let name = input.ident;
  41. let process_struct_fields = |fields: Punctuated<Field, Comma>| {
  42. let recurse = fields.into_iter().map(|f| {
  43. let mut max_len_args = get_max_len_args(&f.attrs);
  44. len_from_type(f.ty, &mut max_len_args)
  45. });
  46. quote! {
  47. #[automatically_derived]
  48. impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
  49. const INIT_SPACE: usize = 0 #(+ #recurse)*;
  50. }
  51. }
  52. };
  53. let expanded: TokenStream2 = match input.data {
  54. syn::Data::Struct(strct) => match strct.fields {
  55. Fields::Named(named) => process_struct_fields(named.named),
  56. Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
  57. Fields::Unit => quote! {
  58. #[automatically_derived]
  59. impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
  60. const INIT_SPACE: usize = 0;
  61. }
  62. },
  63. },
  64. syn::Data::Enum(enm) => {
  65. let variants = enm.variants.into_iter().map(|v| {
  66. let len = v.fields.into_iter().map(|f| {
  67. let mut max_len_args = get_max_len_args(&f.attrs);
  68. len_from_type(f.ty, &mut max_len_args)
  69. });
  70. quote! {
  71. 0 #(+ #len)*
  72. }
  73. });
  74. let max = gen_max(variants);
  75. quote! {
  76. #[automatically_derived]
  77. impl anchor_lang::Space for #name {
  78. const INIT_SPACE: usize = 1 + #max;
  79. }
  80. }
  81. }
  82. _ => unimplemented!(),
  83. };
  84. TokenStream::from(expanded)
  85. }
  86. fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
  87. if let Some(item) = iter.next() {
  88. let next_item = gen_max(iter);
  89. quote!(anchor_lang::__private::max(#item, #next_item))
  90. } else {
  91. quote!(0)
  92. }
  93. }
  94. fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
  95. match ty {
  96. Type::Array(TypeArray { elem, len, .. }) => {
  97. let array_len = len.to_token_stream();
  98. let type_len = len_from_type(*elem, attrs);
  99. quote!((#array_len * #type_len))
  100. }
  101. Type::Path(ty_path) => {
  102. let path_segment = ty_path.path.segments.last().unwrap();
  103. let ident = &path_segment.ident;
  104. let type_name = ident.to_string();
  105. let first_ty = get_first_ty_arg(&path_segment.arguments);
  106. match type_name.as_str() {
  107. "i8" | "u8" | "bool" => quote!(1),
  108. "i16" | "u16" => quote!(2),
  109. "i32" | "u32" | "f32" => quote!(4),
  110. "i64" | "u64" | "f64" => quote!(8),
  111. "i128" | "u128" => quote!(16),
  112. "String" => {
  113. let max_len = get_next_arg(ident, attrs);
  114. quote!((4 + #max_len))
  115. }
  116. "Pubkey" => quote!(32),
  117. "Option" => {
  118. if let Some(ty) = first_ty {
  119. let type_len = len_from_type(ty, attrs);
  120. quote!((1 + #type_len))
  121. } else {
  122. quote_spanned!(ident.span() => compile_error!("Invalid argument in Option"))
  123. }
  124. }
  125. "Vec" => {
  126. if let Some(ty) = first_ty {
  127. let max_len = get_next_arg(ident, attrs);
  128. let type_len = len_from_type(ty, attrs);
  129. quote!((4 + #type_len * #max_len))
  130. } else {
  131. quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
  132. }
  133. }
  134. _ => {
  135. let ty = &ty_path.path;
  136. quote!(<#ty as anchor_lang::Space>::INIT_SPACE)
  137. }
  138. }
  139. }
  140. _ => panic!("Type {ty:?} is not supported"),
  141. }
  142. }
  143. fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
  144. match args {
  145. PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
  146. GenericArgument::Type(ty) => Some(ty.to_owned()),
  147. _ => None,
  148. }),
  149. _ => None,
  150. }
  151. }
  152. fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
  153. let mut result = VecDeque::new();
  154. while let Some(token_tree) = item.parse()? {
  155. match token_tree {
  156. TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
  157. TokenTree::Literal(lit) => {
  158. if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
  159. result.push_front(quote!(#lit_int))
  160. }
  161. }
  162. _ => (),
  163. }
  164. }
  165. Ok(result)
  166. }
  167. fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
  168. attributes
  169. .iter()
  170. .find(|a| a.path.is_ident("max_len"))
  171. .and_then(|a| a.parse_args_with(parse_len_arg).ok())
  172. }
  173. fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
  174. if let Some(arg_list) = args {
  175. if let Some(arg) = arg_list.pop_back() {
  176. quote!(#arg)
  177. } else {
  178. quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
  179. }
  180. } else {
  181. quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
  182. }
  183. }