accounts.rs 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. use anchor_idl::types::{Idl, IdlSerialization};
  2. use quote::{format_ident, quote};
  3. use super::common::{convert_idl_type_def_to_ts, gen_discriminator, get_canonical_program_id};
  4. pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
  5. let accounts = idl.accounts.iter().map(|acc| {
  6. let name = format_ident!("{}", acc.name);
  7. let discriminator = gen_discriminator(&acc.discriminator);
  8. let ty_def = idl
  9. .types
  10. .iter()
  11. .find(|ty| ty.name == acc.name)
  12. .expect("Type must exist");
  13. let impls = {
  14. let try_deserialize = quote! {
  15. fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  16. if buf.len() < #discriminator.len() {
  17. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  18. }
  19. let given_disc = &buf[..8];
  20. if &#discriminator != given_disc {
  21. return Err(
  22. anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
  23. .with_account_name(stringify!(#name))
  24. );
  25. }
  26. Self::try_deserialize_unchecked(buf)
  27. }
  28. };
  29. match ty_def.serialization {
  30. IdlSerialization::Borsh => quote! {
  31. impl anchor_lang::AccountSerialize for #name {
  32. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
  33. if writer.write_all(&#discriminator).is_err() {
  34. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  35. }
  36. if AnchorSerialize::serialize(self, writer).is_err() {
  37. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  38. }
  39. Ok(())
  40. }
  41. }
  42. impl anchor_lang::AccountDeserialize for #name {
  43. #try_deserialize
  44. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  45. let mut data: &[u8] = &buf[8..];
  46. AnchorDeserialize::deserialize(&mut data)
  47. .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
  48. }
  49. }
  50. },
  51. _ => {
  52. let unsafe_bytemuck_impl =
  53. matches!(ty_def.serialization, IdlSerialization::BytemuckUnsafe)
  54. .then(|| {
  55. quote! {
  56. unsafe impl anchor_lang::__private::Pod for #name {}
  57. unsafe impl anchor_lang::__private::Zeroable for #name {}
  58. }
  59. })
  60. .unwrap_or_default();
  61. quote! {
  62. impl anchor_lang::ZeroCopy for #name {}
  63. impl anchor_lang::AccountDeserialize for #name {
  64. #try_deserialize
  65. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  66. let data: &[u8] = &buf[8..];
  67. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  68. Ok(*account)
  69. }
  70. }
  71. #unsafe_bytemuck_impl
  72. }
  73. }
  74. }
  75. };
  76. let type_def_ts = convert_idl_type_def_to_ts(ty_def, &idl.types);
  77. let program_id = get_canonical_program_id();
  78. quote! {
  79. #type_def_ts
  80. #impls
  81. impl anchor_lang::Discriminator for #name {
  82. const DISCRIMINATOR: [u8; 8] = #discriminator;
  83. }
  84. impl anchor_lang::Owner for #name {
  85. fn owner() -> Pubkey {
  86. #program_id
  87. }
  88. }
  89. }
  90. });
  91. quote! {
  92. /// Program account type definitions.
  93. pub mod accounts {
  94. use super::{*, types::*};
  95. #(#accounts)*
  96. }
  97. }
  98. }