pda.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. use crate::idl::*;
  2. use crate::parser;
  3. use crate::parser::context::CrateContext;
  4. use crate::ConstraintSeedsGroup;
  5. use crate::{AccountsStruct, Field};
  6. use std::collections::HashMap;
  7. use std::str::FromStr;
  8. use syn::{Expr, ExprLit, Lit};
  9. // Parses a seeds constraint, extracting the IdlSeed types.
  10. //
  11. // Note: This implementation makes assumptions about the types that can be used
  12. // (e.g., no program-defined function calls in seeds).
  13. //
  14. // This probably doesn't cover all cases. If you see a warning log, you
  15. // can add a new case here. In the worst case, we miss a seed and
  16. // the parser will treat the given seeds as empty and so clients will
  17. // simply fail to automatically populate the PDA accounts.
  18. //
  19. // Seed Assumptions: Seeds must be of one of the following forms:
  20. //
  21. // - instruction argument.
  22. // - account context field pubkey.
  23. // - account data, where the account is defined in the current program.
  24. // We make an exception for the SPL token program, since it is so common
  25. // and sometimes convenient to use fields as a seed (e.g. Auction house
  26. // program). In the case of nested structs/account data, all nested structs
  27. // must be defined in the current program as well.
  28. // - byte string literal (e.g. b"MY_SEED").
  29. // - byte string literal constant (e.g. `pub const MY_SEED: [u8; 2] = *b"hi";`).
  30. // - array constants.
  31. //
  32. pub fn parse(
  33. ctx: &CrateContext,
  34. accounts: &AccountsStruct,
  35. acc: &Field,
  36. seeds_feature: bool,
  37. ) -> Option<IdlPda> {
  38. if !seeds_feature {
  39. return None;
  40. }
  41. let pda_parser = PdaParser::new(ctx, accounts);
  42. acc.constraints
  43. .seeds
  44. .as_ref()
  45. .map(|s| pda_parser.parse(s))
  46. .unwrap_or(None)
  47. }
  48. struct PdaParser<'a> {
  49. ctx: &'a CrateContext,
  50. // Accounts context.
  51. accounts: &'a AccountsStruct,
  52. // Maps var name to var type. These are the instruction arguments in a
  53. // given accounts context.
  54. ix_args: HashMap<String, String>,
  55. // Constants available in the crate.
  56. const_names: Vec<String>,
  57. // Constants declared in impl blocks available in the crate
  58. impl_const_names: Vec<String>,
  59. // All field names of the accounts in the accounts context.
  60. account_field_names: Vec<String>,
  61. }
  62. impl<'a> PdaParser<'a> {
  63. fn new(ctx: &'a CrateContext, accounts: &'a AccountsStruct) -> Self {
  64. // All the available sources of seeds.
  65. let ix_args = accounts.instruction_args().unwrap_or_default();
  66. let const_names: Vec<String> = ctx.consts().map(|c| c.ident.to_string()).collect();
  67. let impl_const_names: Vec<String> = ctx
  68. .impl_consts()
  69. .map(|(ident, item)| format!("{} :: {}", ident, item.ident))
  70. .collect();
  71. let account_field_names = accounts.field_names();
  72. Self {
  73. ctx,
  74. accounts,
  75. ix_args,
  76. const_names,
  77. impl_const_names,
  78. account_field_names,
  79. }
  80. }
  81. fn parse(&self, seeds_grp: &ConstraintSeedsGroup) -> Option<IdlPda> {
  82. // Extract the idl seed types from the constraints.
  83. let seeds = seeds_grp
  84. .seeds
  85. .iter()
  86. .map(|s| self.parse_seed(s))
  87. .collect::<Option<Vec<_>>>()?;
  88. // Parse the program id from the constraints.
  89. let program_id = seeds_grp
  90. .program_seed
  91. .as_ref()
  92. .map(|pid| self.parse_seed(pid))
  93. .unwrap_or_default();
  94. // Done.
  95. Some(IdlPda { seeds, program_id })
  96. }
  97. fn parse_seed(&self, seed: &Expr) -> Option<IdlSeed> {
  98. match seed {
  99. Expr::MethodCall(_) => {
  100. let seed_path = parse_seed_path(seed)?;
  101. if self.is_instruction(&seed_path) {
  102. self.parse_instruction(&seed_path)
  103. } else if self.is_const(&seed_path) {
  104. self.parse_const(&seed_path)
  105. } else if self.is_impl_const(&seed_path) {
  106. self.parse_impl_const(&seed_path)
  107. } else if self.is_account(&seed_path) {
  108. self.parse_account(&seed_path)
  109. } else if self.is_str_literal(&seed_path) {
  110. self.parse_str_literal(&seed_path)
  111. } else {
  112. println!("WARNING: unexpected seed category for var: {:?}", seed_path);
  113. None
  114. }
  115. }
  116. Expr::Reference(expr_reference) => self.parse_seed(&expr_reference.expr),
  117. Expr::Index(_) => {
  118. println!("WARNING: auto pda derivation not currently supported for slice literals");
  119. None
  120. }
  121. Expr::Lit(ExprLit {
  122. lit: Lit::ByteStr(lit_byte_str),
  123. ..
  124. }) => {
  125. let seed_path: SeedPath = SeedPath(lit_byte_str.token().to_string(), Vec::new());
  126. self.parse_str_literal(&seed_path)
  127. }
  128. // Unknown type. Please file an issue.
  129. _ => {
  130. println!("WARNING: unexpected seed: {:?}", seed);
  131. None
  132. }
  133. }
  134. }
  135. fn parse_instruction(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
  136. let idl_ty = IdlType::from_str(self.ix_args.get(&seed_path.name()).unwrap()).ok()?;
  137. Some(IdlSeed::Arg(IdlSeedArg {
  138. ty: idl_ty,
  139. path: seed_path.path(),
  140. }))
  141. }
  142. fn parse_const(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
  143. // Pull in the constant value directly into the IDL.
  144. assert!(seed_path.components().is_empty());
  145. let const_item = self
  146. .ctx
  147. .consts()
  148. .find(|c| c.ident == seed_path.name())
  149. .unwrap();
  150. let idl_ty = IdlType::from_str(&parser::tts_to_string(&const_item.ty)).ok()?;
  151. let idl_ty_value = parser::tts_to_string(&const_item.expr);
  152. let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value);
  153. Some(IdlSeed::Const(IdlSeedConst {
  154. ty: idl_ty,
  155. value: serde_json::from_str(&idl_ty_value).unwrap(),
  156. }))
  157. }
  158. fn parse_impl_const(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
  159. // Pull in the constant value directly into the IDL.
  160. assert!(seed_path.components().is_empty());
  161. let static_item = self
  162. .ctx
  163. .impl_consts()
  164. .find(|(ident, item)| format!("{} :: {}", ident, item.ident) == seed_path.name())
  165. .unwrap()
  166. .1;
  167. let idl_ty = IdlType::from_str(&parser::tts_to_string(&static_item.ty)).ok()?;
  168. let idl_ty_value = parser::tts_to_string(&static_item.expr);
  169. let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value);
  170. Some(IdlSeed::Const(IdlSeedConst {
  171. ty: idl_ty,
  172. value: serde_json::from_str(&idl_ty_value).unwrap(),
  173. }))
  174. }
  175. fn parse_account(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
  176. // Get the anchor account field from the derive accounts struct.
  177. let account_field = self
  178. .accounts
  179. .fields
  180. .iter()
  181. .find(|field| *field.ident() == seed_path.name())
  182. .unwrap();
  183. // Follow the path to find the seed type.
  184. let ty = {
  185. let mut path = seed_path.components();
  186. match path.len() {
  187. 0 => IdlType::PublicKey,
  188. 1 => {
  189. // Name of the account struct.
  190. let account = account_field.ty_name()?;
  191. if account == "TokenAccount" {
  192. assert!(path.len() == 1);
  193. match path[0].as_str() {
  194. "mint" => IdlType::PublicKey,
  195. "amount" => IdlType::U64,
  196. "authority" => IdlType::PublicKey,
  197. "delegated_amount" => IdlType::U64,
  198. _ => {
  199. println!("WARNING: token field isn't supported: {}", &path[0]);
  200. return None;
  201. }
  202. }
  203. } else {
  204. // Get the rust representation of the field's struct.
  205. let strct = self.ctx.structs().find(|s| s.ident == account).unwrap();
  206. parse_field_path(self.ctx, strct, &mut path)
  207. }
  208. }
  209. _ => panic!("invariant violation"),
  210. }
  211. };
  212. Some(IdlSeed::Account(IdlSeedAccount {
  213. ty,
  214. account: account_field.ty_name(),
  215. path: seed_path.path(),
  216. }))
  217. }
  218. fn parse_str_literal(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
  219. let mut var_name = seed_path.name();
  220. // Remove the byte `b` prefix if the string is of the form `b"seed".
  221. if var_name.starts_with("b\"") {
  222. var_name.remove(0);
  223. }
  224. let value_string: String = var_name.chars().filter(|c| *c != '"').collect();
  225. Some(IdlSeed::Const(IdlSeedConst {
  226. value: serde_json::Value::String(value_string),
  227. ty: IdlType::String,
  228. }))
  229. }
  230. fn is_instruction(&self, seed_path: &SeedPath) -> bool {
  231. self.ix_args.contains_key(&seed_path.name())
  232. }
  233. fn is_const(&self, seed_path: &SeedPath) -> bool {
  234. self.const_names.contains(&seed_path.name())
  235. }
  236. fn is_impl_const(&self, seed_path: &SeedPath) -> bool {
  237. self.impl_const_names.contains(&seed_path.name())
  238. }
  239. fn is_account(&self, seed_path: &SeedPath) -> bool {
  240. self.account_field_names.contains(&seed_path.name())
  241. }
  242. fn is_str_literal(&self, seed_path: &SeedPath) -> bool {
  243. seed_path.components().is_empty() && seed_path.name().contains('"')
  244. }
  245. }
  246. // SeedPath represents the deconstructed syntax of a single pda seed,
  247. // consisting of a variable name and a vec of all the sub fields accessed
  248. // on that variable name. For example, if a seed is `my_field.my_data.as_ref()`,
  249. // then the field name is `my_field` and the vec of sub fields is `[my_data]`.
  250. #[derive(Debug)]
  251. struct SeedPath(String, Vec<String>);
  252. impl SeedPath {
  253. fn name(&self) -> String {
  254. self.0.clone()
  255. }
  256. // Full path to the data this seed represents.
  257. fn path(&self) -> String {
  258. match self.1.len() {
  259. 0 => self.0.clone(),
  260. _ => format!("{}.{}", self.name(), self.components().join(".")),
  261. }
  262. }
  263. // All path components for the subfields accessed on this seed.
  264. fn components(&self) -> &[String] {
  265. &self.1
  266. }
  267. }
  268. // Extracts the seed path from a single seed expression.
  269. fn parse_seed_path(seed: &Expr) -> Option<SeedPath> {
  270. // Convert the seed into the raw string representation.
  271. let seed_str = parser::tts_to_string(&seed);
  272. // Break up the seed into each sub field component.
  273. let mut components: Vec<&str> = seed_str.split(" . ").collect();
  274. if components.len() <= 1 {
  275. println!("WARNING: seeds are in an unexpected format: {:?}", seed);
  276. return None;
  277. }
  278. // The name of the variable (or field).
  279. let name = components.remove(0).to_string();
  280. // The path to the seed (only if the `name` type is a struct).
  281. let mut path = Vec::new();
  282. while !components.is_empty() {
  283. let c = components.remove(0);
  284. if c.contains("()") {
  285. break;
  286. }
  287. path.push(c.to_string());
  288. }
  289. if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
  290. path = Vec::new();
  291. }
  292. Some(SeedPath(name, path))
  293. }
  294. fn parse_field_path(ctx: &CrateContext, strct: &syn::ItemStruct, path: &mut &[String]) -> IdlType {
  295. let field_name = &path[0];
  296. *path = &path[1..];
  297. // Get the type name for the field.
  298. let next_field = strct
  299. .fields
  300. .iter()
  301. .find(|f| &f.ident.clone().unwrap().to_string() == field_name)
  302. .unwrap();
  303. let next_field_ty_str = parser::tts_to_string(&next_field.ty);
  304. // The path is empty so this must be a primitive type.
  305. if path.is_empty() {
  306. return next_field_ty_str.parse().unwrap();
  307. }
  308. // Get the rust representation of hte field's struct.
  309. let strct = ctx
  310. .structs()
  311. .find(|s| s.ident == next_field_ty_str)
  312. .unwrap();
  313. parse_field_path(ctx, strct, path)
  314. }
  315. fn str_lit_to_array(idl_ty: &IdlType, idl_ty_value: &String) -> String {
  316. if let IdlType::Array(_ty, _size) = &idl_ty {
  317. // Convert str literal to array.
  318. if idl_ty_value.contains("b\"") {
  319. let components: Vec<&str> = idl_ty_value.split('b').collect();
  320. assert_eq!(components.len(), 2);
  321. let mut str_lit = components[1].to_string();
  322. str_lit.retain(|c| c != '"');
  323. return format!("{:?}", str_lit.as_bytes());
  324. }
  325. }
  326. idl_ty_value.to_string()
  327. }