Просмотр исходного кода

refactor(target_chains/starknet): remove Result from reader (#1536)

Pavel Strakhov 1 год назад
Родитель
Сommit
308599714f

+ 3 - 10
target_chains/starknet/contracts/src/hash.cairo

@@ -63,16 +63,9 @@ pub impl HasherImpl of HasherTrait {
                 // reader.len() < 8
                 chunk_len = reader.len().try_into().expect(UNEXPECTED_OVERFLOW);
             }
-            match reader.read_num_bytes(chunk_len) {
-                Result::Ok(value) => {
-                    // chunk_len <= 8 so value must fit in u64.
-                    self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len);
-                },
-                Result::Err(err) => {
-                    result = Result::Err(err);
-                    break;
-                },
-            }
+            let value = reader.read_num_bytes(chunk_len);
+            // chunk_len <= 8 so value must fit in u64.
+            self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len);
         };
         result
     }

+ 3 - 10
target_chains/starknet/contracts/src/merkle_tree.cairo

@@ -46,20 +46,13 @@ pub fn read_and_verify_proof(
     let mut message_reader = ReaderImpl::new(message.clone());
     let mut current_hash = leaf_hash(message_reader.clone()).map_err()?;
 
-    let proof_size = reader.read_u8().map_err()?;
+    let proof_size = reader.read_u8();
     let mut i = 0;
 
     let mut result = Result::Ok(());
     while i < proof_size {
-        match reader.read_u160().map_err() {
-            Result::Ok(sibling_digest) => {
-                current_hash = node_hash(current_hash, sibling_digest);
-            },
-            Result::Err(err) => {
-                result = Result::Err(err);
-                break;
-            },
-        }
+        let sibling_digest = reader.read_u160();
+        current_hash = node_hash(current_hash, sibling_digest);
         i += 1;
     };
     result?;

+ 25 - 25
target_chains/starknet/contracts/src/pyth.cairo

@@ -357,21 +357,21 @@ mod pyth {
             ref self: ContractState, data: ByteArray
         ) -> Result<(), UpdatePriceFeedsError> {
             let mut reader = ReaderImpl::new(data);
-            let x = reader.read_u32().map_err()?;
+            let x = reader.read_u32();
             if x != ACCUMULATOR_MAGIC {
                 return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
             }
-            if reader.read_u8().map_err()? != MAJOR_VERSION {
+            if reader.read_u8() != MAJOR_VERSION {
                 return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
             }
-            if reader.read_u8().map_err()? < MINIMUM_ALLOWED_MINOR_VERSION {
+            if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION {
                 return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
             }
 
-            let trailing_header_size = reader.read_u8().map_err()?;
-            reader.skip(trailing_header_size).map_err()?;
+            let trailing_header_size = reader.read_u8();
+            reader.skip(trailing_header_size);
 
-            let update_type: Option<UpdateType> = reader.read_u8().map_err()?.try_into();
+            let update_type: Option<UpdateType> = reader.read_u8().try_into();
             match update_type {
                 Option::Some(v) => match v {
                     UpdateType::WormholeMerkle => {}
@@ -379,8 +379,8 @@ mod pyth {
                 Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
             };
 
-            let wh_proof_size = reader.read_u16().map_err()?;
-            let wh_proof = reader.read_byte_array(wh_proof_size.into()).map_err()?;
+            let wh_proof_size = reader.read_u16();
+            let wh_proof = reader.read_byte_array(wh_proof_size.into());
             let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
             let vm = wormhole.parse_and_verify_vm(wh_proof).map_err()?;
 
@@ -392,12 +392,12 @@ mod pyth {
             }
 
             let mut payload_reader = ReaderImpl::new(vm.payload);
-            let x = payload_reader.read_u32().map_err()?;
+            let x = payload_reader.read_u32();
             if x != ACCUMULATOR_WORMHOLE_MAGIC {
                 return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
             }
 
-            let update_type: Option<UpdateType> = payload_reader.read_u8().map_err()?.try_into();
+            let update_type: Option<UpdateType> = payload_reader.read_u8().try_into();
             match update_type {
                 Option::Some(v) => match v {
                     UpdateType::WormholeMerkle => {}
@@ -405,11 +405,11 @@ mod pyth {
                 Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
             };
 
-            let _slot = payload_reader.read_u64().map_err()?;
-            let _ring_size = payload_reader.read_u32().map_err()?;
-            let root_digest = payload_reader.read_u160().map_err()?;
+            let _slot = payload_reader.read_u64();
+            let _ring_size = payload_reader.read_u32();
+            let root_digest = payload_reader.read_u160();
 
-            let num_updates = reader.read_u8().map_err()?;
+            let num_updates = reader.read_u8();
 
             let total_fee = get_total_fee(ref self, num_updates);
             let fee_contract = IERC20CamelDispatcher {
@@ -451,12 +451,12 @@ mod pyth {
     fn read_and_verify_message(
         ref reader: Reader, root_digest: u256
     ) -> Result<PriceFeedMessage, UpdatePriceFeedsError> {
-        let message_size = reader.read_u16().map_err()?;
-        let message = reader.read_byte_array(message_size.into()).map_err()?;
+        let message_size = reader.read_u16();
+        let message = reader.read_byte_array(message_size.into());
         read_and_verify_proof(root_digest, @message, ref reader).map_err()?;
 
         let mut message_reader = ReaderImpl::new(message);
-        let message_type: Option<MessageType> = message_reader.read_u8().map_err()?.try_into();
+        let message_type: Option<MessageType> = message_reader.read_u8().try_into();
         match message_type {
             Option::Some(v) => match v {
                 MessageType::PriceFeed => {}
@@ -464,14 +464,14 @@ mod pyth {
             Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
         };
 
-        let price_id = message_reader.read_u256().map_err()?;
-        let price = u64_as_i64(message_reader.read_u64().map_err()?);
-        let conf = message_reader.read_u64().map_err()?;
-        let expo = u32_as_i32(message_reader.read_u32().map_err()?);
-        let publish_time = message_reader.read_u64().map_err()?;
-        let prev_publish_time = message_reader.read_u64().map_err()?;
-        let ema_price = u64_as_i64(message_reader.read_u64().map_err()?);
-        let ema_conf = message_reader.read_u64().map_err()?;
+        let price_id = message_reader.read_u256();
+        let price = u64_as_i64(message_reader.read_u64());
+        let conf = message_reader.read_u64();
+        let expo = u32_as_i32(message_reader.read_u32());
+        let publish_time = message_reader.read_u64();
+        let prev_publish_time = message_reader.read_u64();
+        let ema_price = u64_as_i64(message_reader.read_u64());
+        let ema_conf = message_reader.read_u64();
 
         let message = PriceFeedMessage {
             price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf,

+ 49 - 83
target_chains/starknet/contracts/src/reader.cairo

@@ -42,112 +42,77 @@ pub impl ReaderImpl of ReaderTrait {
     }
 
     /// Reads the specified number of bytes (up to 16) as a big endian unsigned integer.
-    fn read_num_bytes(ref self: Reader, num_bytes: u8) -> Result<u128, Error> {
+    fn read_num_bytes(ref self: Reader, num_bytes: u8) -> u128 {
         assert!(num_bytes <= 16, "Reader::read_num_bytes: num_bytes is too large");
         if num_bytes <= self.num_current_bytes {
-            let x = self.read_from_current(num_bytes);
-            return Result::Ok(x);
+            return self.read_from_current(num_bytes);
         }
         let num_low_bytes = num_bytes - self.num_current_bytes;
         let high = self.current;
-        self.fetch_next()?;
-        let low = self.read_num_bytes(num_low_bytes)?;
-        let value = if num_low_bytes == 16 {
+        self.fetch_next();
+        let low = self.read_num_bytes(num_low_bytes);
+        if num_low_bytes == 16 {
             low
         } else {
             high * one_shift_left_bytes_u128(num_low_bytes) + low
-        };
-        Result::Ok(value)
+        }
     }
 
-    fn read_u256(ref self: Reader) -> Result<u256, Error> {
-        let high = self.read_num_bytes(16)?;
-        let low = self.read_num_bytes(16)?;
+    fn read_u256(ref self: Reader) -> u256 {
+        let high = self.read_num_bytes(16);
+        let low = self.read_num_bytes(16);
         let value = u256 { high, low };
-        Result::Ok(value)
+        value
     }
-    fn read_u160(ref self: Reader) -> Result<u256, Error> {
-        let high = self.read_num_bytes(4)?;
-        let low = self.read_num_bytes(16)?;
-        let value = u256 { high, low };
-        Result::Ok(value)
+    fn read_u160(ref self: Reader) -> u256 {
+        let high = self.read_num_bytes(4);
+        let low = self.read_num_bytes(16);
+        u256 { high, low }
     }
-    fn read_u128(ref self: Reader) -> Result<u128, Error> {
+    fn read_u128(ref self: Reader) -> u128 {
         self.read_num_bytes(16)
     }
-    fn read_u64(ref self: Reader) -> Result<u64, Error> {
-        let value = self.read_num_bytes(8)?.try_into().expect(UNEXPECTED_OVERFLOW);
-        Result::Ok(value)
+    fn read_u64(ref self: Reader) -> u64 {
+        self.read_num_bytes(8).try_into().expect(UNEXPECTED_OVERFLOW)
     }
-    fn read_u32(ref self: Reader) -> Result<u32, Error> {
-        let value = self.read_num_bytes(4)?.try_into().expect(UNEXPECTED_OVERFLOW);
-        Result::Ok(value)
+    fn read_u32(ref self: Reader) -> u32 {
+        self.read_num_bytes(4).try_into().expect(UNEXPECTED_OVERFLOW)
     }
-    fn read_u16(ref self: Reader) -> Result<u16, Error> {
-        let value = self.read_num_bytes(2)?.try_into().expect(UNEXPECTED_OVERFLOW);
-        Result::Ok(value)
+    fn read_u16(ref self: Reader) -> u16 {
+        self.read_num_bytes(2).try_into().expect(UNEXPECTED_OVERFLOW)
     }
-    fn read_u8(ref self: Reader) -> Result<u8, Error> {
-        let value = self.read_num_bytes(1)?.try_into().expect(UNEXPECTED_OVERFLOW);
-        Result::Ok(value)
+    fn read_u8(ref self: Reader) -> u8 {
+        self.read_num_bytes(1).try_into().expect(UNEXPECTED_OVERFLOW)
     }
 
     // TODO: skip without calculating values
-    fn skip(ref self: Reader, mut num_bytes: u8) -> Result<(), Error> {
-        let mut result = Result::Ok(());
+    fn skip(ref self: Reader, mut num_bytes: u8) {
         while num_bytes > 0 {
             if num_bytes > 16 {
-                match self.read_num_bytes(16) {
-                    Result::Ok(_) => {},
-                    Result::Err(err) => {
-                        result = Result::Err(err);
-                        break;
-                    }
-                }
+                self.read_num_bytes(16);
                 num_bytes -= 16;
             } else {
-                match self.read_num_bytes(num_bytes) {
-                    Result::Ok(_) => {},
-                    Result::Err(err) => {
-                        result = Result::Err(err);
-                        break;
-                    }
-                }
-                break;
+                self.read_num_bytes(num_bytes);
             }
-        };
-        result
+        }
     }
 
     /// Reads the specified number of bytes as a new byte array.
-    fn read_byte_array(ref self: Reader, num_bytes: usize) -> Result<ByteArray, Error> {
+    fn read_byte_array(ref self: Reader, num_bytes: usize) -> ByteArray {
         let mut array: Array<bytes31> = array![];
-        let mut num_last_bytes = Option::None;
+        let mut num_last_bytes = 0;
         let mut num_remaining_bytes = num_bytes;
         loop {
-            let r = self.read_bytes_iteration(num_remaining_bytes, ref array);
-            match r {
-                Result::Ok((
-                    num_read, eof
-                )) => {
-                    num_remaining_bytes -= num_read;
-                    if eof {
-                        num_last_bytes = Option::Some(Result::Ok(num_read));
-                        break;
-                    }
-                },
-                Result::Err(err) => {
-                    num_last_bytes = Option::Some(Result::Err(err));
-                    break;
-                }
+            let (num_read, eof) = self.read_bytes_iteration(num_remaining_bytes, ref array);
+            num_remaining_bytes -= num_read;
+            if eof {
+                num_last_bytes = num_read;
+                break;
             }
         };
-        // `num_last_bytes` is always set to Some before break.
-        let num_last_bytes = num_last_bytes.unwrap()?;
         // num_last_bytes < 31
         let num_last_bytes = num_last_bytes.try_into().expect(UNEXPECTED_OVERFLOW);
-        let array = ByteArrayImpl::new(array, num_last_bytes);
-        Result::Ok(array)
+        ByteArrayImpl::new(array, num_last_bytes)
     }
 
     /// Returns number of remaining bytes to read.
@@ -179,7 +144,7 @@ impl ReaderPrivateImpl of ReaderPrivateTrait {
     /// Replenishes `self.current` and `self.num_current_bytes`.
     /// This should only be called when all bytes from `self.current` has been read.
     /// Returns `EOF` error if no more data is available.
-    fn fetch_next(ref self: Reader) -> Result<(), Error> {
+    fn fetch_next(ref self: Reader) {
         match self.next {
             Option::Some(next) => {
                 self.next = Option::None;
@@ -187,7 +152,10 @@ impl ReaderPrivateImpl of ReaderPrivateTrait {
                 self.num_current_bytes = 16;
             },
             Option::None => {
-                let (value, bytes) = self.array.pop_front().ok_or(Error::UnexpectedEndOfInput)?;
+                let (value, bytes) = self
+                    .array
+                    .pop_front()
+                    .expect(Error::UnexpectedEndOfInput.into());
                 let value: u256 = value.into();
                 if bytes > 16 {
                     self.current = value.high;
@@ -199,33 +167,31 @@ impl ReaderPrivateImpl of ReaderPrivateTrait {
                 }
             },
         }
-        Result::Ok(())
     }
 
     // Moved out from `read_bytes` because we cannot use `return` or `?` within a loop.
     fn read_bytes_iteration(
         ref self: Reader, num_bytes: usize, ref array: Array<bytes31>
-    ) -> Result<(usize, bool), Error> {
+    ) -> (usize, bool) {
         if num_bytes >= 31 {
-            let high = self.read_num_bytes(15)?;
-            let low = self.read_num_bytes(16)?;
+            let high = self.read_num_bytes(15);
+            let low = self.read_num_bytes(16);
             let value: felt252 = u256 { high, low }.try_into().expect(UNEXPECTED_OVERFLOW);
             array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
-            Result::Ok((31, false))
+            (31, false)
         } else if num_bytes > 16 {
             // num_bytes < 31
-            let high = self
-                .read_num_bytes((num_bytes - 16).try_into().expect(UNEXPECTED_OVERFLOW))?;
-            let low = self.read_num_bytes(16)?;
+            let high = self.read_num_bytes((num_bytes - 16).try_into().expect(UNEXPECTED_OVERFLOW));
+            let low = self.read_num_bytes(16);
             let value: felt252 = u256 { high, low }.try_into().expect(UNEXPECTED_OVERFLOW);
             array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
-            Result::Ok((num_bytes, true))
+            (num_bytes, true)
         } else {
             // bytes < 16
-            let low = self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW))?;
+            let low = self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW));
             let value: felt252 = low.try_into().expect(UNEXPECTED_OVERFLOW);
             array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
-            Result::Ok((num_bytes, true))
+            (num_bytes, true)
         }
     }
 }

+ 14 - 14
target_chains/starknet/contracts/src/wormhole.cairo

@@ -282,10 +282,10 @@ mod wormhole {
     }
 
     fn parse_signature(ref reader: Reader) -> Result<GuardianSignature, ParseAndVerifyVmError> {
-        let guardian_index = reader.read_u8().map_err()?;
-        let r = reader.read_u256().map_err()?;
-        let s = reader.read_u256().map_err()?;
-        let recovery_id = reader.read_u8().map_err()?;
+        let guardian_index = reader.read_u8();
+        let r = reader.read_u256();
+        let s = reader.read_u256();
+        let recovery_id = reader.read_u8();
         let y_parity = (recovery_id % 2) > 0;
         let signature = GuardianSignature {
             guardian_index, signature: Signature { r, s, y_parity }
@@ -295,13 +295,13 @@ mod wormhole {
 
     fn parse_vm(encoded_vm: ByteArray) -> Result<(VM, u256), ParseAndVerifyVmError> {
         let mut reader = ReaderImpl::new(encoded_vm);
-        let version = reader.read_u8().map_err()?;
+        let version = reader.read_u8();
         if version != 1 {
             return Result::Err(ParseAndVerifyVmError::VmVersionIncompatible);
         }
-        let guardian_set_index = reader.read_u32().map_err()?;
+        let guardian_set_index = reader.read_u32();
 
-        let sig_count = reader.read_u8().map_err()?;
+        let sig_count = reader.read_u8();
         let mut i = 0;
         let mut signatures = array![];
 
@@ -326,14 +326,14 @@ mod wormhole {
         hasher2.push_u256(body_hash1);
         let body_hash2 = hasher2.finalize();
 
-        let timestamp = reader.read_u32().map_err()?;
-        let nonce = reader.read_u32().map_err()?;
-        let emitter_chain_id = reader.read_u16().map_err()?;
-        let emitter_address = reader.read_u256().map_err()?;
-        let sequence = reader.read_u64().map_err()?;
-        let consistency_level = reader.read_u8().map_err()?;
+        let timestamp = reader.read_u32();
+        let nonce = reader.read_u32();
+        let emitter_chain_id = reader.read_u16();
+        let emitter_address = reader.read_u256();
+        let sequence = reader.read_u64();
+        let consistency_level = reader.read_u8();
         let payload_len = reader.len();
-        let payload = reader.read_byte_array(payload_len).map_err()?;
+        let payload = reader.read_byte_array(payload_len);
 
         let vm = VM {
             version,

+ 3 - 3
target_chains/starknet/contracts/tests/wormhole.cairo

@@ -29,9 +29,9 @@ fn test_parse_and_verify_vm_works() {
     assert!(vm.payload.len() == 37);
 
     let mut reader = ReaderImpl::new(vm.payload);
-    assert!(reader.read_u8().unwrap() == 65);
-    assert!(reader.read_u8().unwrap() == 85);
-    assert!(reader.read_u8().unwrap() == 87);
+    assert!(reader.read_u8() == 65);
+    assert!(reader.read_u8() == 85);
+    assert!(reader.read_u8() == 87);
 }
 
 #[test]