lib.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. extern crate proc_macro;
  2. use quote::quote;
  3. use syn::parse_macro_input;
  4. mod id;
  5. /// An attribute for a data structure representing a Solana account.
  6. ///
  7. /// `#[account]` generates trait implementations for the following traits:
  8. ///
  9. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  10. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  11. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  12. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  13. /// - [`Clone`](https://doc.rust-lang.org/std/clone/trait.Clone.html)
  14. /// - [`Discriminator`](./trait.Discriminator.html)
  15. /// - [`Owner`](./trait.Owner.html)
  16. ///
  17. /// When implementing account serialization traits the first 8 bytes are
  18. /// reserved for a unique account discriminator, self described by the first 8
  19. /// bytes of the SHA256 of the account's Rust ident.
  20. ///
  21. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  22. /// check this discriminator. If it doesn't match, an invalid account was given,
  23. /// and the account deserialization will exit with an error.
  24. ///
  25. /// # Zero Copy Deserialization
  26. ///
  27. /// **WARNING**: Zero copy deserialization is an experimental feature. It's
  28. /// recommended to use it only when necessary, i.e., when you have extremely
  29. /// large accounts that cannot be Borsh deserialized without hitting stack or
  30. /// heap limits.
  31. ///
  32. /// ## Usage
  33. ///
  34. /// To enable zero-copy-deserialization, one can pass in the `zero_copy`
  35. /// argument to the macro as follows:
  36. ///
  37. /// ```ignore
  38. /// #[account(zero_copy)]
  39. /// ```
  40. ///
  41. /// This can be used to conveniently implement
  42. /// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
  43. /// with [`AccountLoader`](./accounts/account_loader/struct.AccountLoader.html).
  44. ///
  45. /// Other than being more efficient, the most salient benefit this provides is
  46. /// the ability to define account types larger than the max stack or heap size.
  47. /// When using borsh, the account has to be copied and deserialized into a new
  48. /// data structure and thus is constrained by stack and heap limits imposed by
  49. /// the BPF VM. With zero copy deserialization, all bytes from the account's
  50. /// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
  51. /// the data structure. No allocations or copies necessary. Hence the ability
  52. /// to get around stack and heap limitations.
  53. ///
  54. /// To facilitate this, all fields in an account must be constrained to be
  55. /// "plain old data", i.e., they must implement
  56. /// [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html). Please review the
  57. /// [`safety`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html#safety)
  58. /// section before using.
  59. ///
  60. /// Using `zero_copy` requires adding the following to your `cargo.toml` file:
  61. /// `bytemuck = { version = "1.4.0", features = ["derive", "min_const_generics"]}`
  62. #[proc_macro_attribute]
  63. pub fn account(
  64. args: proc_macro::TokenStream,
  65. input: proc_macro::TokenStream,
  66. ) -> proc_macro::TokenStream {
  67. let mut namespace = "".to_string();
  68. let mut is_zero_copy = false;
  69. let mut unsafe_bytemuck = false;
  70. let args_str = args.to_string();
  71. let args: Vec<&str> = args_str.split(',').collect();
  72. if args.len() > 2 {
  73. panic!("Only two args are allowed to the account attribute.")
  74. }
  75. for arg in args {
  76. let ns = arg
  77. .to_string()
  78. .replace('\"', "")
  79. .chars()
  80. .filter(|c| !c.is_whitespace())
  81. .collect();
  82. if ns == "zero_copy" {
  83. is_zero_copy = true;
  84. unsafe_bytemuck = false;
  85. } else if ns == "zero_copy(unsafe)" {
  86. is_zero_copy = true;
  87. unsafe_bytemuck = true;
  88. } else {
  89. namespace = ns;
  90. }
  91. }
  92. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  93. let account_name = &account_strct.ident;
  94. let account_name_str = account_name.to_string();
  95. let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
  96. let discriminator: proc_macro2::TokenStream = {
  97. // Namespace the discriminator to prevent collisions.
  98. let discriminator_preimage = {
  99. // For now, zero copy accounts can't be namespaced.
  100. if namespace.is_empty() {
  101. format!("account:{account_name}")
  102. } else {
  103. format!("{namespace}:{account_name}")
  104. }
  105. };
  106. let mut discriminator = [0u8; 8];
  107. discriminator.copy_from_slice(
  108. &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
  109. );
  110. format!("{discriminator:?}").parse().unwrap()
  111. };
  112. let owner_impl = {
  113. if namespace.is_empty() {
  114. quote! {
  115. #[automatically_derived]
  116. impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
  117. fn owner() -> Pubkey {
  118. crate::ID
  119. }
  120. }
  121. }
  122. } else {
  123. quote! {}
  124. }
  125. };
  126. let unsafe_bytemuck_impl = {
  127. if unsafe_bytemuck {
  128. quote! {
  129. #[automatically_derived]
  130. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
  131. #[automatically_derived]
  132. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
  133. }
  134. } else {
  135. quote! {}
  136. }
  137. };
  138. let bytemuck_derives = {
  139. if !unsafe_bytemuck {
  140. quote! {
  141. #[zero_copy]
  142. }
  143. } else {
  144. quote! {
  145. #[zero_copy(unsafe)]
  146. }
  147. }
  148. };
  149. proc_macro::TokenStream::from({
  150. if is_zero_copy {
  151. quote! {
  152. #bytemuck_derives
  153. #account_strct
  154. #unsafe_bytemuck_impl
  155. #[automatically_derived]
  156. impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
  157. #[automatically_derived]
  158. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  159. const DISCRIMINATOR: &'static [u8] = &#discriminator;
  160. }
  161. // This trait is useful for clients deserializing accounts.
  162. // It's expected on-chain programs deserialize via zero-copy.
  163. #[automatically_derived]
  164. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  165. fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  166. if buf.len() < #discriminator.len() {
  167. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  168. }
  169. let given_disc = &buf[..#discriminator.len()];
  170. if &#discriminator != given_disc {
  171. return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
  172. }
  173. Self::try_deserialize_unchecked(buf)
  174. }
  175. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  176. let data: &[u8] = &buf[#discriminator.len()..];
  177. // Re-interpret raw bytes into the POD data structure.
  178. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  179. // Copy out the bytes into a new, owned data structure.
  180. Ok(*account)
  181. }
  182. }
  183. #owner_impl
  184. }
  185. } else {
  186. quote! {
  187. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  188. #account_strct
  189. #[automatically_derived]
  190. impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
  191. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
  192. if writer.write_all(&#discriminator).is_err() {
  193. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  194. }
  195. if AnchorSerialize::serialize(self, writer).is_err() {
  196. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  197. }
  198. Ok(())
  199. }
  200. }
  201. #[automatically_derived]
  202. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  203. fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  204. if buf.len() < #discriminator.len() {
  205. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  206. }
  207. let given_disc = &buf[..#discriminator.len()];
  208. if &#discriminator != given_disc {
  209. return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
  210. }
  211. Self::try_deserialize_unchecked(buf)
  212. }
  213. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  214. let mut data: &[u8] = &buf[#discriminator.len()..];
  215. AnchorDeserialize::deserialize(&mut data)
  216. .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
  217. }
  218. }
  219. #[automatically_derived]
  220. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  221. const DISCRIMINATOR: &'static [u8] = &#discriminator;
  222. }
  223. #owner_impl
  224. }
  225. }
  226. })
  227. }
  228. #[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
  229. pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
  230. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  231. let account_name = &account_strct.ident;
  232. let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
  233. let fields = match &account_strct.fields {
  234. syn::Fields::Named(n) => n,
  235. _ => panic!("Fields must be named"),
  236. };
  237. let methods: Vec<proc_macro2::TokenStream> = fields
  238. .named
  239. .iter()
  240. .filter_map(|field: &syn::Field| {
  241. field
  242. .attrs
  243. .iter()
  244. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
  245. .map(|attr| {
  246. let mut tts = attr.tokens.clone().into_iter();
  247. let g_stream = match tts.next().expect("Must have a token group") {
  248. proc_macro2::TokenTree::Group(g) => g.stream(),
  249. _ => panic!("Invalid syntax"),
  250. };
  251. let accessor_ty = match g_stream.into_iter().next() {
  252. Some(token) => token,
  253. _ => panic!("Missing accessor type"),
  254. };
  255. let field_name = field.ident.as_ref().unwrap();
  256. let get_field: proc_macro2::TokenStream =
  257. format!("get_{field_name}").parse().unwrap();
  258. let set_field: proc_macro2::TokenStream =
  259. format!("set_{field_name}").parse().unwrap();
  260. quote! {
  261. pub fn #get_field(&self) -> #accessor_ty {
  262. anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
  263. }
  264. pub fn #set_field(&mut self, input: &#accessor_ty) {
  265. self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
  266. }
  267. }
  268. })
  269. })
  270. .collect();
  271. proc_macro::TokenStream::from(quote! {
  272. #[automatically_derived]
  273. impl #impl_gen #account_name #ty_gen #where_clause {
  274. #(#methods)*
  275. }
  276. })
  277. }
  278. /// A data structure that can be used as an internal field for a zero copy
  279. /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
  280. ///
  281. /// `#[zero_copy]` is just a convenient alias for
  282. ///
  283. /// ```ignore
  284. /// #[derive(Copy, Clone)]
  285. /// #[derive(bytemuck::Zeroable)]
  286. /// #[derive(bytemuck::Pod)]
  287. /// #[repr(C)]
  288. /// struct MyStruct {...}
  289. /// ```
  290. #[proc_macro_attribute]
  291. pub fn zero_copy(
  292. args: proc_macro::TokenStream,
  293. item: proc_macro::TokenStream,
  294. ) -> proc_macro::TokenStream {
  295. let mut is_unsafe = false;
  296. for arg in args.into_iter() {
  297. match arg {
  298. proc_macro::TokenTree::Ident(ident) => {
  299. if ident.to_string() == "unsafe" {
  300. // `#[zero_copy(unsafe)]` maintains the old behaviour
  301. //
  302. // ```ignore
  303. // #[derive(Copy, Clone)]
  304. // #[repr(packed)]
  305. // struct MyStruct {...}
  306. // ```
  307. is_unsafe = true;
  308. } else {
  309. // TODO: how to return a compile error with a span (can't return prase error because expected type TokenStream)
  310. panic!("expected single ident `unsafe`");
  311. }
  312. }
  313. _ => {
  314. panic!("expected single ident `unsafe`");
  315. }
  316. }
  317. }
  318. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  319. // Takes the first repr. It's assumed that more than one are not on the
  320. // struct.
  321. let attr = account_strct
  322. .attrs
  323. .iter()
  324. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
  325. let repr = match attr {
  326. // Users might want to manually specify repr modifiers e.g. repr(C, packed)
  327. Some(_attr) => quote! {},
  328. None => {
  329. if is_unsafe {
  330. quote! {#[repr(packed)]}
  331. } else {
  332. quote! {#[repr(C)]}
  333. }
  334. }
  335. };
  336. let mut has_pod_attr = false;
  337. let mut has_zeroable_attr = false;
  338. for attr in account_strct.attrs.iter() {
  339. let token_string = attr.tokens.to_string();
  340. if token_string.contains("bytemuck :: Pod") {
  341. has_pod_attr = true;
  342. }
  343. if token_string.contains("bytemuck :: Zeroable") {
  344. has_zeroable_attr = true;
  345. }
  346. }
  347. // Once the Pod derive macro is expanded the compiler has to use the local crate's
  348. // bytemuck `::bytemuck::Pod` anyway, so we're no longer using the privately
  349. // exported anchor bytemuck `__private::bytemuck`, so that there won't be any
  350. // possible disparity between the anchor version and the local crate's version.
  351. let pod = if has_pod_attr || is_unsafe {
  352. quote! {}
  353. } else {
  354. quote! {#[derive(::bytemuck::Pod)]}
  355. };
  356. let zeroable = if has_zeroable_attr || is_unsafe {
  357. quote! {}
  358. } else {
  359. quote! {#[derive(::bytemuck::Zeroable)]}
  360. };
  361. let ret = quote! {
  362. #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
  363. #repr
  364. #pod
  365. #zeroable
  366. #account_strct
  367. };
  368. #[cfg(feature = "idl-build")]
  369. {
  370. let derive_unsafe = if is_unsafe {
  371. // Not a real proc-macro but exists in order to pass the serialization info
  372. quote! { #[derive(bytemuck::Unsafe)] }
  373. } else {
  374. quote! {}
  375. };
  376. let zc_struct = syn::parse2(quote! {
  377. #derive_unsafe
  378. #ret
  379. })
  380. .unwrap();
  381. let idl_build_impl = anchor_syn::idl::impl_idl_build_struct(&zc_struct);
  382. return proc_macro::TokenStream::from(quote! {
  383. #ret
  384. #idl_build_impl
  385. });
  386. }
  387. #[allow(unreachable_code)]
  388. proc_macro::TokenStream::from(ret)
  389. }
  390. /// Convenience macro to define a static public key.
  391. ///
  392. /// Input: a single literal base58 string representation of a Pubkey.
  393. #[proc_macro]
  394. pub fn pubkey(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  395. let pk = parse_macro_input!(input as id::Pubkey);
  396. proc_macro::TokenStream::from(quote! {#pk})
  397. }
  398. /// Defines the program's ID. This should be used at the root of all Anchor
  399. /// based programs.
  400. #[proc_macro]
  401. pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  402. #[cfg(feature = "idl-build")]
  403. let address = input.clone().to_string();
  404. let id = parse_macro_input!(input as id::Id);
  405. let ret = quote! { #id };
  406. #[cfg(feature = "idl-build")]
  407. {
  408. let idl_print = anchor_syn::idl::gen_idl_print_fn_address(address);
  409. return proc_macro::TokenStream::from(quote! {
  410. #ret
  411. #idl_print
  412. });
  413. }
  414. #[allow(unreachable_code)]
  415. proc_macro::TokenStream::from(ret)
  416. }