lib.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. extern crate proc_macro;
  2. use quote::quote;
  3. use syn::parse_macro_input;
  4. /// A data structure representing a Solana account, implementing various traits:
  5. ///
  6. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  7. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  8. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  9. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  10. ///
  11. /// When implementing account serialization traits the first 8 bytes are
  12. /// reserved for a unique account discriminator, self described by the first 8
  13. /// bytes of the SHA256 of the account's Rust ident.
  14. ///
  15. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  16. /// check this discriminator. If it doesn't match, an invalid account was given,
  17. /// and the account deserialization will exit with an error.
  18. ///
  19. /// # Zero Copy Deserialization
  20. ///
  21. /// **WARNING**: Zero copy deserialization is an experimental feature. It's
  22. /// recommended to use it only when necessary, i.e., when you have extremely
  23. /// large accounts that cannot be Borsh deserialized without hitting stack or
  24. /// heap limits.
  25. ///
  26. /// ## Usage
  27. ///
  28. /// To enable zero-copy-deserialization, one can pass in the `zero_copy`
  29. /// argument to the macro as follows:
  30. ///
  31. /// ```ignore
  32. /// #[account(zero_copy)]
  33. /// ```
  34. ///
  35. /// This can be used to conveniently implement
  36. /// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
  37. /// with [`Loader`](./struct.Loader.html).
  38. ///
  39. /// Other than being more efficient, the most salient benefit this provides is
  40. /// the ability to define account types larger than the max stack or heap size.
  41. /// When using borsh, the account has to be copied and deserialized into a new
  42. /// data structure and thus is constrained by stack and heap limits imposed by
  43. /// the BPF VM. With zero copy deserialization, all bytes from the account's
  44. /// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
  45. /// the data structure. No allocations or copies necessary. Hence the ability
  46. /// to get around stack and heap limitations.
  47. ///
  48. /// To facilitate this, all fields in an account must be constrained to be
  49. /// "plain old data", i.e., they must implement
  50. /// [`Pod`](../bytemuck/trait.Pod.html). Please review the
  51. /// [`safety`](file:///home/armaniferrante/Documents/code/src/github.com/project-serum/anchor/target/doc/bytemuck/trait.Pod.html#safety)
  52. /// section before using.
  53. #[proc_macro_attribute]
  54. pub fn account(
  55. args: proc_macro::TokenStream,
  56. input: proc_macro::TokenStream,
  57. ) -> proc_macro::TokenStream {
  58. let mut namespace = "".to_string();
  59. let mut is_zero_copy = false;
  60. if args.to_string().split(",").collect::<Vec<_>>().len() > 2 {
  61. panic!("Only two args are allowed to the account attribute.")
  62. }
  63. for arg in args.to_string().split(",") {
  64. let ns = arg
  65. .to_string()
  66. .replace("\"", "")
  67. .chars()
  68. .filter(|c| !c.is_whitespace())
  69. .collect();
  70. if ns == "zero_copy" {
  71. is_zero_copy = true;
  72. } else {
  73. namespace = ns;
  74. }
  75. }
  76. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  77. let account_name = &account_strct.ident;
  78. let discriminator: proc_macro2::TokenStream = {
  79. // Namespace the discriminator to prevent collisions.
  80. let discriminator_preimage = {
  81. // For now, zero copy accounts can't be namespaced.
  82. if namespace.is_empty() {
  83. format!("account:{}", account_name.to_string())
  84. } else {
  85. format!("{}:{}", namespace, account_name.to_string())
  86. }
  87. };
  88. let mut discriminator = [0u8; 8];
  89. discriminator.copy_from_slice(
  90. &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
  91. );
  92. format!("{:?}", discriminator).parse().unwrap()
  93. };
  94. proc_macro::TokenStream::from({
  95. if is_zero_copy {
  96. quote! {
  97. #[zero_copy]
  98. #account_strct
  99. unsafe impl anchor_lang::__private::bytemuck::Pod for #account_name {}
  100. unsafe impl anchor_lang::__private::bytemuck::Zeroable for #account_name {}
  101. impl anchor_lang::ZeroCopy for #account_name {}
  102. impl anchor_lang::Discriminator for #account_name {
  103. fn discriminator() -> [u8; 8] {
  104. #discriminator
  105. }
  106. }
  107. // This trait is useful for clients deserializing accounts.
  108. // It's expected on-chain programs deserialize via zero-copy.
  109. impl anchor_lang::AccountDeserialize for #account_name {
  110. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  111. if buf.len() < #discriminator.len() {
  112. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
  113. }
  114. let given_disc = &buf[..8];
  115. if &#discriminator != given_disc {
  116. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
  117. }
  118. Self::try_deserialize_unchecked(buf)
  119. }
  120. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  121. let data: &[u8] = &buf[8..];
  122. // Re-interpret raw bytes into the POD data structure.
  123. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  124. // Copy out the bytes into a new, owned data structure.
  125. Ok(*account)
  126. }
  127. }
  128. }
  129. } else {
  130. quote! {
  131. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  132. #account_strct
  133. impl anchor_lang::AccountSerialize for #account_name {
  134. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
  135. writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
  136. AnchorSerialize::serialize(
  137. self,
  138. writer
  139. )
  140. .map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
  141. Ok(())
  142. }
  143. }
  144. impl anchor_lang::AccountDeserialize for #account_name {
  145. fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  146. if buf.len() < #discriminator.len() {
  147. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
  148. }
  149. let given_disc = &buf[..8];
  150. if &#discriminator != given_disc {
  151. return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
  152. }
  153. Self::try_deserialize_unchecked(buf)
  154. }
  155. fn try_deserialize_unchecked(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
  156. let mut data: &[u8] = &buf[8..];
  157. AnchorDeserialize::deserialize(&mut data)
  158. .map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotDeserialize.into())
  159. }
  160. }
  161. impl anchor_lang::Discriminator for #account_name {
  162. fn discriminator() -> [u8; 8] {
  163. #discriminator
  164. }
  165. }
  166. }
  167. }
  168. })
  169. }
  170. /// Extends the `#[account]` attribute to allow one to create associated
  171. /// accounts. This includes a `Default` implementation, which means all fields
  172. /// in an `#[associated]` struct must implement `Default` and an
  173. /// `anchor_lang::Bump` trait implementation, which allows the account to be
  174. /// used as a program derived address.
  175. ///
  176. /// # Zero Copy Deserialization
  177. ///
  178. /// Similar to the `#[account]` attribute one can enable zero copy
  179. /// deserialization by using the `zero_copy` argument:
  180. ///
  181. /// ```ignore
  182. /// #[associated(zero_copy)]
  183. /// ```
  184. ///
  185. /// For more, see the [`account`](./attr.account.html) attribute.
  186. #[proc_macro_attribute]
  187. pub fn associated(
  188. args: proc_macro::TokenStream,
  189. input: proc_macro::TokenStream,
  190. ) -> proc_macro::TokenStream {
  191. let mut account_strct = parse_macro_input!(input as syn::ItemStruct);
  192. let account_name = &account_strct.ident;
  193. // Add a `__nonce: u8` field to the struct to hold the bump seed for
  194. // the program dervied address.
  195. match &mut account_strct.fields {
  196. syn::Fields::Named(fields) => {
  197. let mut segments = syn::punctuated::Punctuated::new();
  198. segments.push(syn::PathSegment {
  199. ident: syn::Ident::new("u8", proc_macro2::Span::call_site()),
  200. arguments: syn::PathArguments::None,
  201. });
  202. fields.named.push(syn::Field {
  203. attrs: Vec::new(),
  204. vis: syn::Visibility::Inherited,
  205. ident: Some(syn::Ident::new("__nonce", proc_macro2::Span::call_site())),
  206. colon_token: Some(syn::token::Colon {
  207. spans: [proc_macro2::Span::call_site()],
  208. }),
  209. ty: syn::Type::Path(syn::TypePath {
  210. qself: None,
  211. path: syn::Path {
  212. leading_colon: None,
  213. segments,
  214. },
  215. }),
  216. });
  217. }
  218. _ => panic!("Fields must be named"),
  219. }
  220. let args: proc_macro2::TokenStream = args.into();
  221. proc_macro::TokenStream::from(quote! {
  222. #[anchor_lang::account(#args)]
  223. #[derive(Default)]
  224. #account_strct
  225. impl anchor_lang::Bump for #account_name {
  226. fn seed(&self) -> u8 {
  227. self.__nonce
  228. }
  229. }
  230. })
  231. }
  232. #[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
  233. pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
  234. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  235. let account_name = &account_strct.ident;
  236. let fields = match &account_strct.fields {
  237. syn::Fields::Named(n) => n,
  238. _ => panic!("Fields must be named"),
  239. };
  240. let methods: Vec<proc_macro2::TokenStream> = fields
  241. .named
  242. .iter()
  243. .filter_map(|field: &syn::Field| {
  244. field
  245. .attrs
  246. .iter()
  247. .filter(|attr| {
  248. let name = anchor_syn::parser::tts_to_string(&attr.path);
  249. if name != "accessor" {
  250. return false;
  251. }
  252. return true;
  253. })
  254. .next()
  255. .map(|attr| {
  256. let mut tts = attr.tokens.clone().into_iter();
  257. let g_stream = match tts.next().expect("Must have a token group") {
  258. proc_macro2::TokenTree::Group(g) => g.stream(),
  259. _ => panic!("Invalid syntax"),
  260. };
  261. let accessor_ty = match g_stream.into_iter().next() {
  262. Some(token) => token,
  263. _ => panic!("Missing accessor type"),
  264. };
  265. let field_name = field.ident.as_ref().unwrap();
  266. let get_field: proc_macro2::TokenStream =
  267. format!("get_{}", field_name.to_string()).parse().unwrap();
  268. let set_field: proc_macro2::TokenStream =
  269. format!("set_{}", field_name.to_string()).parse().unwrap();
  270. quote! {
  271. pub fn #get_field(&self) -> #accessor_ty {
  272. anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
  273. }
  274. pub fn #set_field(&mut self, input: &#accessor_ty) {
  275. self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
  276. }
  277. }
  278. })
  279. })
  280. .collect();
  281. proc_macro::TokenStream::from(quote! {
  282. impl #account_name {
  283. #(#methods)*
  284. }
  285. })
  286. }
  287. /// A data structure that can be used as an internal field for a zero copy
  288. /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
  289. ///
  290. /// This is just a convenient alias for
  291. ///
  292. /// ```ignore
  293. /// #[derive(Copy, Clone)]
  294. /// #[repr(packed)]
  295. /// struct MyStruct {...}
  296. /// ```
  297. #[proc_macro_attribute]
  298. pub fn zero_copy(
  299. _args: proc_macro::TokenStream,
  300. item: proc_macro::TokenStream,
  301. ) -> proc_macro::TokenStream {
  302. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  303. proc_macro::TokenStream::from(quote! {
  304. #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
  305. #[repr(packed)]
  306. #account_strct
  307. })
  308. }