cpi.rs 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. use anchor_lang_idl::types::Idl;
  2. use heck::CamelCase;
  3. use quote::{format_ident, quote};
  4. use super::common::{convert_idl_type_to_syn_type, gen_accounts_common, gen_discriminator};
  5. pub fn gen_cpi_mod(idl: &Idl) -> proc_macro2::TokenStream {
  6. let cpi_instructions = gen_cpi_instructions(idl);
  7. let cpi_return_type = gen_cpi_return_type();
  8. let cpi_accounts_mod = gen_cpi_accounts_mod(idl);
  9. quote! {
  10. /// Cross program invocation (CPI) helpers.
  11. pub mod cpi {
  12. use super::*;
  13. #cpi_instructions
  14. #cpi_return_type
  15. #cpi_accounts_mod
  16. }
  17. }
  18. }
  19. fn gen_cpi_instructions(idl: &Idl) -> proc_macro2::TokenStream {
  20. let ixs = idl.instructions.iter().map(|ix| {
  21. let method_name = format_ident!("{}", ix.name);
  22. let accounts_ident = format_ident!("{}", ix.name.to_camel_case());
  23. let accounts_generic = if ix.accounts.is_empty() {
  24. quote!()
  25. } else {
  26. quote!(<'info>)
  27. };
  28. let args = ix.args.iter().map(|arg| {
  29. let name = format_ident!("{}", arg.name);
  30. let ty = convert_idl_type_to_syn_type(&arg.ty);
  31. quote! { #name: #ty }
  32. });
  33. let arg_value = if ix.args.is_empty() {
  34. quote! { #accounts_ident }
  35. } else {
  36. let fields= ix.args.iter().map(|arg| format_ident!("{}", arg.name));
  37. quote! {
  38. #accounts_ident {
  39. #(#fields),*
  40. }
  41. }
  42. };
  43. let discriminator = gen_discriminator(&ix.discriminator);
  44. let (ret_type, ret_value) = match ix.returns.as_ref() {
  45. Some(ty) => {
  46. let ty = convert_idl_type_to_syn_type(ty);
  47. (
  48. quote! { anchor_lang::Result<Return::<#ty>> },
  49. quote! { Ok(Return::<#ty> { phantom: std::marker::PhantomData }) },
  50. )
  51. },
  52. None => (
  53. quote! { anchor_lang::Result<()> },
  54. quote! { Ok(()) },
  55. )
  56. };
  57. quote! {
  58. pub fn #method_name<'a, 'b, 'c, 'info>(
  59. ctx: anchor_lang::context::CpiContext<'a, 'b, 'c, 'info, accounts::#accounts_ident #accounts_generic>,
  60. #(#args),*
  61. ) -> #ret_type {
  62. let ix = {
  63. let ix = internal::args::#arg_value;
  64. let mut data = Vec::with_capacity(256);
  65. data.extend_from_slice(&#discriminator);
  66. AnchorSerialize::serialize(&ix, &mut data)
  67. .map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotSerialize)?;
  68. let accounts = ctx.to_account_metas(None);
  69. anchor_lang::solana_program::instruction::Instruction {
  70. program_id: ctx.program.key(),
  71. accounts,
  72. data,
  73. }
  74. };
  75. let mut acc_infos = ctx.to_account_infos();
  76. anchor_lang::solana_program::program::invoke_signed(
  77. &ix,
  78. &acc_infos,
  79. ctx.signer_seeds,
  80. ).map_or_else(
  81. |e| Err(Into::into(e)),
  82. |_| { #ret_value }
  83. )
  84. }
  85. }
  86. });
  87. quote! {
  88. #(#ixs)*
  89. }
  90. }
  91. fn gen_cpi_return_type() -> proc_macro2::TokenStream {
  92. quote! {
  93. pub struct Return<T> {
  94. phantom: std::marker::PhantomData<T>
  95. }
  96. impl<T: AnchorDeserialize> Return<T> {
  97. pub fn get(&self) -> T {
  98. let (_key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
  99. T::try_from_slice(&data).unwrap()
  100. }
  101. }
  102. }
  103. }
  104. fn gen_cpi_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
  105. gen_accounts_common(idl, "cpi_client")
  106. }