Browse Source

Merge pull request #2929 from pyth-network/correcting-fee-calc-order

fix(stylus) - fixed order of fee calculation
Ayush Suresh 3 tháng trước cách đây
mục cha
commit
314c44733a

+ 29 - 0
target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs

@@ -8,6 +8,7 @@ mod test {
     use motsu::prelude::*;
     use pythnet_sdk::wire::v1::{AccumulatorUpdateData, Proof};
     use std::time::Duration;
+    use stylus_sdk::types::AddressVM;
     use wormhole_contract::WormholeContract;
 
     const PYTHNET_CHAIN_ID: u16 = 26;
@@ -118,8 +119,12 @@ mod test {
         let result = pyth_contract
             .sender_and_value(alice, update_fee)
             .update_price_feeds(update_data);
+
         assert!(result.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         let price_result = pyth_contract
             .sender(alice)
             .get_price_unsafe(ban_usd_feed_id());
@@ -169,11 +174,17 @@ mod test {
             .update_price_feeds(update_data1);
         assert!(result1.is_ok());
 
+        assert_eq!(alice.balance(), update_fee2);
+        assert_eq!(pyth_contract.balance(), update_fee1);
+
         let result2 = pyth_contract
             .sender_and_value(alice, update_fee2)
             .update_price_feeds(update_data2);
         assert!(result2.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee1 + update_fee2);
+
         let price_result = pyth_contract
             .sender(alice)
             .get_price_unsafe(ban_usd_feed_id());
@@ -243,6 +254,9 @@ mod test {
             .update_price_feeds(update_data);
         assert!(result.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         let price_result = pyth_contract
             .sender(alice)
             .get_price_no_older_than(btc_usd_feed_id(), u64::MAX);
@@ -269,6 +283,9 @@ mod test {
             .update_price_feeds(update_data);
         assert!(result.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         let price_result = pyth_contract
             .sender(alice)
             .get_price_no_older_than(btc_usd_feed_id(), 1);
@@ -298,6 +315,9 @@ mod test {
             .update_price_feeds(update_data);
         assert!(result.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         let first_price_result = pyth_contract
             .sender(alice)
             .get_price_unsafe(ban_usd_feed_id());
@@ -339,6 +359,9 @@ mod test {
             .update_price_feeds(update_data);
         assert!(result.is_ok());
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         assert!(pyth_contract
             .sender(alice)
             .price_feed_exists(ban_usd_feed_id()));
@@ -380,6 +403,9 @@ mod test {
             .sender_and_value(alice, update_fee)
             .update_price_feeds(update_data);
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         assert!(result.is_ok());
 
         let price_result = pyth_contract
@@ -407,6 +433,9 @@ mod test {
             .sender_and_value(alice, update_fee)
             .update_price_feeds(update_data);
 
+        assert_eq!(alice.balance(), U256::ZERO);
+        assert_eq!(pyth_contract.balance(), update_fee);
+
         assert!(result.is_ok());
 
         let price_result1 = pyth_contract

+ 5 - 5
target_chains/stylus/contracts/pyth-receiver/src/lib.rs

@@ -219,17 +219,17 @@ impl PythReceiver {
         &mut self,
         update_data: Vec<Vec<u8>>,
     ) -> Result<(), PythReceiverError> {
-        for data in &update_data {
-            self.update_price_feeds_internal(data.clone(), 0, 0, false)?;
-        }
-
-        let total_fee = self.get_update_fee(update_data)?;
+        let total_fee = self.get_update_fee(update_data.clone())?;
 
         let value = self.vm().msg_value();
 
         if value < total_fee {
             return Err(PythReceiverError::InsufficientFee);
         }
+
+        for data in &update_data {
+            self.update_price_feeds_internal(data.clone(), 0, 0, false)?;
+        }
         Ok(())
     }