lib.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. use proc_macro::TokenStream;
  2. use proc_macro2::Ident;
  3. use quote::{quote, ToTokens};
  4. use syn::{
  5. parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
  6. ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
  7. };
  /// State carried by the pass that rewrites a system module's `execute` function.
  #[derive(Default)]
  struct SystemTransform {
      /// Number of component values the system returns. It sets the arity of the
      /// `Result<(Vec<u8>, ..)>` return tuple, and when it is greater than one, `execute` is
      /// renamed `execute_<N>`.
      return_values: usize,
  }
  /// State gathered by the pass that locates a system's component bundle.
  #[derive(Default)]
  struct Extractor {
      /// Name of the type parameter of the most recently visited `Context<...>` argument,
      /// if any function in the module takes one.
      context_struct_name: Option<String>,
      /// Field count of the struct named by `context_struct_name`. It is `None` until that
      /// struct has been matched.
      field_count: Option<usize>,
  }
  17. /// This macro attribute is used to define a BOLT system.
  18. ///
  19. /// Bolt components are themselves programs. The macro adds parsing and serialization
  20. ///
  21. /// # Example
  22. /// ```ignore
  23. /// #[system]
  24. /// pub mod system_fly {
  25. /// pub fn execute(ctx: Context<Component>, _args: Vec<u8>) -> Result<Position> {
  26. /// let pos = Position {
  27. /// x: ctx.accounts.position.x,
  28. /// y: ctx.accounts.position.y,
  29. /// z: ctx.accounts.position.z + 1,
  30. /// };
  31. /// Ok(pos)
  32. /// }
  33. /// }
  34. /// ```
  35. #[proc_macro_attribute]
  36. pub fn system(attr: TokenStream, item: TokenStream) -> TokenStream {
  37. let mut ast = parse_macro_input!(item as ItemMod);
  38. let _attr = parse_macro_input!(attr as syn::AttributeArgs);
  39. // Extract the number of components from the module
  40. let mut extractor = Extractor::default();
  41. extractor.visit_item_mod_mut(&mut ast);
  42. if let Some(components_len) = extractor.field_count {
  43. let use_super = syn::parse_quote! { use super::*; };
  44. if let Some(ref mut content) = ast.content {
  45. content.1.insert(0, syn::Item::Use(use_super));
  46. }
  47. let mut transform = SystemTransform {
  48. return_values: components_len,
  49. };
  50. transform.visit_item_mod_mut(&mut ast);
  51. // Add `#[program]` macro and try_to_vec implementation
  52. let expanded = quote! {
  53. #[program]
  54. #ast
  55. };
  56. TokenStream::from(expanded)
  57. } else {
  58. panic!(
  59. "Could not find the component bundle: {} in the module",
  60. extractor.context_struct_name.unwrap()
  61. );
  62. }
  63. }
  64. /// Visits the AST and modifies the system function
  65. impl VisitMut for SystemTransform {
  66. // Modify the return instruction to return Result<Vec<u8>>
  67. fn visit_expr_mut(&mut self, expr: &mut Expr) {
  68. if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
  69. let new_return_expr: Expr = match inner_variable {
  70. Expr::Tuple(tuple_expr) => {
  71. let tuple_elements = tuple_expr.elems.iter().map(|elem| {
  72. quote! { (#elem).try_to_vec()? }
  73. });
  74. parse_quote! { Ok((#(#tuple_elements),*)) }
  75. }
  76. _ => {
  77. parse_quote! {
  78. #inner_variable.try_to_vec()
  79. }
  80. }
  81. };
  82. *expr = new_return_expr;
  83. }
  84. }
  85. // Modify the return type of the system function to Result<Vec<u8>,*>
  86. fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
  87. if item_fn.sig.ident == "execute" {
  88. // Modify the return type to Result<Vec<u8>> if necessary
  89. if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
  90. if let Type::Path(type_path) = &**type_box {
  91. if self.return_values > 1 {
  92. item_fn.sig.ident = Ident::new(
  93. format!("execute_{}", self.return_values).as_str(),
  94. item_fn.sig.ident.span(),
  95. );
  96. }
  97. if !Self::check_is_result_vec_u8(type_path) {
  98. Self::modify_fn_return_type(item_fn, self.return_values);
  99. // Modify the return statement inside the function body
  100. let block = &mut item_fn.block;
  101. for stmt in &mut block.stmts {
  102. if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
  103. self.visit_expr_mut(expr);
  104. }
  105. }
  106. }
  107. }
  108. }
  109. // If second argument is not Vec<u8>, modify it to be so and use parse_args
  110. Self::modify_args(item_fn);
  111. }
  112. }
  113. // Visit all the functions inside the system module and inject the init_extra_accounts function
  114. // if the module contains a struct with the `extra_accounts` attribute
  115. fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
  116. let content = match item_mod.content.as_mut() {
  117. Some(content) => &mut content.1,
  118. None => return,
  119. };
  120. let mut extra_accounts_struct_name = None;
  121. for item in content.iter_mut() {
  122. match item {
  123. syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
  124. syn::Item::Struct(item_struct)
  125. if item_struct
  126. .attrs
  127. .iter()
  128. .any(|attr| attr.path.is_ident("extra_accounts")) =>
  129. {
  130. extra_accounts_struct_name = Some(&item_struct.ident);
  131. break;
  132. }
  133. _ => {}
  134. }
  135. }
  136. if let Some(struct_name) = extra_accounts_struct_name {
  137. let initialize_extra_accounts = quote! {
  138. #[automatically_derived]
  139. pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
  140. Ok(())
  141. }
  142. };
  143. content.push(syn::parse2(initialize_extra_accounts).unwrap());
  144. }
  145. }
  146. }
  147. impl SystemTransform {
  148. // Helper function to check if a type is `Vec<u8>` or `(Vec<u8>, Vec<u8>, ...)`
  149. fn check_is_result_vec_u8(ty: &TypePath) -> bool {
  150. if let Some(segment) = ty.path.segments.last() {
  151. if segment.ident == "Result" {
  152. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  153. if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
  154. return tuple.elems.iter().all(|elem| {
  155. if let Type::Path(type_path) = elem {
  156. if let Some(segment) = type_path.path.segments.first() {
  157. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  158. }
  159. }
  160. false
  161. });
  162. } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
  163. args.args.first()
  164. {
  165. if let Some(segment) = type_path.path.segments.first() {
  166. return segment.ident == "Vec" && Self::is_u8_vec(segment);
  167. }
  168. }
  169. }
  170. }
  171. }
  172. false
  173. }
  174. // Helper function to check if a type is Vec<u8>
  175. fn is_u8_vec(segment: &syn::PathSegment) -> bool {
  176. if let PathArguments::AngleBracketed(args) = &segment.arguments {
  177. if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
  178. if let Some(segment) = path.path.segments.first() {
  179. return segment.ident == "u8";
  180. }
  181. }
  182. }
  183. false
  184. }
  185. // Helper function to modify the return type of a function to be Result<Vec<u8>> or Result<(Vec<u8>, Vec<u8>, ...)>
  186. fn modify_fn_return_type(item_fn: &mut ItemFn, ret_values: usize) {
  187. item_fn.sig.output = if ret_values == 1 {
  188. parse_quote! { -> Result<(Vec<u8>,)> }
  189. } else {
  190. let types = std::iter::repeat(quote! { Vec<u8> })
  191. .take(ret_values)
  192. .collect::<Vec<_>>();
  193. let tuple = quote! { (#(#types),*) };
  194. syn::parse2(quote! { -> Result<#tuple> }).unwrap()
  195. };
  196. }
  197. // Helper function to check if an expression is an `Ok(...)` or `return Ok(...);` variant
  198. fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
  199. match expr {
  200. Expr::Call(expr_call) => {
  201. // Direct `Ok(...)` call
  202. if let Expr::Path(expr_path) = &*expr_call.func {
  203. if let Some(last_segment) = expr_path.path.segments.last() {
  204. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  205. // Return the first argument of the Ok(...) call
  206. return expr_call.args.first();
  207. }
  208. }
  209. }
  210. }
  211. Expr::Return(expr_return) => {
  212. // `return Ok(...);`
  213. if let Some(expr_return_inner) = &expr_return.expr {
  214. if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
  215. if let Expr::Path(expr_path) = &*expr_call.func {
  216. if let Some(last_segment) = expr_path.path.segments.last() {
  217. if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
  218. // Return the first argument of the return Ok(...) call
  219. return expr_call.args.first();
  220. }
  221. }
  222. }
  223. }
  224. }
  225. }
  226. _ => {}
  227. }
  228. None
  229. }
  230. fn modify_args(item_fn: &mut ItemFn) {
  231. if item_fn.sig.inputs.len() >= 2 {
  232. let second_arg = &mut item_fn.sig.inputs[1];
  233. let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
  234. match &**ty {
  235. Type::Path(type_path) => {
  236. if let Some(segment) = type_path.path.segments.first() {
  237. segment.ident == "Vec" && Self::is_u8_vec(segment)
  238. } else {
  239. false
  240. }
  241. }
  242. _ => false,
  243. }
  244. } else {
  245. false
  246. };
  247. if !is_vec_u8 {
  248. if let FnArg::Typed(pat_type) = second_arg {
  249. let original_type = pat_type.ty.to_token_stream();
  250. let arg_original_name = pat_type.pat.to_token_stream();
  251. if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
  252. let new_ident_name = format!("_{}", pat_ident.ident);
  253. pat_ident.ident =
  254. Ident::new(&new_ident_name, proc_macro2::Span::call_site());
  255. }
  256. let arg_name = pat_type.pat.to_token_stream();
  257. pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
  258. let parse_stmt: Stmt = parse_quote! {
  259. let #arg_original_name = parse_args::<#original_type>(&#arg_name);
  260. };
  261. item_fn.block.stmts.insert(0, parse_stmt);
  262. }
  263. }
  264. }
  265. }
  266. }
  267. /// Visits the AST to extract the number of input components
  268. impl VisitMut for Extractor {
  269. fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
  270. for input in &i.sig.inputs {
  271. if let FnArg::Typed(pat_type) = input {
  272. if let Type::Path(type_path) = &*pat_type.ty {
  273. let last_segment = type_path.path.segments.last().unwrap();
  274. if last_segment.ident == "Context" {
  275. if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
  276. if let Some(syn::GenericArgument::Type(syn::Type::Path(type_path))) =
  277. args.args.first()
  278. {
  279. let ident = &type_path.path.segments.first().unwrap().ident;
  280. self.context_struct_name = Some(ident.to_string());
  281. }
  282. }
  283. }
  284. }
  285. }
  286. }
  287. }
  288. fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
  289. if let Some(name) = &self.context_struct_name {
  290. if i.ident == name {
  291. self.field_count = Some(i.fields.len());
  292. }
  293. }
  294. }
  295. }