浏览代码

lang: Handle const array size casting and add const array size support for events (#1485)

Tom Linton 3 年之前
父节点
当前提交
2529b06c02

+ 1 - 0
CHANGELOG.md

@@ -35,6 +35,7 @@ incremented for features.
 ### Fixes
 
 * cli: Fix rust template ([#1488](https://github.com/project-serum/anchor/pull/1488)).
+* lang: Handle array sizes with variable sizes in events and array size casting in IDL parsing ([#1485](https://github.com/project-serum/anchor/pull/1485))
 
 ## [0.22.0] - 2022-02-20
 

+ 51 - 23
lang/syn/src/idl/file.rs

@@ -194,7 +194,7 @@ pub fn parse(
                     };
                     IdlEventField {
                         name: f.ident.clone().unwrap().to_string().to_mixed_case(),
-                        ty: parser::tts_to_string(&f.ty).parse().unwrap(),
+                        ty: to_idl_type(&ctx, f),
                         index,
                     }
                 })
@@ -409,16 +409,9 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                     .named
                     .iter()
                     .map(|f: &syn::Field| {
-                        let mut tts = proc_macro2::TokenStream::new();
-                        f.ty.to_tokens(&mut tts);
-                        // Handle array sizes that are constants
-                        let mut tts_string = tts.to_string();
-                        if tts_string.starts_with('[') {
-                            tts_string = resolve_variable_array_length(ctx, tts_string);
-                        }
                         Ok(IdlField {
                             name: f.ident.as_ref().unwrap().to_string().to_mixed_case(),
-                            ty: tts_string.parse()?,
+                            ty: to_idl_type(ctx, f),
                         })
                     })
                     .collect::<Result<Vec<IdlField>>>(),
@@ -442,7 +435,7 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                         syn::Fields::Unit => None,
                         syn::Fields::Unnamed(fields) => {
                             let fields: Vec<IdlType> =
-                                fields.unnamed.iter().map(to_idl_type).collect();
+                                fields.unnamed.iter().map(|f| to_idl_type(ctx, f)).collect();
                             Some(EnumFields::Tuple(fields))
                         }
                         syn::Fields::Named(fields) => {
@@ -451,7 +444,7 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
                                 .iter()
                                 .map(|f: &syn::Field| {
                                     let name = f.ident.as_ref().unwrap().to_string();
-                                    let ty = to_idl_type(f);
+                                    let ty = to_idl_type(ctx, f);
                                     IdlField { name, ty }
                                 })
                                 .collect();
@@ -470,11 +463,42 @@ fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
 }
 
 // Replace variable array lengths with values
-fn resolve_variable_array_length(ctx: &CrateContext, tts_string: String) -> String {
-    for constant in ctx.consts() {
-        if constant.ty.to_token_stream().to_string() == "usize"
-            && tts_string.contains(&constant.ident.to_string())
-        {
+fn resolve_variable_array_lengths(ctx: &CrateContext, mut tts_string: String) -> String {
+    for constant in ctx.consts().filter(|c| match *c.ty {
+        // Filter to only those consts that are of type usize or could be cast to usize
+        syn::Type::Path(ref p) => {
+            let segment = p.path.segments.last().unwrap();
+            matches!(
+                segment.ident.to_string().as_str(),
+                "usize"
+                    | "u8"
+                    | "u16"
+                    | "u32"
+                    | "u64"
+                    | "u128"
+                    | "isize"
+                    | "i8"
+                    | "i16"
+                    | "i32"
+                    | "i64"
+                    | "i128"
+            )
+        }
+        _ => false,
+    }) {
+        let mut check_string = tts_string.clone();
+        // Strip whitespace to handle accidental double whitespaces
+        check_string.retain(|c| !c.is_whitespace());
+        let size_string = format!("{}]", &constant.ident.to_string());
+        let cast_size_string = format!("{}asusize]", &constant.ident.to_string());
+        // Check for something to replace
+        let mut replacement_string = None;
+        if check_string.contains(cast_size_string.as_str()) {
+            replacement_string = Some(cast_size_string);
+        } else if check_string.contains(size_string.as_str()) {
+            replacement_string = Some(size_string);
+        }
+        if let Some(replacement_string) = replacement_string {
             // Check for the existence of consts existing elsewhere in the
             // crate which have the same name, are usize, and have a
             // different value. We can't know which was intended for the
@@ -487,19 +511,23 @@ fn resolve_variable_array_length(ctx: &CrateContext, tts_string: String) -> Stri
             }) {
                 panic!("Crate wide unique name required for array size const.");
             }
-            return tts_string.replace(
-                &constant.ident.to_string(),
-                &constant.expr.to_token_stream().to_string(),
+            // Replace the match, don't break because there might be multiple replacements to be
+            // made in the case of multidimensional arrays
+            tts_string = check_string.replace(
+                &replacement_string,
+                format!("{}]", &constant.expr.to_token_stream()).as_str(),
             );
         }
     }
     tts_string
 }
 
-fn to_idl_type(f: &syn::Field) -> IdlType {
-    let mut tts = proc_macro2::TokenStream::new();
-    f.ty.to_tokens(&mut tts);
-    tts.to_string().parse().unwrap()
+fn to_idl_type(ctx: &CrateContext, f: &syn::Field) -> IdlType {
+    let mut tts_string = parser::tts_to_string(&f.ty);
+    if tts_string.starts_with('[') {
+        tts_string = resolve_variable_array_lengths(ctx, tts_string);
+    }
+    tts_string.parse().unwrap()
 }
 
 fn idl_accounts(

+ 12 - 0
tests/misc/programs/misc/src/account.rs

@@ -9,6 +9,7 @@ macro_rules! size {
 }
 
 pub const MAX_SIZE: usize = 10;
+pub const MAX_SIZE_U8: u8 = 11;
 
 #[account]
 pub struct Data {
@@ -61,3 +62,14 @@ pub struct DataConstArraySize {
     pub data: [u8; MAX_SIZE], // 10
 }
 size!(DataConstArraySize, MAX_SIZE);
+
+#[account]
+pub struct DataConstCastArraySize {
+    pub data_one: [u8; MAX_SIZE as usize],
+    pub data_two: [u8; MAX_SIZE_U8 as usize],
+}
+
+#[account]
+pub struct DataMultidimensionalArrayConstSizes {
+    pub data: [[u8; MAX_SIZE_U8 as usize]; MAX_SIZE],
+}

+ 6 - 0
tests/misc/programs/misc/src/context.rs

@@ -385,6 +385,12 @@ pub struct TestConstArraySize<'info> {
     pub data: Account<'info, DataConstArraySize>,
 }
 
+#[derive(Accounts)]
+pub struct TestMultidimensionalArrayConstSizes<'info> {
+    #[account(zero)]
+    pub data: Account<'info, DataMultidimensionalArrayConstSizes>,
+}
+
 #[derive(Accounts)]
 pub struct NoRentExempt<'info> {
     /// CHECK:

+ 13 - 0
tests/misc/programs/misc/src/event.rs

@@ -1,5 +1,8 @@
 use anchor_lang::prelude::*;
 
+pub const MAX_EVENT_SIZE: usize = 10;
+pub const MAX_EVENT_SIZE_U8: u8 = 11;
+
 #[event]
 pub struct E1 {
     pub data: u32,
@@ -19,3 +22,13 @@ pub struct E3 {
 pub struct E4 {
     pub data: Pubkey,
 }
+
+#[event]
+pub struct E5 {
+    pub data: [u8; MAX_EVENT_SIZE],
+}
+
+#[event]
+pub struct E6 {
+    pub data: [u8; MAX_EVENT_SIZE_U8 as usize],
+}

+ 14 - 0
tests/misc/programs/misc/src/lib.rs

@@ -82,6 +82,12 @@ pub mod misc {
         emit!(E1 { data });
         emit!(E2 { data: 1234 });
         emit!(E3 { data: 9 });
+        emit!(E5 {
+            data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        });
+        emit!(E6 {
+            data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        });
         Ok(())
     }
 
@@ -247,6 +253,14 @@ pub mod misc {
         Ok(())
     }
 
+    pub fn test_multidimensional_array_const_sizes(
+        ctx: Context<TestMultidimensionalArrayConstSizes>,
+        data: [[u8; 11]; 10],
+    ) -> Result<()> {
+        ctx.accounts.data.data = data;
+        Ok(())
+    }
+
     pub fn test_no_rent_exempt(ctx: Context<NoRentExempt>) -> Result<()> {
         Ok(())
     }

+ 33 - 1
tests/misc/tests/misc.js

@@ -164,6 +164,16 @@ describe("misc", () => {
     assert.ok(resp.events[1].data.data === 1234);
     assert.ok(resp.events[2].name === "E3");
     assert.ok(resp.events[2].data.data === 9);
+    assert.ok(resp.events[3].name === "E5");
+    assert.deepStrictEqual(
+      resp.events[3].data.data,
+      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    );
+    assert.ok(resp.events[4].name === "E6");
+    assert.deepStrictEqual(
+      resp.events[4].data.data,
+      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+    );
   });
 
   let dataI8;
@@ -1432,7 +1442,7 @@ describe("misc", () => {
   it("Can use multidimensional array", async () => {
     const array2d = new Array(10).fill(new Array(10).fill(99));
     const data = anchor.web3.Keypair.generate();
-    const tx = await program.rpc.testMultidimensionalArray(array2d, {
+    await program.rpc.testMultidimensionalArray(array2d, {
       accounts: {
         data: data.publicKey,
         rent: anchor.web3.SYSVAR_RENT_PUBKEY,
@@ -1448,6 +1458,28 @@ describe("misc", () => {
     assert.deepStrictEqual(dataAccount.data, array2d);
   });
 
+  it("Can use multidimensional array with const sizes", async () => {
+    const array2d = new Array(10).fill(new Array(11).fill(22));
+    const data = anchor.web3.Keypair.generate();
+    await program.rpc.testMultidimensionalArrayConstSizes(array2d, {
+      accounts: {
+        data: data.publicKey,
+        rent: anchor.web3.SYSVAR_RENT_PUBKEY,
+      },
+      signers: [data],
+      instructions: [
+        await program.account.dataMultidimensionalArrayConstSizes.createInstruction(
+          data
+        ),
+      ],
+    });
+    const dataAccount =
+      await program.account.dataMultidimensionalArrayConstSizes.fetch(
+        data.publicKey
+      );
+    assert.deepStrictEqual(dataAccount.data, array2d);
+  });
+
   it("allows non-rent exempt accounts", async () => {
     const data = anchor.web3.Keypair.generate();
     await program.rpc.initializeNoRentExempt({