lib.rs 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. use proc_macro::TokenStream;
  2. use quote::quote;
  3. use syn::{parse_macro_input, parse_quote, Attribute, DeriveInput, Lit, Meta, NestedMeta};
  4. /// This BoltAccount attribute is used to automatically generate the seed and size functions
  5. ///
  6. /// The component_id define the seed used to generate the PDA which stores the component data.
  7. /// The macro also adds the InitSpace and Default derives to the struct.
  8. ///
  9. /// #[account]
  10. /// #[bolt_account]
  11. /// pub struct Position {
  12. /// pub x: i64,
  13. /// pub y: i64,
  14. /// pub z: i64,
  15. /// }
  16. /// ```
  17. #[proc_macro_attribute]
  18. pub fn bolt_account(attr: TokenStream, item: TokenStream) -> TokenStream {
  19. let mut input = parse_macro_input!(item as DeriveInput);
  20. let mut component_id_value = None;
  21. if !attr.is_empty() {
  22. let attr_meta = parse_macro_input!(attr as Meta);
  23. component_id_value = match attr_meta {
  24. Meta::Path(_) => None,
  25. Meta::NameValue(meta_name_value) if meta_name_value.path.is_ident("component_id") => {
  26. if let Lit::Str(lit) = meta_name_value.lit {
  27. Some(lit.value())
  28. } else {
  29. None
  30. }
  31. }
  32. Meta::List(meta) => meta.nested.into_iter().find_map(|nested_meta| {
  33. if let NestedMeta::Meta(Meta::NameValue(meta_name_value)) = nested_meta {
  34. if meta_name_value.path.is_ident("component_id") {
  35. if let Lit::Str(lit) = meta_name_value.lit {
  36. Some(lit.value())
  37. } else {
  38. None
  39. }
  40. } else {
  41. None
  42. }
  43. } else {
  44. None
  45. }
  46. }),
  47. _ => None,
  48. };
  49. }
  50. let component_id_value = component_id_value.unwrap_or_else(|| "".to_string());
  51. let additional_derives: Attribute = parse_quote! { #[derive(InitSpace, Default)] };
  52. input.attrs.push(additional_derives);
  53. let name = &input.ident;
  54. let expanded = quote! {
  55. #input
  56. #[automatically_derived]
  57. impl ComponentTraits for #name {
  58. fn seed() -> &'static [u8] {
  59. #component_id_value.as_bytes()
  60. }
  61. fn size() -> usize {
  62. 8 + <#name>::INIT_SPACE
  63. }
  64. }
  65. };
  66. expanded.into()
  67. }