Ver Fonte

refactor(target_chains/starknet): remove Result from merkle_tree and pyth setters (#1548)

* refactor(target_chains/starknet): remove Result from merkle_tree

* refactor(target_chains/starknet): remove Result from pyth contract setters
Pavel Strakhov há 1 ano atrás
pai
commit
42b64ac09f

+ 15 - 11
target_chains/starknet/contracts/src/merkle_tree.cairo

@@ -3,6 +3,7 @@ use super::reader::{Reader, ReaderImpl};
 use super::byte_array::ByteArray;
 use super::util::ONE_SHIFT_96;
 use core::cmp::{min, max};
+use core::panic_with_felt252;
 
 const MERKLE_LEAF_PREFIX: u8 = 0;
 const MERKLE_NODE_PREFIX: u8 = 1;
@@ -14,6 +15,15 @@ pub enum MerkleVerificationError {
     DigestMismatch,
 }
 
+impl MerkleVerificationErrorIntoFelt252 of Into<MerkleVerificationError, felt252> {
+    fn into(self: MerkleVerificationError) -> felt252 {
+        match self {
+            MerkleVerificationError::Reader(err) => err.into(),
+            MerkleVerificationError::DigestMismatch => 'digest mismatch',
+        }
+    }
+}
+
 #[generate_trait]
 impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrait<T> {
     fn map_err(self: Result<T, pyth::reader::Error>) -> Result<T, MerkleVerificationError> {
@@ -24,12 +34,11 @@ impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrai
     }
 }
 
-fn leaf_hash(mut reader: Reader) -> Result<u256, super::reader::Error> {
+fn leaf_hash(mut reader: Reader) -> u256 {
     let mut hasher = HasherImpl::new();
     hasher.push_u8(MERKLE_LEAF_PREFIX);
     hasher.push_reader(ref reader);
-    let hash = hasher.finalize() / ONE_SHIFT_96;
-    Result::Ok(hash)
+    hasher.finalize() / ONE_SHIFT_96
 }
 
 fn node_hash(a: u256, b: u256) -> u256 {
@@ -40,25 +49,20 @@ fn node_hash(a: u256, b: u256) -> u256 {
     hasher.finalize() / ONE_SHIFT_96
 }
 
-pub fn read_and_verify_proof(
-    root_digest: u256, message: @ByteArray, ref reader: Reader
-) -> Result<(), MerkleVerificationError> {
+pub fn read_and_verify_proof(root_digest: u256, message: @ByteArray, ref reader: Reader) {
     let mut message_reader = ReaderImpl::new(message.clone());
-    let mut current_hash = leaf_hash(message_reader.clone()).map_err()?;
+    let mut current_hash = leaf_hash(message_reader.clone());
 
     let proof_size = reader.read_u8();
     let mut i = 0;
 
-    let mut result = Result::Ok(());
     while i < proof_size {
         let sibling_digest = reader.read_u160();
         current_hash = node_hash(current_hash, sibling_digest);
         i += 1;
     };
-    result?;
 
     if root_digest != current_hash {
-        return Result::Err(MerkleVerificationError::DigestMismatch);
+        panic_with_felt252(MerkleVerificationError::DigestMismatch.into());
     }
-    Result::Ok(())
 }

+ 43 - 63
target_chains/starknet/contracts/src/pyth.cairo

@@ -9,11 +9,9 @@ pub use pyth::{Event, PriceFeedUpdateEvent};
 pub trait IPyth<T> {
     fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
     fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
-    fn set_data_sources(
-        ref self: T, sources: Array<DataSource>
-    ) -> Result<(), GovernanceActionError>;
-    fn set_fee(ref self: T, single_update_fee: u256) -> Result<(), GovernanceActionError>;
-    fn update_price_feeds(ref self: T, data: ByteArray) -> Result<(), UpdatePriceFeedsError>;
+    fn set_data_sources(ref self: T, sources: Array<DataSource>);
+    fn set_fee(ref self: T, single_update_fee: u256);
+    fn update_price_feeds(ref self: T, data: ByteArray);
 }
 
 #[derive(Copy, Drop, Debug, Serde, PartialEq)]
@@ -333,51 +331,44 @@ mod pyth {
             Result::Ok(price)
         }
 
-        fn set_data_sources(
-            ref self: ContractState, sources: Array<DataSource>
-        ) -> Result<(), GovernanceActionError> {
+        fn set_data_sources(ref self: ContractState, sources: Array<DataSource>) {
             if self.owner.read() != get_caller_address() {
-                return Result::Err(GovernanceActionError::AccessDenied);
+                panic_with_felt252(GovernanceActionError::AccessDenied.into());
             }
             write_data_sources(ref self, sources);
-            Result::Ok(())
         }
 
-        fn set_fee(
-            ref self: ContractState, single_update_fee: u256
-        ) -> Result<(), GovernanceActionError> {
+        fn set_fee(ref self: ContractState, single_update_fee: u256) {
             if self.owner.read() != get_caller_address() {
-                return Result::Err(GovernanceActionError::AccessDenied);
+                panic_with_felt252(GovernanceActionError::AccessDenied.into());
             }
             self.single_update_fee.write(single_update_fee);
-            Result::Ok(())
         }
 
-        fn update_price_feeds(
-            ref self: ContractState, data: ByteArray
-        ) -> Result<(), UpdatePriceFeedsError> {
+        fn update_price_feeds(ref self: ContractState, data: ByteArray) {
             let mut reader = ReaderImpl::new(data);
             let x = reader.read_u32();
             if x != ACCUMULATOR_MAGIC {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
             }
             if reader.read_u8() != MAJOR_VERSION {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
             }
             if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
             }
 
             let trailing_header_size = reader.read_u8();
             reader.skip(trailing_header_size);
 
-            let update_type: Option<UpdateType> = reader.read_u8().try_into();
+            let update_type: UpdateType = reader
+                .read_u8()
+                .try_into()
+                .expect(UpdatePriceFeedsError::InvalidUpdateData.into());
+
             match update_type {
-                Option::Some(v) => match v {
-                    UpdateType::WormholeMerkle => {}
-                },
-                Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
-            };
+                UpdateType::WormholeMerkle => {}
+            }
 
             let wh_proof_size = reader.read_u16();
             let wh_proof = reader.read_byte_array(wh_proof_size.into());
@@ -388,22 +379,23 @@ mod pyth {
                 emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address
             };
             if !self.is_valid_data_source.read(source) {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateDataSource);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateDataSource.into());
             }
 
             let mut payload_reader = ReaderImpl::new(vm.payload);
             let x = payload_reader.read_u32();
             if x != ACCUMULATOR_WORMHOLE_MAGIC {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
             }
 
-            let update_type: Option<UpdateType> = payload_reader.read_u8().try_into();
+            let update_type: UpdateType = payload_reader
+                .read_u8()
+                .try_into()
+                .expect(UpdatePriceFeedsError::InvalidUpdateData.into());
+
             match update_type {
-                Option::Some(v) => match v {
-                    UpdateType::WormholeMerkle => {}
-                },
-                Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
-            };
+                UpdateType::WormholeMerkle => {}
+            }
 
             let _slot = payload_reader.read_u64();
             let _ring_size = payload_reader.read_u32();
@@ -419,50 +411,39 @@ mod pyth {
             let caller = execution_info.caller_address;
             let contract = execution_info.contract_address;
             if fee_contract.allowance(caller, contract) < total_fee {
-                return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
+                panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
             }
             if !fee_contract.transferFrom(caller, contract, total_fee) {
-                return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
+                panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
             }
 
             let mut i = 0;
-            let mut result = Result::Ok(());
             while i < num_updates {
-                let r = read_and_verify_message(ref reader, root_digest);
-                match r {
-                    Result::Ok(message) => { update_latest_price_if_necessary(ref self, message); },
-                    Result::Err(err) => {
-                        result = Result::Err(err);
-                        break;
-                    }
-                }
+                let message = read_and_verify_message(ref reader, root_digest);
+                update_latest_price_if_necessary(ref self, message);
                 i += 1;
             };
-            result?;
 
             if reader.len() != 0 {
-                return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
+                panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
             }
-
-            Result::Ok(())
         }
     }
 
-    fn read_and_verify_message(
-        ref reader: Reader, root_digest: u256
-    ) -> Result<PriceFeedMessage, UpdatePriceFeedsError> {
+    fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage {
         let message_size = reader.read_u16();
         let message = reader.read_byte_array(message_size.into());
-        read_and_verify_proof(root_digest, @message, ref reader).map_err()?;
+        read_and_verify_proof(root_digest, @message, ref reader);
 
         let mut message_reader = ReaderImpl::new(message);
-        let message_type: Option<MessageType> = message_reader.read_u8().try_into();
+        let message_type: MessageType = message_reader
+            .read_u8()
+            .try_into()
+            .expect(UpdatePriceFeedsError::InvalidUpdateData.into());
+
         match message_type {
-            Option::Some(v) => match v {
-                MessageType::PriceFeed => {}
-            },
-            Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
-        };
+            MessageType::PriceFeed => {}
+        }
 
         let price_id = message_reader.read_u256();
         let price = u64_as_i64(message_reader.read_u64());
@@ -473,10 +454,9 @@ mod pyth {
         let ema_price = u64_as_i64(message_reader.read_u64());
         let ema_conf = message_reader.read_u64();
 
-        let message = PriceFeedMessage {
+        PriceFeedMessage {
             price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf,
-        };
-        Result::Ok(message)
+        }
     }
 
     fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) {

+ 1 - 1
target_chains/starknet/contracts/tests/pyth.cairo

@@ -55,7 +55,7 @@ fn update_price_feeds_works() {
     let mut spy = spy_events(SpyOn::One(pyth.contract_address));
 
     start_prank(CheatTarget::One(pyth.contract_address), user.try_into().unwrap());
-    pyth.update_price_feeds(good_update1()).unwrap_with_felt252();
+    pyth.update_price_feeds(good_update1());
     stop_prank(CheatTarget::One(pyth.contract_address));
 
     spy.fetch_events();