accounts.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. use crate::{
  2. AccountField, AccountsStruct, CompositeField, Constraint, ConstraintBelongsTo,
  3. ConstraintLiteral, ConstraintOwner, ConstraintRentExempt, ConstraintSeeds, ConstraintSigner,
  4. CpiAccountTy, Field, ProgramAccountTy, ProgramStateTy, SysvarTy, Ty,
  5. };
  6. pub fn parse(strct: &syn::ItemStruct) -> AccountsStruct {
  7. let fields = match &strct.fields {
  8. syn::Fields::Named(fields) => fields,
  9. _ => panic!("invalid input"),
  10. };
  11. let fields: Vec<AccountField> = fields
  12. .named
  13. .iter()
  14. .map(|f: &syn::Field| {
  15. let anchor_attr = {
  16. let anchor_attrs: Vec<&syn::Attribute> = f
  17. .attrs
  18. .iter()
  19. .filter_map(|attr: &syn::Attribute| {
  20. if attr.path.segments.len() != 1 {
  21. return None;
  22. }
  23. if attr.path.segments[0].ident.to_string() != "account" {
  24. return None;
  25. }
  26. Some(attr)
  27. })
  28. .collect();
  29. match anchor_attrs.len() {
  30. 0 => None,
  31. 1 => Some(anchor_attrs[0]),
  32. _ => panic!("invalid syntax: only one account attribute is allowed"),
  33. }
  34. };
  35. parse_field(f, anchor_attr)
  36. })
  37. .collect();
  38. AccountsStruct::new(strct.clone(), fields)
  39. }
  40. fn parse_field(f: &syn::Field, anchor: Option<&syn::Attribute>) -> AccountField {
  41. let ident = f.ident.clone().unwrap();
  42. let (constraints, is_mut, is_signer, is_init) = match anchor {
  43. None => (vec![], false, false, false),
  44. Some(anchor) => parse_constraints(anchor),
  45. };
  46. match is_field_primitive(f) {
  47. true => {
  48. let ty = parse_ty(f);
  49. AccountField::Field(Field {
  50. ident,
  51. ty,
  52. constraints,
  53. is_mut,
  54. is_signer,
  55. is_init,
  56. })
  57. }
  58. false => AccountField::AccountsStruct(CompositeField {
  59. ident,
  60. symbol: ident_string(f),
  61. constraints,
  62. raw_field: f.clone(),
  63. }),
  64. }
  65. }
  66. fn is_field_primitive(f: &syn::Field) -> bool {
  67. match ident_string(f).as_str() {
  68. "ProgramState" | "ProgramAccount" | "CpiAccount" | "Sysvar" | "AccountInfo" => true,
  69. _ => false,
  70. }
  71. }
  72. fn parse_ty(f: &syn::Field) -> Ty {
  73. let path = match &f.ty {
  74. syn::Type::Path(ty_path) => ty_path.path.clone(),
  75. _ => panic!("invalid account syntax"),
  76. };
  77. match ident_string(f).as_str() {
  78. "ProgramState" => Ty::ProgramState(parse_program_state(&path)),
  79. "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)),
  80. "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)),
  81. "Sysvar" => Ty::Sysvar(parse_sysvar(&path)),
  82. "AccountInfo" => Ty::AccountInfo,
  83. _ => panic!("invalid account type"),
  84. }
  85. }
  86. fn ident_string(f: &syn::Field) -> String {
  87. let path = match &f.ty {
  88. syn::Type::Path(ty_path) => ty_path.path.clone(),
  89. _ => panic!("invalid account syntax"),
  90. };
  91. // TODO: allow segmented paths.
  92. assert!(path.segments.len() == 1);
  93. let segments = &path.segments[0];
  94. segments.ident.to_string()
  95. }
  96. fn parse_program_state(path: &syn::Path) -> ProgramStateTy {
  97. let account_ident = parse_account(&path);
  98. ProgramStateTy { account_ident }
  99. }
  100. fn parse_cpi_account(path: &syn::Path) -> CpiAccountTy {
  101. let account_ident = parse_account(path);
  102. CpiAccountTy { account_ident }
  103. }
  104. fn parse_program_account(path: &syn::Path) -> ProgramAccountTy {
  105. let account_ident = parse_account(path);
  106. ProgramAccountTy { account_ident }
  107. }
  108. fn parse_account(path: &syn::Path) -> syn::Ident {
  109. let segments = &path.segments[0];
  110. let account_ident = match &segments.arguments {
  111. syn::PathArguments::AngleBracketed(args) => {
  112. // Expected: <'info, MyType>.
  113. assert!(args.args.len() == 2);
  114. match &args.args[1] {
  115. syn::GenericArgument::Type(ty) => match ty {
  116. syn::Type::Path(ty_path) => {
  117. // TODO: allow segmented paths.
  118. assert!(ty_path.path.segments.len() == 1);
  119. let path_segment = &ty_path.path.segments[0];
  120. path_segment.ident.clone()
  121. }
  122. _ => panic!("Invalid ProgramAccount"),
  123. },
  124. _ => panic!("Invalid ProgramAccount"),
  125. }
  126. }
  127. _ => panic!("Invalid ProgramAccount"),
  128. };
  129. account_ident
  130. }
  131. fn parse_sysvar(path: &syn::Path) -> SysvarTy {
  132. let segments = &path.segments[0];
  133. let account_ident = match &segments.arguments {
  134. syn::PathArguments::AngleBracketed(args) => {
  135. // Expected: <'info, MyType>.
  136. assert!(args.args.len() == 2);
  137. match &args.args[1] {
  138. syn::GenericArgument::Type(ty) => match ty {
  139. syn::Type::Path(ty_path) => {
  140. // TODO: allow segmented paths.
  141. assert!(ty_path.path.segments.len() == 1);
  142. let path_segment = &ty_path.path.segments[0];
  143. path_segment.ident.clone()
  144. }
  145. _ => panic!("Invalid Sysvar"),
  146. },
  147. _ => panic!("Invalid Sysvar"),
  148. }
  149. }
  150. _ => panic!("Invalid Sysvar"),
  151. };
  152. match account_ident.to_string().as_str() {
  153. "Clock" => SysvarTy::Clock,
  154. "Rent" => SysvarTy::Rent,
  155. "EpochSchedule" => SysvarTy::EpochSchedule,
  156. "Fees" => SysvarTy::Fees,
  157. "RecentBlockhashes" => SysvarTy::RecentBlockHashes,
  158. "SlotHashes" => SysvarTy::SlotHashes,
  159. "SlotHistory" => SysvarTy::SlotHistory,
  160. "StakeHistory" => SysvarTy::StakeHistory,
  161. "Instructions" => SysvarTy::Instructions,
  162. "Rewards" => SysvarTy::Rewards,
  163. _ => panic!("Invalid Sysvar"),
  164. }
  165. }
  166. fn parse_constraints(anchor: &syn::Attribute) -> (Vec<Constraint>, bool, bool, bool) {
  167. let mut tts = anchor.tokens.clone().into_iter();
  168. let g_stream = match tts.next().expect("Must have a token group") {
  169. proc_macro2::TokenTree::Group(g) => g.stream(),
  170. _ => panic!("Invalid syntax"),
  171. };
  172. let mut is_init = false;
  173. let mut is_mut = false;
  174. let mut is_signer = false;
  175. let mut constraints = vec![];
  176. let mut is_rent_exempt = None;
  177. let mut inner_tts = g_stream.into_iter();
  178. while let Some(token) = inner_tts.next() {
  179. match token {
  180. proc_macro2::TokenTree::Ident(ident) => match ident.to_string().as_str() {
  181. "init" => {
  182. is_init = true;
  183. is_mut = true;
  184. // If it's not specified, all program owned accounts default
  185. // to being rent exempt.
  186. if is_rent_exempt.is_none() {
  187. is_rent_exempt = Some(true);
  188. }
  189. }
  190. "mut" => {
  191. is_mut = true;
  192. }
  193. "signer" => {
  194. is_signer = true;
  195. constraints.push(Constraint::Signer(ConstraintSigner {}));
  196. }
  197. "seeds" => {
  198. match inner_tts.next().unwrap() {
  199. proc_macro2::TokenTree::Punct(punct) => {
  200. assert!(punct.as_char() == '=');
  201. punct
  202. }
  203. _ => panic!("invalid syntax"),
  204. };
  205. let seeds = match inner_tts.next().unwrap() {
  206. proc_macro2::TokenTree::Group(g) => g,
  207. _ => panic!("invalid syntax"),
  208. };
  209. constraints.push(Constraint::Seeds(ConstraintSeeds { seeds }))
  210. }
  211. "belongs_to" | "has_one" => {
  212. match inner_tts.next().unwrap() {
  213. proc_macro2::TokenTree::Punct(punct) => {
  214. assert!(punct.as_char() == '=');
  215. punct
  216. }
  217. _ => panic!("invalid syntax"),
  218. };
  219. let join_target = match inner_tts.next().unwrap() {
  220. proc_macro2::TokenTree::Ident(ident) => ident,
  221. _ => panic!("invalid syntax"),
  222. };
  223. constraints.push(Constraint::BelongsTo(ConstraintBelongsTo { join_target }))
  224. }
  225. "owner" => {
  226. match inner_tts.next().unwrap() {
  227. proc_macro2::TokenTree::Punct(punct) => {
  228. assert!(punct.as_char() == '=');
  229. punct
  230. }
  231. _ => panic!("invalid syntax"),
  232. };
  233. let owner = match inner_tts.next().unwrap() {
  234. proc_macro2::TokenTree::Ident(ident) => ident,
  235. _ => panic!("invalid syntax"),
  236. };
  237. let constraint = match owner.to_string().as_str() {
  238. "program" => ConstraintOwner::Program,
  239. "skip" => ConstraintOwner::Skip,
  240. _ => panic!("invalid syntax"),
  241. };
  242. constraints.push(Constraint::Owner(constraint));
  243. }
  244. "rent_exempt" => {
  245. match inner_tts.next() {
  246. None => is_rent_exempt = Some(true),
  247. Some(tkn) => {
  248. match tkn {
  249. proc_macro2::TokenTree::Punct(punct) => {
  250. assert!(punct.as_char() == '=');
  251. punct
  252. }
  253. _ => panic!("invalid syntax"),
  254. };
  255. let should_skip = match inner_tts.next().unwrap() {
  256. proc_macro2::TokenTree::Ident(ident) => ident,
  257. _ => panic!("invalid syntax"),
  258. };
  259. match should_skip.to_string().as_str() {
  260. "skip" => {
  261. is_rent_exempt = Some(false);
  262. },
  263. _ => panic!("invalid syntax: omit the rent_exempt attribute to enforce rent exemption"),
  264. };
  265. }
  266. };
  267. }
  268. _ => {
  269. panic!("invalid syntax");
  270. }
  271. },
  272. proc_macro2::TokenTree::Punct(punct) => {
  273. if punct.as_char() != ',' {
  274. panic!("invalid syntax");
  275. }
  276. }
  277. proc_macro2::TokenTree::Literal(literal) => {
  278. let tokens: proc_macro2::TokenStream =
  279. literal.to_string().replace("\"", "").parse().unwrap();
  280. constraints.push(Constraint::Literal(ConstraintLiteral { tokens }));
  281. }
  282. _ => {
  283. panic!("invalid syntax");
  284. }
  285. }
  286. }
  287. if let Some(is_re) = is_rent_exempt {
  288. match is_re {
  289. false => constraints.push(Constraint::RentExempt(ConstraintRentExempt::Skip)),
  290. true => constraints.push(Constraint::RentExempt(ConstraintRentExempt::Enforce)),
  291. }
  292. }
  293. (constraints, is_mut, is_signer, is_init)
  294. }