accounts.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. use crate::{
  2. AccountField, AccountsStruct, CompositeField, Constraint, ConstraintBelongsTo,
  3. ConstraintExecutable, ConstraintLiteral, ConstraintOwner, ConstraintRentExempt,
  4. ConstraintSeeds, ConstraintSigner, CpiAccountTy, Field, ProgramAccountTy, ProgramStateTy,
  5. SysvarTy, Ty,
  6. };
  7. pub fn parse(strct: &syn::ItemStruct) -> AccountsStruct {
  8. let fields = match &strct.fields {
  9. syn::Fields::Named(fields) => fields.named.iter().map(parse_account_field).collect(),
  10. _ => panic!("invalid input"),
  11. };
  12. AccountsStruct::new(strct.clone(), fields)
  13. }
  14. fn parse_account_field(f: &syn::Field) -> AccountField {
  15. let anchor_attr = parse_account_attr(f);
  16. parse_field(f, anchor_attr)
  17. }
  18. fn parse_account_attr(f: &syn::Field) -> Option<&syn::Attribute> {
  19. let anchor_attrs: Vec<&syn::Attribute> = f
  20. .attrs
  21. .iter()
  22. .filter(|attr| {
  23. if attr.path.segments.len() != 1 {
  24. return false;
  25. }
  26. if attr.path.segments[0].ident != "account" {
  27. return false;
  28. }
  29. true
  30. })
  31. .collect();
  32. match anchor_attrs.len() {
  33. 0 => None,
  34. 1 => Some(anchor_attrs[0]),
  35. _ => panic!("Invalid syntax: please specify one account attribute."),
  36. }
  37. }
  38. fn parse_field(f: &syn::Field, anchor: Option<&syn::Attribute>) -> AccountField {
  39. let ident = f.ident.clone().unwrap();
  40. let (constraints, is_mut, is_signer, is_init) = match anchor {
  41. None => (vec![], false, false, false),
  42. Some(anchor) => parse_constraints(anchor),
  43. };
  44. match is_field_primitive(f) {
  45. true => {
  46. let ty = parse_ty(f);
  47. AccountField::Field(Field {
  48. ident,
  49. ty,
  50. constraints,
  51. is_mut,
  52. is_signer,
  53. is_init,
  54. })
  55. }
  56. false => AccountField::AccountsStruct(CompositeField {
  57. ident,
  58. symbol: ident_string(f),
  59. constraints,
  60. raw_field: f.clone(),
  61. }),
  62. }
  63. }
  64. fn is_field_primitive(f: &syn::Field) -> bool {
  65. match ident_string(f).as_str() {
  66. "ProgramState" | "ProgramAccount" | "CpiAccount" | "Sysvar" | "AccountInfo" => true,
  67. _ => false,
  68. }
  69. }
  70. fn parse_ty(f: &syn::Field) -> Ty {
  71. let path = match &f.ty {
  72. syn::Type::Path(ty_path) => ty_path.path.clone(),
  73. _ => panic!("invalid account syntax"),
  74. };
  75. match ident_string(f).as_str() {
  76. "ProgramState" => Ty::ProgramState(parse_program_state(&path)),
  77. "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)),
  78. "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)),
  79. "Sysvar" => Ty::Sysvar(parse_sysvar(&path)),
  80. "AccountInfo" => Ty::AccountInfo,
  81. _ => panic!("invalid account type"),
  82. }
  83. }
  84. fn ident_string(f: &syn::Field) -> String {
  85. let path = match &f.ty {
  86. syn::Type::Path(ty_path) => ty_path.path.clone(),
  87. _ => panic!("invalid account syntax"),
  88. };
  89. // TODO: allow segmented paths.
  90. assert!(path.segments.len() == 1);
  91. let segments = &path.segments[0];
  92. segments.ident.to_string()
  93. }
  94. fn parse_program_state(path: &syn::Path) -> ProgramStateTy {
  95. let account_ident = parse_account(&path);
  96. ProgramStateTy { account_ident }
  97. }
  98. fn parse_cpi_account(path: &syn::Path) -> CpiAccountTy {
  99. let account_ident = parse_account(path);
  100. CpiAccountTy { account_ident }
  101. }
  102. fn parse_program_account(path: &syn::Path) -> ProgramAccountTy {
  103. let account_ident = parse_account(path);
  104. ProgramAccountTy { account_ident }
  105. }
  106. fn parse_account(path: &syn::Path) -> syn::Ident {
  107. let segments = &path.segments[0];
  108. match &segments.arguments {
  109. syn::PathArguments::AngleBracketed(args) => {
  110. // Expected: <'info, MyType>.
  111. assert!(args.args.len() == 2);
  112. match &args.args[1] {
  113. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  114. // TODO: allow segmented paths.
  115. assert!(ty_path.path.segments.len() == 1);
  116. let path_segment = &ty_path.path.segments[0];
  117. path_segment.ident.clone()
  118. }
  119. _ => panic!("Invalid ProgramAccount"),
  120. }
  121. }
  122. _ => panic!("Invalid ProgramAccount"),
  123. }
  124. }
  125. fn parse_sysvar(path: &syn::Path) -> SysvarTy {
  126. let segments = &path.segments[0];
  127. let account_ident = match &segments.arguments {
  128. syn::PathArguments::AngleBracketed(args) => {
  129. // Expected: <'info, MyType>.
  130. assert!(args.args.len() == 2);
  131. match &args.args[1] {
  132. syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
  133. // TODO: allow segmented paths.
  134. assert!(ty_path.path.segments.len() == 1);
  135. let path_segment = &ty_path.path.segments[0];
  136. path_segment.ident.clone()
  137. }
  138. _ => panic!("Invalid Sysvar"),
  139. }
  140. }
  141. _ => panic!("Invalid Sysvar"),
  142. };
  143. match account_ident.to_string().as_str() {
  144. "Clock" => SysvarTy::Clock,
  145. "Rent" => SysvarTy::Rent,
  146. "EpochSchedule" => SysvarTy::EpochSchedule,
  147. "Fees" => SysvarTy::Fees,
  148. "RecentBlockhashes" => SysvarTy::RecentBlockHashes,
  149. "SlotHashes" => SysvarTy::SlotHashes,
  150. "SlotHistory" => SysvarTy::SlotHistory,
  151. "StakeHistory" => SysvarTy::StakeHistory,
  152. "Instructions" => SysvarTy::Instructions,
  153. "Rewards" => SysvarTy::Rewards,
  154. _ => panic!("Invalid Sysvar"),
  155. }
  156. }
  157. fn parse_constraints(anchor: &syn::Attribute) -> (Vec<Constraint>, bool, bool, bool) {
  158. let mut tts = anchor.tokens.clone().into_iter();
  159. let g_stream = match tts.next().expect("Must have a token group") {
  160. proc_macro2::TokenTree::Group(g) => g.stream(),
  161. _ => panic!("Invalid syntax"),
  162. };
  163. let mut is_init = false;
  164. let mut is_mut = false;
  165. let mut is_signer = false;
  166. let mut constraints = vec![];
  167. let mut is_rent_exempt = None;
  168. let mut inner_tts = g_stream.into_iter();
  169. while let Some(token) = inner_tts.next() {
  170. match token {
  171. proc_macro2::TokenTree::Ident(ident) => match ident.to_string().as_str() {
  172. "init" => {
  173. is_init = true;
  174. is_mut = true;
  175. // If it's not specified, all program owned accounts default
  176. // to being rent exempt.
  177. if is_rent_exempt.is_none() {
  178. is_rent_exempt = Some(true);
  179. }
  180. }
  181. "mut" => {
  182. is_mut = true;
  183. }
  184. "signer" => {
  185. is_signer = true;
  186. constraints.push(Constraint::Signer(ConstraintSigner {}));
  187. }
  188. "seeds" => {
  189. match inner_tts.next().unwrap() {
  190. proc_macro2::TokenTree::Punct(punct) => {
  191. assert!(punct.as_char() == '=');
  192. punct
  193. }
  194. _ => panic!("invalid syntax"),
  195. };
  196. let seeds = match inner_tts.next().unwrap() {
  197. proc_macro2::TokenTree::Group(g) => g,
  198. _ => panic!("invalid syntax"),
  199. };
  200. constraints.push(Constraint::Seeds(ConstraintSeeds { seeds }))
  201. }
  202. "belongs_to" | "has_one" => {
  203. match inner_tts.next().unwrap() {
  204. proc_macro2::TokenTree::Punct(punct) => {
  205. assert!(punct.as_char() == '=');
  206. punct
  207. }
  208. _ => panic!("invalid syntax"),
  209. };
  210. let join_target = match inner_tts.next().unwrap() {
  211. proc_macro2::TokenTree::Ident(ident) => ident,
  212. _ => panic!("invalid syntax"),
  213. };
  214. constraints.push(Constraint::BelongsTo(ConstraintBelongsTo { join_target }))
  215. }
  216. "owner" => {
  217. match inner_tts.next().unwrap() {
  218. proc_macro2::TokenTree::Punct(punct) => {
  219. assert!(punct.as_char() == '=');
  220. punct
  221. }
  222. _ => panic!("invalid syntax"),
  223. };
  224. let owner = match inner_tts.next().unwrap() {
  225. proc_macro2::TokenTree::Ident(ident) => ident,
  226. _ => panic!("invalid syntax"),
  227. };
  228. let constraint = match owner.to_string().as_str() {
  229. "program" => ConstraintOwner::Program,
  230. "skip" => ConstraintOwner::Skip,
  231. _ => panic!("invalid syntax"),
  232. };
  233. constraints.push(Constraint::Owner(constraint));
  234. }
  235. "rent_exempt" => {
  236. match inner_tts.next() {
  237. None => is_rent_exempt = Some(true),
  238. Some(tkn) => {
  239. match tkn {
  240. proc_macro2::TokenTree::Punct(punct) => {
  241. assert!(punct.as_char() == '=');
  242. punct
  243. }
  244. _ => panic!("invalid syntax"),
  245. };
  246. let should_skip = match inner_tts.next().unwrap() {
  247. proc_macro2::TokenTree::Ident(ident) => ident,
  248. _ => panic!("invalid syntax"),
  249. };
  250. match should_skip.to_string().as_str() {
  251. "skip" => {
  252. is_rent_exempt = Some(false);
  253. },
  254. _ => panic!("invalid syntax: omit the rent_exempt attribute to enforce rent exemption"),
  255. };
  256. }
  257. };
  258. }
  259. "executable" => {
  260. constraints.push(Constraint::Executable(ConstraintExecutable {}));
  261. }
  262. _ => {
  263. panic!("invalid syntax");
  264. }
  265. },
  266. proc_macro2::TokenTree::Punct(punct) => {
  267. if punct.as_char() != ',' {
  268. panic!("invalid syntax");
  269. }
  270. }
  271. proc_macro2::TokenTree::Literal(literal) => {
  272. let tokens: proc_macro2::TokenStream =
  273. literal.to_string().replace("\"", "").parse().unwrap();
  274. constraints.push(Constraint::Literal(ConstraintLiteral { tokens }));
  275. }
  276. _ => {
  277. panic!("invalid syntax");
  278. }
  279. }
  280. }
  281. if let Some(is_re) = is_rent_exempt {
  282. match is_re {
  283. false => constraints.push(Constraint::RentExempt(ConstraintRentExempt::Skip)),
  284. true => constraints.push(Constraint::RentExempt(ConstraintRentExempt::Enforce)),
  285. }
  286. }
  287. (constraints, is_mut, is_signer, is_init)
  288. }