lib.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. extern crate proc_macro;
  2. use quote::quote;
  3. use syn::parse_macro_input;
  4. mod id;
  5. /// An attribute for a data structure representing a Solana account.
  6. ///
  7. /// `#[account]` generates trait implementations for the following traits:
  8. ///
  9. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  10. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  11. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  12. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  13. /// - [`Owner`](./trait.Owner.html)
  14. /// - [`Discriminator`](./trait.Discriminator.html)
  15. ///
  16. /// When implementing account serialization traits the first 8 bytes are
  17. /// reserved for a unique account discriminator, self described by the first 8
  18. /// bytes of the SHA256 of the account's Rust ident.
  19. ///
  20. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  21. /// check this discriminator. If it doesn't match, an invalid account was given,
  22. /// and the account deserialization will exit with an error.
  23. ///
  24. /// # Zero Copy Deserialization
  25. ///
  26. /// **WARNING**: Zero copy deserialization is an experimental feature. It's
  27. /// recommended to use it only when necessary, i.e., when you have extremely
  28. /// large accounts that cannot be Borsh deserialized without hitting stack or
  29. /// heap limits.
  30. ///
  31. /// ## Usage
  32. ///
  33. /// To enable zero-copy-deserialization, one can pass in the `zero_copy`
  34. /// argument to the macro as follows:
  35. ///
  36. /// ```ignore
  37. /// #[account(zero_copy)]
  38. /// ```
  39. ///
  40. /// This can be used to conveniently implement
  41. /// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
  42. /// with [`Loader`](./struct.Loader.html).
  43. ///
  44. /// Other than being more efficient, the most salient benefit this provides is
  45. /// the ability to define account types larger than the max stack or heap size.
  46. /// When using borsh, the account has to be copied and deserialized into a new
  47. /// data structure and thus is constrained by stack and heap limits imposed by
  48. /// the BPF VM. With zero copy deserialization, all bytes from the account's
  49. /// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
  50. /// the data structure. No allocations or copies necessary. Hence the ability
  51. /// to get around stack and heap limitations.
  52. ///
  53. /// To facilitate this, all fields in an account must be constrained to be
  54. /// "plain old data", i.e., they must implement
  55. /// [`Pod`](../bytemuck/trait.Pod.html). Please review the
  56. /// [`safety`](../bytemuck/trait.Pod.html#safety)
  57. /// section before using.
  58. #[proc_macro_attribute]
  59. pub fn account(
  60. args: proc_macro::TokenStream,
  61. input: proc_macro::TokenStream,
  62. ) -> proc_macro::TokenStream {
  63. let mut namespace = "".to_string();
  64. let mut is_zero_copy = false;
  65. let args_str = args.to_string();
  66. let args: Vec<&str> = args_str.split(',').collect();
  67. if args.len() > 2 {
  68. panic!("Only two args are allowed to the account attribute.")
  69. }
  70. for arg in args {
  71. let ns = arg
  72. .to_string()
  73. .replace('\"', "")
  74. .chars()
  75. .filter(|c| !c.is_whitespace())
  76. .collect();
  77. if ns == "zero_copy" {
  78. is_zero_copy = true;
  79. } else {
  80. namespace = ns;
  81. }
  82. }
  83. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  84. let account_name = &account_strct.ident;
  85. let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
  86. let discriminator: proc_macro2::TokenStream = {
  87. let discriminator = anchor_common::header::create_discriminator(
  88. &account_name.to_string(),
  89. if namespace.is_empty() {
  90. None
  91. } else {
  92. Some(&namespace)
  93. },
  94. );
  95. format!("{:?}", discriminator).parse().unwrap()
  96. };
  97. let owner_impl = {
  98. if namespace.is_empty() {
  99. quote! {
  100. #[automatically_derived]
  101. impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
  102. fn owner() -> Pubkey {
  103. crate::ID
  104. }
  105. }
  106. }
  107. } else {
  108. quote! {}
  109. }
  110. };
  111. let disc_fn = {
  112. let len: proc_macro2::TokenStream = anchor_common::header::discriminator_len_str()
  113. .parse()
  114. .unwrap();
  115. quote! {
  116. fn discriminator() -> [u8; #len] {
  117. #discriminator
  118. }
  119. }
  120. };
  121. proc_macro::TokenStream::from({
  122. if is_zero_copy {
  123. quote! {
  124. #[zero_copy]
  125. #account_strct
  126. #[automatically_derived]
  127. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
  128. #[automatically_derived]
  129. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
  130. #[automatically_derived]
  131. impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
  132. #[automatically_derived]
  133. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  134. #disc_fn
  135. }
  136. // This trait is useful for clients deserializing accounts.
  137. // It's expected on-chain programs deserialize via zero-copy.
  138. #[automatically_derived]
  139. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  140. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  141. // Header is always 8 bytes.
  142. if buf.len() < anchor_lang::accounts::header::HEADER_LEN {
  143. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  144. }
  145. let given_disc = anchor_lang::accounts::header::read_discriminator(&buf);
  146. if &#discriminator != given_disc {
  147. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch.into());
  148. }
  149. Self::try_deserialize_unchecked(buf)
  150. }
  151. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  152. let data: &[u8] = &buf[8..];
  153. // Re-interpret raw bytes into the POD data structure.
  154. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  155. // Copy out the bytes into a new, owned data structure.
  156. Ok(*account)
  157. }
  158. }
  159. #owner_impl
  160. }
  161. } else {
  162. quote! {
  163. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  164. #account_strct
  165. #[automatically_derived]
  166. impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
  167. fn try_serialize(&self, buf: &mut [u8]) -> std::result::Result<(), ProgramError> {
  168. let dst = anchor_lang::accounts::header::read_data_mut(buf);
  169. let mut writer = std::io::Cursor::new(dst);
  170. AnchorSerialize::serialize(
  171. self,
  172. &mut writer
  173. )
  174. .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotSerialize)?;
  175. Ok(())
  176. }
  177. }
  178. #[automatically_derived]
  179. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  180. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  181. if buf.len() < #discriminator.len() {
  182. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  183. }
  184. let given_disc = anchor_lang::accounts::header::read_discriminator(&buf);
  185. if &#discriminator != given_disc {
  186. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch.into());
  187. }
  188. Self::try_deserialize_unchecked(buf)
  189. }
  190. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  191. let mut data: &[u8] = &buf[8..];
  192. AnchorDeserialize::deserialize(&mut data)
  193. .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
  194. }
  195. }
  196. #[automatically_derived]
  197. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  198. #disc_fn
  199. }
  200. #owner_impl
  201. }
  202. }
  203. })
  204. }
  205. #[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
  206. pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
  207. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  208. let account_name = &account_strct.ident;
  209. let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
  210. let fields = match &account_strct.fields {
  211. syn::Fields::Named(n) => n,
  212. _ => panic!("Fields must be named"),
  213. };
  214. let methods: Vec<proc_macro2::TokenStream> = fields
  215. .named
  216. .iter()
  217. .filter_map(|field: &syn::Field| {
  218. field
  219. .attrs
  220. .iter()
  221. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
  222. .map(|attr| {
  223. let mut tts = attr.tokens.clone().into_iter();
  224. let g_stream = match tts.next().expect("Must have a token group") {
  225. proc_macro2::TokenTree::Group(g) => g.stream(),
  226. _ => panic!("Invalid syntax"),
  227. };
  228. let accessor_ty = match g_stream.into_iter().next() {
  229. Some(token) => token,
  230. _ => panic!("Missing accessor type"),
  231. };
  232. let field_name = field.ident.as_ref().unwrap();
  233. let get_field: proc_macro2::TokenStream =
  234. format!("get_{}", field_name).parse().unwrap();
  235. let set_field: proc_macro2::TokenStream =
  236. format!("set_{}", field_name).parse().unwrap();
  237. quote! {
  238. pub fn #get_field(&self) -> #accessor_ty {
  239. anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
  240. }
  241. pub fn #set_field(&mut self, input: &#accessor_ty) {
  242. self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
  243. }
  244. }
  245. })
  246. })
  247. .collect();
  248. proc_macro::TokenStream::from(quote! {
  249. #[automatically_derived]
  250. impl #impl_gen #account_name #ty_gen #where_clause {
  251. #(#methods)*
  252. }
  253. })
  254. }
  255. /// A data structure that can be used as an internal field for a zero copy
  256. /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
  257. ///
  258. /// This is just a convenient alias for
  259. ///
  260. /// ```ignore
  261. /// #[derive(Copy, Clone)]
  262. /// #[repr(packed)]
  263. /// struct MyStruct {...}
  264. /// ```
  265. #[proc_macro_attribute]
  266. pub fn zero_copy(
  267. _args: proc_macro::TokenStream,
  268. item: proc_macro::TokenStream,
  269. ) -> proc_macro::TokenStream {
  270. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  271. // Takes the first repr. It's assumed that more than one are not on the
  272. // struct.
  273. let attr = account_strct
  274. .attrs
  275. .iter()
  276. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
  277. let repr = match attr {
  278. Some(_attr) => quote! {},
  279. None => quote! {#[repr(C)]},
  280. };
  281. proc_macro::TokenStream::from(quote! {
  282. #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
  283. #repr
  284. #account_strct
  285. })
  286. }
  287. /// Defines the program's ID. This should be used at the root of all Anchor
  288. /// based programs.
  289. #[proc_macro]
  290. pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  291. let id = parse_macro_input!(input as id::Id);
  292. proc_macro::TokenStream::from(quote! {#id})
  293. }