Bläddra i källkod

removing use of 'as' and cleaning byte conversion of parse_vm

Ayush Suresh 5 månader sedan
förälder
incheckning
44fbd5413b
1 ändrade filer med 84 tillägg och 43 borttagningar
  1. 84 43
      target_chains/stylus/contracts/wormhole/src/lib.rs

+ 84 - 43
target_chains/stylus/contracts/wormhole/src/lib.rs

@@ -57,6 +57,7 @@ sol! {
     error InsufficientSignatures();
     error InvalidGuardianIndex();
     error InvalidAddressLength();
+    error InvalidSignatures();
     error VerifyVAAError();
 }
 
@@ -82,6 +83,7 @@ impl_debug_for_sol_error!(
     NotInitialized,
     InvalidInput,
     InsufficientSignatures,
+    InvalidSignatures,
     InvalidGuardianIndex,
     InvalidAddressLength,
     VerifyVAAError
@@ -100,6 +102,7 @@ pub enum WormholeError {
     NotInitialized(NotInitialized),
     InvalidInput(InvalidInput),
     InsufficientSignatures(InsufficientSignatures),
+    InvalidSignatures(InvalidSignatures),
     InvalidGuardianIndex(InvalidGuardianIndex),
     InvalidAddressLength(InvalidAddressLength),
     VerifyVAAError(VerifyVAAError),
@@ -130,6 +133,7 @@ pub struct WormholeContract {
     guardian_keys: StorageMap<U256, StorageAddress>,
 }
 
+#[public]
 impl WormholeContract {
     pub fn initialize(
         &mut self,
@@ -229,6 +233,8 @@ impl WormholeContract {
         Self::parse_vm_static(encoded_vaa)
     }
 
+    // Parsing a Wormhole VAA according to the structure defined 
+    // by https://wormhole.com/docs/protocol/infrastructure/vaas/
     fn parse_vm_static(encoded_vaa: &[u8]) -> Result<VAA, WormholeError> {
         if encoded_vaa.len() < 6 {
             return Err(WormholeError::InvalidVAAFormat(InvalidVAAFormat {}));
@@ -243,18 +249,23 @@ impl WormholeContract {
             return Err(WormholeError::InvalidVAAFormat(InvalidVAAFormat {}));
         }
 
-        let guardian_set_index = u32::from_be_bytes([
-            encoded_vaa[cursor],
-            encoded_vaa[cursor + 1],
-            encoded_vaa[cursor + 2],
-            encoded_vaa[cursor + 3],
-        ]);
+        let guardian_set_index_bytes: [u8; 4] = encoded_vaa[cursor..cursor + 4]
+            .try_into()
+            .map_err(|_| WormholeError::InvalidVAAFormat(InvalidVAAFormat {}))?;
+
+        let guardian_set_index = u32::from_be_bytes(guardian_set_index_bytes);
+
         cursor += 4;
 
         let len_signatures = encoded_vaa[cursor];
         cursor += 1;
 
         let mut signatures = Vec::new();
+
+        if len_signatures > 19 {
+            return Err(WormholeError::InvalidVAAFormat(InvalidVAAFormat {}));
+        }
+
         for _ in 0..len_signatures {
             if cursor + 66 > encoded_vaa.len() {
                 return Err(WormholeError::InvalidVAAFormat(InvalidVAAFormat {}));
@@ -277,20 +288,18 @@ impl WormholeContract {
             return Err(WormholeError::InvalidVAAFormat(InvalidVAAFormat {}));
         }
 
-        let timestamp = u32::from_be_bytes([
-            encoded_vaa[cursor],
-            encoded_vaa[cursor + 1],
-            encoded_vaa[cursor + 2],
-            encoded_vaa[cursor + 3],
-        ]);
+        let timestamp_bytes: [u8; 4] = encoded_vaa[cursor..cursor + 4]
+            .try_into()
+            .map_err(|_| WormholeError::InvalidVAAFormat(InvalidVAAFormat {}))?;
+
+        let timestamp = u32::from_be_bytes(timestamp_bytes);
         cursor += 4;
 
-        let nonce = u32::from_be_bytes([
-            encoded_vaa[cursor],
-            encoded_vaa[cursor + 1],
-            encoded_vaa[cursor + 2],
-            encoded_vaa[cursor + 3],
-        ]);
+        let nonce_bytes: [u8; 4] = encoded_vaa[cursor..cursor + 4]
+            .try_into()
+            .map_err(|_| WormholeError::InvalidVAAFormat(InvalidVAAFormat {}))?;
+
+        let nonce = u32::from_be_bytes(nonce_bytes);
         cursor += 4;
 
         let emitter_chain_id = u16::from_be_bytes([
@@ -303,16 +312,12 @@ impl WormholeContract {
         emitter_address_bytes.copy_from_slice(&encoded_vaa[cursor..cursor + 32]);
         cursor += 32;
 
-        let sequence = u64::from_be_bytes([
-            encoded_vaa[cursor],
-            encoded_vaa[cursor + 1],
-            encoded_vaa[cursor + 2],
-            encoded_vaa[cursor + 3],
-            encoded_vaa[cursor + 4],
-            encoded_vaa[cursor + 5],
-            encoded_vaa[cursor + 6],
-            encoded_vaa[cursor + 7],
-        ]);
+        let sequence_bytes: [u8; 8] = encoded_vaa[cursor..cursor + 8]
+            .try_into()
+            .map_err(|_| WormholeError::InvalidVAAFormat(InvalidVAAFormat {}))?;
+
+        let sequence = u64::from_be_bytes(sequence_bytes);
+        
         cursor += 8;
 
         let consistency_level = encoded_vaa[cursor];
@@ -345,10 +350,13 @@ impl WormholeContract {
             && guardian_set.expiration_time > 0 {
                 return Err(WormholeError::GuardianSetExpired(GuardianSetExpired {}))
         }
+        
+        let num_guardians : u32 = guardian_set.keys.len().try_into().map_err(|_| WormholeError::InvalidInput(InvalidInput {}))?;
 
-        let required_signatures = Self::quorum(guardian_set.keys.len() as u32);
+        let required_signatures = Self::quorum(num_guardians);
+        let num_signatures : u32 = vaa.signatures.len().try_into().map_err(|_| WormholeError::InvalidInput(InvalidInput {}))?;
 
-        if vaa.signatures.len() < required_signatures as usize {
+        if num_signatures < required_signatures {
             return Err(WormholeError::InsufficientSignatures(InsufficientSignatures {}));
         }
 
@@ -362,11 +370,16 @@ impl WormholeContract {
             }
             last_guardian_index = Some(signature.guardian_index);
 
-            if signature.guardian_index as usize >= guardian_set.keys.len() {
+            let index: usize = signature
+                .guardian_index
+                .try_into()
+                .map_err(|_| WormholeError::InvalidGuardianIndex(InvalidGuardianIndex {}))?;
+
+            if index >= guardian_set.keys.len() {
                 return Err(WormholeError::InvalidGuardianIndex(InvalidGuardianIndex {}));
             }
 
-            let guardian_address = guardian_set.keys[signature.guardian_index as usize];
+            let guardian_address = guardian_set.keys[index];
             let hashed_vaa_hash: FixedBytes<32> = FixedBytes::from(keccak256(vaa.hash));
 
             match self.verify_signature(&hashed_vaa_hash, &signature.signature, guardian_address) {
@@ -404,7 +417,9 @@ impl WormholeContract {
         self.guardian_set_expiry.setter(U256::from(set_index)).set(U256::from(expiration_time));
 
         for (i, guardian) in guardians.iter().enumerate() {
-            let key = self.compute_guardian_key(set_index, i as u8);
+            let i_u8: u8 = i.try_into()
+                .map_err(|_| WormholeError::InvalidGuardianIndex(InvalidGuardianIndex {}))?;
+            let key = self.compute_guardian_key(set_index, i_u8);
             self.guardian_keys.setter(key).set(*guardian);
         }
 
@@ -424,7 +439,10 @@ impl WormholeContract {
 
         let secp = Secp256k1::new();
 
-        let recid = RecoveryId::try_from(signature[64] as i32)
+        let recid_raw_i32: i32 = signature[64].try_into()
+            .map_err(|_| WormholeError::InvalidGuardianIndex(InvalidGuardianIndex {}))?;
+
+        let recid = RecoveryId::try_from(recid_raw_i32)
             .map_err(|_| WormholeError::InvalidSignature(InvalidSignature {}))?;
         let recoverable_sig = RecoverableSignature::from_compact(&signature[..64], recid)
             .map_err(|_| WormholeError::InvalidSignature(InvalidSignature {}))?;
@@ -453,7 +471,13 @@ impl WormholeContract {
         let mut keys = Vec::new();
         let size_u32: u32 = size.try_into().unwrap_or(0);
         for i in 0..size_u32 {
-            let key = self.compute_guardian_key(index, i as u8);
+            let i_u8: u8 = match i.try_into() {
+                Ok(val) => val,
+                Err(_) => {
+                    return None;
+                }
+            };
+            let key = self.compute_guardian_key(index, i_u8);
             let guardian_address = self.guardian_keys.getter(key).get();
             keys.push(guardian_address);
         }
@@ -705,7 +729,7 @@ mod tests {
         }
     }
 
-    fn create_valid_guardian_signature(guardian_index: u8, hash: &FixedBytes<32>) -> GuardianSignature {
+    fn create_valid_guardian_signature(guardian_index: u8, hash: &FixedBytes<32>) -> Result<GuardianSignature, WormholeError> {
         // Select a test guardian secret key
         let secret: SecretKey = match guardian_index {
             0 => test_guardian_secret1(),
@@ -724,18 +748,23 @@ mod tests {
 
         // Build Ethereum-compatible 65-byte signature (r || s || v)
         let mut signature_bytes = [0u8; 65];
-        signature_bytes[..64].copy_from_slice(&sig_bytes);
+        signature_bytes
+            .get_mut(..64)
+            .ok_or(WormholeError::InvalidInput(InvalidInput {}))?
+            .copy_from_slice(&sig_bytes);
+
         signature_bytes[64] = match recovery_id {
             RecoveryId::Zero => 27,
             RecoveryId::One => 28,
             RecoveryId::Two => 29,
             RecoveryId::Three => 30,
+            _ => return Err(WormholeError::InvalidInput(InvalidInput {})),
         };
 
-        GuardianSignature {
+        Ok(GuardianSignature {
             guardian_index,
             signature: FixedBytes::from(signature_bytes),
-        }
+        })
     }
 
     fn create_guardian_signature(guardian_index: u8) -> GuardianSignature {
@@ -932,7 +961,13 @@ mod tests {
         let _contract = deploy_with_mainnet_guardians();
 
         for i in 0..10 {
-            let corrupted_data = corrupted_vaa(vec![1, 0, 0, 1, 0, 0], i, i as u8, (i * 2) as u8);
+            let i_u8: u8 = match i.try_into() {
+                Ok(val) => val,
+                Err(_) => {
+                    panic!("Invalid index");
+                }
+            };
+            let corrupted_data = corrupted_vaa(vec![1, 0, 0, 1, 0, 0], i, i_u8, (i_u8 * 2));
             let result = WormholeContract::parse_vm_static(&corrupted_data);
             assert!(result.is_err());
         }
@@ -943,8 +978,14 @@ mod tests {
         let contract = deploy_with_mainnet_guardians();
 
         for i in 0..5 {
+            let i_u8: u8 = match i.try_into() {
+                Ok(val) => val,
+                Err(_) => {
+                    panic!("Invalid index");
+                }
+            };
             let base_vaa = vec![1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
-            let corrupted_data = corrupted_vaa(base_vaa, i, i as u8, (i * 3) as u8);
+            let corrupted_data = corrupted_vaa(base_vaa, i, i_u8, (i_u8 * 3));
             let result = WormholeContract::parse_vm_static(&corrupted_data);
             assert!(result.is_err());
         }
@@ -1032,8 +1073,8 @@ mod tests {
             version: 1,
             guardian_set_index: 0,
             signatures: vec![
-                create_valid_guardian_signature(0, &hash),
-                create_valid_guardian_signature(1, &hash),
+                create_valid_guardian_signature(0, &hash).unwrap(),
+                create_valid_guardian_signature(1, &hash).unwrap(),
             ],
             timestamp: 0,
             nonce: 0,