lib.rs 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. use proc_macro::TokenStream;
  2. use proc_macro2::Ident;
  3. use quote::quote;
  4. use syn::{
  5. parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, GenericArgument, ItemFn, ItemMod,
  6. PathArguments, ReturnType, Stmt, Type, TypePath,
  7. };
  8. struct SystemTransform;
  9. /// This macro attribute is used to define a BOLT system.
  10. ///
  11. /// Bolt components are themselves programs. The macro adds parsing and serialization
  12. ///
  13. /// # Example
  14. /// ```ignore
  15. /// #[system]
  16. /// pub mod system_fly {
  17. /// use super::*;
  18. ///
  19. /// pub fn execute(ctx: Context<Component>, _args: Vec<u8>) -> Result<Position> {
  20. /// let pos = Position {
  21. /// x: ctx.accounts.position.x,
  22. /// y: ctx.accounts.position.y,
  23. /// z: ctx.accounts.position.z + 1,
  24. /// };
  25. /// Ok(pos)
  26. /// }
  27. /// }
  28. /// ```
  29. #[proc_macro_attribute]
  30. pub fn system(attr: TokenStream, item: TokenStream) -> TokenStream {
  31. let mut input = parse_macro_input!(item as ItemMod);
  32. let _attr = parse_macro_input!(attr as syn::AttributeArgs);
  33. let use_super = syn::parse_quote! { use super::*; };
  34. if let Some(ref mut content) = input.content {
  35. content.1.insert(0, syn::Item::Use(use_super));
  36. }
  37. let mut transform = SystemTransform;
  38. transform.visit_item_mod_mut(&mut input);
  39. // Add `#[program]` macro
  40. let expanded = quote! {
  41. #[program]
  42. #input
  43. };
  44. TokenStream::from(expanded)
  45. }
  46. /// Visits the AST and modifies the system function
  47. impl VisitMut for SystemTransform {
  48. // Modify the return instruction to return Result<Vec<u8>>
  49. fn visit_expr_mut(&mut self, expr: &mut Expr) {
  50. if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
  51. let new_return_expr: Expr = match inner_variable {
  52. Expr::Tuple(tuple_expr) => {
  53. let tuple_elements = tuple_expr.elems.iter().map(|elem| {
  54. quote! { (#elem).try_to_vec()? }
  55. });
  56. parse_quote! { Ok((#(#tuple_elements),*)) }
  57. }
  58. _ => {
  59. parse_quote! { Ok((#inner_variable).try_to_vec()?) }
  60. }
  61. };
  62. *expr = new_return_expr;
  63. }
  64. }
  65. // Modify the return type of the system function to Result<Vec<u8>>
  66. fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
  67. if item_fn.sig.ident == "execute" {
  68. // Modify the return type to Result<Vec<u8>> if necessary
  69. if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
  70. if let Type::Path(type_path) = &**type_box {
  71. let ret_values = Self::extract_return_value(type_path);
  72. if ret_values > 1 {
  73. item_fn.sig.ident = Ident::new(
  74. format!("execute_{}", ret_values).as_str(),
  75. item_fn.sig.ident.span(),
  76. );
  77. }
  78. if !Self::check_is_vec_u8(type_path) {
  79. Self::modify_fn_return_type(item_fn, ret_values);
  80. // Modify the return statement inside the function body
  81. let block = &mut item_fn.block;
  82. for stmt in &mut block.stmts {
  83. if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
  84. self.visit_expr_mut(expr);
  85. }
  86. }
  87. }
  88. }
  89. }
  90. }
  91. }
  92. // Visit all the functions inside the system module
  93. fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
  94. for item in &mut item_mod.content.as_mut().unwrap().1 {
  95. if let syn::Item::Fn(item_fn) = item {
  96. self.visit_item_fn_mut(item_fn)
  97. }
  98. }
  99. }
  100. }
  101. impl SystemTransform {
  102. // Helper function to check if a type is `Vec<u8>` or `(Vec<u8>, Vec<u8>, ...)`
  103. fn check_is_vec_u8(ty: &TypePath) -> bool {
  104. if let Some(segment) = ty.path.segments.last() {
  105. if segment.ident == "Result" {
  106. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  107. if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
  108. return tuple.elems.iter().all(|elem| {
  109. if let Type::Path(type_path) = elem {
  110. if let Some(segment) = type_path.path.segments.first() {
  111. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  112. }
  113. }
  114. false
  115. });
  116. } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
  117. args.args.first()
  118. {
  119. if let Some(segment) = type_path.path.segments.first() {
  120. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  121. }
  122. }
  123. }
  124. }
  125. }
  126. false
  127. }
  128. // Helper function to check if a type is Vec<u8>
  129. fn is_u8_vec(segment: &syn::PathSegment) -> bool {
  130. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  131. if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
  132. if let Some(segment) = path.path.segments.first() {
  133. return segment.ident == "u8";
  134. }
  135. }
  136. }
  137. false
  138. }
  139. // Helper function to extract the number of return values from a type
  140. fn extract_return_value(ty: &TypePath) -> usize {
  141. if let Some(segment) = ty.path.segments.last() {
  142. if segment.ident == "Result" {
  143. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  144. return if let Some(GenericArgument::Type(Type::Tuple(tuple))) =
  145. args.args.first()
  146. {
  147. tuple.elems.len()
  148. } else {
  149. 1
  150. };
  151. }
  152. }
  153. }
  154. 0
  155. }
  156. // Helper function to modify the return type of a function to be Result<Vec<u8>> or Result<(Vec<u8>, Vec<u8>, ...)>
  157. fn modify_fn_return_type(item_fn: &mut syn::ItemFn, ret_values: usize) {
  158. item_fn.sig.output = if ret_values == 1 {
  159. parse_quote! { -> Result<Vec<u8>> }
  160. } else {
  161. let types = std::iter::repeat(quote! { Vec<u8> })
  162. .take(ret_values)
  163. .collect::<Vec<_>>();
  164. let tuple = quote! { (#(#types),*) };
  165. syn::parse2(quote! { -> Result<#tuple> }).unwrap()
  166. };
  167. }
  168. // Helper function to check if an expression is an `Ok(...)` or `return Ok(...);` variant
  169. fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
  170. match expr {
  171. Expr::Call(expr_call) => {
  172. // Direct `Ok(...)` call
  173. if let Expr::Path(expr_path) = &*expr_call.func {
  174. if let Some(last_segment) = expr_path.path.segments.last() {
  175. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  176. // Return the first argument of the Ok(...) call
  177. return expr_call.args.first();
  178. }
  179. }
  180. }
  181. }
  182. Expr::Return(expr_return) => {
  183. // `return Ok(...);`
  184. if let Some(expr_return_inner) = &expr_return.expr {
  185. if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
  186. if let Expr::Path(expr_path) = &*expr_call.func {
  187. if let Some(last_segment) = expr_path.path.segments.last() {
  188. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  189. // Return the first argument of the return Ok(...) call
  190. return expr_call.args.first();
  191. }
  192. }
  193. }
  194. }
  195. }
  196. }
  197. _ => {}
  198. }
  199. None
  200. }
  201. }