lib.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. extern crate proc_macro;
  2. use quote::quote;
  3. use syn::parse_macro_input;
  4. mod id;
  5. /// A data structure representing a Solana account, implementing various traits:
  6. ///
  7. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  8. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  9. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  10. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  11. ///
  12. /// When implementing account serialization traits the first 8 bytes are
  13. /// reserved for a unique account discriminator, self described by the first 8
  14. /// bytes of the SHA256 of the account's Rust ident.
  15. ///
  16. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  17. /// check this discriminator. If it doesn't match, an invalid account was given,
  18. /// and the account deserialization will exit with an error.
  19. ///
  20. /// # Zero Copy Deserialization
  21. ///
  22. /// **WARNING**: Zero copy deserialization is an experimental feature. It's
  23. /// recommended to use it only when necessary, i.e., when you have extremely
  24. /// large accounts that cannot be Borsh deserialized without hitting stack or
  25. /// heap limits.
  26. ///
  27. /// ## Usage
  28. ///
  29. /// To enable zero-copy-deserialization, one can pass in the `zero_copy`
  30. /// argument to the macro as follows:
  31. ///
  32. /// ```ignore
  33. /// #[account(zero_copy)]
  34. /// ```
  35. ///
  36. /// This can be used to conveniently implement
  37. /// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
  38. /// with [`Loader`](./struct.Loader.html).
  39. ///
  40. /// Other than being more efficient, the most salient benefit this provides is
  41. /// the ability to define account types larger than the max stack or heap size.
  42. /// When using borsh, the account has to be copied and deserialized into a new
  43. /// data structure and thus is constrained by stack and heap limits imposed by
  44. /// the BPF VM. With zero copy deserialization, all bytes from the account's
  45. /// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
  46. /// the data structure. No allocations or copies necessary. Hence the ability
  47. /// to get around stack and heap limitations.
  48. ///
  49. /// To facilitate this, all fields in an account must be constrained to be
  50. /// "plain old data", i.e., they must implement
  51. /// [`Pod`](../bytemuck/trait.Pod.html). Please review the
  52. /// [`safety`](file:///home/armaniferrante/Documents/code/src/github.com/project-serum/anchor/target/doc/bytemuck/trait.Pod.html#safety)
  53. /// section before using.
  54. #[proc_macro_attribute]
  55. pub fn account(
  56. args: proc_macro::TokenStream,
  57. input: proc_macro::TokenStream,
  58. ) -> proc_macro::TokenStream {
  59. let mut namespace = "".to_string();
  60. let mut is_zero_copy = false;
  61. let args_str = args.to_string();
  62. let args: Vec<&str> = args_str.split(',').collect();
  63. if args.len() > 2 {
  64. panic!("Only two args are allowed to the account attribute.")
  65. }
  66. for arg in args {
  67. let ns = arg
  68. .to_string()
  69. .replace('\"', "")
  70. .chars()
  71. .filter(|c| !c.is_whitespace())
  72. .collect();
  73. if ns == "zero_copy" {
  74. is_zero_copy = true;
  75. } else {
  76. namespace = ns;
  77. }
  78. }
  79. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  80. let account_name = &account_strct.ident;
  81. let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
  82. let discriminator: proc_macro2::TokenStream = {
  83. // Namespace the discriminator to prevent collisions.
  84. let discriminator_preimage = {
  85. // For now, zero copy accounts can't be namespaced.
  86. if namespace.is_empty() {
  87. format!("account:{}", account_name)
  88. } else {
  89. format!("{}:{}", namespace, account_name)
  90. }
  91. };
  92. let mut discriminator = [0u8; 8];
  93. discriminator.copy_from_slice(
  94. &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
  95. );
  96. format!("{:?}", discriminator).parse().unwrap()
  97. };
  98. let owner_impl = {
  99. if namespace.is_empty() {
  100. quote! {
  101. #[automatically_derived]
  102. impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
  103. fn owner() -> Pubkey {
  104. crate::ID
  105. }
  106. }
  107. }
  108. } else {
  109. quote! {}
  110. }
  111. };
  112. proc_macro::TokenStream::from({
  113. if is_zero_copy {
  114. quote! {
  115. #[zero_copy]
  116. #account_strct
  117. #[automatically_derived]
  118. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
  119. #[automatically_derived]
  120. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
  121. #[automatically_derived]
  122. impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
  123. #[automatically_derived]
  124. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  125. fn discriminator() -> [u8; 8] {
  126. #discriminator
  127. }
  128. }
  129. // This trait is useful for clients deserializing accounts.
  130. // It's expected on-chain programs deserialize via zero-copy.
  131. #[automatically_derived]
  132. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  133. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  134. if buf.len() < #discriminator.len() {
  135. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
  136. }
  137. let given_disc = &buf[..8];
  138. if &#discriminator != given_disc {
  139. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
  140. }
  141. Self::try_deserialize_unchecked(buf)
  142. }
  143. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  144. let data: &[u8] = &buf[8..];
  145. // Re-interpret raw bytes into the POD data structure.
  146. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  147. // Copy out the bytes into a new, owned data structure.
  148. Ok(*account)
  149. }
  150. }
  151. #owner_impl
  152. }
  153. } else {
  154. quote! {
  155. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  156. #account_strct
  157. #[automatically_derived]
  158. impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
  159. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
  160. writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
  161. AnchorSerialize::serialize(
  162. self,
  163. writer
  164. )
  165. .map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
  166. Ok(())
  167. }
  168. }
  169. #[automatically_derived]
  170. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  171. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  172. if buf.len() < #discriminator.len() {
  173. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
  174. }
  175. let given_disc = &buf[..8];
  176. if &#discriminator != given_disc {
  177. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
  178. }
  179. Self::try_deserialize_unchecked(buf)
  180. }
  181. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  182. let mut data: &[u8] = &buf[8..];
  183. AnchorDeserialize::deserialize(&mut data)
  184. .map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotDeserialize.into())
  185. }
  186. }
  187. #[automatically_derived]
  188. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  189. fn discriminator() -> [u8; 8] {
  190. #discriminator
  191. }
  192. }
  193. #owner_impl
  194. }
  195. }
  196. })
  197. }
  198. #[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
  199. pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
  200. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  201. let account_name = &account_strct.ident;
  202. let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
  203. let fields = match &account_strct.fields {
  204. syn::Fields::Named(n) => n,
  205. _ => panic!("Fields must be named"),
  206. };
  207. let methods: Vec<proc_macro2::TokenStream> = fields
  208. .named
  209. .iter()
  210. .filter_map(|field: &syn::Field| {
  211. field
  212. .attrs
  213. .iter()
  214. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
  215. .map(|attr| {
  216. let mut tts = attr.tokens.clone().into_iter();
  217. let g_stream = match tts.next().expect("Must have a token group") {
  218. proc_macro2::TokenTree::Group(g) => g.stream(),
  219. _ => panic!("Invalid syntax"),
  220. };
  221. let accessor_ty = match g_stream.into_iter().next() {
  222. Some(token) => token,
  223. _ => panic!("Missing accessor type"),
  224. };
  225. let field_name = field.ident.as_ref().unwrap();
  226. let get_field: proc_macro2::TokenStream =
  227. format!("get_{}", field_name).parse().unwrap();
  228. let set_field: proc_macro2::TokenStream =
  229. format!("set_{}", field_name).parse().unwrap();
  230. quote! {
  231. pub fn #get_field(&self) -> #accessor_ty {
  232. anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
  233. }
  234. pub fn #set_field(&mut self, input: &#accessor_ty) {
  235. self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
  236. }
  237. }
  238. })
  239. })
  240. .collect();
  241. proc_macro::TokenStream::from(quote! {
  242. #[automatically_derived]
  243. impl #impl_gen #account_name #ty_gen #where_clause {
  244. #(#methods)*
  245. }
  246. })
  247. }
  248. /// A data structure that can be used as an internal field for a zero copy
  249. /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
  250. ///
  251. /// This is just a convenient alias for
  252. ///
  253. /// ```ignore
  254. /// #[derive(Copy, Clone)]
  255. /// #[repr(packed)]
  256. /// struct MyStruct {...}
  257. /// ```
  258. #[proc_macro_attribute]
  259. pub fn zero_copy(
  260. _args: proc_macro::TokenStream,
  261. item: proc_macro::TokenStream,
  262. ) -> proc_macro::TokenStream {
  263. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  264. proc_macro::TokenStream::from(quote! {
  265. #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
  266. #[repr(packed)]
  267. #account_strct
  268. })
  269. }
  270. /// Defines the program's ID. This should be used at the root of all Anchor
  271. /// based programs.
  272. #[proc_macro]
  273. pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  274. let id = parse_macro_input!(input as id::Id);
  275. proc_macro::TokenStream::from(quote! {#id})
  276. }