Răsfoiți Sursa

Add padding in account deserialization (#1306)

Lucas Steuernagel 2 ani în urmă
părinte
comite
15d92e9169
4 a modificat fișierele cu 178 adăugiri și 11 ștergeri
  1. 10 1
      stdlib/solana_sdk.h
  2. 39 10
      tests/solana.rs
  3. 128 0
      tests/solana_tests/account_serialization.rs
  4. 1 0
      tests/solana_tests/mod.rs

+ 10 - 1
stdlib/solana_sdk.h

@@ -244,6 +244,8 @@ static uint64_t sol_deserialize(const uint8_t *input, SolParameters *params)
     {
         return ERROR_INVALID_ARGUMENT;
     }
+
+    uint64_t max_accounts = SOL_ARRAY_SIZE(params->ka);
     params->ka_num = *(uint64_t *)input;
     input += sizeof(uint64_t);
 
@@ -252,7 +254,7 @@ static uint64_t sol_deserialize(const uint8_t *input, SolParameters *params)
         uint8_t dup_info = input[0];
         input += sizeof(uint8_t);
 
-        if (i >= SOL_ARRAY_SIZE(params->ka))
+        if (i >= max_accounts)
         {
             if (dup_info == UINT8_MAX)
             {
@@ -270,6 +272,10 @@ static uint64_t sol_deserialize(const uint8_t *input, SolParameters *params)
                 input = (uint8_t *)(((uint64_t)input + 8 - 1) & ~(8 - 1)); // padding
                 input += sizeof(uint64_t);
             }
+            else
+            {
+                input += 7; // padding for the 64-bit alignment
+            }
             continue;
         }
         if (dup_info == UINT8_MAX)
@@ -337,6 +343,9 @@ static uint64_t sol_deserialize(const uint8_t *input, SolParameters *params)
     params->program_id = (SolPubkey *)input;
     input += sizeof(SolPubkey);
 
+    if (params->ka_num > max_accounts)
+        params->ka_num = max_accounts;
+
     return 0;
 }
 

+ 39 - 10
tests/solana.rs

@@ -256,6 +256,26 @@ struct AccountRef {
     length: usize,
 }
 
+enum SerializableAccount {
+    Unique(AccountMeta),
+    Duplicate(usize),
+}
+
+fn remove_duplicates(metas: &[AccountMeta]) -> Vec<SerializableAccount> {
+    let mut serializable_format: Vec<SerializableAccount> = Vec::new();
+    let mut inserted: HashMap<AccountMeta, usize> = HashMap::new();
+
+    for (idx, account) in metas.iter().enumerate() {
+        if let Some(idx) = inserted.get(account) {
+            serializable_format.push(SerializableAccount::Duplicate(*idx));
+        } else {
+            serializable_format.push(SerializableAccount::Unique(account.clone()));
+            inserted.insert(account.clone(), idx);
+        }
+    }
+    serializable_format
+}
+
 fn serialize_parameters(
     input: &[u8],
     metas: &[AccountMeta],
@@ -313,16 +333,25 @@ fn serialize_parameters(
         v.write_u64::<LittleEndian>(0).unwrap();
     }
 
+    let no_duplicates_meta = remove_duplicates(metas);
     // ka_num
-    v.write_u64::<LittleEndian>(metas.len() as u64).unwrap();
-
-    for account in metas {
-        serialize_account(
-            &mut v,
-            &mut refs,
-            account,
-            &vm.account_data[&account.pubkey.0],
-        );
+    v.write_u64::<LittleEndian>(no_duplicates_meta.len() as u64)
+        .unwrap();
+
+    for account_item in &no_duplicates_meta {
+        match account_item {
+            SerializableAccount::Unique(account) => {
+                serialize_account(
+                    &mut v,
+                    &mut refs,
+                    account,
+                    &vm.account_data[&account.pubkey.0],
+                );
+            }
+            SerializableAccount::Duplicate(idx) => {
+                v.write_u64::<LittleEndian>(*idx as u64).unwrap();
+            }
+        }
     }
 
     // calldata
@@ -970,7 +999,7 @@ impl Pubkey {
     }
 }
 
-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
 pub struct AccountMeta {
     /// An account's public key
     pub pubkey: Pubkey,

+ 128 - 0
tests/solana_tests/account_serialization.rs

@@ -0,0 +1,128 @@
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::borsh_encoding::BorshToken;
+use crate::{account_new, build_solidity, AccountMeta, AccountState, Pubkey};
+
+#[test]
+fn deserialize_duplicate_account() {
+    let mut vm = build_solidity(
+        r#"
+        contract Testing {
+    function check_deserialization(address my_address) public view {
+        assert(tx.accounts[1].key == tx.accounts[2].key);
+        assert(tx.accounts[1].is_signer == tx.accounts[2].is_signer);
+        assert(tx.accounts[1].is_writable == tx.accounts[2].is_writable);
+
+        assert(my_address == tx.program_id);
+    }
+}
+        "#,
+    );
+
+    vm.constructor(&[]);
+
+    let random_key = account_new();
+    vm.account_data.insert(
+        random_key,
+        AccountState {
+            data: vec![],
+            owner: None,
+            lamports: 0,
+        },
+    );
+
+    let other_key = account_new();
+    vm.account_data.insert(
+        other_key,
+        AccountState {
+            data: vec![],
+            owner: None,
+            lamports: 0,
+        },
+    );
+    let metas = vec![
+        AccountMeta {
+            pubkey: Pubkey(vm.stack[0].data),
+            is_writable: true,
+            is_signer: false,
+        },
+        AccountMeta {
+            pubkey: Pubkey(random_key),
+            is_signer: true,
+            is_writable: false,
+        },
+        AccountMeta {
+            pubkey: Pubkey(random_key),
+            is_signer: true,
+            is_writable: false,
+        },
+    ];
+
+    vm.function_metas(
+        "check_deserialization",
+        &metas,
+        &[BorshToken::Address(vm.stack[0].program)],
+    );
+}
+
+#[test]
+fn more_than_10_accounts() {
+    let mut vm = build_solidity(
+        r#"
+        contract Testing {
+    function check_deserialization(address my_address) public view {
+        // This assertion ensure the padding is correctly added when
+        // deserializing accounts
+        assert(my_address == tx.program_id);
+    }
+}
+        "#,
+    );
+
+    vm.constructor(&[]);
+
+    let mut metas: Vec<AccountMeta> = Vec::new();
+    metas.push(AccountMeta {
+        pubkey: Pubkey(vm.stack[0].data),
+        is_writable: true,
+        is_signer: false,
+    });
+    for i in 0..11 {
+        let account = account_new();
+        metas.push(AccountMeta {
+            pubkey: Pubkey(account),
+            is_writable: i % 2 == 0,
+            is_signer: i % 2 == 1,
+        });
+        vm.account_data.insert(
+            account,
+            AccountState {
+                data: vec![],
+                owner: None,
+                lamports: 0,
+            },
+        );
+    }
+
+    metas.push(metas[3].clone());
+    let account = account_new();
+    metas.push(AccountMeta {
+        pubkey: Pubkey(account),
+        is_signer: false,
+        is_writable: false,
+    });
+    vm.account_data.insert(
+        account,
+        AccountState {
+            data: vec![],
+            owner: None,
+            lamports: 0,
+        },
+    );
+
+    vm.function_metas(
+        "check_deserialization",
+        &metas,
+        &[BorshToken::Address(vm.stack[0].program)],
+    );
+}

+ 1 - 0
tests/solana_tests/mod.rs

@@ -5,6 +5,7 @@ mod abi_decode;
 mod abi_encode;
 mod accessor;
 mod account_info;
+mod account_serialization;
 mod arrays;
 mod balance;
 mod base58_encoding;