lib.rs 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. use proc_macro::TokenStream;
  2. use quote::quote;
  3. use syn::{parse_macro_input, Fields, ItemStruct, Lit, Meta, NestedMeta};
  4. /// This macro attribute is used to define a BOLT system input.
  5. ///
  6. /// The input can be defined as a struct and will be transformed into an Anchor context.
  7. ///
  8. ///
  9. /// # Example
  10. /// ```ignore
  11. ///#[system_input]
  12. ///pub struct Components {
  13. /// pub position: Position,
  14. ///}
  15. ///
  16. /// ```
  17. #[proc_macro_attribute]
  18. pub fn system_input(_attr: TokenStream, item: TokenStream) -> TokenStream {
  19. // Parse the input TokenStream (the struct) into a Rust data structure
  20. let input = parse_macro_input!(item as ItemStruct);
  21. // Ensure the struct has named fields
  22. let fields = match &input.fields {
  23. Fields::Named(fields) => &fields.named,
  24. _ => panic!("system_input macro only supports structs with named fields"),
  25. };
  26. let name = &input.ident;
  27. // Collect imports for components
  28. let components_imports: Vec<_> = fields
  29. .iter()
  30. .filter_map(|field| {
  31. field.attrs.iter().find_map(|attr| {
  32. if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
  33. if meta_list.path.is_ident("component_id") {
  34. meta_list.nested.first().and_then(|nested_meta| {
  35. if let NestedMeta::Lit(Lit::Str(lit_str)) = nested_meta {
  36. let component_type =
  37. format!("bolt_types::Component{}", lit_str.value());
  38. if let Ok(parsed_component_type) =
  39. syn::parse_str::<syn::Type>(&component_type)
  40. {
  41. let field_type = &field.ty;
  42. let component_import = quote! {
  43. use #parsed_component_type as #field_type;
  44. };
  45. return Some(component_import);
  46. }
  47. }
  48. None
  49. })
  50. } else {
  51. None
  52. }
  53. } else {
  54. None
  55. }
  56. })
  57. })
  58. .collect();
  59. let bolt_accounts = fields.iter().map(|f| {
  60. let field_type = &f.ty;
  61. quote! {
  62. pub type #field_type = bolt_lang::account::BoltAccount<super::#field_type, { bolt_lang::account::pubkey_p0(crate::ID) }, { bolt_lang::account::pubkey_p1(crate::ID) }>;
  63. }
  64. });
  65. // Transform fields for the struct definition
  66. let transformed_fields = fields.iter().map(|f| {
  67. let field_name = &f.ident;
  68. let field_type = &f.ty;
  69. quote! {
  70. #[account(mut)]
  71. pub #field_name: Account<'info, bolt_accounts::#field_type>,
  72. }
  73. });
  74. // Generate the new struct with the Accounts derive and transformed fields
  75. let output_struct = quote! {
  76. #[derive(Accounts)]
  77. pub struct #name<'info> {
  78. #(#transformed_fields)*
  79. /// CHECK: Authority check
  80. #[account()]
  81. pub authority: AccountInfo<'info>,
  82. }
  83. };
  84. let try_from_fields = fields.iter().enumerate().map(|(i, f)| {
  85. let field_name = &f.ident;
  86. quote! {
  87. #field_name: {
  88. Account::try_from(context.remaining_accounts.as_ref().get(#i).ok_or_else(|| ErrorCode::ConstraintAccountIsNone)?)?
  89. },
  90. }
  91. });
  92. let number_of_components = fields.len();
  93. let output_trait = quote! {
  94. pub trait NumberOfComponents<'a, 'b, 'c, 'info, T> {
  95. const NUMBER_OF_COMPONENTS: usize;
  96. }
  97. };
  98. let output_trait_implementation = quote! {
  99. impl<'a, 'b, 'c, 'info, T: bolt_lang::Bumps> NumberOfComponents<'a, 'b, 'c, 'info, T> for Context<'a, 'b, 'c, 'info, T> {
  100. const NUMBER_OF_COMPONENTS: usize = #number_of_components;
  101. }
  102. };
  103. // Generate the implementation of try_from for the struct
  104. let output_impl = quote! {
  105. impl<'info> #name<'info> {
  106. fn try_from<'a, 'b>(context: &Context<'a, 'b, 'info, 'info, VariadicBoltComponents<'info>>) -> Result<Self> {
  107. Ok(Self {
  108. authority: context.accounts.authority.clone(),
  109. #(#try_from_fields)*
  110. })
  111. }
  112. }
  113. };
  114. // Combine the struct definition and its implementation into the final TokenStream
  115. let output = quote! {
  116. mod bolt_accounts {
  117. #(#bolt_accounts)*
  118. }
  119. #output_struct
  120. #output_impl
  121. #output_trait
  122. #output_trait_implementation
  123. #(#components_imports)*
  124. #[derive(Accounts)]
  125. pub struct VariadicBoltComponents<'info> {
  126. /// CHECK: Authority check
  127. #[account()]
  128. pub authority: AccountInfo<'info>,
  129. }
  130. };
  131. TokenStream::from(output)
  132. }