lib.rs 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. extern crate proc_macro;
  2. use quote::quote;
  3. use syn::parse_macro_input;
  4. /// A data structure representing a Solana account, implementing various traits:
  5. ///
  6. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  7. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  8. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  9. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  10. ///
  11. /// When implementing account serialization traits the first 8 bytes are
  12. /// reserved for a unique account discriminator, self described by the first 8
  13. /// bytes of the SHA256 of the account's Rust ident.
  14. ///
  15. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  16. /// check this discriminator. If it doesn't match, an invalid account was given,
  17. /// and the account deserialization will exit with an error.
  18. #[proc_macro_attribute]
  19. pub fn account(
  20. args: proc_macro::TokenStream,
  21. input: proc_macro::TokenStream,
  22. ) -> proc_macro::TokenStream {
  23. let namespace = args.to_string().replace("\"", "");
  24. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  25. let account_name = &account_strct.ident;
  26. let discriminator: proc_macro2::TokenStream = {
  27. // Namespace the discriminator to prevent collisions.
  28. let discriminator_preimage = {
  29. if namespace.is_empty() {
  30. format!("account:{}", account_name.to_string())
  31. } else {
  32. format!("{}:{}", namespace, account_name.to_string())
  33. }
  34. };
  35. let mut discriminator = [0u8; 8];
  36. discriminator.copy_from_slice(
  37. &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
  38. );
  39. format!("{:?}", discriminator).parse().unwrap()
  40. };
  41. let coder = quote! {
  42. impl anchor_lang::AccountSerialize for #account_name {
  43. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
  44. writer.write_all(&#discriminator).map_err(|_| ProgramError::InvalidAccountData)?;
  45. AnchorSerialize::serialize(
  46. self,
  47. writer
  48. )
  49. .map_err(|_| ProgramError::InvalidAccountData)?;
  50. Ok(())
  51. }
  52. }
  53. impl anchor_lang::AccountDeserialize for #account_name {
  54. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  55. if buf.len() < #discriminator.len() {
  56. return Err(ProgramError::AccountDataTooSmall);
  57. }
  58. let given_disc = &buf[..8];
  59. if &#discriminator != given_disc {
  60. return Err(ProgramError::InvalidInstructionData);
  61. }
  62. Self::try_deserialize_unchecked(buf)
  63. }
  64. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  65. let mut data: &[u8] = &buf[8..];
  66. AnchorDeserialize::deserialize(&mut data)
  67. .map_err(|_| ProgramError::InvalidAccountData)
  68. }
  69. }
  70. };
  71. proc_macro::TokenStream::from(quote! {
  72. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  73. #account_strct
  74. #coder
  75. })
  76. }