file.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. use crate::idl::*;
  2. use crate::parser::{self, accounts, error, program};
  3. use crate::{AccountsStruct, StateIx};
  4. use anyhow::Result;
  5. use heck::MixedCase;
  6. use quote::ToTokens;
  7. use std::collections::{HashMap, HashSet};
  8. use std::fs::File;
  9. use std::io::Read;
  10. use std::iter::FromIterator;
  11. use std::path::Path;
  12. const DERIVE_NAME: &str = "Accounts";
  13. // Parse an entire interface file.
  14. pub fn parse(filename: impl AsRef<Path>) -> Result<Idl> {
  15. let mut file = File::open(&filename)?;
  16. let mut src = String::new();
  17. file.read_to_string(&mut src).expect("Unable to read file");
  18. let f = syn::parse_file(&src).expect("Unable to parse file");
  19. let p = program::parse(parse_program_mod(&f));
  20. let accs = parse_account_derives(&f);
  21. let state = match p.state {
  22. None => None,
  23. Some(state) => match state.ctor_and_anchor {
  24. None => None, // State struct defined but no implementation
  25. Some((ctor, anchor_ident)) => {
  26. let mut methods = state
  27. .impl_block_and_methods
  28. .map(|(_impl_block, methods)| {
  29. methods
  30. .iter()
  31. .map(|method: &StateIx| {
  32. let name = method.ident.to_string().to_mixed_case();
  33. let args = method
  34. .args
  35. .iter()
  36. .map(|arg| {
  37. let mut tts = proc_macro2::TokenStream::new();
  38. arg.raw_arg.ty.to_tokens(&mut tts);
  39. let ty = tts.to_string().parse().unwrap();
  40. IdlField {
  41. name: arg.name.to_string().to_mixed_case(),
  42. ty,
  43. }
  44. })
  45. .collect::<Vec<_>>();
  46. let accounts_strct =
  47. accs.get(&method.anchor_ident.to_string()).unwrap();
  48. let accounts = accounts_strct.idl_accounts(&accs);
  49. IdlStateMethod {
  50. name,
  51. args,
  52. accounts,
  53. }
  54. })
  55. .collect::<Vec<_>>()
  56. })
  57. .unwrap_or_default();
  58. let ctor = {
  59. let name = "new".to_string();
  60. let args = ctor
  61. .sig
  62. .inputs
  63. .iter()
  64. .filter(|arg| match arg {
  65. syn::FnArg::Typed(pat_ty) => {
  66. // TODO: this filtering should be donein the parser.
  67. let mut arg_str = parser::tts_to_string(&pat_ty.ty);
  68. arg_str.retain(|c| !c.is_whitespace());
  69. !arg_str.starts_with("Context<")
  70. }
  71. _ => false,
  72. })
  73. .map(|arg: &syn::FnArg| match arg {
  74. syn::FnArg::Typed(arg_typed) => {
  75. let mut tts = proc_macro2::TokenStream::new();
  76. arg_typed.ty.to_tokens(&mut tts);
  77. let ty = tts.to_string().parse().unwrap();
  78. IdlField {
  79. name: parser::tts_to_string(&arg_typed.pat).to_mixed_case(),
  80. ty,
  81. }
  82. }
  83. _ => panic!("Invalid syntax"),
  84. })
  85. .collect();
  86. let accounts_strct = accs.get(&anchor_ident.to_string()).unwrap();
  87. let accounts = accounts_strct.idl_accounts(&accs);
  88. IdlStateMethod {
  89. name,
  90. args,
  91. accounts,
  92. }
  93. };
  94. methods.insert(0, ctor);
  95. let strct = {
  96. let fields = match state.strct.fields {
  97. syn::Fields::Named(f_named) => f_named
  98. .named
  99. .iter()
  100. .map(|f: &syn::Field| {
  101. let mut tts = proc_macro2::TokenStream::new();
  102. f.ty.to_tokens(&mut tts);
  103. let ty = tts.to_string().parse().unwrap();
  104. IdlField {
  105. name: f.ident.as_ref().unwrap().to_string().to_mixed_case(),
  106. ty,
  107. }
  108. })
  109. .collect::<Vec<IdlField>>(),
  110. _ => panic!("State must be a struct"),
  111. };
  112. IdlTypeDef {
  113. name: state.name,
  114. ty: IdlTypeDefTy::Struct { fields },
  115. }
  116. };
  117. Some(IdlState { strct, methods })
  118. }
  119. },
  120. };
  121. let error = parse_error_enum(&f).map(|mut e| error::parse(&mut e));
  122. let error_codes = error.as_ref().map(|e| {
  123. e.codes
  124. .iter()
  125. .map(|code| IdlErrorCode {
  126. code: 100 + code.id,
  127. name: code.ident.to_string(),
  128. msg: code.msg.clone(),
  129. })
  130. .collect::<Vec<IdlErrorCode>>()
  131. });
  132. let instructions = p
  133. .ixs
  134. .iter()
  135. .map(|ix| {
  136. let args = ix
  137. .args
  138. .iter()
  139. .map(|arg| {
  140. let mut tts = proc_macro2::TokenStream::new();
  141. arg.raw_arg.ty.to_tokens(&mut tts);
  142. let ty = tts.to_string().parse().unwrap();
  143. IdlField {
  144. name: arg.name.to_string().to_mixed_case(),
  145. ty,
  146. }
  147. })
  148. .collect::<Vec<_>>();
  149. // todo: don't unwrap
  150. let accounts_strct = accs.get(&ix.anchor_ident.to_string()).unwrap();
  151. let accounts = accounts_strct.idl_accounts(&accs);
  152. IdlIx {
  153. name: ix.ident.to_string().to_mixed_case(),
  154. accounts,
  155. args,
  156. }
  157. })
  158. .collect::<Vec<_>>();
  159. let events = parse_events(&f)
  160. .iter()
  161. .map(|e: &&syn::ItemStruct| {
  162. let fields = match &e.fields {
  163. syn::Fields::Named(n) => n,
  164. _ => panic!("Event fields must be named"),
  165. };
  166. let fields = fields
  167. .named
  168. .iter()
  169. .map(|f: &syn::Field| {
  170. let index = match f.attrs.iter().next() {
  171. None => false,
  172. Some(i) => parser::tts_to_string(&i.path) == "index",
  173. };
  174. IdlEventField {
  175. name: f.ident.clone().unwrap().to_string(),
  176. ty: parser::tts_to_string(&f.ty).to_string().parse().unwrap(),
  177. index,
  178. }
  179. })
  180. .collect::<Vec<IdlEventField>>();
  181. IdlEvent {
  182. name: e.ident.to_string(),
  183. fields,
  184. }
  185. })
  186. .collect::<Vec<IdlEvent>>();
  187. // All user defined types.
  188. let mut accounts = vec![];
  189. let mut types = vec![];
  190. let ty_defs = parse_ty_defs(&f)?;
  191. let account_structs = parse_accounts(&f);
  192. let account_names: HashSet<String> =
  193. HashSet::from_iter(account_structs.iter().map(|a| a.ident.to_string()));
  194. let error_name = error.map(|e| e.name).unwrap_or_else(|| "".to_string());
  195. // All types that aren't in the accounts section, are in the types section.
  196. for ty_def in ty_defs {
  197. // Don't add the error type to the types or accounts sections.
  198. if ty_def.name != error_name {
  199. if account_names.contains(&ty_def.name) {
  200. accounts.push(ty_def);
  201. } else if events.iter().position(|e| e.name == ty_def.name).is_none() {
  202. types.push(ty_def);
  203. }
  204. }
  205. }
  206. Ok(Idl {
  207. version: "0.0.0".to_string(),
  208. name: p.name.to_string(),
  209. state,
  210. instructions,
  211. types,
  212. accounts,
  213. events: if events.is_empty() {
  214. None
  215. } else {
  216. Some(events)
  217. },
  218. errors: error_codes,
  219. metadata: None,
  220. })
  221. }
  222. // Parse the main program mod.
  223. fn parse_program_mod(f: &syn::File) -> syn::ItemMod {
  224. let mods = f
  225. .items
  226. .iter()
  227. .filter_map(|i| match i {
  228. syn::Item::Mod(item_mod) => {
  229. let mod_count = item_mod
  230. .attrs
  231. .iter()
  232. .filter(|attr| attr.path.segments.last().unwrap().ident == "program")
  233. .count();
  234. if mod_count != 1 {
  235. return None;
  236. }
  237. Some(item_mod)
  238. }
  239. _ => None,
  240. })
  241. .collect::<Vec<_>>();
  242. if mods.len() != 1 {
  243. panic!("Did not find program attribute");
  244. }
  245. mods[0].clone()
  246. }
  247. fn parse_error_enum(f: &syn::File) -> Option<syn::ItemEnum> {
  248. f.items
  249. .iter()
  250. .filter_map(|i| match i {
  251. syn::Item::Enum(item_enum) => {
  252. let attrs_count = item_enum
  253. .attrs
  254. .iter()
  255. .filter(|attr| {
  256. let segment = attr.path.segments.last().unwrap();
  257. segment.ident == "error"
  258. })
  259. .count();
  260. match attrs_count {
  261. 0 => None,
  262. 1 => Some(item_enum),
  263. _ => panic!("Invalid syntax: one error attribute allowed"),
  264. }
  265. }
  266. _ => None,
  267. })
  268. .next()
  269. .cloned()
  270. }
  271. fn parse_events(f: &syn::File) -> Vec<&syn::ItemStruct> {
  272. f.items
  273. .iter()
  274. .filter_map(|i| match i {
  275. syn::Item::Struct(item_strct) => {
  276. let attrs_count = item_strct
  277. .attrs
  278. .iter()
  279. .filter(|attr| {
  280. let segment = attr.path.segments.last().unwrap();
  281. segment.ident == "event"
  282. })
  283. .count();
  284. match attrs_count {
  285. 0 => None,
  286. 1 => Some(item_strct),
  287. _ => panic!("Invalid syntax: one event attribute allowed"),
  288. }
  289. }
  290. _ => None,
  291. })
  292. .collect()
  293. }
  294. fn parse_accounts(f: &syn::File) -> Vec<&syn::ItemStruct> {
  295. f.items
  296. .iter()
  297. .filter_map(|i| match i {
  298. syn::Item::Struct(item_strct) => {
  299. let attrs_count = item_strct
  300. .attrs
  301. .iter()
  302. .filter(|attr| {
  303. let segment = attr.path.segments.last().unwrap();
  304. segment.ident == "account" || segment.ident == "associated"
  305. })
  306. .count();
  307. match attrs_count {
  308. 0 => None,
  309. 1 => Some(item_strct),
  310. _ => panic!("Invalid syntax: one event attribute allowed"),
  311. }
  312. }
  313. _ => None,
  314. })
  315. .collect()
  316. }
  317. // Parse all structs implementing the `Accounts` trait.
  318. fn parse_account_derives(f: &syn::File) -> HashMap<String, AccountsStruct> {
  319. f.items
  320. .iter()
  321. .filter_map(|i| match i {
  322. syn::Item::Struct(i_strct) => {
  323. for attr in &i_strct.attrs {
  324. if attr.tokens.to_string().contains(DERIVE_NAME) {
  325. let strct = accounts::parse(i_strct);
  326. return Some((strct.ident.to_string(), strct));
  327. }
  328. }
  329. None
  330. }
  331. // TODO: parse manual implementations. Currently we only look
  332. // for derives.
  333. _ => None,
  334. })
  335. .collect()
  336. }
  337. // Parse all user defined types in the file.
  338. fn parse_ty_defs(f: &syn::File) -> Result<Vec<IdlTypeDef>> {
  339. f.items
  340. .iter()
  341. .filter_map(|i| match i {
  342. syn::Item::Struct(item_strct) => {
  343. for attr in &item_strct.attrs {
  344. if attr.tokens.to_string().contains(DERIVE_NAME) {
  345. return None;
  346. }
  347. }
  348. if let syn::Visibility::Public(_) = &item_strct.vis {
  349. let name = item_strct.ident.to_string();
  350. let fields = match &item_strct.fields {
  351. syn::Fields::Named(fields) => fields
  352. .named
  353. .iter()
  354. .map(|f: &syn::Field| {
  355. let mut tts = proc_macro2::TokenStream::new();
  356. f.ty.to_tokens(&mut tts);
  357. Ok(IdlField {
  358. name: f.ident.as_ref().unwrap().to_string().to_mixed_case(),
  359. ty: tts.to_string().parse()?,
  360. })
  361. })
  362. .collect::<Result<Vec<IdlField>>>(),
  363. _ => panic!("Only named structs are allowed."),
  364. };
  365. return Some(fields.map(|fields| IdlTypeDef {
  366. name,
  367. ty: IdlTypeDefTy::Struct { fields },
  368. }));
  369. }
  370. None
  371. }
  372. syn::Item::Enum(enm) => {
  373. let name = enm.ident.to_string();
  374. let variants = enm
  375. .variants
  376. .iter()
  377. .map(|variant: &syn::Variant| {
  378. let name = variant.ident.to_string();
  379. let fields = match &variant.fields {
  380. syn::Fields::Unit => None,
  381. syn::Fields::Unnamed(fields) => {
  382. let fields: Vec<IdlType> =
  383. fields.unnamed.iter().map(to_idl_type).collect();
  384. Some(EnumFields::Tuple(fields))
  385. }
  386. syn::Fields::Named(fields) => {
  387. let fields: Vec<IdlField> = fields
  388. .named
  389. .iter()
  390. .map(|f: &syn::Field| {
  391. let name = f.ident.as_ref().unwrap().to_string();
  392. let ty = to_idl_type(f);
  393. IdlField { name, ty }
  394. })
  395. .collect();
  396. Some(EnumFields::Named(fields))
  397. }
  398. };
  399. EnumVariant { name, fields }
  400. })
  401. .collect::<Vec<EnumVariant>>();
  402. Some(Ok(IdlTypeDef {
  403. name,
  404. ty: IdlTypeDefTy::Enum { variants },
  405. }))
  406. }
  407. _ => None,
  408. })
  409. .collect()
  410. }
  411. fn to_idl_type(f: &syn::Field) -> IdlType {
  412. let mut tts = proc_macro2::TokenStream::new();
  413. f.ty.to_tokens(&mut tts);
  414. tts.to_string().parse().unwrap()
  415. }