mod.rs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. use crate::parser::docs;
  2. use crate::*;
  3. use syn::parse::{Error as ParseError, Result as ParseResult};
  4. use syn::punctuated::Punctuated;
  5. use syn::spanned::Spanned;
  6. use syn::token::Comma;
  7. use syn::Expr;
  8. use syn::Path;
  9. pub mod constraints;
  10. pub fn parse(strct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
  11. let instruction_api: Option<Punctuated<Expr, Comma>> = strct
  12. .attrs
  13. .iter()
  14. .find(|a| {
  15. a.path
  16. .get_ident()
  17. .map_or(false, |ident| ident == "instruction")
  18. })
  19. .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
  20. .transpose()?;
  21. let fields = match &strct.fields {
  22. syn::Fields::Named(fields) => fields
  23. .named
  24. .iter()
  25. .map(parse_account_field)
  26. .collect::<ParseResult<Vec<AccountField>>>()?,
  27. _ => {
  28. return Err(ParseError::new_spanned(
  29. &strct.fields,
  30. "fields must be named",
  31. ))
  32. }
  33. };
  34. constraints_cross_checks(&fields)?;
  35. Ok(AccountsStruct::new(strct.clone(), fields, instruction_api))
  36. }
  37. fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
  38. // COMMON ERROR MESSAGE
  39. let message = |constraint: &str, field: &str, required: bool| {
  40. if required {
  41. format! {
  42. "The {} constraint requires \
  43. a {} field to exist in the account \
  44. validation struct. Use the Program type to add \
  45. the {} field to your validation struct.", constraint, field, field
  46. }
  47. } else {
  48. format! {
  49. "An optional {} constraint requires \
  50. an optional or required {} field to exist \
  51. in the account validation struct. Use the Program type \
  52. to add the {} field to your validation struct.", constraint, field, field
  53. }
  54. }
  55. };
  56. // INIT
  57. let mut required_init = false;
  58. let init_fields: Vec<&Field> = fields
  59. .iter()
  60. .filter_map(|f| match f {
  61. AccountField::Field(field) if field.constraints.init.is_some() => {
  62. if !field.is_optional {
  63. required_init = true
  64. }
  65. Some(field)
  66. }
  67. _ => None,
  68. })
  69. .collect();
  70. if !init_fields.is_empty() {
  71. // init needs system program.
  72. if !fields
  73. .iter()
  74. // ensures that a non optional `system_program` is present with non optional `init`
  75. .any(|f| f.ident() == "system_program" && !(required_init && f.is_optional()))
  76. {
  77. return Err(ParseError::new(
  78. init_fields[0].ident.span(),
  79. message("init", "system_program", required_init),
  80. ));
  81. }
  82. let kind = &init_fields[0].constraints.init.as_ref().unwrap().kind;
  83. // init token/a_token/mint needs token program.
  84. match kind {
  85. InitKind::Program { .. } => (),
  86. InitKind::Token { .. } | InitKind::AssociatedToken { .. } | InitKind::Mint { .. } => {
  87. if !fields
  88. .iter()
  89. .any(|f| f.ident() == "token_program" && !(required_init && f.is_optional()))
  90. {
  91. return Err(ParseError::new(
  92. init_fields[0].ident.span(),
  93. message("init", "token_program", required_init),
  94. ));
  95. }
  96. }
  97. }
  98. // a_token needs associated token program.
  99. if let InitKind::AssociatedToken { .. } = kind {
  100. if !fields.iter().any(|f| {
  101. f.ident() == "associated_token_program" && !(required_init && f.is_optional())
  102. }) {
  103. return Err(ParseError::new(
  104. init_fields[0].ident.span(),
  105. message("init", "associated_token_program", required_init),
  106. ));
  107. }
  108. }
  109. for field in init_fields {
  110. // Get payer for init-ed account
  111. let associated_payer_name = match field.constraints.init.clone().unwrap().payer {
  112. // composite payer, check not supported
  113. Expr::Field(_) => continue,
  114. // method call, check not supported
  115. Expr::MethodCall(_) => continue,
  116. field_name => field_name.to_token_stream().to_string(),
  117. };
  118. // Check payer is mutable
  119. let associated_payer_field = fields.iter().find_map(|f| match f {
  120. AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
  121. _ => None,
  122. });
  123. match associated_payer_field {
  124. Some(associated_payer_field) => {
  125. if !associated_payer_field.constraints.is_mutable() {
  126. return Err(ParseError::new(
  127. field.ident.span(),
  128. "the payer specified for an init constraint must be mutable.",
  129. ));
  130. } else if associated_payer_field.is_optional && required_init {
  131. return Err(ParseError::new(
  132. field.ident.span(),
  133. "the payer specified for a required init constraint must be required.",
  134. ));
  135. }
  136. }
  137. _ => {
  138. return Err(ParseError::new(
  139. field.ident.span(),
  140. "the payer specified does not exist.",
  141. ));
  142. }
  143. }
  144. match kind {
  145. // This doesn't catch cases like account.key() or account.key.
  146. // My guess is that doesn't happen often and we can revisit
  147. // this if I'm wrong.
  148. InitKind::Token { mint, .. } | InitKind::AssociatedToken { mint, .. } => {
  149. if !fields.iter().any(|f| {
  150. f.ident()
  151. .to_string()
  152. .starts_with(&mint.to_token_stream().to_string())
  153. }) {
  154. return Err(ParseError::new(
  155. field.ident.span(),
  156. "the mint constraint has to be an account field for token initializations (not a public key)",
  157. ));
  158. }
  159. }
  160. _ => (),
  161. }
  162. }
  163. }
  164. // REALLOC
  165. let mut required_realloc = false;
  166. let realloc_fields: Vec<&Field> = fields
  167. .iter()
  168. .filter_map(|f| match f {
  169. AccountField::Field(field) if field.constraints.realloc.is_some() => {
  170. if !field.is_optional {
  171. required_realloc = true
  172. }
  173. Some(field)
  174. }
  175. _ => None,
  176. })
  177. .collect();
  178. if !realloc_fields.is_empty() {
  179. // realloc needs system program.
  180. if !fields
  181. .iter()
  182. .any(|f| f.ident() == "system_program" && !(required_realloc && f.is_optional()))
  183. {
  184. return Err(ParseError::new(
  185. realloc_fields[0].ident.span(),
  186. message("realloc", "system_program", required_realloc),
  187. ));
  188. }
  189. for field in realloc_fields {
  190. // Get allocator for realloc-ed account
  191. let associated_payer_name = match field.constraints.realloc.clone().unwrap().payer {
  192. // composite allocator, check not supported
  193. Expr::Field(_) => continue,
  194. // method call, check not supported
  195. Expr::MethodCall(_) => continue,
  196. field_name => field_name.to_token_stream().to_string(),
  197. };
  198. // Check allocator is mutable
  199. let associated_payer_field = fields.iter().find_map(|f| match f {
  200. AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
  201. _ => None,
  202. });
  203. match associated_payer_field {
  204. Some(associated_payer_field) => {
  205. if !associated_payer_field.constraints.is_mutable() {
  206. return Err(ParseError::new(
  207. field.ident.span(),
  208. "the realloc::payer specified for an realloc constraint must be mutable.",
  209. ));
  210. } else if associated_payer_field.is_optional && required_realloc {
  211. return Err(ParseError::new(
  212. field.ident.span(),
  213. "the realloc::payer specified for a required realloc constraint must be required.",
  214. ));
  215. }
  216. }
  217. _ => {
  218. return Err(ParseError::new(
  219. field.ident.span(),
  220. "the realloc::payer specified does not exist.",
  221. ));
  222. }
  223. }
  224. }
  225. }
  226. Ok(())
  227. }
  228. pub fn parse_account_field(f: &syn::Field) -> ParseResult<AccountField> {
  229. let ident = f.ident.clone().unwrap();
  230. let docs = docs::parse(&f.attrs);
  231. let account_field = match is_field_primitive(f)? {
  232. true => {
  233. let (ty, is_optional) = parse_ty(f)?;
  234. let account_constraints = constraints::parse(f, Some(&ty))?;
  235. AccountField::Field(Field {
  236. ident,
  237. ty,
  238. is_optional,
  239. constraints: account_constraints,
  240. docs,
  241. })
  242. }
  243. false => {
  244. let (_, optional, _) = ident_string(f)?;
  245. if optional {
  246. return Err(ParseError::new(
  247. f.ty.span(),
  248. "Cannot have Optional composite accounts",
  249. ));
  250. }
  251. let account_constraints = constraints::parse(f, None)?;
  252. AccountField::CompositeField(CompositeField {
  253. ident,
  254. constraints: account_constraints,
  255. symbol: ident_string(f)?.0,
  256. raw_field: f.clone(),
  257. docs,
  258. })
  259. }
  260. };
  261. Ok(account_field)
  262. }
  263. fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
  264. let r = matches!(
  265. ident_string(f)?.0.as_str(),
  266. "ProgramState"
  267. | "ProgramAccount"
  268. | "CpiAccount"
  269. | "Sysvar"
  270. | "AccountInfo"
  271. | "UncheckedAccount"
  272. | "CpiState"
  273. | "Loader"
  274. | "AccountLoader"
  275. | "Account"
  276. | "Program"
  277. | "Signer"
  278. | "SystemAccount"
  279. | "ProgramData"
  280. );
  281. Ok(r)
  282. }
  283. // TODO call `account_parse` a single time at the start of this function and then init each of the types
  284. fn parse_ty(f: &syn::Field) -> ParseResult<(Ty, bool)> {
  285. let (ident, optional, path) = ident_string(f)?;
  286. let ty = match ident.as_str() {
  287. "ProgramState" => Ty::ProgramState(parse_program_state(&path)?),
  288. "CpiState" => Ty::CpiState(parse_cpi_state(&path)?),
  289. "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)?),
  290. "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)?),
  291. "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
  292. "AccountInfo" => Ty::AccountInfo,
  293. "UncheckedAccount" => Ty::UncheckedAccount,
  294. "Loader" => Ty::Loader(parse_program_account_zero_copy(&path)?),
  295. "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
  296. "Account" => Ty::Account(parse_account_ty(&path)?),
  297. "Program" => Ty::Program(parse_program_ty(&path)?),
  298. "Signer" => Ty::Signer,
  299. "SystemAccount" => Ty::SystemAccount,
  300. "ProgramData" => Ty::ProgramData,
  301. _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
  302. };
  303. Ok((ty, optional))
  304. }
  305. fn option_to_inner_path(path: &Path) -> ParseResult<Path> {
  306. let segment_0 = path.segments[0].clone();
  307. match segment_0.arguments {
  308. syn::PathArguments::AngleBracketed(args) => {
  309. if args.args.len() != 1 {
  310. return Err(ParseError::new(
  311. args.args.span(),
  312. "can only have one argument in option",
  313. ));
  314. }
  315. match &args.args[0] {
  316. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.path.clone()),
  317. _ => Err(ParseError::new(
  318. args.args[1].span(),
  319. "first bracket argument must be a lifetime",
  320. )),
  321. }
  322. }
  323. _ => Err(ParseError::new(
  324. segment_0.arguments.span(),
  325. "expected angle brackets with a lifetime and type",
  326. )),
  327. }
  328. }
  329. fn ident_string(f: &syn::Field) -> ParseResult<(String, bool, Path)> {
  330. // TODO support parsing references to account infos
  331. let mut path = match &f.ty {
  332. syn::Type::Path(ty_path) => ty_path.path.clone(),
  333. _ => {
  334. return Err(ParseError::new_spanned(
  335. f,
  336. format!(
  337. "Field {} has a non-path type",
  338. f.ident.as_ref().expect("named fields only")
  339. ),
  340. ))
  341. }
  342. };
  343. // TODO replace string matching with helper functions using syn type match statments
  344. let mut optional = false;
  345. if parser::tts_to_string(&path)
  346. .replace(' ', "")
  347. .starts_with("Option<")
  348. {
  349. path = option_to_inner_path(&path)?;
  350. optional = true;
  351. }
  352. if parser::tts_to_string(&path)
  353. .replace(' ', "")
  354. .starts_with("Box<Account<")
  355. {
  356. return Ok(("Account".to_string(), optional, path));
  357. }
  358. // TODO: allow segmented paths.
  359. if path.segments.len() != 1 {
  360. return Err(ParseError::new(
  361. f.ty.span(),
  362. "segmented paths are not currently allowed",
  363. ));
  364. }
  365. let segments = &path.segments[0];
  366. Ok((segments.ident.to_string(), optional, path))
  367. }
  368. fn parse_program_state(path: &syn::Path) -> ParseResult<ProgramStateTy> {
  369. let account_ident = parse_account(path)?;
  370. Ok(ProgramStateTy {
  371. account_type_path: account_ident,
  372. })
  373. }
  374. fn parse_cpi_state(path: &syn::Path) -> ParseResult<CpiStateTy> {
  375. let account_ident = parse_account(path)?;
  376. Ok(CpiStateTy {
  377. account_type_path: account_ident,
  378. })
  379. }
  380. fn parse_cpi_account(path: &syn::Path) -> ParseResult<CpiAccountTy> {
  381. let account_ident = parse_account(path)?;
  382. Ok(CpiAccountTy {
  383. account_type_path: account_ident,
  384. })
  385. }
  386. fn parse_program_account(path: &syn::Path) -> ParseResult<ProgramAccountTy> {
  387. let account_ident = parse_account(path)?;
  388. Ok(ProgramAccountTy {
  389. account_type_path: account_ident,
  390. })
  391. }
  392. fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult<LoaderTy> {
  393. let account_ident = parse_account(path)?;
  394. Ok(LoaderTy {
  395. account_type_path: account_ident,
  396. })
  397. }
  398. fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
  399. let account_ident = parse_account(path)?;
  400. Ok(AccountLoaderTy {
  401. account_type_path: account_ident,
  402. })
  403. }
  404. fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
  405. let account_type_path = parse_account(path)?;
  406. let boxed = parser::tts_to_string(path)
  407. .replace(' ', "")
  408. .starts_with("Box<Account<");
  409. Ok(AccountTy {
  410. account_type_path,
  411. boxed,
  412. })
  413. }
  414. fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
  415. let account_type_path = parse_account(path)?;
  416. Ok(ProgramTy { account_type_path })
  417. }
  418. // Extract the type path T from Box<[AccountType]<'info, T>> or [AccountType]<'info, T>
  419. fn parse_account(path: &syn::Path) -> ParseResult<syn::TypePath> {
  420. let segment = &path.segments[0];
  421. match segment.ident.to_string().as_str() {
  422. "Box" => match &segment.arguments {
  423. syn::PathArguments::AngleBracketed(args) => {
  424. if args.args.len() != 1 {
  425. return Err(ParseError::new(
  426. args.args.span(),
  427. "Expected a single argument: [AccountType]<'info, T>",
  428. ));
  429. }
  430. match &args.args[0] {
  431. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  432. parse_account(&ty_path.path)
  433. }
  434. _ => Err(ParseError::new(
  435. args.args[1].span(),
  436. "Expected a path containing: [AccountType]<'info, T>",
  437. )),
  438. }
  439. }
  440. _ => Err(ParseError::new(
  441. segment.arguments.span(),
  442. "Expected angle brackets with a type: Box<[AccountType]<'info, T>",
  443. )),
  444. },
  445. _ => match &segment.arguments {
  446. syn::PathArguments::AngleBracketed(args) => {
  447. if args.args.len() != 2 {
  448. return Err(ParseError::new(
  449. args.args.span(),
  450. "Expected only two arguments, a lifetime and a type: [AccountType]<'info, T>",
  451. ));
  452. }
  453. match (&args.args[0], &args.args[1]) {
  454. (
  455. syn::GenericArgument::Lifetime(_),
  456. syn::GenericArgument::Type(syn::Type::Path(ty_path)),
  457. ) => {
  458. Ok(ty_path.clone())
  459. }
  460. _ => {
  461. Err(ParseError::new(
  462. args.args.span(),
  463. "Expected the two arguments to be a lifetime and a type: [AccountType]<'info, T>",
  464. ))
  465. }
  466. }
  467. }
  468. _ => Err(ParseError::new(
  469. segment.arguments.span(),
  470. "Expected angle brackets with a type: [AccountType]<'info, T>",
  471. )),
  472. },
  473. }
  474. }
  475. fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
  476. let segments = &path.segments[0];
  477. let account_ident = match &segments.arguments {
  478. syn::PathArguments::AngleBracketed(args) => {
  479. // Expected: <'info, MyType>.
  480. if args.args.len() != 2 {
  481. return Err(ParseError::new(
  482. args.args.span(),
  483. "bracket arguments must be the lifetime and type",
  484. ));
  485. }
  486. match &args.args[1] {
  487. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  488. // TODO: allow segmented paths.
  489. if ty_path.path.segments.len() != 1 {
  490. return Err(ParseError::new(
  491. ty_path.path.span(),
  492. "segmented paths are not currently allowed",
  493. ));
  494. }
  495. let path_segment = &ty_path.path.segments[0];
  496. path_segment.ident.clone()
  497. }
  498. _ => {
  499. return Err(ParseError::new(
  500. args.args[1].span(),
  501. "first bracket argument must be a lifetime",
  502. ))
  503. }
  504. }
  505. }
  506. _ => {
  507. return Err(ParseError::new(
  508. segments.arguments.span(),
  509. "expected angle brackets with a lifetime and type",
  510. ))
  511. }
  512. };
  513. let ty = match account_ident.to_string().as_str() {
  514. "Clock" => SysvarTy::Clock,
  515. "Rent" => SysvarTy::Rent,
  516. "EpochSchedule" => SysvarTy::EpochSchedule,
  517. "Fees" => SysvarTy::Fees,
  518. "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
  519. "SlotHashes" => SysvarTy::SlotHashes,
  520. "SlotHistory" => SysvarTy::SlotHistory,
  521. "StakeHistory" => SysvarTy::StakeHistory,
  522. "Instructions" => SysvarTy::Instructions,
  523. "Rewards" => SysvarTy::Rewards,
  524. _ => {
  525. return Err(ParseError::new(
  526. account_ident.span(),
  527. "invalid sysvar provided",
  528. ))
  529. }
  530. };
  531. Ok(ty)
  532. }
  #[test]
  fn test_parse_account() {
      let expected: syn::TypePath = syn::parse_quote!(u32);

      // The boxed and the plain form both resolve to the inner `u32`.
      let accepted: [syn::Path; 2] = [
          syn::parse_quote! { Box<Account<'info, u32>> },
          syn::parse_quote! { Account<'info, u32> },
      ];
      for path in &accepted {
          assert_eq!(parse_account(path).unwrap(), expected);
      }

      // Malformed shapes and the exact diagnostic each one must produce.
      let arity_msg =
          "Expected only two arguments, a lifetime and a type: [AccountType]<'info, T>";
      let rejected: [(syn::Path, &str); 3] = [
          (syn::parse_quote! { Box<Account<'info, u32, u64>> }, arity_msg),
          (syn::parse_quote! { Box<Account<'info>> }, arity_msg),
          (
              syn::parse_quote! { Box<Account> },
              "Expected angle brackets with a type: [AccountType]<'info, T>",
          ),
      ];
      for (path, msg) in &rejected {
          assert_eq!(parse_account(path).unwrap_err().to_string(), *msg);
      }
  }
  560. }