mod.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. use crate::*;
  2. use syn::parse::{Error as ParseError, Result as ParseResult};
  3. use syn::punctuated::Punctuated;
  4. use syn::spanned::Spanned;
  5. use syn::token::Comma;
  6. use syn::Expr;
  7. pub mod constraints;
  8. pub fn parse(strct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
  9. let instruction_api: Option<Punctuated<Expr, Comma>> = strct
  10. .attrs
  11. .iter()
  12. .find(|a| {
  13. a.path
  14. .get_ident()
  15. .map_or(false, |ident| ident == "instruction")
  16. })
  17. .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
  18. .transpose()?;
  19. let fields = match &strct.fields {
  20. syn::Fields::Named(fields) => fields
  21. .named
  22. .iter()
  23. .map(|f| parse_account_field(f, instruction_api.is_some()))
  24. .collect::<ParseResult<Vec<AccountField>>>()?,
  25. _ => {
  26. return Err(ParseError::new_spanned(
  27. &strct.fields,
  28. "fields must be named",
  29. ))
  30. }
  31. };
  32. let _ = constraints_cross_checks(&fields)?;
  33. Ok(AccountsStruct::new(strct.clone(), fields, instruction_api))
  34. }
  35. fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
  36. // INIT
  37. let init_field = fields.iter().find(|f| {
  38. if let AccountField::Field(field) = f {
  39. field.constraints.init.is_some()
  40. } else {
  41. false
  42. }
  43. });
  44. if let Some(init_field) = init_field {
  45. // init needs system program.
  46. if fields.iter().all(|f| f.ident() != "system_program") {
  47. return Err(ParseError::new(
  48. init_field.ident().span(),
  49. "the init constraint requires \
  50. the system_program field to exist in the account \
  51. validation struct. Use the program type to add \
  52. the system_program field to your validation struct.",
  53. ));
  54. }
  55. if let AccountField::Field(field) = init_field {
  56. let kind = &field.constraints.init.as_ref().unwrap().kind;
  57. // init token/a_token/mint needs token program.
  58. match kind {
  59. InitKind::Program { .. } => (),
  60. InitKind::Token { .. }
  61. | InitKind::AssociatedToken { .. }
  62. | InitKind::Mint { .. } => {
  63. if fields.iter().all(|f| f.ident() != "token_program") {
  64. return Err(ParseError::new(
  65. init_field.ident().span(),
  66. "the init constraint requires \
  67. the token_program field to exist in the account \
  68. validation struct. Use the program type to add \
  69. the token_program field to your validation struct.",
  70. ));
  71. }
  72. }
  73. }
  74. // a_token needs associated token program.
  75. if let InitKind::AssociatedToken { .. } = kind {
  76. if fields
  77. .iter()
  78. .all(|f| f.ident() != "associated_token_program")
  79. {
  80. return Err(ParseError::new(
  81. init_field.ident().span(),
  82. "the init constraint requires \
  83. the associated_token_program field to exist in the account \
  84. validation struct. Use the program type to add \
  85. the associated_token_program field to your validation struct.",
  86. ));
  87. }
  88. }
  89. }
  90. }
  91. Ok(())
  92. }
  93. pub fn parse_account_field(f: &syn::Field, has_instruction_api: bool) -> ParseResult<AccountField> {
  94. let ident = f.ident.clone().unwrap();
  95. let account_field = match is_field_primitive(f)? {
  96. true => {
  97. let ty = parse_ty(f)?;
  98. let (account_constraints, instruction_constraints) =
  99. constraints::parse(f, Some(&ty), has_instruction_api)?;
  100. AccountField::Field(Field {
  101. ident,
  102. ty,
  103. constraints: account_constraints,
  104. instruction_constraints,
  105. })
  106. }
  107. false => {
  108. let (account_constraints, instruction_constraints) =
  109. constraints::parse(f, None, has_instruction_api)?;
  110. AccountField::CompositeField(CompositeField {
  111. ident,
  112. constraints: account_constraints,
  113. instruction_constraints,
  114. symbol: ident_string(f)?,
  115. raw_field: f.clone(),
  116. })
  117. }
  118. };
  119. Ok(account_field)
  120. }
  121. fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
  122. let r = matches!(
  123. ident_string(f)?.as_str(),
  124. "ProgramState"
  125. | "ProgramAccount"
  126. | "CpiAccount"
  127. | "Sysvar"
  128. | "AccountInfo"
  129. | "UncheckedAccount"
  130. | "CpiState"
  131. | "Loader"
  132. | "AccountLoader"
  133. | "Account"
  134. | "Program"
  135. | "Signer"
  136. | "SystemAccount"
  137. | "ProgramData"
  138. );
  139. Ok(r)
  140. }
  141. fn parse_ty(f: &syn::Field) -> ParseResult<Ty> {
  142. let path = match &f.ty {
  143. syn::Type::Path(ty_path) => ty_path.path.clone(),
  144. _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
  145. };
  146. let ty = match ident_string(f)?.as_str() {
  147. "ProgramState" => Ty::ProgramState(parse_program_state(&path)?),
  148. "CpiState" => Ty::CpiState(parse_cpi_state(&path)?),
  149. "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)?),
  150. "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)?),
  151. "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
  152. "AccountInfo" => Ty::AccountInfo,
  153. "UncheckedAccount" => Ty::UncheckedAccount,
  154. "Loader" => Ty::Loader(parse_program_account_zero_copy(&path)?),
  155. "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
  156. "Account" => Ty::Account(parse_account_ty(&path)?),
  157. "Program" => Ty::Program(parse_program_ty(&path)?),
  158. "Signer" => Ty::Signer,
  159. "SystemAccount" => Ty::SystemAccount,
  160. "ProgramData" => Ty::ProgramData,
  161. _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
  162. };
  163. Ok(ty)
  164. }
  165. fn ident_string(f: &syn::Field) -> ParseResult<String> {
  166. let path = match &f.ty {
  167. syn::Type::Path(ty_path) => ty_path.path.clone(),
  168. _ => return Err(ParseError::new(f.ty.span(), "invalid type")),
  169. };
  170. if parser::tts_to_string(&path)
  171. .replace(' ', "")
  172. .starts_with("Box<Account<")
  173. {
  174. return Ok("Account".to_string());
  175. }
  176. // TODO: allow segmented paths.
  177. if path.segments.len() != 1 {
  178. return Err(ParseError::new(
  179. f.ty.span(),
  180. "segmented paths are not currently allowed",
  181. ));
  182. }
  183. let segments = &path.segments[0];
  184. Ok(segments.ident.to_string())
  185. }
  186. fn parse_program_state(path: &syn::Path) -> ParseResult<ProgramStateTy> {
  187. let account_ident = parse_account(path)?;
  188. Ok(ProgramStateTy {
  189. account_type_path: account_ident,
  190. })
  191. }
  192. fn parse_cpi_state(path: &syn::Path) -> ParseResult<CpiStateTy> {
  193. let account_ident = parse_account(path)?;
  194. Ok(CpiStateTy {
  195. account_type_path: account_ident,
  196. })
  197. }
  198. fn parse_cpi_account(path: &syn::Path) -> ParseResult<CpiAccountTy> {
  199. let account_ident = parse_account(path)?;
  200. Ok(CpiAccountTy {
  201. account_type_path: account_ident,
  202. })
  203. }
  204. fn parse_program_account(path: &syn::Path) -> ParseResult<ProgramAccountTy> {
  205. let account_ident = parse_account(path)?;
  206. Ok(ProgramAccountTy {
  207. account_type_path: account_ident,
  208. })
  209. }
  210. fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult<LoaderTy> {
  211. let account_ident = parse_account(path)?;
  212. Ok(LoaderTy {
  213. account_type_path: account_ident,
  214. })
  215. }
  216. fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
  217. let account_ident = parse_account(path)?;
  218. Ok(AccountLoaderTy {
  219. account_type_path: account_ident,
  220. })
  221. }
  222. fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
  223. let account_type_path = parse_account(path)?;
  224. let boxed = parser::tts_to_string(&path)
  225. .replace(' ', "")
  226. .starts_with("Box<Account<");
  227. Ok(AccountTy {
  228. account_type_path,
  229. boxed,
  230. })
  231. }
  232. fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
  233. let account_type_path = parse_account(path)?;
  234. Ok(ProgramTy { account_type_path })
  235. }
  236. // TODO: this whole method is a hack. Do something more idiomatic.
  237. fn parse_account(mut path: &syn::Path) -> ParseResult<syn::TypePath> {
  238. if parser::tts_to_string(path)
  239. .replace(' ', "")
  240. .starts_with("Box<Account<")
  241. {
  242. let segments = &path.segments[0];
  243. match &segments.arguments {
  244. syn::PathArguments::AngleBracketed(args) => {
  245. // Expected: <'info, MyType>.
  246. if args.args.len() != 1 {
  247. return Err(ParseError::new(
  248. args.args.span(),
  249. "bracket arguments must be the lifetime and type",
  250. ));
  251. }
  252. match &args.args[0] {
  253. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  254. path = &ty_path.path;
  255. }
  256. _ => {
  257. return Err(ParseError::new(
  258. args.args[1].span(),
  259. "first bracket argument must be a lifetime",
  260. ))
  261. }
  262. }
  263. }
  264. _ => {
  265. return Err(ParseError::new(
  266. segments.arguments.span(),
  267. "expected angle brackets with a lifetime and type",
  268. ))
  269. }
  270. }
  271. }
  272. let segments = &path.segments[0];
  273. match &segments.arguments {
  274. syn::PathArguments::AngleBracketed(args) => {
  275. // Expected: <'info, MyType>.
  276. if args.args.len() != 2 {
  277. return Err(ParseError::new(
  278. args.args.span(),
  279. "bracket arguments must be the lifetime and type",
  280. ));
  281. }
  282. match &args.args[1] {
  283. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()),
  284. _ => Err(ParseError::new(
  285. args.args[1].span(),
  286. "first bracket argument must be a lifetime",
  287. )),
  288. }
  289. }
  290. _ => Err(ParseError::new(
  291. segments.arguments.span(),
  292. "expected angle brackets with a lifetime and type",
  293. )),
  294. }
  295. }
  296. fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
  297. let segments = &path.segments[0];
  298. let account_ident = match &segments.arguments {
  299. syn::PathArguments::AngleBracketed(args) => {
  300. // Expected: <'info, MyType>.
  301. if args.args.len() != 2 {
  302. return Err(ParseError::new(
  303. args.args.span(),
  304. "bracket arguments must be the lifetime and type",
  305. ));
  306. }
  307. match &args.args[1] {
  308. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  309. // TODO: allow segmented paths.
  310. if ty_path.path.segments.len() != 1 {
  311. return Err(ParseError::new(
  312. ty_path.path.span(),
  313. "segmented paths are not currently allowed",
  314. ));
  315. }
  316. let path_segment = &ty_path.path.segments[0];
  317. path_segment.ident.clone()
  318. }
  319. _ => {
  320. return Err(ParseError::new(
  321. args.args[1].span(),
  322. "first bracket argument must be a lifetime",
  323. ))
  324. }
  325. }
  326. }
  327. _ => {
  328. return Err(ParseError::new(
  329. segments.arguments.span(),
  330. "expected angle brackets with a lifetime and type",
  331. ))
  332. }
  333. };
  334. let ty = match account_ident.to_string().as_str() {
  335. "Clock" => SysvarTy::Clock,
  336. "Rent" => SysvarTy::Rent,
  337. "EpochSchedule" => SysvarTy::EpochSchedule,
  338. "Fees" => SysvarTy::Fees,
  339. "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
  340. "SlotHashes" => SysvarTy::SlotHashes,
  341. "SlotHistory" => SysvarTy::SlotHistory,
  342. "StakeHistory" => SysvarTy::StakeHistory,
  343. "Instructions" => SysvarTy::Instructions,
  344. "Rewards" => SysvarTy::Rewards,
  345. _ => {
  346. return Err(ParseError::new(
  347. account_ident.span(),
  348. "invalid sysvar provided",
  349. ))
  350. }
  351. };
  352. Ok(ty)
  353. }