lib.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. extern crate proc_macro;
  2. use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
  3. use quote::{quote, ToTokens};
  4. use syn::{
  5. parenthesized,
  6. parse::{Parse, ParseStream},
  7. parse_macro_input,
  8. token::{Comma, Paren},
  9. Ident, LitStr,
  10. };
  11. mod id;
  12. #[cfg(feature = "lazy-account")]
  13. mod lazy;
  14. /// An attribute for a data structure representing a Solana account.
  15. ///
  16. /// `#[account]` generates trait implementations for the following traits:
  17. ///
  18. /// - [`AccountSerialize`](./trait.AccountSerialize.html)
  19. /// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
  20. /// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
  21. /// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
  22. /// - [`Clone`](https://doc.rust-lang.org/std/clone/trait.Clone.html)
  23. /// - [`Discriminator`](./trait.Discriminator.html)
  24. /// - [`Owner`](./trait.Owner.html)
  25. ///
  26. /// When implementing account serialization traits the first 8 bytes are
  27. /// reserved for a unique account discriminator by default, self described by
  28. /// the first 8 bytes of the SHA256 of the account's Rust ident. This is unless
  29. /// the discriminator was overridden with the `discriminator` argument (see
  30. /// [Arguments](#arguments)).
  31. ///
  32. /// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
  33. /// check this discriminator. If it doesn't match, an invalid account was given,
  34. /// and the account deserialization will exit with an error.
  35. ///
  36. /// # Arguments
  37. ///
  38. /// - `discriminator`: Override the default 8-byte discriminator
  39. ///
  40. /// **Usage:** `discriminator = <CONST_EXPR>`
  41. ///
  42. /// All constant expressions are supported.
  43. ///
  44. /// **Examples:**
  45. ///
  46. /// - `discriminator = 1` (shortcut for `[1]`)
  47. /// - `discriminator = [1, 2, 3, 4]`
  48. /// - `discriminator = b"hi"`
  49. /// - `discriminator = MY_DISC`
  50. /// - `discriminator = get_disc(...)`
  51. ///
  52. /// # Zero Copy Deserialization
  53. ///
  54. /// **WARNING**: Zero copy deserialization is an experimental feature. It's
  55. /// recommended to use it only when necessary, i.e., when you have extremely
  56. /// large accounts that cannot be Borsh deserialized without hitting stack or
  57. /// heap limits.
  58. ///
  59. /// ## Usage
  60. ///
  61. /// To enable zero-copy-deserialization, one can pass in the `zero_copy`
  62. /// argument to the macro as follows:
  63. ///
  64. /// ```ignore
  65. /// #[account(zero_copy)]
  66. /// ```
  67. ///
  68. /// This can be used to conveniently implement
  69. /// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
  70. /// with [`AccountLoader`](./accounts/account_loader/struct.AccountLoader.html).
  71. ///
  72. /// Other than being more efficient, the most salient benefit this provides is
  73. /// the ability to define account types larger than the max stack or heap size.
  74. /// When using borsh, the account has to be copied and deserialized into a new
  75. /// data structure and thus is constrained by stack and heap limits imposed by
  76. /// the BPF VM. With zero copy deserialization, all bytes from the account's
  77. /// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
  78. /// the data structure. No allocations or copies necessary. Hence the ability
  79. /// to get around stack and heap limitations.
  80. ///
  81. /// To facilitate this, all fields in an account must be constrained to be
  82. /// "plain old data", i.e., they must implement
  83. /// [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html). Please review the
  84. /// [`safety`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html#safety)
  85. /// section before using.
  86. ///
  87. /// Using `zero_copy` requires adding the following dependency to your `Cargo.toml` file:
  88. ///
  89. /// ```toml
  90. /// bytemuck = { version = "1.17", features = ["derive", "min_const_generics"] }
  91. /// ```
  92. #[proc_macro_attribute]
  93. pub fn account(
  94. args: proc_macro::TokenStream,
  95. input: proc_macro::TokenStream,
  96. ) -> proc_macro::TokenStream {
  97. let args = parse_macro_input!(args as AccountArgs);
  98. let namespace = args.namespace.unwrap_or_default();
  99. let is_zero_copy = args.zero_copy.is_some();
  100. let unsafe_bytemuck = args.zero_copy.unwrap_or_default();
  101. let account_strct = parse_macro_input!(input as syn::ItemStruct);
  102. let account_name = &account_strct.ident;
  103. let account_name_str = account_name.to_string();
  104. let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
  105. let discriminator = args
  106. .overrides
  107. .and_then(|ov| ov.discriminator)
  108. .unwrap_or_else(|| {
  109. // Namespace the discriminator to prevent collisions.
  110. let namespace = if namespace.is_empty() {
  111. "account"
  112. } else {
  113. &namespace
  114. };
  115. gen_discriminator(namespace, account_name)
  116. });
  117. let disc = if account_strct.generics.lt_token.is_some() {
  118. quote! { #account_name::#type_gen::DISCRIMINATOR }
  119. } else {
  120. quote! { #account_name::DISCRIMINATOR }
  121. };
  122. let owner_impl = {
  123. if namespace.is_empty() {
  124. quote! {
  125. #[automatically_derived]
  126. impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
  127. fn owner() -> Pubkey {
  128. crate::ID
  129. }
  130. }
  131. }
  132. } else {
  133. quote! {}
  134. }
  135. };
  136. let unsafe_bytemuck_impl = {
  137. if unsafe_bytemuck {
  138. quote! {
  139. #[automatically_derived]
  140. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
  141. #[automatically_derived]
  142. unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
  143. }
  144. } else {
  145. quote! {}
  146. }
  147. };
  148. let bytemuck_derives = {
  149. if !unsafe_bytemuck {
  150. quote! {
  151. #[zero_copy]
  152. }
  153. } else {
  154. quote! {
  155. #[zero_copy(unsafe)]
  156. }
  157. }
  158. };
  159. proc_macro::TokenStream::from({
  160. if is_zero_copy {
  161. quote! {
  162. #bytemuck_derives
  163. #account_strct
  164. #unsafe_bytemuck_impl
  165. #[automatically_derived]
  166. impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
  167. #[automatically_derived]
  168. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  169. const DISCRIMINATOR: &'static [u8] = #discriminator;
  170. }
  171. // This trait is useful for clients deserializing accounts.
  172. // It's expected on-chain programs deserialize via zero-copy.
  173. #[automatically_derived]
  174. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  175. fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  176. if buf.len() < #disc.len() {
  177. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  178. }
  179. let given_disc = &buf[..#disc.len()];
  180. if #disc != given_disc {
  181. return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
  182. }
  183. Self::try_deserialize_unchecked(buf)
  184. }
  185. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  186. let data: &[u8] = &buf[#disc.len()..];
  187. // Re-interpret raw bytes into the POD data structure.
  188. let account = anchor_lang::__private::bytemuck::from_bytes(data);
  189. // Copy out the bytes into a new, owned data structure.
  190. Ok(*account)
  191. }
  192. }
  193. #owner_impl
  194. }
  195. } else {
  196. let lazy = {
  197. #[cfg(feature = "lazy-account")]
  198. match namespace.is_empty().then(|| lazy::gen_lazy(&account_strct)) {
  199. Some(Ok(lazy)) => lazy,
  200. // If lazy codegen fails for whatever reason, return empty tokenstream which
  201. // will make the account unusable with `LazyAccount<T>`
  202. _ => Default::default(),
  203. }
  204. #[cfg(not(feature = "lazy-account"))]
  205. proc_macro2::TokenStream::default()
  206. };
  207. quote! {
  208. #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
  209. #account_strct
  210. #[automatically_derived]
  211. impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
  212. fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
  213. if writer.write_all(#disc).is_err() {
  214. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  215. }
  216. if AnchorSerialize::serialize(self, writer).is_err() {
  217. return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
  218. }
  219. Ok(())
  220. }
  221. }
  222. #[automatically_derived]
  223. impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
  224. fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  225. if buf.len() < #disc.len() {
  226. return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
  227. }
  228. let given_disc = &buf[..#disc.len()];
  229. if #disc != given_disc {
  230. return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
  231. }
  232. Self::try_deserialize_unchecked(buf)
  233. }
  234. fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
  235. let mut data: &[u8] = &buf[#disc.len()..];
  236. AnchorDeserialize::deserialize(&mut data)
  237. .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
  238. }
  239. }
  240. #[automatically_derived]
  241. impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
  242. const DISCRIMINATOR: &'static [u8] = #discriminator;
  243. }
  244. #owner_impl
  245. #lazy
  246. }
  247. }
  248. })
  249. }
  250. #[derive(Debug, Default)]
  251. struct AccountArgs {
  252. /// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
  253. zero_copy: Option<bool>,
  254. /// Account namespace override, `account` if not specified
  255. namespace: Option<String>,
  256. /// Named overrides
  257. overrides: Option<Overrides>,
  258. }
  259. impl Parse for AccountArgs {
  260. fn parse(input: ParseStream) -> syn::Result<Self> {
  261. let mut parsed = Self::default();
  262. let args = input.parse_terminated::<_, Comma>(AccountArg::parse)?;
  263. for arg in args {
  264. match arg {
  265. AccountArg::ZeroCopy { is_unsafe } => {
  266. parsed.zero_copy.replace(is_unsafe);
  267. }
  268. AccountArg::Namespace(ns) => {
  269. parsed.namespace.replace(ns);
  270. }
  271. AccountArg::Overrides(ov) => {
  272. parsed.overrides.replace(ov);
  273. }
  274. }
  275. }
  276. Ok(parsed)
  277. }
  278. }
  279. enum AccountArg {
  280. ZeroCopy { is_unsafe: bool },
  281. Namespace(String),
  282. Overrides(Overrides),
  283. }
  284. impl Parse for AccountArg {
  285. fn parse(input: ParseStream) -> syn::Result<Self> {
  286. // Namespace
  287. if let Ok(ns) = input.parse::<LitStr>() {
  288. return Ok(Self::Namespace(
  289. ns.to_token_stream().to_string().replace('\"', ""),
  290. ));
  291. }
  292. // Zero copy
  293. if input.fork().parse::<Ident>()? == "zero_copy" {
  294. input.parse::<Ident>()?;
  295. let is_unsafe = if input.peek(Paren) {
  296. let content;
  297. parenthesized!(content in input);
  298. let content = content.parse::<proc_macro2::TokenStream>()?;
  299. if content.to_string().as_str().trim() != "unsafe" {
  300. return Err(syn::Error::new(
  301. syn::spanned::Spanned::span(&content),
  302. "Expected `unsafe`",
  303. ));
  304. }
  305. true
  306. } else {
  307. false
  308. };
  309. return Ok(Self::ZeroCopy { is_unsafe });
  310. };
  311. // Overrides
  312. input.parse::<Overrides>().map(Self::Overrides)
  313. }
  314. }
  315. #[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
  316. pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
  317. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  318. let account_name = &account_strct.ident;
  319. let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
  320. let fields = match &account_strct.fields {
  321. syn::Fields::Named(n) => n,
  322. _ => panic!("Fields must be named"),
  323. };
  324. let methods: Vec<proc_macro2::TokenStream> = fields
  325. .named
  326. .iter()
  327. .filter_map(|field: &syn::Field| {
  328. field
  329. .attrs
  330. .iter()
  331. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
  332. .map(|attr| {
  333. let mut tts = attr.tokens.clone().into_iter();
  334. let g_stream = match tts.next().expect("Must have a token group") {
  335. proc_macro2::TokenTree::Group(g) => g.stream(),
  336. _ => panic!("Invalid syntax"),
  337. };
  338. let accessor_ty = match g_stream.into_iter().next() {
  339. Some(token) => token,
  340. _ => panic!("Missing accessor type"),
  341. };
  342. let field_name = field.ident.as_ref().unwrap();
  343. let get_field: proc_macro2::TokenStream =
  344. format!("get_{field_name}").parse().unwrap();
  345. let set_field: proc_macro2::TokenStream =
  346. format!("set_{field_name}").parse().unwrap();
  347. quote! {
  348. pub fn #get_field(&self) -> #accessor_ty {
  349. anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
  350. }
  351. pub fn #set_field(&mut self, input: &#accessor_ty) {
  352. self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
  353. }
  354. }
  355. })
  356. })
  357. .collect();
  358. proc_macro::TokenStream::from(quote! {
  359. #[automatically_derived]
  360. impl #impl_gen #account_name #ty_gen #where_clause {
  361. #(#methods)*
  362. }
  363. })
  364. }
  365. /// A data structure that can be used as an internal field for a zero copy
  366. /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
  367. ///
  368. /// `#[zero_copy]` is just a convenient alias for
  369. ///
  370. /// ```ignore
  371. /// #[derive(Copy, Clone)]
  372. /// #[derive(bytemuck::Zeroable)]
  373. /// #[derive(bytemuck::Pod)]
  374. /// #[repr(C)]
  375. /// struct MyStruct {...}
  376. /// ```
  377. #[proc_macro_attribute]
  378. pub fn zero_copy(
  379. args: proc_macro::TokenStream,
  380. item: proc_macro::TokenStream,
  381. ) -> proc_macro::TokenStream {
  382. let mut is_unsafe = false;
  383. for arg in args.into_iter() {
  384. match arg {
  385. proc_macro::TokenTree::Ident(ident) => {
  386. if ident.to_string() == "unsafe" {
  387. // `#[zero_copy(unsafe)]` maintains the old behaviour
  388. //
  389. // ```ignore
  390. // #[derive(Copy, Clone)]
  391. // #[repr(packed)]
  392. // struct MyStruct {...}
  393. // ```
  394. is_unsafe = true;
  395. } else {
  396. // TODO: how to return a compile error with a span (can't return prase error because expected type TokenStream)
  397. panic!("expected single ident `unsafe`");
  398. }
  399. }
  400. _ => {
  401. panic!("expected single ident `unsafe`");
  402. }
  403. }
  404. }
  405. let account_strct = parse_macro_input!(item as syn::ItemStruct);
  406. // Takes the first repr. It's assumed that more than one are not on the
  407. // struct.
  408. let attr = account_strct
  409. .attrs
  410. .iter()
  411. .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
  412. let repr = match attr {
  413. // Users might want to manually specify repr modifiers e.g. repr(C, packed)
  414. Some(_attr) => quote! {},
  415. None => {
  416. if is_unsafe {
  417. quote! {#[repr(Rust, packed)]}
  418. } else {
  419. quote! {#[repr(C)]}
  420. }
  421. }
  422. };
  423. let mut has_pod_attr = false;
  424. let mut has_zeroable_attr = false;
  425. for attr in account_strct.attrs.iter() {
  426. let token_string = attr.tokens.to_string();
  427. if token_string.contains("bytemuck :: Pod") {
  428. has_pod_attr = true;
  429. }
  430. if token_string.contains("bytemuck :: Zeroable") {
  431. has_zeroable_attr = true;
  432. }
  433. }
  434. // Once the Pod derive macro is expanded the compiler has to use the local crate's
  435. // bytemuck `::bytemuck::Pod` anyway, so we're no longer using the privately
  436. // exported anchor bytemuck `__private::bytemuck`, so that there won't be any
  437. // possible disparity between the anchor version and the local crate's version.
  438. let pod = if has_pod_attr || is_unsafe {
  439. quote! {}
  440. } else {
  441. quote! {#[derive(::bytemuck::Pod)]}
  442. };
  443. let zeroable = if has_zeroable_attr || is_unsafe {
  444. quote! {}
  445. } else {
  446. quote! {#[derive(::bytemuck::Zeroable)]}
  447. };
  448. let ret = quote! {
  449. #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
  450. #repr
  451. #pod
  452. #zeroable
  453. #account_strct
  454. };
  455. #[cfg(feature = "idl-build")]
  456. {
  457. let derive_unsafe = if is_unsafe {
  458. // Not a real proc-macro but exists in order to pass the serialization info
  459. quote! { #[derive(bytemuck::Unsafe)] }
  460. } else {
  461. quote! {}
  462. };
  463. let zc_struct = syn::parse2(quote! {
  464. #derive_unsafe
  465. #ret
  466. })
  467. .unwrap();
  468. let idl_build_impl = anchor_syn::idl::impl_idl_build_struct(&zc_struct);
  469. return proc_macro::TokenStream::from(quote! {
  470. #ret
  471. #idl_build_impl
  472. });
  473. }
  474. #[allow(unreachable_code)]
  475. proc_macro::TokenStream::from(ret)
  476. }
  477. /// Convenience macro to define a static public key.
  478. ///
  479. /// Input: a single literal base58 string representation of a Pubkey.
  480. #[proc_macro]
  481. pub fn pubkey(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  482. let pk = parse_macro_input!(input as id::Pubkey);
  483. proc_macro::TokenStream::from(quote! {#pk})
  484. }
  485. /// Defines the program's ID. This should be used at the root of all Anchor
  486. /// based programs.
  487. #[proc_macro]
  488. pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  489. #[cfg(feature = "idl-build")]
  490. let address = input.clone().to_string();
  491. let id = parse_macro_input!(input as id::Id);
  492. let ret = quote! { #id };
  493. #[cfg(feature = "idl-build")]
  494. {
  495. let idl_print = anchor_syn::idl::gen_idl_print_fn_address(address);
  496. return proc_macro::TokenStream::from(quote! {
  497. #ret
  498. #idl_print
  499. });
  500. }
  501. #[allow(unreachable_code)]
  502. proc_macro::TokenStream::from(ret)
  503. }