common.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. use anchor_lang_idl::types::{
  2. Idl, IdlArrayLen, IdlDefinedFields, IdlField, IdlGenericArg, IdlRepr, IdlSerialization,
  3. IdlType, IdlTypeDef, IdlTypeDefGeneric, IdlTypeDefTy,
  4. };
  5. use proc_macro2::Literal;
  6. use quote::{format_ident, quote};
  7. /// This function should ideally return the absolute path to the declared program's id but because
  8. /// `proc_macro2::Span::call_site().source_file().path()` is behind an unstable feature flag, we
  9. /// are not able to reliably decide where the definition is.
  10. pub fn get_canonical_program_id() -> proc_macro2::TokenStream {
  11. quote! { super::__ID }
  12. }
  13. pub fn gen_docs(docs: &[String]) -> proc_macro2::TokenStream {
  14. let docs = docs
  15. .iter()
  16. .map(|doc| format!("{}{doc}", if doc.is_empty() { "" } else { " " }))
  17. .map(|doc| quote! { #[doc = #doc] });
  18. quote! { #(#docs)* }
  19. }
  20. pub fn gen_discriminator(disc: &[u8]) -> proc_macro2::TokenStream {
  21. quote! { [#(#disc), *] }
  22. }
  23. pub fn gen_accounts_common(idl: &Idl, prefix: &str) -> proc_macro2::TokenStream {
  24. let re_exports = idl
  25. .instructions
  26. .iter()
  27. .map(|ix| format_ident!("__{}_accounts_{}", prefix, ix.name))
  28. .map(|ident| quote! { pub use super::internal::#ident::*; });
  29. quote! {
  30. pub mod accounts {
  31. #(#re_exports)*
  32. }
  33. }
  34. }
  35. pub fn convert_idl_type_to_syn_type(ty: &IdlType) -> syn::Type {
  36. syn::parse_str(&convert_idl_type_to_str(ty)).unwrap()
  37. }
  38. // TODO: Impl `ToString` for `IdlType`
  39. pub fn convert_idl_type_to_str(ty: &IdlType) -> String {
  40. match ty {
  41. IdlType::Bool => "bool".into(),
  42. IdlType::U8 => "u8".into(),
  43. IdlType::I8 => "i8".into(),
  44. IdlType::U16 => "u16".into(),
  45. IdlType::I16 => "i16".into(),
  46. IdlType::U32 => "u32".into(),
  47. IdlType::I32 => "i32".into(),
  48. IdlType::F32 => "f32".into(),
  49. IdlType::U64 => "u64".into(),
  50. IdlType::I64 => "i64".into(),
  51. IdlType::F64 => "f64".into(),
  52. IdlType::U128 => "u128".into(),
  53. IdlType::I128 => "i128".into(),
  54. IdlType::U256 => "u256".into(),
  55. IdlType::I256 => "i256".into(),
  56. IdlType::Bytes => "Vec<u8>".into(),
  57. IdlType::String => "String".into(),
  58. IdlType::Pubkey => "Pubkey".into(),
  59. IdlType::Option(ty) => format!("Option<{}>", convert_idl_type_to_str(ty)),
  60. IdlType::Vec(ty) => format!("Vec<{}>", convert_idl_type_to_str(ty)),
  61. IdlType::Array(ty, len) => format!(
  62. "[{}; {}]",
  63. convert_idl_type_to_str(ty),
  64. match len {
  65. IdlArrayLen::Generic(len) => len.into(),
  66. IdlArrayLen::Value(len) => len.to_string(),
  67. }
  68. ),
  69. IdlType::Defined { name, generics } => generics
  70. .iter()
  71. .map(|generic| match generic {
  72. IdlGenericArg::Type { ty } => convert_idl_type_to_str(ty),
  73. IdlGenericArg::Const { value } => value.into(),
  74. })
  75. .reduce(|mut acc, cur| {
  76. if !acc.is_empty() {
  77. acc.push(',');
  78. }
  79. acc.push_str(&cur);
  80. acc
  81. })
  82. .map(|generics| format!("{name}<{generics}>"))
  83. .unwrap_or(name.into()),
  84. IdlType::Generic(ty) => ty.into(),
  85. _ => unimplemented!("{ty:?}"),
  86. }
  87. }
  88. pub fn convert_idl_type_def_to_ts(
  89. ty_def: &IdlTypeDef,
  90. ty_defs: &[IdlTypeDef],
  91. ) -> proc_macro2::TokenStream {
  92. let name = format_ident!("{}", ty_def.name);
  93. let docs = gen_docs(&ty_def.docs);
  94. let generics = {
  95. let generics = ty_def
  96. .generics
  97. .iter()
  98. .map(|generic| match generic {
  99. IdlTypeDefGeneric::Type { name } => {
  100. let name = format_ident!("{}", name);
  101. quote! { #name }
  102. }
  103. IdlTypeDefGeneric::Const { name, ty } => {
  104. let name = format_ident!("{}", name);
  105. let ty = format_ident!("{}", ty);
  106. quote! { const #name: #ty }
  107. }
  108. })
  109. .collect::<Vec<_>>();
  110. if generics.is_empty() {
  111. quote!()
  112. } else {
  113. quote!(<#(#generics,)*>)
  114. }
  115. };
  116. let attrs = {
  117. let debug_attr = quote!(#[derive(Debug)]);
  118. let default_attr = can_derive_default(ty_def, ty_defs)
  119. .then(|| quote!(#[derive(Default)]))
  120. .unwrap_or_default();
  121. let ser_attr = match &ty_def.serialization {
  122. IdlSerialization::Borsh => quote!(#[derive(AnchorSerialize, AnchorDeserialize)]),
  123. IdlSerialization::Bytemuck => quote!(#[zero_copy]),
  124. IdlSerialization::BytemuckUnsafe => quote!(#[zero_copy(unsafe)]),
  125. _ => unimplemented!("{:?}", ty_def.serialization),
  126. };
  127. let clone_attr = matches!(ty_def.serialization, IdlSerialization::Borsh)
  128. .then(|| quote!(#[derive(Clone)]))
  129. .unwrap_or_default();
  130. let copy_attr = matches!(ty_def.serialization, IdlSerialization::Borsh)
  131. .then(|| can_derive_copy(ty_def, ty_defs).then(|| quote!(#[derive(Copy)])))
  132. .flatten()
  133. .unwrap_or_default();
  134. quote! {
  135. #debug_attr
  136. #default_attr
  137. #ser_attr
  138. #clone_attr
  139. #copy_attr
  140. }
  141. };
  142. let repr = if let Some(repr) = &ty_def.repr {
  143. let kind = match repr {
  144. IdlRepr::Rust(_) => "Rust",
  145. IdlRepr::C(_) => "C",
  146. IdlRepr::Transparent => "transparent",
  147. _ => unimplemented!("{repr:?}"),
  148. };
  149. let kind = format_ident!("{kind}");
  150. let modifier = match repr {
  151. IdlRepr::Rust(modifier) | IdlRepr::C(modifier) => {
  152. let packed = modifier.packed.then(|| quote!(packed)).unwrap_or_default();
  153. let align = modifier
  154. .align
  155. .map(Literal::usize_unsuffixed)
  156. .map(|align| quote!(align(#align)))
  157. .unwrap_or_default();
  158. if packed.is_empty() {
  159. align
  160. } else if align.is_empty() {
  161. packed
  162. } else {
  163. quote! { #packed, #align }
  164. }
  165. }
  166. _ => quote!(),
  167. };
  168. let modifier = if modifier.is_empty() {
  169. modifier
  170. } else {
  171. quote! { , #modifier }
  172. };
  173. quote! { #[repr(#kind #modifier)] }
  174. } else {
  175. quote!()
  176. };
  177. let ty = match &ty_def.ty {
  178. IdlTypeDefTy::Struct { fields } => {
  179. let declare_struct = quote! { pub struct #name #generics };
  180. handle_defined_fields(
  181. fields.as_ref(),
  182. || quote! { #declare_struct; },
  183. |fields| {
  184. let fields = fields.iter().map(|field| {
  185. let name = format_ident!("{}", field.name);
  186. let ty = convert_idl_type_to_syn_type(&field.ty);
  187. quote! { pub #name : #ty }
  188. });
  189. quote! {
  190. #declare_struct {
  191. #(#fields,)*
  192. }
  193. }
  194. },
  195. |tys| {
  196. let tys = tys
  197. .iter()
  198. .map(convert_idl_type_to_syn_type)
  199. .map(|ty| quote! { pub #ty });
  200. quote! {
  201. #declare_struct (#(#tys,)*);
  202. }
  203. },
  204. )
  205. }
  206. IdlTypeDefTy::Enum { variants } => {
  207. let variants = variants.iter().map(|variant| {
  208. let variant_name = format_ident!("{}", variant.name);
  209. handle_defined_fields(
  210. variant.fields.as_ref(),
  211. || quote! { #variant_name },
  212. |fields| {
  213. let fields = fields.iter().map(|field| {
  214. let name = format_ident!("{}", field.name);
  215. let ty = convert_idl_type_to_syn_type(&field.ty);
  216. quote! { #name : #ty }
  217. });
  218. quote! {
  219. #variant_name {
  220. #(#fields,)*
  221. }
  222. }
  223. },
  224. |tys| {
  225. let tys = tys.iter().map(convert_idl_type_to_syn_type);
  226. quote! {
  227. #variant_name (#(#tys,)*)
  228. }
  229. },
  230. )
  231. });
  232. quote! {
  233. pub enum #name #generics {
  234. #(#variants,)*
  235. }
  236. }
  237. }
  238. IdlTypeDefTy::Type { alias } => {
  239. let alias = convert_idl_type_to_syn_type(alias);
  240. quote! { pub type #name = #alias; }
  241. }
  242. };
  243. quote! {
  244. #docs
  245. #attrs
  246. #repr
  247. #ty
  248. }
  249. }
  250. fn can_derive_copy(ty_def: &IdlTypeDef, ty_defs: &[IdlTypeDef]) -> bool {
  251. match &ty_def.ty {
  252. IdlTypeDefTy::Struct { fields } => {
  253. can_derive_common(fields.as_ref(), ty_defs, can_derive_copy_ty)
  254. }
  255. IdlTypeDefTy::Enum { variants } => variants
  256. .iter()
  257. .all(|variant| can_derive_common(variant.fields.as_ref(), ty_defs, can_derive_copy_ty)),
  258. IdlTypeDefTy::Type { alias } => can_derive_copy_ty(alias, ty_defs),
  259. }
  260. }
  261. fn can_derive_default(ty_def: &IdlTypeDef, ty_defs: &[IdlTypeDef]) -> bool {
  262. match &ty_def.ty {
  263. IdlTypeDefTy::Struct { fields } => {
  264. can_derive_common(fields.as_ref(), ty_defs, can_derive_default_ty)
  265. }
  266. // TODO: Consider storing the default enum variant in IDL
  267. IdlTypeDefTy::Enum { .. } => false,
  268. IdlTypeDefTy::Type { alias } => can_derive_default_ty(alias, ty_defs),
  269. }
  270. }
  271. fn can_derive_copy_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
  272. match ty {
  273. IdlType::Option(inner) => can_derive_copy_ty(inner, ty_defs),
  274. IdlType::Array(inner, len) => {
  275. if !can_derive_copy_ty(inner, ty_defs) {
  276. return false;
  277. }
  278. match len {
  279. IdlArrayLen::Value(_) => true,
  280. IdlArrayLen::Generic(_) => false,
  281. }
  282. }
  283. IdlType::Defined { name, .. } => ty_defs
  284. .iter()
  285. .find(|ty_def| &ty_def.name == name)
  286. .map(|ty_def| can_derive_copy(ty_def, ty_defs))
  287. .expect("Type def must exist"),
  288. IdlType::Bytes | IdlType::String | IdlType::Vec(_) | IdlType::Generic(_) => false,
  289. _ => true,
  290. }
  291. }
  292. fn can_derive_default_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
  293. match ty {
  294. IdlType::Option(inner) => can_derive_default_ty(inner, ty_defs),
  295. IdlType::Vec(inner) => can_derive_default_ty(inner, ty_defs),
  296. IdlType::Array(inner, len) => {
  297. if !can_derive_default_ty(inner, ty_defs) {
  298. return false;
  299. }
  300. match len {
  301. IdlArrayLen::Value(len) => *len <= 32,
  302. IdlArrayLen::Generic(_) => false,
  303. }
  304. }
  305. IdlType::Defined { name, .. } => ty_defs
  306. .iter()
  307. .find(|ty_def| &ty_def.name == name)
  308. .map(|ty_def| can_derive_default(ty_def, ty_defs))
  309. .expect("Type def must exist"),
  310. IdlType::Generic(_) => false,
  311. _ => true,
  312. }
  313. }
  314. fn can_derive_common(
  315. fields: Option<&IdlDefinedFields>,
  316. ty_defs: &[IdlTypeDef],
  317. can_derive_ty: fn(&IdlType, &[IdlTypeDef]) -> bool,
  318. ) -> bool {
  319. handle_defined_fields(
  320. fields,
  321. || true,
  322. |fields| {
  323. fields
  324. .iter()
  325. .map(|field| &field.ty)
  326. .all(|ty| can_derive_ty(ty, ty_defs))
  327. },
  328. |tys| tys.iter().all(|ty| can_derive_ty(ty, ty_defs)),
  329. )
  330. }
  331. fn handle_defined_fields<R>(
  332. fields: Option<&IdlDefinedFields>,
  333. unit_cb: impl Fn() -> R,
  334. named_cb: impl Fn(&[IdlField]) -> R,
  335. tuple_cb: impl Fn(&[IdlType]) -> R,
  336. ) -> R {
  337. match fields {
  338. Some(fields) => match fields {
  339. IdlDefinedFields::Named(fields) => named_cb(fields),
  340. IdlDefinedFields::Tuple(tys) => tuple_cb(tys),
  341. },
  342. _ => unit_cb(),
  343. }
  344. }