Explorar el Código

Merge pull request #2840 from pyth-network/pyth-stylus-parse-updates

feat: stylus parse price feed updates fxn
Ayush Suresh hace 4 meses
padre
commit
3664ef3f9e

+ 4 - 0
target_chains/stylus/contracts/pyth-receiver/src/error.rs

@@ -17,6 +17,8 @@ pub enum PythReceiverError {
     InsufficientFee,
     InvalidEmitterAddress,
     TooManyUpdates,
+    PriceFeedNotFoundWithinRange,
+    NoFreshUpdate,
 }
 
 impl core::fmt::Debug for PythReceiverError {
@@ -43,6 +45,8 @@ impl From<PythReceiverError> for Vec<u8> {
             PythReceiverError::InsufficientFee => 13,
             PythReceiverError::InvalidEmitterAddress => 14,
             PythReceiverError::TooManyUpdates => 15,
+            PythReceiverError::PriceFeedNotFoundWithinRange => 16,
+            PythReceiverError::NoFreshUpdate => 17,
         }]
     }
 }

+ 10 - 0
target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs

@@ -342,4 +342,14 @@ mod test {
             multiple_updates_diff_vaa_results()[1]
         );
     }
+
+    #[motsu::test]
+    fn test_multiple_updates_same_id_updates_latest(
+        pyth_contract: Contract<PythReceiver>,
+        wormhole_contract: Contract<WormholeContract>,
+        alice: Address,
+    ) {
+        pyth_wormhole_init(&pyth_contract, &wormhole_contract, &alice);
+        alice.fund(U256::from(200));
+    }
 }

+ 210 - 93
target_chains/stylus/contracts/pyth-receiver/src/lib.rs

@@ -15,7 +15,7 @@ mod test_data;
 #[cfg(test)]
 use mock_instant::global::MockClock;
 
-use alloc::vec::Vec;
+use alloc::{collections::BTreeMap, vec::Vec};
 use stylus_sdk::{
     alloy_primitives::{Address, FixedBytes, I32, I64, U16, U256, U32, U64},
     call::Call,
@@ -97,7 +97,6 @@ impl PythReceiver {
         for (i, chain_id) in data_source_emitter_chain_ids.iter().enumerate() {
             let emitter_address = FixedBytes::<32>::from(data_source_emitter_addresses[i]);
 
-            // Create a new data source storage slot
             let mut data_source = self.valid_data_sources.grow();
             data_source.chain_id.set(U16::from(*chain_id));
             data_source.emitter_address.set(emitter_address);
@@ -178,7 +177,7 @@ impl PythReceiver {
         update_data: Vec<Vec<u8>>,
     ) -> Result<(), PythReceiverError> {
         for data in &update_data {
-            self.update_price_feeds_internal(data.clone())?;
+            self.update_price_feeds_internal(data.clone(), 0, 0, false)?;
         }
 
         let total_fee = self.get_update_fee(update_data)?;
@@ -193,17 +192,166 @@ impl PythReceiver {
 
     pub fn update_price_feeds_if_necessary(
         &mut self,
-        _update_data: Vec<Vec<u8>>,
-        _price_ids: Vec<[u8; 32]>,
-        _publish_times: Vec<u64>,
-    ) {
-        // dummy implementation
+        update_data: Vec<Vec<u8>>,
+        price_ids: Vec<[u8; 32]>,
+        publish_times: Vec<u64>,
+    ) -> Result<(), PythReceiverError> {
+        if (price_ids.len() != publish_times.len())
+            || (price_ids.is_empty() && publish_times.is_empty())
+        {
+            return Err(PythReceiverError::InvalidUpdateData);
+        }
+
+        for i in 0..price_ids.len() {
+            if self.latest_price_info_publish_time(price_ids[i]) < publish_times[i] {
+                self.update_price_feeds(update_data.clone())?;
+                return Ok(());
+            }
+        }
+
+        return Err(PythReceiverError::NoFreshUpdate);
+    }
+
+    fn latest_price_info_publish_time(&self, price_id: [u8; 32]) -> u64 {
+        let price_id_fb: FixedBytes<32> = FixedBytes::from(price_id);
+        let recent_price_info = self.latest_price_info.get(price_id_fb);
+        recent_price_info.publish_time.get().to::<u64>()
     }
 
     fn update_price_feeds_internal(
         &mut self,
         update_data: Vec<u8>,
-    ) -> Result<(), PythReceiverError> {
+        min_publish_time: u64,
+        max_publish_time: u64,
+        unique: bool,
+    ) -> Result<Vec<([u8; 32], PriceInfoReturn)>, PythReceiverError> {
+        let price_pairs = self.parse_price_feed_updates_internal(
+            update_data,
+            min_publish_time,
+            max_publish_time,
+            unique,
+        )?;
+
+        for (price_id, price_return) in price_pairs.clone() {
+            let price_id_fb: FixedBytes<32> = FixedBytes::from(price_id);
+            let mut recent_price_info = self.latest_price_info.setter(price_id_fb);
+
+            if recent_price_info.publish_time.get() < price_return.0
+                || recent_price_info.price.get() == I64::ZERO
+            {
+                recent_price_info.publish_time.set(price_return.0);
+                recent_price_info.expo.set(price_return.1);
+                recent_price_info.price.set(price_return.2);
+                recent_price_info.conf.set(price_return.3);
+                recent_price_info.ema_price.set(price_return.4);
+                recent_price_info.ema_conf.set(price_return.5);
+            }
+        }
+
+        Ok(price_pairs)
+    }
+
+    fn get_update_fee(&self, update_data: Vec<Vec<u8>>) -> Result<U256, PythReceiverError> {
+        let mut total_num_updates: u64 = 0;
+        for data in &update_data {
+            let update_data_array: &[u8] = &data;
+            let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array)
+                .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?;
+            match accumulator_update.proof {
+                Proof::WormholeMerkle { vaa: _, updates } => {
+                    let num_updates = u64::try_from(updates.len())
+                        .map_err(|_| PythReceiverError::TooManyUpdates)?;
+                    total_num_updates += num_updates;
+                }
+            }
+        }
+        Ok(U256::from(total_num_updates).saturating_mul(self.single_update_fee_in_wei.get())
+            + self.transaction_fee_in_wei.get())
+    }
+
+    pub fn get_twap_update_fee(&self, _update_data: Vec<Vec<u8>>) -> U256 {
+        U256::from(0u8)
+    }
+
+    pub fn parse_price_feed_updates(
+        &mut self,
+        update_data: Vec<u8>,
+        price_ids: Vec<[u8; 32]>,
+        min_publish_time: u64,
+        max_publish_time: u64,
+    ) -> Result<Vec<PriceInfoReturn>, PythReceiverError> {
+        let price_feeds = self.parse_price_feed_updates_with_config(
+            vec![update_data],
+            price_ids,
+            min_publish_time,
+            max_publish_time,
+            false,
+            false,
+            false,
+        );
+        price_feeds
+    }
+
+    pub fn parse_price_feed_updates_with_config(
+        &mut self,
+        update_data: Vec<Vec<u8>>,
+        price_ids: Vec<[u8; 32]>,
+        min_allowed_publish_time: u64,
+        max_allowed_publish_time: u64,
+        check_uniqueness: bool,
+        check_update_data_is_minimal: bool,
+        store_updates_if_fresh: bool,
+    ) -> Result<Vec<PriceInfoReturn>, PythReceiverError> {
+        let mut all_parsed_price_pairs = Vec::new();
+        for data in &update_data {
+            if store_updates_if_fresh {
+                all_parsed_price_pairs.extend(self.update_price_feeds_internal(
+                    data.clone(),
+                    min_allowed_publish_time,
+                    max_allowed_publish_time,
+                    check_uniqueness,
+                )?);
+            } else {
+                all_parsed_price_pairs.extend(self.parse_price_feed_updates_internal(
+                    data.clone(),
+                    min_allowed_publish_time,
+                    max_allowed_publish_time,
+                    check_uniqueness,
+                )?);
+            }
+        }
+
+        if check_update_data_is_minimal && all_parsed_price_pairs.len() != price_ids.len() {
+            return Err(PythReceiverError::InvalidUpdateData);
+        }
+
+        let mut result: Vec<PriceInfoReturn> = Vec::with_capacity(price_ids.len());
+        let mut price_map: BTreeMap<[u8; 32], PriceInfoReturn> = BTreeMap::new();
+
+        for (price_id, price_info) in all_parsed_price_pairs {
+            if !price_map.contains_key(&price_id) {
+                price_map.insert(price_id, price_info);
+            }
+        }
+
+        for price_id in price_ids {
+            if let Some(price_info) = price_map.get(&price_id) {
+                result.push(*price_info);
+            } else {
+                return Err(PythReceiverError::PriceFeedNotFoundWithinRange);
+            }
+        }
+
+        Ok(result)
+    }
+
+    fn parse_price_feed_updates_internal(
+        &mut self,
+        update_data: Vec<u8>,
+        min_allowed_publish_time: u64,
+        max_allowed_publish_time: u64,
+        check_uniqueness: bool,
+    ) -> Result<Vec<([u8; 32], PriceInfoReturn)>, PythReceiverError> {
         let update_data_array: &[u8] = &update_data;
         // Check the first 4 bytes of the update_data_array for the magic header
         if update_data_array.len() < 4 {
@@ -220,6 +368,8 @@ impl PythReceiver {
         let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array)
             .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?;
 
+        let mut price_feeds: BTreeMap<[u8; 32], PriceInfoReturn> = BTreeMap::new();
+
         match accumulator_update.proof {
             Proof::WormholeMerkle { vaa, updates } => {
                 let wormhole: IWormholeContract = IWormholeContract::new(self.wormhole.get());
@@ -228,10 +378,10 @@ impl PythReceiver {
                     .parse_and_verify_vm(config, Vec::from(vaa.clone()))
                     .map_err(|_| PythReceiverError::InvalidWormholeMessage)?;
 
-                let vaa = Vaa::read(&mut Vec::from(vaa.clone()).as_slice())
+                let vaa_obj = Vaa::read(&mut Vec::from(vaa.clone()).as_slice())
                     .map_err(|_| PythReceiverError::VaaVerificationFailed)?;
 
-                let cur_emitter_address: &[u8; 32] = vaa
+                let cur_emitter_address: &[u8; 32] = vaa_obj
                     .body
                     .emitter_address
                     .as_slice()
@@ -239,7 +389,7 @@ impl PythReceiver {
                     .map_err(|_| PythReceiverError::InvalidEmitterAddress)?;
 
                 let cur_data_source = DataSource {
-                    chain_id: U16::from(vaa.body.emitter_chain),
+                    chain_id: U16::from(vaa_obj.body.emitter_chain),
                     emitter_address: FixedBytes::from(cur_emitter_address),
                 };
 
@@ -247,7 +397,7 @@ impl PythReceiver {
                     return Err(PythReceiverError::InvalidWormholeMessage);
                 }
 
-                let root_digest: MerkleRoot<Keccak160> = parse_wormhole_proof(vaa)?;
+                let root_digest: MerkleRoot<Keccak160> = parse_wormhole_proof(vaa_obj)?;
 
                 for update in updates {
                     let message_vec = Vec::from(update.message);
@@ -262,33 +412,40 @@ impl PythReceiver {
 
                     match msg {
                         Message::PriceFeedMessage(price_feed_message) => {
-                            let price_id_fb: FixedBytes<32> =
-                                FixedBytes::from(price_feed_message.feed_id);
-                            let mut recent_price_info = self.latest_price_info.setter(price_id_fb);
+                            let publish_time = price_feed_message.publish_time;
 
-                            if recent_price_info.publish_time.get()
-                                < U64::from(price_feed_message.publish_time)
-                                || recent_price_info.price.get() == I64::ZERO
+                            if (min_allowed_publish_time > 0
+                                && publish_time < min_allowed_publish_time as i64)
+                                || (max_allowed_publish_time > 0
+                                    && publish_time > max_allowed_publish_time as i64)
                             {
-                                recent_price_info
-                                    .publish_time
-                                    .set(U64::from(price_feed_message.publish_time));
-                                recent_price_info.price.set(I64::from_le_bytes(
-                                    price_feed_message.price.to_le_bytes(),
-                                ));
-                                recent_price_info
-                                    .conf
-                                    .set(U64::from(price_feed_message.conf));
-                                recent_price_info.expo.set(I32::from_le_bytes(
-                                    price_feed_message.exponent.to_le_bytes(),
-                                ));
-                                recent_price_info.ema_price.set(I64::from_le_bytes(
-                                    price_feed_message.ema_price.to_le_bytes(),
-                                ));
-                                recent_price_info
-                                    .ema_conf
-                                    .set(U64::from(price_feed_message.ema_conf));
+                                return Err(PythReceiverError::PriceFeedNotFoundWithinRange);
                             }
+
+                            if check_uniqueness {
+                                let price_id_fb =
+                                    FixedBytes::<32>::from(price_feed_message.feed_id);
+                                let prev_price_info = self.latest_price_info.get(price_id_fb);
+                                let prev_publish_time =
+                                    prev_price_info.publish_time.get().to::<u64>();
+
+                                if prev_publish_time > 0
+                                    && min_allowed_publish_time <= prev_publish_time
+                                {
+                                    return Err(PythReceiverError::PriceFeedNotFoundWithinRange);
+                                }
+                            }
+
+                            let price_info_return = (
+                                U64::from(publish_time),
+                                I32::from_be_bytes(price_feed_message.exponent.to_be_bytes()),
+                                I64::from_be_bytes(price_feed_message.price.to_be_bytes()),
+                                U64::from(price_feed_message.conf),
+                                I64::from_be_bytes(price_feed_message.ema_price.to_be_bytes()),
+                                U64::from(price_feed_message.ema_conf),
+                            );
+
+                            price_feeds.insert(price_feed_message.feed_id, price_info_return);
                         }
                         _ => {
                             return Err(PythReceiverError::InvalidAccumulatorMessageType);
@@ -298,56 +455,7 @@ impl PythReceiver {
             }
         };
 
-        Ok(())
-    }
-
-    fn get_update_fee(&self, update_data: Vec<Vec<u8>>) -> Result<U256, PythReceiverError> {
-        let mut total_num_updates: u64 = 0;
-        for data in &update_data {
-            let update_data_array: &[u8] = &data;
-            let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array)
-                .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?;
-            match accumulator_update.proof {
-                Proof::WormholeMerkle { vaa: _, updates } => {
-                    let num_updates = u64::try_from(updates.len())
-                        .map_err(|_| PythReceiverError::TooManyUpdates)?;
-                    total_num_updates += num_updates;
-                }
-            }
-        }
-        Ok(self.get_total_fee(total_num_updates))
-    }
-
-    fn get_total_fee(&self, total_num_updates: u64) -> U256 {
-        U256::from(total_num_updates).saturating_mul(self.single_update_fee_in_wei.get())
-            + self.transaction_fee_in_wei.get()
-    }
-
-    pub fn get_twap_update_fee(&self, _update_data: Vec<Vec<u8>>) -> U256 {
-        U256::from(0u8)
-    }
-
-    pub fn parse_price_feed_updates(
-        &mut self,
-        _update_data: Vec<Vec<u8>>,
-        _price_ids: Vec<[u8; 32]>,
-        _min_publish_time: u64,
-        _max_publish_time: u64,
-    ) -> Vec<PriceInfoReturn> {
-        Vec::new()
-    }
-
-    pub fn parse_price_feed_updates_with_config(
-        &mut self,
-        _update_data: Vec<Vec<u8>>,
-        _price_ids: Vec<[u8; 32]>,
-        _min_allowed_publish_time: u64,
-        _max_allowed_publish_time: u64,
-        _check_uniqueness: bool,
-        _check_update_data_is_minimal: bool,
-        _store_updates_if_fresh: bool,
-    ) -> (Vec<PriceInfoReturn>, Vec<u64>) {
-        (Vec::new(), Vec::new())
+        Ok(price_feeds.into_iter().collect())
     }
 
     pub fn parse_twap_price_feed_updates(
@@ -360,12 +468,21 @@ impl PythReceiver {
 
     pub fn parse_price_feed_updates_unique(
         &mut self,
-        _update_data: Vec<Vec<u8>>,
-        _price_ids: Vec<[u8; 32]>,
-        _min_publish_time: u64,
-        _max_publish_time: u64,
-    ) -> Vec<PriceInfoReturn> {
-        Vec::new()
+        update_data: Vec<Vec<u8>>,
+        price_ids: Vec<[u8; 32]>,
+        min_publish_time: u64,
+        max_publish_time: u64,
+    ) -> Result<Vec<PriceInfoReturn>, PythReceiverError> {
+        let price_feeds = self.parse_price_feed_updates_with_config(
+            update_data,
+            price_ids,
+            min_publish_time,
+            max_publish_time,
+            true,
+            false,
+            false,
+        );
+        price_feeds
     }
 
     fn is_no_older_than(&self, publish_time: U64, max_age: u64) -> bool {

+ 8 - 8
target_chains/stylus/contracts/wormhole/src/lib.rs

@@ -500,7 +500,7 @@ mod tests {
     use core::str::FromStr;
     use k256::ecdsa::SigningKey;
     use stylus_sdk::alloy_primitives::keccak256;
-    
+
     #[cfg(test)]
     use base64::engine::general_purpose;
     #[cfg(test)]
@@ -543,7 +543,7 @@ mod tests {
             0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40,
         ]
     }
-    
+
     #[cfg(test)]
     fn current_guardians() -> Vec<Address> {
         vec![
@@ -634,7 +634,7 @@ mod tests {
         contract.initialize(guardians, 1, CHAIN_ID, GOVERNANCE_CHAIN_ID, governance_contract).unwrap();
         contract
     }
-    
+
     #[cfg(test)]
     fn deploy_with_current_mainnet_guardians() -> WormholeContract {
         let mut contract = WormholeContract::default();
@@ -802,7 +802,7 @@ mod tests {
     #[motsu::test]
     fn test_verification_multiple_guardian_sets() {
         let mut contract = deploy_with_current_mainnet_guardians();
-        
+
         let store_result = contract.store_gs(4, current_guardians(), 0);
         if let Err(_) = store_result {
             panic!("Error deploying multiple guardian sets");
@@ -816,7 +816,7 @@ mod tests {
     #[motsu::test]
     fn test_verification_incorrect_guardian_set() {
         let mut contract = deploy_with_current_mainnet_guardians();
-        
+
         let store_result = contract.store_gs(4, mock_guardian_set13(), 0);
         if let Err(_) = store_result {
             panic!("Error deploying guardian set");
@@ -1147,7 +1147,7 @@ mod tests {
         let mut contract = WormholeContract::default();
         let guardians = current_guardians();
         let governance_contract = Address::from_slice(&GOVERNANCE_CONTRACT.to_be_bytes::<32>()[12..32]);
-        
+
         let result = contract.initialize(guardians.clone(), 4, CHAIN_ID, GOVERNANCE_CHAIN_ID, governance_contract);
         assert!(result.is_ok(), "Contract initialization should succeed");
     }
@@ -1222,5 +1222,5 @@ mod tests {
         assert!(result2.is_ok());
     }
 
-    
-}
+
+}