constraints.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. use crate::{
  2. ConstraintAssociated, ConstraintAssociatedGroup, ConstraintAssociatedPayer,
  3. ConstraintAssociatedSpace, ConstraintAssociatedWith, ConstraintBelongsTo, ConstraintClose,
  4. ConstraintExecutable, ConstraintGroup, ConstraintInit, ConstraintLiteral, ConstraintMut,
  5. ConstraintOwner, ConstraintRaw, ConstraintRentExempt, ConstraintSeeds, ConstraintSigner,
  6. ConstraintState, ConstraintToken, Context, Ty,
  7. };
  8. use syn::ext::IdentExt;
  9. use syn::parse::{Error as ParseError, Parse, ParseStream, Result as ParseResult};
  10. use syn::punctuated::Punctuated;
  11. use syn::spanned::Spanned;
  12. use syn::token::Comma;
  13. use syn::{bracketed, Expr, Ident, LitStr, Token};
  14. pub fn parse(f: &syn::Field, f_ty: Option<&Ty>) -> ParseResult<ConstraintGroup> {
  15. let mut constraints = ConstraintGroupBuilder::new(f_ty);
  16. for attr in f.attrs.iter().filter(is_account) {
  17. for c in attr.parse_args_with(Punctuated::<ConstraintToken, Comma>::parse_terminated)? {
  18. constraints.add(c)?;
  19. }
  20. }
  21. constraints.build()
  22. }
  23. pub fn is_account(attr: &&syn::Attribute) -> bool {
  24. attr.path
  25. .get_ident()
  26. .map_or(false, |ident| ident == "account")
  27. }
  28. // Parses a single constraint from a parse stream for `#[account(<STREAM>)]`.
  29. pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
  30. let is_lit = stream.peek(LitStr);
  31. if is_lit {
  32. let lit: LitStr = stream.parse()?;
  33. let c = ConstraintToken::Literal(Context::new(lit.span(), ConstraintLiteral { lit }));
  34. return Ok(c);
  35. }
  36. let ident = stream.call(Ident::parse_any)?;
  37. let kw = ident.to_string();
  38. let c = match kw.as_str() {
  39. "init" => ConstraintToken::Init(Context::new(ident.span(), ConstraintInit {})),
  40. "mut" => ConstraintToken::Mut(Context::new(ident.span(), ConstraintMut {})),
  41. "signer" => ConstraintToken::Signer(Context::new(ident.span(), ConstraintSigner {})),
  42. "executable" => {
  43. ConstraintToken::Executable(Context::new(ident.span(), ConstraintExecutable {}))
  44. }
  45. _ => {
  46. stream.parse::<Token![=]>()?;
  47. let span = ident.span().join(stream.span()).unwrap_or(ident.span());
  48. match kw.as_str() {
  49. "belongs_to" | "has_one" => ConstraintToken::BelongsTo(Context::new(
  50. span,
  51. ConstraintBelongsTo {
  52. join_target: stream.parse()?,
  53. },
  54. )),
  55. "owner" => ConstraintToken::Owner(Context::new(
  56. span,
  57. ConstraintOwner {
  58. owner_target: stream.parse()?,
  59. },
  60. )),
  61. "rent_exempt" => ConstraintToken::RentExempt(Context::new(
  62. span,
  63. match stream.parse::<Ident>()?.to_string().as_str() {
  64. "skip" => ConstraintRentExempt::Skip,
  65. "enforce" => ConstraintRentExempt::Enforce,
  66. _ => {
  67. return Err(ParseError::new(
  68. span,
  69. "rent_exempt must be either skip or enforce",
  70. ))
  71. }
  72. },
  73. )),
  74. "state" => ConstraintToken::State(Context::new(
  75. span,
  76. ConstraintState {
  77. program_target: stream.parse()?,
  78. },
  79. )),
  80. "associated" => ConstraintToken::Associated(Context::new(
  81. span,
  82. ConstraintAssociated {
  83. target: stream.parse()?,
  84. },
  85. )),
  86. "payer" => ConstraintToken::AssociatedPayer(Context::new(
  87. span,
  88. ConstraintAssociatedPayer {
  89. target: stream.parse()?,
  90. },
  91. )),
  92. "with" => ConstraintToken::AssociatedWith(Context::new(
  93. span,
  94. ConstraintAssociatedWith {
  95. target: stream.parse()?,
  96. },
  97. )),
  98. "space" => ConstraintToken::AssociatedSpace(Context::new(
  99. span,
  100. ConstraintAssociatedSpace {
  101. space: stream.parse()?,
  102. },
  103. )),
  104. "seeds" => {
  105. let seeds;
  106. let bracket = bracketed!(seeds in stream);
  107. ConstraintToken::Seeds(Context::new(
  108. span.join(bracket.span).unwrap_or(span),
  109. ConstraintSeeds {
  110. seeds: seeds.parse_terminated(Expr::parse)?,
  111. },
  112. ))
  113. }
  114. "constraint" => ConstraintToken::Raw(Context::new(
  115. span,
  116. ConstraintRaw {
  117. raw: stream.parse()?,
  118. },
  119. )),
  120. "close" => ConstraintToken::Close(Context::new(
  121. span,
  122. ConstraintClose {
  123. sol_dest: stream.parse()?,
  124. },
  125. )),
  126. _ => Err(ParseError::new(ident.span(), "Invalid attribute"))?,
  127. }
  128. }
  129. };
  130. Ok(c)
  131. }
  132. #[derive(Default)]
  133. pub struct ConstraintGroupBuilder<'ty> {
  134. pub f_ty: Option<&'ty Ty>,
  135. pub init: Option<Context<ConstraintInit>>,
  136. pub mutable: Option<Context<ConstraintMut>>,
  137. pub signer: Option<Context<ConstraintSigner>>,
  138. pub belongs_to: Vec<Context<ConstraintBelongsTo>>,
  139. pub literal: Vec<Context<ConstraintLiteral>>,
  140. pub raw: Vec<Context<ConstraintRaw>>,
  141. pub owner: Option<Context<ConstraintOwner>>,
  142. pub rent_exempt: Option<Context<ConstraintRentExempt>>,
  143. pub seeds: Option<Context<ConstraintSeeds>>,
  144. pub executable: Option<Context<ConstraintExecutable>>,
  145. pub state: Option<Context<ConstraintState>>,
  146. pub associated: Option<Context<ConstraintAssociated>>,
  147. pub associated_payer: Option<Context<ConstraintAssociatedPayer>>,
  148. pub associated_space: Option<Context<ConstraintAssociatedSpace>>,
  149. pub associated_with: Vec<Context<ConstraintAssociatedWith>>,
  150. pub close: Option<Context<ConstraintClose>>,
  151. }
  152. impl<'ty> ConstraintGroupBuilder<'ty> {
  153. pub fn new(f_ty: Option<&'ty Ty>) -> Self {
  154. Self {
  155. f_ty,
  156. init: None,
  157. mutable: None,
  158. signer: None,
  159. belongs_to: Vec::new(),
  160. literal: Vec::new(),
  161. raw: Vec::new(),
  162. owner: None,
  163. rent_exempt: None,
  164. seeds: None,
  165. executable: None,
  166. state: None,
  167. associated: None,
  168. associated_payer: None,
  169. associated_space: None,
  170. associated_with: Vec::new(),
  171. close: None,
  172. }
  173. }
  174. pub fn build(mut self) -> ParseResult<ConstraintGroup> {
  175. // Init implies mutable and rent exempt.
  176. if let Some(i) = &self.init {
  177. match self.mutable {
  178. Some(m) => {
  179. return Err(ParseError::new(
  180. m.span(),
  181. "mut cannot be provided with init",
  182. ))
  183. }
  184. None => self
  185. .mutable
  186. .replace(Context::new(i.span(), ConstraintMut {})),
  187. };
  188. if self.rent_exempt.is_none() {
  189. self.rent_exempt
  190. .replace(Context::new(i.span(), ConstraintRentExempt::Enforce));
  191. }
  192. }
  193. let ConstraintGroupBuilder {
  194. f_ty: _,
  195. init,
  196. mutable,
  197. signer,
  198. belongs_to,
  199. literal,
  200. raw,
  201. owner,
  202. rent_exempt,
  203. seeds,
  204. executable,
  205. state,
  206. associated,
  207. associated_payer,
  208. associated_space,
  209. associated_with,
  210. close,
  211. } = self;
  212. // Converts Option<Context<T>> -> Option<T>.
  213. macro_rules! into_inner {
  214. ($opt:ident) => {
  215. $opt.map(|c| c.into_inner())
  216. };
  217. }
  218. // Converts Vec<Context<T>> - Vec<T>.
  219. macro_rules! into_inner_vec {
  220. ($opt:ident) => {
  221. $opt.into_iter().map(|c| c.into_inner()).collect()
  222. };
  223. }
  224. let is_init = init.is_some();
  225. Ok(ConstraintGroup {
  226. init: into_inner!(init),
  227. mutable: into_inner!(mutable),
  228. signer: into_inner!(signer),
  229. belongs_to: into_inner_vec!(belongs_to),
  230. literal: into_inner_vec!(literal),
  231. raw: into_inner_vec!(raw),
  232. owner: into_inner!(owner),
  233. rent_exempt: into_inner!(rent_exempt),
  234. seeds: into_inner!(seeds),
  235. executable: into_inner!(executable),
  236. state: into_inner!(state),
  237. associated: associated.map(|associated| ConstraintAssociatedGroup {
  238. is_init,
  239. associated_target: associated.target.clone(),
  240. associated_seeds: associated_with.iter().map(|s| s.target.clone()).collect(),
  241. payer: associated_payer.map(|p| p.target.clone()),
  242. space: associated_space.map(|s| s.space.clone()),
  243. }),
  244. close: into_inner!(close),
  245. })
  246. }
  247. pub fn add(&mut self, c: ConstraintToken) -> ParseResult<()> {
  248. match c {
  249. ConstraintToken::Init(c) => self.add_init(c),
  250. ConstraintToken::Mut(c) => self.add_mut(c),
  251. ConstraintToken::Signer(c) => self.add_signer(c),
  252. ConstraintToken::BelongsTo(c) => self.add_belongs_to(c),
  253. ConstraintToken::Literal(c) => self.add_literal(c),
  254. ConstraintToken::Raw(c) => self.add_raw(c),
  255. ConstraintToken::Owner(c) => self.add_owner(c),
  256. ConstraintToken::RentExempt(c) => self.add_rent_exempt(c),
  257. ConstraintToken::Seeds(c) => self.add_seeds(c),
  258. ConstraintToken::Executable(c) => self.add_executable(c),
  259. ConstraintToken::State(c) => self.add_state(c),
  260. ConstraintToken::Associated(c) => self.add_associated(c),
  261. ConstraintToken::AssociatedPayer(c) => self.add_associated_payer(c),
  262. ConstraintToken::AssociatedSpace(c) => self.add_associated_space(c),
  263. ConstraintToken::AssociatedWith(c) => self.add_associated_with(c),
  264. ConstraintToken::Close(c) => self.add_close(c),
  265. }
  266. }
  267. fn add_init(&mut self, c: Context<ConstraintInit>) -> ParseResult<()> {
  268. if self.init.is_some() {
  269. return Err(ParseError::new(c.span(), "init already provided"));
  270. }
  271. self.init.replace(c);
  272. Ok(())
  273. }
  274. fn add_close(&mut self, c: Context<ConstraintClose>) -> ParseResult<()> {
  275. if !matches!(self.f_ty, Some(Ty::ProgramAccount(_)))
  276. && !matches!(self.f_ty, Some(Ty::Loader(_)))
  277. {
  278. return Err(ParseError::new(
  279. c.span(),
  280. "close must be on a ProgramAccount",
  281. ));
  282. }
  283. if self.mutable.is_none() {
  284. return Err(ParseError::new(
  285. c.span(),
  286. "mut must be provided before close",
  287. ));
  288. }
  289. if self.close.is_some() {
  290. return Err(ParseError::new(c.span(), "close already provided"));
  291. }
  292. self.close.replace(c);
  293. Ok(())
  294. }
  295. fn add_mut(&mut self, c: Context<ConstraintMut>) -> ParseResult<()> {
  296. if self.mutable.is_some() {
  297. return Err(ParseError::new(c.span(), "mut already provided"));
  298. }
  299. self.mutable.replace(c);
  300. Ok(())
  301. }
  302. fn add_signer(&mut self, c: Context<ConstraintSigner>) -> ParseResult<()> {
  303. if self.signer.is_some() {
  304. return Err(ParseError::new(c.span(), "signer already provided"));
  305. }
  306. self.signer.replace(c);
  307. Ok(())
  308. }
  309. fn add_belongs_to(&mut self, c: Context<ConstraintBelongsTo>) -> ParseResult<()> {
  310. if self
  311. .belongs_to
  312. .iter()
  313. .filter(|item| item.join_target == c.join_target)
  314. .count()
  315. > 0
  316. {
  317. return Err(ParseError::new(
  318. c.span(),
  319. "belongs_to target already provided",
  320. ));
  321. }
  322. self.belongs_to.push(c);
  323. Ok(())
  324. }
  325. fn add_literal(&mut self, c: Context<ConstraintLiteral>) -> ParseResult<()> {
  326. self.literal.push(c);
  327. Ok(())
  328. }
  329. fn add_raw(&mut self, c: Context<ConstraintRaw>) -> ParseResult<()> {
  330. self.raw.push(c);
  331. Ok(())
  332. }
  333. fn add_owner(&mut self, c: Context<ConstraintOwner>) -> ParseResult<()> {
  334. if self.owner.is_some() {
  335. return Err(ParseError::new(c.span(), "owner already provided"));
  336. }
  337. self.owner.replace(c);
  338. Ok(())
  339. }
  340. fn add_rent_exempt(&mut self, c: Context<ConstraintRentExempt>) -> ParseResult<()> {
  341. if self.rent_exempt.is_some() {
  342. return Err(ParseError::new(c.span(), "rent already provided"));
  343. }
  344. self.rent_exempt.replace(c);
  345. Ok(())
  346. }
  347. fn add_seeds(&mut self, c: Context<ConstraintSeeds>) -> ParseResult<()> {
  348. if self.seeds.is_some() {
  349. return Err(ParseError::new(c.span(), "seeds already provided"));
  350. }
  351. self.seeds.replace(c);
  352. Ok(())
  353. }
  354. fn add_executable(&mut self, c: Context<ConstraintExecutable>) -> ParseResult<()> {
  355. if self.executable.is_some() {
  356. return Err(ParseError::new(c.span(), "executable already provided"));
  357. }
  358. self.executable.replace(c);
  359. Ok(())
  360. }
  361. fn add_state(&mut self, c: Context<ConstraintState>) -> ParseResult<()> {
  362. if self.state.is_some() {
  363. return Err(ParseError::new(c.span(), "state already provided"));
  364. }
  365. self.state.replace(c);
  366. Ok(())
  367. }
  368. fn add_associated(&mut self, c: Context<ConstraintAssociated>) -> ParseResult<()> {
  369. if self.associated.is_some() {
  370. return Err(ParseError::new(c.span(), "associated already provided"));
  371. }
  372. self.associated.replace(c);
  373. Ok(())
  374. }
  375. fn add_associated_payer(&mut self, c: Context<ConstraintAssociatedPayer>) -> ParseResult<()> {
  376. if self.associated.is_none() {
  377. return Err(ParseError::new(
  378. c.span(),
  379. "associated must be provided before payer",
  380. ));
  381. }
  382. if self.associated_payer.is_some() {
  383. return Err(ParseError::new(c.span(), "payer already provided"));
  384. }
  385. self.associated_payer.replace(c);
  386. Ok(())
  387. }
  388. fn add_associated_space(&mut self, c: Context<ConstraintAssociatedSpace>) -> ParseResult<()> {
  389. if self.associated.is_none() {
  390. return Err(ParseError::new(
  391. c.span(),
  392. "associated must be provided before space",
  393. ));
  394. }
  395. if self.associated_space.is_some() {
  396. return Err(ParseError::new(c.span(), "space already provided"));
  397. }
  398. self.associated_space.replace(c);
  399. Ok(())
  400. }
  401. fn add_associated_with(&mut self, c: Context<ConstraintAssociatedWith>) -> ParseResult<()> {
  402. if self.associated.is_none() {
  403. return Err(ParseError::new(
  404. c.span(),
  405. "associated must be provided before with",
  406. ));
  407. }
  408. self.associated_with.push(c);
  409. Ok(())
  410. }
  411. }