123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600 |
- use crate::parser::docs;
- use crate::*;
- use syn::parse::{Error as ParseError, Result as ParseResult};
- use syn::punctuated::Punctuated;
- use syn::spanned::Spanned;
- use syn::token::Comma;
- use syn::Expr;
- use syn::Path;
- pub mod constraints;
- pub fn parse(strct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
- let instruction_api: Option<Punctuated<Expr, Comma>> = strct
- .attrs
- .iter()
- .find(|a| {
- a.path
- .get_ident()
- .map_or(false, |ident| ident == "instruction")
- })
- .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
- .transpose()?;
- let fields = match &strct.fields {
- syn::Fields::Named(fields) => fields
- .named
- .iter()
- .map(parse_account_field)
- .collect::<ParseResult<Vec<AccountField>>>()?,
- _ => {
- return Err(ParseError::new_spanned(
- &strct.fields,
- "fields must be named",
- ))
- }
- };
- constraints_cross_checks(&fields)?;
- Ok(AccountsStruct::new(strct.clone(), fields, instruction_api))
- }
- fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
- // COMMON ERROR MESSAGE
- let message = |constraint: &str, field: &str, required: bool| {
- if required {
- format! {
- "The {} constraint requires \
- a {} field to exist in the account \
- validation struct. Use the Program type to add \
- the {} field to your validation struct.", constraint, field, field
- }
- } else {
- format! {
- "An optional {} constraint requires \
- an optional or required {} field to exist \
- in the account validation struct. Use the Program type \
- to add the {} field to your validation struct.", constraint, field, field
- }
- }
- };
- // INIT
- let mut required_init = false;
- let init_fields: Vec<&Field> = fields
- .iter()
- .filter_map(|f| match f {
- AccountField::Field(field) if field.constraints.init.is_some() => {
- if !field.is_optional {
- required_init = true
- }
- Some(field)
- }
- _ => None,
- })
- .collect();
- if !init_fields.is_empty() {
- // init needs system program.
- if !fields
- .iter()
- // ensures that a non optional `system_program` is present with non optional `init`
- .any(|f| f.ident() == "system_program" && !(required_init && f.is_optional()))
- {
- return Err(ParseError::new(
- init_fields[0].ident.span(),
- message("init", "system_program", required_init),
- ));
- }
- let kind = &init_fields[0].constraints.init.as_ref().unwrap().kind;
- // init token/a_token/mint needs token program.
- match kind {
- InitKind::Program { .. } => (),
- InitKind::Token { .. } | InitKind::AssociatedToken { .. } | InitKind::Mint { .. } => {
- if !fields
- .iter()
- .any(|f| f.ident() == "token_program" && !(required_init && f.is_optional()))
- {
- return Err(ParseError::new(
- init_fields[0].ident.span(),
- message("init", "token_program", required_init),
- ));
- }
- }
- }
- // a_token needs associated token program.
- if let InitKind::AssociatedToken { .. } = kind {
- if !fields.iter().any(|f| {
- f.ident() == "associated_token_program" && !(required_init && f.is_optional())
- }) {
- return Err(ParseError::new(
- init_fields[0].ident.span(),
- message("init", "associated_token_program", required_init),
- ));
- }
- }
- for field in init_fields {
- // Get payer for init-ed account
- let associated_payer_name = match field.constraints.init.clone().unwrap().payer {
- // composite payer, check not supported
- Expr::Field(_) => continue,
- // method call, check not supported
- Expr::MethodCall(_) => continue,
- field_name => field_name.to_token_stream().to_string(),
- };
- // Check payer is mutable
- let associated_payer_field = fields.iter().find_map(|f| match f {
- AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
- _ => None,
- });
- match associated_payer_field {
- Some(associated_payer_field) => {
- if !associated_payer_field.constraints.is_mutable() {
- return Err(ParseError::new(
- field.ident.span(),
- "the payer specified for an init constraint must be mutable.",
- ));
- } else if associated_payer_field.is_optional && required_init {
- return Err(ParseError::new(
- field.ident.span(),
- "the payer specified for a required init constraint must be required.",
- ));
- }
- }
- _ => {
- return Err(ParseError::new(
- field.ident.span(),
- "the payer specified does not exist.",
- ));
- }
- }
- match kind {
- // This doesn't catch cases like account.key() or account.key.
- // My guess is that doesn't happen often and we can revisit
- // this if I'm wrong.
- InitKind::Token { mint, .. } | InitKind::AssociatedToken { mint, .. } => {
- if !fields.iter().any(|f| {
- f.ident()
- .to_string()
- .starts_with(&mint.to_token_stream().to_string())
- }) {
- return Err(ParseError::new(
- field.ident.span(),
- "the mint constraint has to be an account field for token initializations (not a public key)",
- ));
- }
- }
- _ => (),
- }
- }
- }
- // REALLOC
- let mut required_realloc = false;
- let realloc_fields: Vec<&Field> = fields
- .iter()
- .filter_map(|f| match f {
- AccountField::Field(field) if field.constraints.realloc.is_some() => {
- if !field.is_optional {
- required_realloc = true
- }
- Some(field)
- }
- _ => None,
- })
- .collect();
- if !realloc_fields.is_empty() {
- // realloc needs system program.
- if !fields
- .iter()
- .any(|f| f.ident() == "system_program" && !(required_realloc && f.is_optional()))
- {
- return Err(ParseError::new(
- realloc_fields[0].ident.span(),
- message("realloc", "system_program", required_realloc),
- ));
- }
- for field in realloc_fields {
- // Get allocator for realloc-ed account
- let associated_payer_name = match field.constraints.realloc.clone().unwrap().payer {
- // composite allocator, check not supported
- Expr::Field(_) => continue,
- // method call, check not supported
- Expr::MethodCall(_) => continue,
- field_name => field_name.to_token_stream().to_string(),
- };
- // Check allocator is mutable
- let associated_payer_field = fields.iter().find_map(|f| match f {
- AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
- _ => None,
- });
- match associated_payer_field {
- Some(associated_payer_field) => {
- if !associated_payer_field.constraints.is_mutable() {
- return Err(ParseError::new(
- field.ident.span(),
- "the realloc::payer specified for an realloc constraint must be mutable.",
- ));
- } else if associated_payer_field.is_optional && required_realloc {
- return Err(ParseError::new(
- field.ident.span(),
- "the realloc::payer specified for a required realloc constraint must be required.",
- ));
- }
- }
- _ => {
- return Err(ParseError::new(
- field.ident.span(),
- "the realloc::payer specified does not exist.",
- ));
- }
- }
- }
- }
- Ok(())
- }
- pub fn parse_account_field(f: &syn::Field) -> ParseResult<AccountField> {
- let ident = f.ident.clone().unwrap();
- let docs = docs::parse(&f.attrs);
- let account_field = match is_field_primitive(f)? {
- true => {
- let (ty, is_optional) = parse_ty(f)?;
- let account_constraints = constraints::parse(f, Some(&ty))?;
- AccountField::Field(Field {
- ident,
- ty,
- is_optional,
- constraints: account_constraints,
- docs,
- })
- }
- false => {
- let (_, optional, _) = ident_string(f)?;
- if optional {
- return Err(ParseError::new(
- f.ty.span(),
- "Cannot have Optional composite accounts",
- ));
- }
- let account_constraints = constraints::parse(f, None)?;
- AccountField::CompositeField(CompositeField {
- ident,
- constraints: account_constraints,
- symbol: ident_string(f)?.0,
- raw_field: f.clone(),
- docs,
- })
- }
- };
- Ok(account_field)
- }
- fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
- let r = matches!(
- ident_string(f)?.0.as_str(),
- "ProgramState"
- | "ProgramAccount"
- | "CpiAccount"
- | "Sysvar"
- | "AccountInfo"
- | "UncheckedAccount"
- | "CpiState"
- | "Loader"
- | "AccountLoader"
- | "Account"
- | "Program"
- | "Signer"
- | "SystemAccount"
- | "ProgramData"
- );
- Ok(r)
- }
- // TODO call `account_parse` a single time at the start of this function and then init each of the types
- fn parse_ty(f: &syn::Field) -> ParseResult<(Ty, bool)> {
- let (ident, optional, path) = ident_string(f)?;
- let ty = match ident.as_str() {
- "ProgramState" => Ty::ProgramState(parse_program_state(&path)?),
- "CpiState" => Ty::CpiState(parse_cpi_state(&path)?),
- "ProgramAccount" => Ty::ProgramAccount(parse_program_account(&path)?),
- "CpiAccount" => Ty::CpiAccount(parse_cpi_account(&path)?),
- "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
- "AccountInfo" => Ty::AccountInfo,
- "UncheckedAccount" => Ty::UncheckedAccount,
- "Loader" => Ty::Loader(parse_program_account_zero_copy(&path)?),
- "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
- "Account" => Ty::Account(parse_account_ty(&path)?),
- "Program" => Ty::Program(parse_program_ty(&path)?),
- "Signer" => Ty::Signer,
- "SystemAccount" => Ty::SystemAccount,
- "ProgramData" => Ty::ProgramData,
- _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
- };
- Ok((ty, optional))
- }
- fn option_to_inner_path(path: &Path) -> ParseResult<Path> {
- let segment_0 = path.segments[0].clone();
- match segment_0.arguments {
- syn::PathArguments::AngleBracketed(args) => {
- if args.args.len() != 1 {
- return Err(ParseError::new(
- args.args.span(),
- "can only have one argument in option",
- ));
- }
- match &args.args[0] {
- syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.path.clone()),
- _ => Err(ParseError::new(
- args.args[1].span(),
- "first bracket argument must be a lifetime",
- )),
- }
- }
- _ => Err(ParseError::new(
- segment_0.arguments.span(),
- "expected angle brackets with a lifetime and type",
- )),
- }
- }
- fn ident_string(f: &syn::Field) -> ParseResult<(String, bool, Path)> {
- // TODO support parsing references to account infos
- let mut path = match &f.ty {
- syn::Type::Path(ty_path) => ty_path.path.clone(),
- _ => {
- return Err(ParseError::new_spanned(
- f,
- format!(
- "Field {} has a non-path type",
- f.ident.as_ref().expect("named fields only")
- ),
- ))
- }
- };
- // TODO replace string matching with helper functions using syn type match statments
- let mut optional = false;
- if parser::tts_to_string(&path)
- .replace(' ', "")
- .starts_with("Option<")
- {
- path = option_to_inner_path(&path)?;
- optional = true;
- }
- if parser::tts_to_string(&path)
- .replace(' ', "")
- .starts_with("Box<Account<")
- {
- return Ok(("Account".to_string(), optional, path));
- }
- // TODO: allow segmented paths.
- if path.segments.len() != 1 {
- return Err(ParseError::new(
- f.ty.span(),
- "segmented paths are not currently allowed",
- ));
- }
- let segments = &path.segments[0];
- Ok((segments.ident.to_string(), optional, path))
- }
- fn parse_program_state(path: &syn::Path) -> ParseResult<ProgramStateTy> {
- let account_ident = parse_account(path)?;
- Ok(ProgramStateTy {
- account_type_path: account_ident,
- })
- }
- fn parse_cpi_state(path: &syn::Path) -> ParseResult<CpiStateTy> {
- let account_ident = parse_account(path)?;
- Ok(CpiStateTy {
- account_type_path: account_ident,
- })
- }
- fn parse_cpi_account(path: &syn::Path) -> ParseResult<CpiAccountTy> {
- let account_ident = parse_account(path)?;
- Ok(CpiAccountTy {
- account_type_path: account_ident,
- })
- }
- fn parse_program_account(path: &syn::Path) -> ParseResult<ProgramAccountTy> {
- let account_ident = parse_account(path)?;
- Ok(ProgramAccountTy {
- account_type_path: account_ident,
- })
- }
- fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult<LoaderTy> {
- let account_ident = parse_account(path)?;
- Ok(LoaderTy {
- account_type_path: account_ident,
- })
- }
- fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
- let account_ident = parse_account(path)?;
- Ok(AccountLoaderTy {
- account_type_path: account_ident,
- })
- }
- fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
- let account_type_path = parse_account(path)?;
- let boxed = parser::tts_to_string(path)
- .replace(' ', "")
- .starts_with("Box<Account<");
- Ok(AccountTy {
- account_type_path,
- boxed,
- })
- }
- fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
- let account_type_path = parse_account(path)?;
- Ok(ProgramTy { account_type_path })
- }
- // Extract the type path T from Box<[AccountType]<'info, T>> or [AccountType]<'info, T>
- fn parse_account(path: &syn::Path) -> ParseResult<syn::TypePath> {
- let segment = &path.segments[0];
- match segment.ident.to_string().as_str() {
- "Box" => match &segment.arguments {
- syn::PathArguments::AngleBracketed(args) => {
- if args.args.len() != 1 {
- return Err(ParseError::new(
- args.args.span(),
- "Expected a single argument: [AccountType]<'info, T>",
- ));
- }
- match &args.args[0] {
- syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
- parse_account(&ty_path.path)
- }
- _ => Err(ParseError::new(
- args.args[1].span(),
- "Expected a path containing: [AccountType]<'info, T>",
- )),
- }
- }
- _ => Err(ParseError::new(
- segment.arguments.span(),
- "Expected angle brackets with a type: Box<[AccountType]<'info, T>",
- )),
- },
- _ => match &segment.arguments {
- syn::PathArguments::AngleBracketed(args) => {
- if args.args.len() != 2 {
- return Err(ParseError::new(
- args.args.span(),
- "Expected only two arguments, a lifetime and a type: [AccountType]<'info, T>",
- ));
- }
- match (&args.args[0], &args.args[1]) {
- (
- syn::GenericArgument::Lifetime(_),
- syn::GenericArgument::Type(syn::Type::Path(ty_path)),
- ) => {
- Ok(ty_path.clone())
- }
- _ => {
- Err(ParseError::new(
- args.args.span(),
- "Expected the two arguments to be a lifetime and a type: [AccountType]<'info, T>",
- ))
- }
- }
- }
- _ => Err(ParseError::new(
- segment.arguments.span(),
- "Expected angle brackets with a type: [AccountType]<'info, T>",
- )),
- },
- }
- }
- fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
- let segments = &path.segments[0];
- let account_ident = match &segments.arguments {
- syn::PathArguments::AngleBracketed(args) => {
- // Expected: <'info, MyType>.
- if args.args.len() != 2 {
- return Err(ParseError::new(
- args.args.span(),
- "bracket arguments must be the lifetime and type",
- ));
- }
- match &args.args[1] {
- syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
- // TODO: allow segmented paths.
- if ty_path.path.segments.len() != 1 {
- return Err(ParseError::new(
- ty_path.path.span(),
- "segmented paths are not currently allowed",
- ));
- }
- let path_segment = &ty_path.path.segments[0];
- path_segment.ident.clone()
- }
- _ => {
- return Err(ParseError::new(
- args.args[1].span(),
- "first bracket argument must be a lifetime",
- ))
- }
- }
- }
- _ => {
- return Err(ParseError::new(
- segments.arguments.span(),
- "expected angle brackets with a lifetime and type",
- ))
- }
- };
- let ty = match account_ident.to_string().as_str() {
- "Clock" => SysvarTy::Clock,
- "Rent" => SysvarTy::Rent,
- "EpochSchedule" => SysvarTy::EpochSchedule,
- "Fees" => SysvarTy::Fees,
- "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
- "SlotHashes" => SysvarTy::SlotHashes,
- "SlotHistory" => SysvarTy::SlotHistory,
- "StakeHistory" => SysvarTy::StakeHistory,
- "Instructions" => SysvarTy::Instructions,
- "Rewards" => SysvarTy::Rewards,
- _ => {
- return Err(ParseError::new(
- account_ident.span(),
- "invalid sysvar provided",
- ))
- }
- };
- Ok(ty)
- }
#[test]
fn test_parse_account() {
    let expected_ty_path: syn::TypePath = syn::parse_quote!(u32);

    // Boxed and unboxed wrappers both yield the inner type path.
    let boxed: syn::Path = syn::parse_quote! { Box<Account<'info, u32>> };
    let plain: syn::Path = syn::parse_quote! { Account<'info, u32> };
    for path in &[boxed, plain] {
        assert_eq!(parse_account(path).unwrap(), expected_ty_path);
    }

    // Malformed wrappers report a descriptive error.
    let two_args = "Expected only two arguments, a lifetime and a type: [AccountType]<'info, T>";
    let error_cases: [(syn::Path, &str); 3] = [
        (syn::parse_quote! { Box<Account<'info, u32, u64>> }, two_args),
        (syn::parse_quote! { Box<Account<'info>> }, two_args),
        (
            syn::parse_quote! { Box<Account> },
            "Expected angle brackets with a type: [AccountType]<'info, T>",
        ),
    ];
    for (path, message) in error_cases.iter() {
        assert_eq!(parse_account(path).unwrap_err().to_string(), *message);
    }
}
|