common.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. use anchor_idl::types::{
  2. Idl, IdlArrayLen, IdlDefinedFields, IdlField, IdlGenericArg, IdlRepr, IdlSerialization,
  3. IdlType, IdlTypeDef, IdlTypeDefGeneric, IdlTypeDefTy,
  4. };
  5. use quote::{format_ident, quote};
  6. /// This function should ideally return the absolute path to the declared program's id but because
  7. /// `proc_macro2::Span::call_site().source_file().path()` is behind an unstable feature flag, we
  8. /// are not able to reliably decide where the definition is.
  9. pub fn get_canonical_program_id() -> proc_macro2::TokenStream {
  10. quote! { super::__ID }
  11. }
  12. pub fn gen_docs(docs: &[String]) -> proc_macro2::TokenStream {
  13. let docs = docs
  14. .iter()
  15. .map(|doc| format!("{}{doc}", if doc.is_empty() { "" } else { " " }))
  16. .map(|doc| quote! { #[doc = #doc] });
  17. quote! { #(#docs)* }
  18. }
  19. pub fn gen_discriminator(disc: &[u8]) -> proc_macro2::TokenStream {
  20. quote! { [#(#disc), *] }
  21. }
  22. pub fn gen_accounts_common(idl: &Idl, prefix: &str) -> proc_macro2::TokenStream {
  23. let re_exports = idl
  24. .instructions
  25. .iter()
  26. .map(|ix| format_ident!("__{}_accounts_{}", prefix, ix.name))
  27. .map(|ident| quote! { pub use super::internal::#ident::*; });
  28. quote! {
  29. pub mod accounts {
  30. #(#re_exports)*
  31. }
  32. }
  33. }
  34. pub fn convert_idl_type_to_syn_type(ty: &IdlType) -> syn::Type {
  35. syn::parse_str(&convert_idl_type_to_str(ty)).unwrap()
  36. }
  37. // TODO: Impl `ToString` for `IdlType`
  38. pub fn convert_idl_type_to_str(ty: &IdlType) -> String {
  39. match ty {
  40. IdlType::Bool => "bool".into(),
  41. IdlType::U8 => "u8".into(),
  42. IdlType::I8 => "i8".into(),
  43. IdlType::U16 => "u16".into(),
  44. IdlType::I16 => "i16".into(),
  45. IdlType::U32 => "u32".into(),
  46. IdlType::I32 => "i32".into(),
  47. IdlType::F32 => "f32".into(),
  48. IdlType::U64 => "u64".into(),
  49. IdlType::I64 => "i64".into(),
  50. IdlType::F64 => "f64".into(),
  51. IdlType::U128 => "u128".into(),
  52. IdlType::I128 => "i128".into(),
  53. IdlType::U256 => "u256".into(),
  54. IdlType::I256 => "i256".into(),
  55. IdlType::Bytes => "bytes".into(),
  56. IdlType::String => "String".into(),
  57. IdlType::Pubkey => "Pubkey".into(),
  58. IdlType::Option(ty) => format!("Option<{}>", convert_idl_type_to_str(ty)),
  59. IdlType::Vec(ty) => format!("Vec<{}>", convert_idl_type_to_str(ty)),
  60. IdlType::Array(ty, len) => format!(
  61. "[{}; {}]",
  62. convert_idl_type_to_str(ty),
  63. match len {
  64. IdlArrayLen::Generic(len) => len.into(),
  65. IdlArrayLen::Value(len) => len.to_string(),
  66. }
  67. ),
  68. IdlType::Defined { name, generics } => generics
  69. .iter()
  70. .map(|generic| match generic {
  71. IdlGenericArg::Type { ty } => convert_idl_type_to_str(ty),
  72. IdlGenericArg::Const { value } => value.into(),
  73. })
  74. .reduce(|mut acc, cur| {
  75. if !acc.is_empty() {
  76. acc.push(',');
  77. }
  78. acc.push_str(&cur);
  79. acc
  80. })
  81. .map(|generics| format!("{name}<{generics}>"))
  82. .unwrap_or(name.into()),
  83. IdlType::Generic(ty) => ty.into(),
  84. }
  85. }
  86. pub fn convert_idl_type_def_to_ts(
  87. ty_def: &IdlTypeDef,
  88. ty_defs: &[IdlTypeDef],
  89. ) -> proc_macro2::TokenStream {
  90. let name = format_ident!("{}", ty_def.name);
  91. let docs = gen_docs(&ty_def.docs);
  92. let generics = {
  93. let generics = ty_def
  94. .generics
  95. .iter()
  96. .map(|generic| match generic {
  97. IdlTypeDefGeneric::Type { name } => format_ident!("{name}"),
  98. IdlTypeDefGeneric::Const { name, ty } => format_ident!("{name}: {ty}"),
  99. })
  100. .collect::<Vec<_>>();
  101. if generics.is_empty() {
  102. quote!()
  103. } else {
  104. quote!(<#(#generics,)*>)
  105. }
  106. };
  107. let attrs = {
  108. let debug_attr = quote!(#[derive(Debug)]);
  109. let default_attr = can_derive_default(ty_def, ty_defs)
  110. .then(|| quote!(#[derive(Default)]))
  111. .unwrap_or_default();
  112. let ser_attr = match &ty_def.serialization {
  113. IdlSerialization::Borsh => quote!(#[derive(AnchorSerialize, AnchorDeserialize)]),
  114. IdlSerialization::Bytemuck => quote!(#[zero_copy]),
  115. IdlSerialization::BytemuckUnsafe => quote!(#[zero_copy(unsafe)]),
  116. _ => unimplemented!("{:?}", ty_def.serialization),
  117. };
  118. let clone_attr = matches!(ty_def.serialization, IdlSerialization::Borsh)
  119. .then(|| quote!(#[derive(Clone)]))
  120. .unwrap_or_default();
  121. let copy_attr = matches!(ty_def.serialization, IdlSerialization::Borsh)
  122. .then(|| can_derive_copy(ty_def, ty_defs).then(|| quote!(#[derive(Copy)])))
  123. .flatten()
  124. .unwrap_or_default();
  125. quote! {
  126. #debug_attr
  127. #default_attr
  128. #ser_attr
  129. #clone_attr
  130. #copy_attr
  131. }
  132. };
  133. let repr = if let Some(repr) = &ty_def.repr {
  134. let kind = match repr {
  135. IdlRepr::Rust(_) => "Rust",
  136. IdlRepr::C(_) => "C",
  137. IdlRepr::Transparent => "transparent",
  138. };
  139. let kind = format_ident!("{kind}");
  140. let modifier = match repr {
  141. IdlRepr::Rust(modifier) | IdlRepr::C(modifier) => {
  142. let packed = modifier.packed.then(|| quote!(packed)).unwrap_or_default();
  143. let align = modifier
  144. .align
  145. .map(|align| quote!(align(#align)))
  146. .unwrap_or_default();
  147. if packed.is_empty() {
  148. align
  149. } else if align.is_empty() {
  150. packed
  151. } else {
  152. quote! { #packed, #align }
  153. }
  154. }
  155. _ => quote!(),
  156. };
  157. let modifier = if modifier.is_empty() {
  158. modifier
  159. } else {
  160. quote! { , #modifier }
  161. };
  162. quote! { #[repr(#kind #modifier)] }
  163. } else {
  164. quote!()
  165. };
  166. let ty = match &ty_def.ty {
  167. IdlTypeDefTy::Struct { fields } => {
  168. let declare_struct = quote! { pub struct #name #generics };
  169. handle_defined_fields(
  170. fields.as_ref(),
  171. || quote! { #declare_struct; },
  172. |fields| {
  173. let fields = fields.iter().map(|field| {
  174. let name = format_ident!("{}", field.name);
  175. let ty = convert_idl_type_to_syn_type(&field.ty);
  176. quote! { pub #name : #ty }
  177. });
  178. quote! {
  179. #declare_struct {
  180. #(#fields,)*
  181. }
  182. }
  183. },
  184. |tys| {
  185. let tys = tys.iter().map(convert_idl_type_to_syn_type);
  186. quote! {
  187. #declare_struct (#(#tys,)*);
  188. }
  189. },
  190. )
  191. }
  192. IdlTypeDefTy::Enum { variants } => {
  193. let variants = variants.iter().map(|variant| {
  194. let variant_name = format_ident!("{}", variant.name);
  195. handle_defined_fields(
  196. variant.fields.as_ref(),
  197. || quote! { #variant_name },
  198. |fields| {
  199. let fields = fields.iter().map(|field| {
  200. let name = format_ident!("{}", field.name);
  201. let ty = convert_idl_type_to_syn_type(&field.ty);
  202. quote! { #name : #ty }
  203. });
  204. quote! {
  205. #variant_name {
  206. #(#fields,)*
  207. }
  208. }
  209. },
  210. |tys| {
  211. let tys = tys.iter().map(convert_idl_type_to_syn_type);
  212. quote! {
  213. #variant_name (#(#tys,)*)
  214. }
  215. },
  216. )
  217. });
  218. quote! {
  219. pub enum #name #generics {
  220. #(#variants,)*
  221. }
  222. }
  223. }
  224. IdlTypeDefTy::Type { alias } => {
  225. let alias = convert_idl_type_to_syn_type(alias);
  226. quote! { pub type #name = #alias; }
  227. }
  228. };
  229. quote! {
  230. #docs
  231. #attrs
  232. #repr
  233. #ty
  234. }
  235. }
  236. fn can_derive_copy(ty_def: &IdlTypeDef, ty_defs: &[IdlTypeDef]) -> bool {
  237. match &ty_def.ty {
  238. IdlTypeDefTy::Struct { fields } => {
  239. can_derive_common(fields.as_ref(), ty_defs, can_derive_copy_ty)
  240. }
  241. IdlTypeDefTy::Enum { variants } => variants
  242. .iter()
  243. .all(|variant| can_derive_common(variant.fields.as_ref(), ty_defs, can_derive_copy_ty)),
  244. IdlTypeDefTy::Type { alias } => can_derive_copy_ty(alias, ty_defs),
  245. }
  246. }
  247. fn can_derive_default(ty_def: &IdlTypeDef, ty_defs: &[IdlTypeDef]) -> bool {
  248. match &ty_def.ty {
  249. IdlTypeDefTy::Struct { fields } => {
  250. can_derive_common(fields.as_ref(), ty_defs, can_derive_default_ty)
  251. }
  252. // TODO: Consider storing the default enum variant in IDL
  253. IdlTypeDefTy::Enum { .. } => false,
  254. IdlTypeDefTy::Type { alias } => can_derive_default_ty(alias, ty_defs),
  255. }
  256. }
  257. fn can_derive_copy_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
  258. match ty {
  259. IdlType::Option(inner) => can_derive_copy_ty(inner, ty_defs),
  260. IdlType::Array(inner, len) => {
  261. if !can_derive_copy_ty(inner, ty_defs) {
  262. return false;
  263. }
  264. match len {
  265. IdlArrayLen::Value(_) => true,
  266. IdlArrayLen::Generic(_) => false,
  267. }
  268. }
  269. IdlType::Defined { name, .. } => ty_defs
  270. .iter()
  271. .find(|ty_def| &ty_def.name == name)
  272. .map(|ty_def| can_derive_copy(ty_def, ty_defs))
  273. .expect("Type def must exist"),
  274. IdlType::Bytes | IdlType::String | IdlType::Vec(_) | IdlType::Generic(_) => false,
  275. _ => true,
  276. }
  277. }
  278. fn can_derive_default_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
  279. match ty {
  280. IdlType::Option(inner) => can_derive_default_ty(inner, ty_defs),
  281. IdlType::Vec(inner) => can_derive_default_ty(inner, ty_defs),
  282. IdlType::Array(inner, len) => {
  283. if !can_derive_default_ty(inner, ty_defs) {
  284. return false;
  285. }
  286. match len {
  287. IdlArrayLen::Value(len) => *len <= 32,
  288. IdlArrayLen::Generic(_) => false,
  289. }
  290. }
  291. IdlType::Defined { name, .. } => ty_defs
  292. .iter()
  293. .find(|ty_def| &ty_def.name == name)
  294. .map(|ty_def| can_derive_default(ty_def, ty_defs))
  295. .expect("Type def must exist"),
  296. IdlType::Generic(_) => false,
  297. _ => true,
  298. }
  299. }
  300. fn can_derive_common(
  301. fields: Option<&IdlDefinedFields>,
  302. ty_defs: &[IdlTypeDef],
  303. can_derive_ty: fn(&IdlType, &[IdlTypeDef]) -> bool,
  304. ) -> bool {
  305. handle_defined_fields(
  306. fields,
  307. || true,
  308. |fields| {
  309. fields
  310. .iter()
  311. .map(|field| &field.ty)
  312. .all(|ty| can_derive_ty(ty, ty_defs))
  313. },
  314. |tys| tys.iter().all(|ty| can_derive_ty(ty, ty_defs)),
  315. )
  316. }
  317. fn handle_defined_fields<R>(
  318. fields: Option<&IdlDefinedFields>,
  319. unit_cb: impl Fn() -> R,
  320. named_cb: impl Fn(&[IdlField]) -> R,
  321. tuple_cb: impl Fn(&[IdlType]) -> R,
  322. ) -> R {
  323. match fields {
  324. Some(fields) => match fields {
  325. IdlDefinedFields::Named(fields) => named_cb(fields),
  326. IdlDefinedFields::Tuple(tys) => tuple_cb(tys),
  327. },
  328. _ => unit_cb(),
  329. }
  330. }