ソースを参照

Handle arrays with const as size in endpoint function args (#1631)

skrrb 3 年 前
コミット
58af625736

+ 1 - 0
CHANGELOG.md

@@ -20,6 +20,7 @@ incremented for features.
 * lang: Add new `AccountSysvarMismatch` error code and test cases for sysvars ([#1535](https://github.com/project-serum/anchor/pull/1535)).
 * lang: Replace `std::io::Cursor` with a custom `Write` impl that uses the Solana mem syscalls ([#1589](https://github.com/project-serum/anchor/pull/1589)).
 * lang: Add `require_neq`, `require_keys_neq`, `require_gt`, and `require_gte` comparison macros ([#1622](https://github.com/project-serum/anchor/pull/1622)).
+* lang: Handle arrays with const as size in instruction data ([#1623](https://github.com/project-serum/anchor/issues/1623).
 * spl: Add support for revoke instruction ([#1493](https://github.com/project-serum/anchor/pull/1493)).
 * ts: Add provider parameter to `Spl.token` factory method ([#1597](https://github.com/project-serum/anchor/pull/1597)).
 

+ 13 - 15
lang/syn/src/idl/file.rs

@@ -156,14 +156,9 @@ pub fn parse(
             let args = ix
                 .args
                 .iter()
-                .map(|arg| {
-                    let mut tts = proc_macro2::TokenStream::new();
-                    arg.raw_arg.ty.to_tokens(&mut tts);
-                    let ty = tts.to_string().parse().unwrap();
-                    IdlField {
-                        name: arg.name.to_string().to_mixed_case(),
-                        ty,
-                    }
+                .map(|arg| IdlField {
+                    name: arg.name.to_string().to_mixed_case(),
+                    ty: to_idl_type(&ctx, &arg.raw_arg.ty),
                 })
                 .collect::<Vec<_>>();
             // todo: don't unwrap
@@ -194,7 +189,7 @@ pub fn parse(
                     };
                     IdlEventField {
                         name: f.ident.clone().unwrap().to_string().to_mixed_case(),
-                        ty: to_idl_type(&ctx, f),
+                        ty: to_idl_type(&ctx, &f.ty),
                         index,
                     }
                 })
@@ -411,7 +406,7 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                     .map(|f: &syn::Field| {
                         Ok(IdlField {
                             name: f.ident.as_ref().unwrap().to_string().to_mixed_case(),
-                            ty: to_idl_type(ctx, f),
+                            ty: to_idl_type(ctx, &f.ty),
                         })
                     })
                     .collect::<Result<Vec<IdlField>>>(),
@@ -434,8 +429,11 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                     let fields = match &variant.fields {
                         syn::Fields::Unit => None,
                         syn::Fields::Unnamed(fields) => {
-                            let fields: Vec<IdlType> =
-                                fields.unnamed.iter().map(|f| to_idl_type(ctx, f)).collect();
+                            let fields: Vec<IdlType> = fields
+                                .unnamed
+                                .iter()
+                                .map(|f| to_idl_type(ctx, &f.ty))
+                                .collect();
                             Some(EnumFields::Tuple(fields))
                         }
                         syn::Fields::Named(fields) => {
@@ -444,7 +442,7 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                                 .iter()
                                 .map(|f: &syn::Field| {
                                     let name = f.ident.as_ref().unwrap().to_string();
-                                    let ty = to_idl_type(ctx, f);
+                                    let ty = to_idl_type(ctx, &f.ty);
                                     IdlField { name, ty }
                                 })
                                 .collect();
@@ -522,8 +520,8 @@ fn resolve_variable_array_lengths(ctx: &CrateContext, mut tts_string: String) ->
     tts_string
 }
 
-fn to_idl_type(ctx: &CrateContext, f: &syn::Field) -> IdlType {
-    let mut tts_string = parser::tts_to_string(&f.ty);
+fn to_idl_type(ctx: &CrateContext, ty: &syn::Type) -> IdlType {
+    let mut tts_string = parser::tts_to_string(&ty);
     if tts_string.starts_with('[') {
         tts_string = resolve_variable_array_lengths(ctx, tts_string);
     }

+ 6 - 0
tests/misc/programs/misc/src/context.rs

@@ -385,6 +385,12 @@ pub struct TestConstArraySize<'info> {
     pub data: Account<'info, DataConstArraySize>,
 }
 
+#[derive(Accounts)]
+pub struct TestConstIxDataSize<'info> {
+    #[account(zero)]
+    pub data: Account<'info, DataConstArraySize>,
+}
+
 #[derive(Accounts)]
 pub struct TestMultidimensionalArrayConstSizes<'info> {
     #[account(zero)]

+ 9 - 0
tests/misc/programs/misc/src/lib.rs

@@ -1,6 +1,7 @@
 //! Misc example is a catchall program for testing unrelated features.
 //! It's not too instructive/coherent by itself, so please see other examples.
 
+use account::MAX_SIZE;
 use anchor_lang::prelude::*;
 use context::*;
 use event::*;
@@ -106,6 +107,14 @@ pub mod misc {
         Ok(())
     }
 
+    pub fn test_const_ix_data_size(
+        ctx: Context<TestConstIxDataSize>,
+        data: [u8; MAX_SIZE],
+    ) -> Result<()> {
+        ctx.accounts.data.data = data;
+        Ok(())
+    }
+
     pub fn test_close(_ctx: Context<TestClose>) -> Result<()> {
         Ok(())
     }

+ 19 - 0
tests/misc/tests/misc.js

@@ -924,6 +924,25 @@ describe("misc", () => {
     assert.deepStrictEqual(dataAccount.data, [99, ...new Array(9).fill(0)]);
   });
 
+  it("Can use const for instruction data size", async () => {
+    const data = anchor.web3.Keypair.generate();
+    const dataArray = [99, ...new Array(9).fill(0)];
+    const tx = await program.rpc.testConstIxDataSize(dataArray, {
+      accounts: {
+        data: data.publicKey,
+        rent: anchor.web3.SYSVAR_RENT_PUBKEY,
+      },
+      signers: [data],
+      instructions: [
+        await program.account.dataConstArraySize.createInstruction(data),
+      ],
+    });
+    const dataAccount = await program.account.dataConstArraySize.fetch(
+      data.publicKey
+    );
+    assert.deepStrictEqual(dataAccount.data, dataArray);
+  });
+
   it("Should include BASE const in IDL", async () => {
     assert(
       miscIdl.constants.find(