lib.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. use proc_macro::TokenStream;
  2. use proc_macro2::Ident;
  3. use quote::{quote, ToTokens};
  4. use syn::{
  5. parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
  6. ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
  7. };
  /// Stateless visitor that rewrites a system module (signature and return values of
  /// `execute`, plus the optional `init_extra_accounts` instruction).
  #[derive(Default)]
  struct SystemTransform;

  /// State gathered while scanning a system module for its component bundle.
  #[derive(Default)]
  struct Extractor {
      /// Field count of the bundle struct, set once that struct has been found.
      field_count: Option<usize>,
      /// Name of the type passed to `Context<...>` in the module's function signature(s).
      context_struct_name: Option<String>,
  }
  15. /// This macro attribute is used to define a BOLT system.
  16. ///
  17. /// Bolt components are themselves programs. The macro adds parsing and serialization
  18. ///
  19. /// # Example
  20. /// ```ignore
  21. /// #[system]
  22. /// pub mod system_fly {
  23. /// pub fn execute(ctx: Context<Component>, _args: Vec<u8>) -> Result<Position> {
  24. /// let pos = Position {
  25. /// x: ctx.accounts.position.x,
  26. /// y: ctx.accounts.position.y,
  27. /// z: ctx.accounts.position.z + 1,
  28. /// };
  29. /// Ok(pos)
  30. /// }
  31. /// }
  32. /// ```
  33. #[proc_macro_attribute]
  34. pub fn system(attr: TokenStream, item: TokenStream) -> TokenStream {
  35. let mut ast = parse_macro_input!(item as ItemMod);
  36. let _attr = parse_macro_input!(attr as syn::AttributeArgs);
  37. // Extract the number of components from the module
  38. let mut extractor = Extractor::default();
  39. extractor.visit_item_mod_mut(&mut ast);
  40. if extractor.field_count.is_some() {
  41. let use_super = syn::parse_quote! { use super::*; };
  42. if let Some((_, ref mut items)) = ast.content {
  43. items.insert(0, syn::Item::Use(use_super));
  44. SystemTransform::add_variadic_execute_function(items);
  45. }
  46. let mut transform = SystemTransform;
  47. transform.visit_item_mod_mut(&mut ast);
  48. // Add `#[program]` macro and try_to_vec implementation
  49. let expanded = quote! {
  50. #[program]
  51. #ast
  52. };
  53. TokenStream::from(expanded)
  54. } else {
  55. panic!(
  56. "Could not find the component bundle: {} in the module",
  57. extractor.context_struct_name.unwrap()
  58. );
  59. }
  60. }
  61. impl SystemTransform {
  62. fn visit_stmts_mut(&mut self, stmts: &mut Vec<Stmt>) {
  63. for stmt in stmts {
  64. if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
  65. self.visit_expr_mut(expr);
  66. }
  67. }
  68. }
  69. }
  70. /// Visits the AST and modifies the system function
  71. impl VisitMut for SystemTransform {
  72. // Modify the return instruction to return Result<Vec<u8>>
  73. fn visit_expr_mut(&mut self, expr: &mut Expr) {
  74. match expr {
  75. Expr::ForLoop(for_loop_expr) => {
  76. self.visit_stmts_mut(&mut for_loop_expr.body.stmts);
  77. }
  78. Expr::Loop(loop_expr) => {
  79. self.visit_stmts_mut(&mut loop_expr.body.stmts);
  80. }
  81. Expr::If(if_expr) => {
  82. self.visit_stmts_mut(&mut if_expr.then_branch.stmts);
  83. if let Some((_, else_expr)) = &mut if_expr.else_branch {
  84. self.visit_expr_mut(else_expr);
  85. }
  86. }
  87. Expr::Block(block_expr) => {
  88. self.visit_stmts_mut(&mut block_expr.block.stmts);
  89. }
  90. _ => (),
  91. }
  92. if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
  93. let new_return_expr: Expr = match inner_variable {
  94. Expr::Tuple(tuple_expr) => {
  95. let tuple_elements = tuple_expr.elems.iter().map(|elem| {
  96. quote! { (#elem).try_to_vec()? }
  97. });
  98. parse_quote! { Ok((#(#tuple_elements),*)) }
  99. }
  100. _ => {
  101. parse_quote! {
  102. #inner_variable.try_to_vec()
  103. }
  104. }
  105. };
  106. if let Expr::Return(return_expr) = expr {
  107. return_expr.expr = Some(Box::new(new_return_expr));
  108. } else {
  109. *expr = new_return_expr;
  110. }
  111. }
  112. }
  113. // Modify the return type of the system function to Result<Vec<u8>,*>
  114. fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
  115. if item_fn.sig.ident == "execute" {
  116. // Modify the return type to Result<Vec<u8>> if necessary
  117. if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
  118. if let Type::Path(type_path) = &**type_box {
  119. if !Self::check_is_result_vec_u8(type_path) {
  120. item_fn.sig.output = parse_quote! { -> Result<Vec<Vec<u8>>> };
  121. // Modify the return statement inside the function body
  122. let block = &mut item_fn.block;
  123. self.visit_stmts_mut(&mut block.stmts);
  124. }
  125. }
  126. }
  127. // If second argument is not Vec<u8>, modify it to be so and use parse_args
  128. Self::modify_args(item_fn);
  129. }
  130. }
  131. // Visit all the functions inside the system module and inject the init_extra_accounts function
  132. // if the module contains a struct with the `extra_accounts` attribute
  133. fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
  134. let content = match item_mod.content.as_mut() {
  135. Some(content) => &mut content.1,
  136. None => return,
  137. };
  138. let mut extra_accounts_struct_name = None;
  139. for item in content.iter_mut() {
  140. match item {
  141. syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
  142. syn::Item::Struct(item_struct)
  143. if item_struct
  144. .attrs
  145. .iter()
  146. .any(|attr| attr.path.is_ident("extra_accounts")) =>
  147. {
  148. extra_accounts_struct_name = Some(&item_struct.ident);
  149. break;
  150. }
  151. _ => {}
  152. }
  153. }
  154. if let Some(struct_name) = extra_accounts_struct_name {
  155. let initialize_extra_accounts = quote! {
  156. #[automatically_derived]
  157. pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
  158. Ok(())
  159. }
  160. };
  161. content.push(syn::parse2(initialize_extra_accounts).unwrap());
  162. }
  163. }
  164. }
  165. impl SystemTransform {
  166. fn add_variadic_execute_function(content: &mut Vec<syn::Item>) {
  167. content.push(syn::parse2(quote! {
  168. pub fn bolt_execute<'info>(ctx: Context<'_, '_, 'info, 'info, VariadicBoltComponents<'info>>, args: Vec<u8>) -> Result<Vec<Vec<u8>>> {
  169. let mut components = Components::try_from(&ctx)?;
  170. let bumps = ComponentsBumps {};
  171. let context = Context::new(ctx.program_id, &mut components, ctx.remaining_accounts, bumps);
  172. execute(context, args)
  173. }
  174. }).unwrap());
  175. }
  176. // Helper function to check if a type is `Vec<u8>` or `(Vec<u8>, Vec<u8>, ...)`
  177. fn check_is_result_vec_u8(ty: &TypePath) -> bool {
  178. if let Some(segment) = ty.path.segments.last() {
  179. if segment.ident == "Result" {
  180. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  181. if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
  182. return tuple.elems.iter().all(|elem| {
  183. if let Type::Path(type_path) = elem {
  184. if let Some(segment) = type_path.path.segments.first() {
  185. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  186. }
  187. }
  188. false
  189. });
  190. } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
  191. args.args.first()
  192. {
  193. if let Some(segment) = type_path.path.segments.first() {
  194. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  195. }
  196. }
  197. }
  198. }
  199. }
  200. false
  201. }
  202. // Helper function to check if a type is Vec<u8>
  203. fn is_u8_vec(segment: &syn::PathSegment) -> bool {
  204. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  205. if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
  206. if let Some(segment) = path.path.segments.first() {
  207. return segment.ident == "u8";
  208. }
  209. }
  210. }
  211. false
  212. }
  213. // Helper function to check if an expression is an `Ok(...)` or `return Ok(...);` variant
  214. fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
  215. match expr {
  216. Expr::Call(expr_call) => {
  217. // Direct `Ok(...)` call
  218. if let Expr::Path(expr_path) = &*expr_call.func {
  219. if let Some(last_segment) = expr_path.path.segments.last() {
  220. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  221. // Return the first argument of the Ok(...) call
  222. return expr_call.args.first();
  223. }
  224. }
  225. }
  226. }
  227. Expr::Return(expr_return) => {
  228. // `return Ok(...);`
  229. if let Some(expr_return_inner) = &expr_return.expr {
  230. if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
  231. if let Expr::Path(expr_path) = &*expr_call.func {
  232. if let Some(last_segment) = expr_path.path.segments.last() {
  233. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  234. // Return the first argument of the return Ok(...) call
  235. return expr_call.args.first();
  236. }
  237. }
  238. }
  239. }
  240. }
  241. }
  242. _ => {}
  243. }
  244. None
  245. }
  246. fn modify_args(item_fn: &mut ItemFn) {
  247. if item_fn.sig.inputs.len() >= 2 {
  248. let second_arg = &mut item_fn.sig.inputs[1];
  249. let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
  250. match &**ty {
  251. Type::Path(type_path) => {
  252. if let Some(segment) = type_path.path.segments.first() {
  253. segment.ident == "Vec" && Self::is_u8_vec(segment)
  254. } else {
  255. false
  256. }
  257. }
  258. _ => false,
  259. }
  260. } else {
  261. false
  262. };
  263. if !is_vec_u8 {
  264. if let FnArg::Typed(pat_type) = second_arg {
  265. let original_type = pat_type.ty.to_token_stream();
  266. let arg_original_name = pat_type.pat.to_token_stream();
  267. if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
  268. let new_ident_name = format!("_{}", pat_ident.ident);
  269. pat_ident.ident =
  270. Ident::new(&new_ident_name, proc_macro2::Span::call_site());
  271. }
  272. let arg_name = pat_type.pat.to_token_stream();
  273. pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
  274. let parse_stmt: Stmt = parse_quote! {
  275. let #arg_original_name = parse_args::<#original_type>(&#arg_name);
  276. };
  277. item_fn.block.stmts.insert(0, parse_stmt);
  278. }
  279. }
  280. }
  281. }
  282. }
  283. /// Visits the AST to extract the number of input components
  284. impl VisitMut for Extractor {
  285. fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
  286. for input in &i.sig.inputs {
  287. if let FnArg::Typed(pat_type) = input {
  288. if let Type::Path(type_path) = &*pat_type.ty {
  289. let last_segment = type_path.path.segments.last().unwrap();
  290. if last_segment.ident == "Context" {
  291. if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
  292. if let Some(syn::GenericArgument::Type(syn::Type::Path(type_path))) =
  293. args.args.first()
  294. {
  295. let ident = &type_path.path.segments.first().unwrap().ident;
  296. self.context_struct_name = Some(ident.to_string());
  297. }
  298. }
  299. }
  300. }
  301. }
  302. }
  303. }
  304. fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
  305. if let Some(name) = &self.context_struct_name {
  306. if i.ident == name {
  307. self.field_count = Some(i.fields.len());
  308. }
  309. }
  310. }
  311. }