Explorar el Código

feat(target_chains/starknet): update if necessary and get no older than (#1614)

* feat(target_chains/starknet): add update_price_feeds_if_necessary()

* feat(target_chains/starknet): add get_price_no_older_than

* feat(target_chains/starknet): add get_ema_price_no_older_than()

* fix(target_chains/starknet): panic if there is no fresh update
Pavel Strakhov hace 1 año
padre
commit
b0cb32f8a4

+ 64 - 4
target_chains/starknet/contracts/src/pyth.cairo

@@ -9,11 +9,17 @@ pub use pyth::{
     Event, PriceFeedUpdated, WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded,
     DataSourcesSet, FeeSet,
 };
-pub use errors::{GetPriceUnsafeError, GovernanceActionError, UpdatePriceFeedsError};
-pub use interface::{IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Price};
+pub use errors::{
+    GetPriceUnsafeError, GovernanceActionError, UpdatePriceFeedsError, GetPriceNoOlderThanError,
+    UpdatePriceFeedsIfNecessaryError,
+};
+pub use interface::{
+    IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Price, PriceFeedPublishTime
+};
 
 #[starknet::contract]
 mod pyth {
+    use pyth::pyth::interface::IPyth;
     use super::price_update::{
         PriceInfo, PriceFeedMessage, read_and_verify_message, read_and_verify_header,
         parse_wormhole_proof
@@ -23,17 +29,19 @@ mod pyth {
     use core::panic_with_felt252;
     use core::starknet::{
         ContractAddress, get_caller_address, get_execution_info, ClassHash, SyscallResultTrait,
-        get_contract_address,
+        get_contract_address, get_block_timestamp,
     };
     use core::starknet::syscalls::replace_class_syscall;
     use pyth::wormhole::{IWormholeDispatcher, IWormholeDispatcherTrait, VerifiedVM};
     use super::{
         DataSource, UpdatePriceFeedsError, GovernanceActionError, Price, GetPriceUnsafeError,
-        IPythDispatcher, IPythDispatcherTrait,
+        IPythDispatcher, IPythDispatcherTrait, PriceFeedPublishTime, GetPriceNoOlderThanError,
+        UpdatePriceFeedsIfNecessaryError,
     };
     use super::governance;
     use super::governance::GovernancePayload;
     use openzeppelin::token::erc20::interface::{IERC20CamelDispatcherTrait, IERC20CamelDispatcher};
+    use pyth::util::ResultMapErrInto;
 
     #[event]
     #[derive(Drop, PartialEq, starknet::Event)]
@@ -143,6 +151,16 @@ mod pyth {
 
     #[abi(embed_v0)]
     impl PythImpl of super::IPyth<ContractState> {
+        fn get_price_no_older_than(
+            self: @ContractState, price_id: u256, age: u64
+        ) -> Result<Price, GetPriceNoOlderThanError> {
+            let info = self.get_price_unsafe(price_id).map_err_into()?;
+            if !is_no_older_than(info.publish_time, age) {
+                return Result::Err(GetPriceNoOlderThanError::StalePrice);
+            }
+            Result::Ok(info)
+        }
+
         fn get_price_unsafe(
             self: @ContractState, price_id: u256
         ) -> Result<Price, GetPriceUnsafeError> {
@@ -159,6 +177,16 @@ mod pyth {
             Result::Ok(price)
         }
 
+        fn get_ema_price_no_older_than(
+            self: @ContractState, price_id: u256, age: u64
+        ) -> Result<Price, GetPriceNoOlderThanError> {
+            let info = self.get_ema_price_unsafe(price_id).map_err_into()?;
+            if !is_no_older_than(info.publish_time, age) {
+                return Result::Err(GetPriceNoOlderThanError::StalePrice);
+            }
+            Result::Ok(info)
+        }
+
         fn get_ema_price_unsafe(
             self: @ContractState, price_id: u256
         ) -> Result<Price, GetPriceUnsafeError> {
@@ -229,6 +257,28 @@ mod pyth {
             self.get_total_fee(num_updates)
         }
 
+        fn update_price_feeds_if_necessary(
+            ref self: ContractState,
+            update: ByteArray,
+            required_publish_times: Array<PriceFeedPublishTime>
+        ) {
+            let mut i = 0;
+            let mut found = false;
+            while i < required_publish_times.len() {
+                let item = required_publish_times.at(i);
+                let latest_time = self.latest_price_info.read(*item.price_id).publish_time;
+                if latest_time < *item.publish_time {
+                    self.update_price_feeds(update);
+                    found = true;
+                    break;
+                }
+                i += 1;
+            };
+            if !found {
+                panic_with_felt252(UpdatePriceFeedsIfNecessaryError::NoFreshUpdate.into());
+            }
+        }
+
         fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
             let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
             let vm = wormhole.parse_and_verify_vm(data.clone());
@@ -451,4 +501,14 @@ mod pyth {
         };
         output
     }
+
+    fn is_no_older_than(publish_time: u64, age: u64) -> bool {
+        let current = get_block_timestamp();
+        let actual_age = if current >= publish_time {
+            current - publish_time
+        } else {
+            0
+        };
+        actual_age <= age
+    }
 }

+ 42 - 0
target_chains/starknet/contracts/src/pyth/errors.cairo

@@ -11,6 +11,31 @@ impl GetPriceUnsafeErrorIntoFelt252 of Into<GetPriceUnsafeError, felt252> {
     }
 }
 
+#[derive(Copy, Drop, Debug, Serde, PartialEq)]
+pub enum GetPriceNoOlderThanError {
+    PriceFeedNotFound,
+    StalePrice,
+}
+
+impl GetPriceNoOlderThanErrorIntoFelt252 of Into<GetPriceNoOlderThanError, felt252> {
+    fn into(self: GetPriceNoOlderThanError) -> felt252 {
+        match self {
+            GetPriceNoOlderThanError::PriceFeedNotFound => 'price feed not found',
+            GetPriceNoOlderThanError::StalePrice => 'stale price',
+        }
+    }
+}
+
+impl GetPriceUnsafeErrorIntoGetPriceNoOlderThanError of Into<
+    GetPriceUnsafeError, GetPriceNoOlderThanError
+> {
+    fn into(self: GetPriceUnsafeError) -> GetPriceNoOlderThanError {
+        match self {
+            GetPriceUnsafeError::PriceFeedNotFound => GetPriceNoOlderThanError::PriceFeedNotFound,
+        }
+    }
+}
+
 #[derive(Copy, Drop, Debug, Serde, PartialEq)]
 pub enum GovernanceActionError {
     AccessDenied,
@@ -56,3 +81,20 @@ impl UpdatePriceFeedsErrorIntoFelt252 of Into<UpdatePriceFeedsError, felt252> {
         }
     }
 }
+
+#[derive(Copy, Drop, Debug, Serde, PartialEq)]
+pub enum UpdatePriceFeedsIfNecessaryError {
+    Update: UpdatePriceFeedsError,
+    NoFreshUpdate,
+}
+
+impl UpdatePriceFeedsIfNecessaryErrorIntoFelt252 of Into<
+    UpdatePriceFeedsIfNecessaryError, felt252
+> {
+    fn into(self: UpdatePriceFeedsIfNecessaryError) -> felt252 {
+        match self {
+            UpdatePriceFeedsIfNecessaryError::Update(err) => err.into(),
+            UpdatePriceFeedsIfNecessaryError::NoFreshUpdate => 'no fresh update',
+        }
+    }
+}

+ 16 - 1
target_chains/starknet/contracts/src/pyth/interface.cairo

@@ -1,11 +1,20 @@
-use super::GetPriceUnsafeError;
+use super::{GetPriceUnsafeError, GetPriceNoOlderThanError};
 use pyth::byte_array::ByteArray;
 
 #[starknet::interface]
 pub trait IPyth<T> {
+    fn get_price_no_older_than(
+        self: @T, price_id: u256, age: u64
+    ) -> Result<Price, GetPriceNoOlderThanError>;
     fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
+    fn get_ema_price_no_older_than(
+        self: @T, price_id: u256, age: u64
+    ) -> Result<Price, GetPriceNoOlderThanError>;
     fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
     fn update_price_feeds(ref self: T, data: ByteArray);
+    fn update_price_feeds_if_necessary(
+        ref self: T, update: ByteArray, required_publish_times: Array<PriceFeedPublishTime>
+    );
     fn get_update_fee(self: @T, data: ByteArray) -> u256;
     fn execute_governance_instruction(ref self: T, data: ByteArray);
     fn pyth_upgradable_magic(self: @T) -> u32;
@@ -24,3 +33,9 @@ pub struct Price {
     pub expo: i32,
     pub publish_time: u64,
 }
+
+#[derive(Drop, Clone, Serde)]
+pub struct PriceFeedPublishTime {
+    pub price_id: u256,
+    pub publish_time: u64,
+}

+ 13 - 0
target_chains/starknet/contracts/src/util.cairo

@@ -145,6 +145,19 @@ pub fn array_try_into<T, U, +TryInto<T, U>, +Drop<T>, +Drop<U>>(mut input: Array
     output
 }
 
+pub trait ResultMapErrInto<T, E1, E2> {
+    fn map_err_into(self: Result<T, E1>) -> Result<T, E2>;
+}
+
+impl ResultMapErrIntoImpl<T, E1, E2, +Into<E1, E2>> of ResultMapErrInto<T, E1, E2> {
+    fn map_err_into(self: Result<T, E1>) -> Result<T, E2> {
+        match self {
+            Result::Ok(v) => Result::Ok(v),
+            Result::Err(err) => Result::Err(err.into()),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::{u64_as_i64, u32_as_i32};

+ 140 - 1
target_chains/starknet/contracts/tests/pyth.cairo

@@ -1,10 +1,11 @@
 use snforge_std::{
     declare, ContractClassTrait, start_prank, stop_prank, CheatTarget, spy_events, SpyOn, EventSpy,
-    EventFetcher, event_name_hash, Event
+    EventFetcher, event_name_hash, Event, start_warp, stop_warp
 };
 use pyth::pyth::{
     IPythDispatcher, IPythDispatcherTrait, DataSource, Event as PythEvent, PriceFeedUpdated,
     WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded, DataSourcesSet, FeeSet,
+    PriceFeedPublishTime, GetPriceNoOlderThanError,
 };
 use pyth::byte_array::{ByteArray, ByteArrayImpl};
 use pyth::util::{array_try_into, UnwrapWithFelt252};
@@ -132,6 +133,144 @@ fn update_price_feeds_works() {
     assert!(last_ema_price.publish_time == 1712589206);
 }
 
+#[test]
+fn test_update_if_necessary_works() {
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
+
+    start_prank(CheatTarget::One(fee_contract.contract_address), user);
+    fee_contract.approve(pyth.contract_address, 10000);
+    stop_prank(CheatTarget::One(fee_contract.contract_address));
+
+    let mut spy = spy_events(SpyOn::One(pyth.contract_address));
+
+    let price_id = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
+    assert!(pyth.get_price_unsafe(price_id).is_err());
+
+    start_prank(CheatTarget::One(pyth.contract_address), user);
+    let times = array![PriceFeedPublishTime { price_id, publish_time: 1715769470 }];
+    pyth.update_price_feeds_if_necessary(data::test_price_update1(), times);
+
+    let last_price = pyth.get_price_unsafe(price_id).unwrap_with_felt252();
+    assert!(last_price.price == 6281060000000);
+    assert!(last_price.publish_time == 1715769470);
+
+    spy.fetch_events();
+    assert!(spy.events.len() == 1);
+
+    let times = array![PriceFeedPublishTime { price_id, publish_time: 1715769475 }];
+    pyth.update_price_feeds_if_necessary(data::test_price_update2(), times);
+
+    let last_price = pyth.get_price_unsafe(price_id).unwrap_with_felt252();
+    assert!(last_price.price == 6281522520745);
+    assert!(last_price.publish_time == 1715769475);
+
+    spy.fetch_events();
+    assert!(spy.events.len() == 2);
+
+    stop_prank(CheatTarget::One(pyth.contract_address));
+}
+
+#[test]
+#[should_panic(expected: ('no fresh update',))]
+fn test_update_if_necessary_rejects_empty() {
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
+
+    start_prank(CheatTarget::One(fee_contract.contract_address), user);
+    fee_contract.approve(pyth.contract_address, 10000);
+    stop_prank(CheatTarget::One(fee_contract.contract_address));
+
+    start_prank(CheatTarget::One(pyth.contract_address), user);
+    pyth.update_price_feeds_if_necessary(data::test_price_update1(), array![]);
+    stop_prank(CheatTarget::One(pyth.contract_address));
+}
+
+#[test]
+#[should_panic(expected: ('no fresh update',))]
+fn test_update_if_necessary_rejects_no_fresh() {
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
+
+    start_prank(CheatTarget::One(fee_contract.contract_address), user);
+    fee_contract.approve(pyth.contract_address, 10000);
+    stop_prank(CheatTarget::One(fee_contract.contract_address));
+
+    let mut spy = spy_events(SpyOn::One(pyth.contract_address));
+
+    start_prank(CheatTarget::One(pyth.contract_address), user);
+    pyth.update_price_feeds_if_necessary(data::test_price_update1(), array![]);
+
+    let price_id = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
+    assert!(pyth.get_price_unsafe(price_id).is_err());
+    spy.fetch_events();
+    assert!(spy.events.len() == 0);
+
+    let times = array![PriceFeedPublishTime { price_id, publish_time: 1715769470 }];
+    pyth.update_price_feeds_if_necessary(data::test_price_update1(), times);
+
+    let last_price = pyth.get_price_unsafe(price_id).unwrap_with_felt252();
+    assert!(last_price.price == 6281060000000);
+    assert!(last_price.publish_time == 1715769470);
+
+    spy.fetch_events();
+    assert!(spy.events.len() == 1);
+
+    let times = array![PriceFeedPublishTime { price_id, publish_time: 1715769470 }];
+    pyth.update_price_feeds_if_necessary(data::test_price_update2(), times);
+}
+
+#[test]
+fn test_get_no_older_works() {
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_mainnet_guardians();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
+    let price_id = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
+    let err = pyth.get_price_no_older_than(price_id, 100).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
+    let err = pyth.get_ema_price_no_older_than(price_id, 100).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
+
+    start_prank(CheatTarget::One(fee_contract.contract_address), user.try_into().unwrap());
+    fee_contract.approve(pyth.contract_address, 10000);
+    stop_prank(CheatTarget::One(fee_contract.contract_address));
+
+    start_prank(CheatTarget::One(pyth.contract_address), user.try_into().unwrap());
+    pyth.update_price_feeds(data::good_update1());
+    stop_prank(CheatTarget::One(pyth.contract_address));
+
+    start_warp(CheatTarget::One(pyth.contract_address), 1712589210);
+    let err = pyth.get_price_no_older_than(price_id, 3).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::StalePrice);
+    let err = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::StalePrice);
+
+    start_warp(CheatTarget::One(pyth.contract_address), 1712589208);
+    let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.publish_time == 1712589206);
+    assert!(val.price == 7192002930010);
+    let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.publish_time == 1712589206);
+    assert!(val.price == 7181868900000);
+
+    start_warp(CheatTarget::One(pyth.contract_address), 1712589204);
+    let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.publish_time == 1712589206);
+    assert!(val.price == 7192002930010);
+    let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.publish_time == 1712589206);
+    assert!(val.price == 7181868900000);
+
+    stop_warp(CheatTarget::One(pyth.contract_address));
+}
+
 #[test]
 fn test_governance_set_fee_works() {
     let user = 'user'.try_into().unwrap();