mod.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. use crate::parser::docs;
  2. use crate::*;
  3. use syn::parse::{Error as ParseError, Result as ParseResult};
  4. use syn::punctuated::Punctuated;
  5. use syn::spanned::Spanned;
  6. use syn::token::Comma;
  7. use syn::Expr;
  8. pub mod constraints;
  9. pub fn parse(strct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
  10. let instruction_api: Option<Punctuated<Expr, Comma>> = strct
  11. .attrs
  12. .iter()
  13. .find(|a| {
  14. a.path
  15. .get_ident()
  16. .map_or(false, |ident| ident == "instruction")
  17. })
  18. .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
  19. .transpose()?;
  20. let fields = match &strct.fields {
  21. syn::Fields::Named(fields) => fields
  22. .named
  23. .iter()
  24. .map(parse_account_field)
  25. .collect::<ParseResult<Vec<AccountField>>>()?,
  26. _ => {
  27. return Err(ParseError::new_spanned(
  28. &strct.fields,
  29. "fields must be named",
  30. ))
  31. }
  32. };
  33. let _ = constraints_cross_checks(&fields)?;
  34. Ok(AccountsStruct::new(strct.clone(), fields, instruction_api))
  35. }
  36. fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
  37. // INIT
  38. let init_fields: Vec<&Field> = fields
  39. .iter()
  40. .filter_map(|f| match f {
  41. AccountField::Field(field) if field.constraints.init.is_some() => Some(field),
  42. _ => None,
  43. })
  44. .collect();
  45. if !init_fields.is_empty() {
  46. // init needs system program.
  47. if fields.iter().all(|f| f.ident() != "system_program") {
  48. return Err(ParseError::new(
  49. init_fields[0].ident.span(),
  50. "the init constraint requires \
  51. the system_program field to exist in the account \
  52. validation struct. Use the program type to add \
  53. the system_program field to your validation struct.",
  54. ));
  55. }
  56. let kind = &init_fields[0].constraints.init.as_ref().unwrap().kind;
  57. // init token/a_token/mint needs token program.
  58. match kind {
  59. InitKind::Program { .. } => (),
  60. InitKind::Token { .. } | InitKind::AssociatedToken { .. } | InitKind::Mint { .. } => {
  61. if fields.iter().all(|f| f.ident() != "token_program") {
  62. return Err(ParseError::new(
  63. init_fields[0].ident.span(),
  64. "the init constraint requires \
  65. the token_program field to exist in the account \
  66. validation struct. Use the program type to add \
  67. the token_program field to your validation struct.",
  68. ));
  69. }
  70. }
  71. }
  72. // a_token needs associated token program.
  73. if let InitKind::AssociatedToken { .. } = kind {
  74. if fields
  75. .iter()
  76. .all(|f| f.ident() != "associated_token_program")
  77. {
  78. return Err(ParseError::new(
  79. init_fields[0].ident.span(),
  80. "the init constraint requires \
  81. the associated_token_program field to exist in the account \
  82. validation struct. Use the program type to add \
  83. the associated_token_program field to your validation struct.",
  84. ));
  85. }
  86. }
  87. for field in init_fields {
  88. // Get payer for init-ed account
  89. let associated_payer_name = match field.constraints.init.clone().unwrap().payer {
  90. // composite payer, check not supported
  91. Expr::Field(_) => continue,
  92. field_name => field_name.to_token_stream().to_string(),
  93. };
  94. // Check payer is mutable
  95. let associated_payer_field = fields.iter().find_map(|f| match f {
  96. AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
  97. _ => None,
  98. });
  99. match associated_payer_field {
  100. Some(associated_payer_field) => {
  101. if !associated_payer_field.constraints.is_mutable() {
  102. return Err(ParseError::new(
  103. field.ident.span(),
  104. "the payer specified for an init constraint must be mutable.",
  105. ));
  106. }
  107. }
  108. _ => {
  109. return Err(ParseError::new(
  110. field.ident.span(),
  111. "the payer specified does not exist.",
  112. ));
  113. }
  114. }
  115. match kind {
  116. // This doesn't catch cases like account.key() or account.key.
  117. // My guess is that doesn't happen often and we can revisit
  118. // this if I'm wrong.
  119. InitKind::Token { mint, .. } | InitKind::AssociatedToken { mint, .. } => {
  120. if !fields.iter().any(|f| {
  121. f.ident()
  122. .to_string()
  123. .starts_with(&mint.to_token_stream().to_string())
  124. }) {
  125. return Err(ParseError::new(
  126. field.ident.span(),
  127. "the mint constraint has to be an account field for token initializations (not a public key)",
  128. ));
  129. }
  130. }
  131. _ => (),
  132. }
  133. }
  134. }
  135. Ok(())
  136. }
  137. pub fn parse_account_field(f: &syn::Field) -> ParseResult<AccountField> {
  138. let ident = f.ident.clone().unwrap();
  139. let docs = docs::parse(&f.attrs);
  140. let account_field = match is_field_primitive(f)? {
  141. true => {
  142. let ty = parse_ty(f)?;
  143. let account_constraints = constraints::parse(f, Some(&ty))?;
  144. AccountField::Field(Field {
  145. ident,
  146. ty,
  147. constraints: account_constraints,
  148. docs,
  149. })
  150. }
  151. false => {
  152. let account_constraints = constraints::parse(f, None)?;
  153. AccountField::CompositeField(CompositeField {
  154. ident,
  155. constraints: account_constraints,
  156. symbol: ident_string(f)?,
  157. raw_field: f.clone(),
  158. docs,
  159. })
  160. }
  161. };
  162. Ok(account_field)
  163. }
  164. fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
  165. let r = matches!(
  166. ident_string(f)?.as_str(),
  167. "ProgramState"
  168. | "ProgramAccount"
  169. | "CpiAccount"
  170. | "Sysvar"
  171. | "AccountInfo"
  172. | "UncheckedAccount"
  173. | "CpiState"
  174. | "Loader"
  175. | "AccountLoader"
  176. | "Account"
  177. | "Program"
  178. | "Signer"
  179. | "SystemAccount"
  180. | "ProgramData"
  181. );
  182. Ok(r)
  183. }
  184. fn parse_ty(f: &syn::Field) -> ParseResult<Ty> {
  185. let path = match &f.ty {
  186. syn::Type::Path(ty_path) => ty_path.path.clone(),
  187. _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
  188. };
  189. let ty = match ident_string(f)?.as_str() {
  190. "ProgramState" => Ty::ProgramState(parse_program_state(&path)?),
  191. "CpiState" => Ty::CpiState(parse_cpi_state(&path)?),
  192. "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)?),
  193. "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)?),
  194. "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
  195. "AccountInfo" => Ty::AccountInfo,
  196. "UncheckedAccount" => Ty::UncheckedAccount,
  197. "Loader" => Ty::Loader(parse_program_account_zero_copy(&path)?),
  198. "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
  199. "Account" => Ty::Account(parse_account_ty(&path)?),
  200. "Program" => Ty::Program(parse_program_ty(&path)?),
  201. "Signer" => Ty::Signer,
  202. "SystemAccount" => Ty::SystemAccount,
  203. "ProgramData" => Ty::ProgramData,
  204. _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
  205. };
  206. Ok(ty)
  207. }
  208. fn ident_string(f: &syn::Field) -> ParseResult<String> {
  209. let path = match &f.ty {
  210. syn::Type::Path(ty_path) => ty_path.path.clone(),
  211. _ => return Err(ParseError::new(f.ty.span(), "invalid type")),
  212. };
  213. if parser::tts_to_string(&path)
  214. .replace(' ', "")
  215. .starts_with("Box<Account<")
  216. {
  217. return Ok("Account".to_string());
  218. }
  219. // TODO: allow segmented paths.
  220. if path.segments.len() != 1 {
  221. return Err(ParseError::new(
  222. f.ty.span(),
  223. "segmented paths are not currently allowed",
  224. ));
  225. }
  226. let segments = &path.segments[0];
  227. Ok(segments.ident.to_string())
  228. }
  229. fn parse_program_state(path: &syn::Path) -> ParseResult<ProgramStateTy> {
  230. let account_ident = parse_account(path)?;
  231. Ok(ProgramStateTy {
  232. account_type_path: account_ident,
  233. })
  234. }
  235. fn parse_cpi_state(path: &syn::Path) -> ParseResult<CpiStateTy> {
  236. let account_ident = parse_account(path)?;
  237. Ok(CpiStateTy {
  238. account_type_path: account_ident,
  239. })
  240. }
  241. fn parse_cpi_account(path: &syn::Path) -> ParseResult<CpiAccountTy> {
  242. let account_ident = parse_account(path)?;
  243. Ok(CpiAccountTy {
  244. account_type_path: account_ident,
  245. })
  246. }
  247. fn parse_program_account(path: &syn::Path) -> ParseResult<ProgramAccountTy> {
  248. let account_ident = parse_account(path)?;
  249. Ok(ProgramAccountTy {
  250. account_type_path: account_ident,
  251. })
  252. }
  253. fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult<LoaderTy> {
  254. let account_ident = parse_account(path)?;
  255. Ok(LoaderTy {
  256. account_type_path: account_ident,
  257. })
  258. }
  259. fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
  260. let account_ident = parse_account(path)?;
  261. Ok(AccountLoaderTy {
  262. account_type_path: account_ident,
  263. })
  264. }
  265. fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
  266. let account_type_path = parse_account(path)?;
  267. let boxed = parser::tts_to_string(&path)
  268. .replace(' ', "")
  269. .starts_with("Box<Account<");
  270. Ok(AccountTy {
  271. account_type_path,
  272. boxed,
  273. })
  274. }
  275. fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
  276. let account_type_path = parse_account(path)?;
  277. Ok(ProgramTy { account_type_path })
  278. }
  279. // TODO: this whole method is a hack. Do something more idiomatic.
  280. fn parse_account(mut path: &syn::Path) -> ParseResult<syn::TypePath> {
  281. if parser::tts_to_string(path)
  282. .replace(' ', "")
  283. .starts_with("Box<Account<")
  284. {
  285. let segments = &path.segments[0];
  286. match &segments.arguments {
  287. syn::PathArguments::AngleBracketed(args) => {
  288. // Expected: <'info, MyType>.
  289. if args.args.len() != 1 {
  290. return Err(ParseError::new(
  291. args.args.span(),
  292. "bracket arguments must be the lifetime and type",
  293. ));
  294. }
  295. match &args.args[0] {
  296. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  297. path = &ty_path.path;
  298. }
  299. _ => {
  300. return Err(ParseError::new(
  301. args.args[1].span(),
  302. "first bracket argument must be a lifetime",
  303. ))
  304. }
  305. }
  306. }
  307. _ => {
  308. return Err(ParseError::new(
  309. segments.arguments.span(),
  310. "expected angle brackets with a lifetime and type",
  311. ))
  312. }
  313. }
  314. }
  315. let segments = &path.segments[0];
  316. match &segments.arguments {
  317. syn::PathArguments::AngleBracketed(args) => {
  318. // Expected: <'info, MyType>.
  319. if args.args.len() != 2 {
  320. return Err(ParseError::new(
  321. args.args.span(),
  322. "bracket arguments must be the lifetime and type",
  323. ));
  324. }
  325. match &args.args[1] {
  326. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()),
  327. _ => Err(ParseError::new(
  328. args.args[1].span(),
  329. "first bracket argument must be a lifetime",
  330. )),
  331. }
  332. }
  333. _ => Err(ParseError::new(
  334. segments.arguments.span(),
  335. "expected angle brackets with a lifetime and type",
  336. )),
  337. }
  338. }
  339. fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
  340. let segments = &path.segments[0];
  341. let account_ident = match &segments.arguments {
  342. syn::PathArguments::AngleBracketed(args) => {
  343. // Expected: <'info, MyType>.
  344. if args.args.len() != 2 {
  345. return Err(ParseError::new(
  346. args.args.span(),
  347. "bracket arguments must be the lifetime and type",
  348. ));
  349. }
  350. match &args.args[1] {
  351. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  352. // TODO: allow segmented paths.
  353. if ty_path.path.segments.len() != 1 {
  354. return Err(ParseError::new(
  355. ty_path.path.span(),
  356. "segmented paths are not currently allowed",
  357. ));
  358. }
  359. let path_segment = &ty_path.path.segments[0];
  360. path_segment.ident.clone()
  361. }
  362. _ => {
  363. return Err(ParseError::new(
  364. args.args[1].span(),
  365. "first bracket argument must be a lifetime",
  366. ))
  367. }
  368. }
  369. }
  370. _ => {
  371. return Err(ParseError::new(
  372. segments.arguments.span(),
  373. "expected angle brackets with a lifetime and type",
  374. ))
  375. }
  376. };
  377. let ty = match account_ident.to_string().as_str() {
  378. "Clock" => SysvarTy::Clock,
  379. "Rent" => SysvarTy::Rent,
  380. "EpochSchedule" => SysvarTy::EpochSchedule,
  381. "Fees" => SysvarTy::Fees,
  382. "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
  383. "SlotHashes" => SysvarTy::SlotHashes,
  384. "SlotHistory" => SysvarTy::SlotHistory,
  385. "StakeHistory" => SysvarTy::StakeHistory,
  386. "Instructions" => SysvarTy::Instructions,
  387. "Rewards" => SysvarTy::Rewards,
  388. _ => {
  389. return Err(ParseError::new(
  390. account_ident.span(),
  391. "invalid sysvar provided",
  392. ))
  393. }
  394. };
  395. Ok(ty)
  396. }