Pārlūkot izejas kodu

feat(target_chains/starknet): add query_price_feed methods (#1627)

Pavel Strakhov 1 gadu atpakaļ
vecāks
revīzija
a66670265a

+ 35 - 0
target_chains/starknet/contracts/src/pyth.cairo

@@ -204,6 +204,41 @@ mod pyth {
             Result::Ok(price)
         }
 
+        fn query_price_feed_no_older_than(
+            self: @ContractState, price_id: u256, age: u64
+        ) -> Result<PriceFeed, GetPriceNoOlderThanError> {
+            let feed = self.query_price_feed_unsafe(price_id).map_err_into()?;
+            if !is_no_older_than(feed.price.publish_time, age) {
+                return Result::Err(GetPriceNoOlderThanError::StalePrice);
+            }
+            Result::Ok(feed)
+        }
+
+        fn query_price_feed_unsafe(
+            self: @ContractState, price_id: u256
+        ) -> Result<PriceFeed, GetPriceUnsafeError> {
+            let info = self.latest_price_info.read(price_id);
+            if info.publish_time == 0 {
+                return Result::Err(GetPriceUnsafeError::PriceFeedNotFound);
+            }
+            let feed = PriceFeed {
+                id: price_id,
+                price: Price {
+                    price: info.price,
+                    conf: info.conf,
+                    expo: info.expo,
+                    publish_time: info.publish_time,
+                },
+                ema_price: Price {
+                    price: info.ema_price,
+                    conf: info.ema_conf,
+                    expo: info.expo,
+                    publish_time: info.publish_time,
+                },
+            };
+            Result::Ok(feed)
+        }
+
         fn update_price_feeds(ref self: ContractState, data: ByteArray) {
             self.update_price_feeds_internal(data, array![], 0, 0, false);
         }

+ 4 - 0
target_chains/starknet/contracts/src/pyth/interface.cairo

@@ -11,6 +11,10 @@ pub trait IPyth<T> {
         self: @T, price_id: u256, age: u64
     ) -> Result<Price, GetPriceNoOlderThanError>;
     fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
+    fn query_price_feed_no_older_than(
+        self: @T, price_id: u256, age: u64
+    ) -> Result<PriceFeed, GetPriceNoOlderThanError>;
+    fn query_price_feed_unsafe(self: @T, price_id: u256) -> Result<PriceFeed, GetPriceUnsafeError>;
     fn update_price_feeds(ref self: T, data: ByteArray);
     fn update_price_feeds_if_necessary(
         ref self: T, update: ByteArray, required_publish_times: Array<PriceFeedPublishTime>

+ 30 - 1
target_chains/starknet/contracts/tests/pyth.cairo

@@ -5,7 +5,7 @@ use snforge_std::{
 use pyth::pyth::{
     IPythDispatcher, IPythDispatcherTrait, DataSource, Event as PythEvent, PriceFeedUpdated,
     WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded, DataSourcesSet, FeeSet,
-    PriceFeedPublishTime, GetPriceNoOlderThanError, Price, PriceFeed,
+    PriceFeedPublishTime, GetPriceNoOlderThanError, Price, PriceFeed, GetPriceUnsafeError,
 };
 use pyth::byte_array::{ByteArray, ByteArrayImpl};
 use pyth::util::{array_try_into, UnwrapWithFelt252};
@@ -131,6 +131,19 @@ fn update_price_feeds_works() {
     assert!(last_ema_price.conf == 4096812700);
     assert!(last_ema_price.expo == -8);
     assert!(last_ema_price.publish_time == 1712589206);
+
+    let feed = pyth
+        .query_price_feed_unsafe(0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43)
+        .unwrap_with_felt252();
+    assert!(feed.id == 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43);
+    assert!(feed.price.price == 7192002930010);
+    assert!(feed.price.conf == 3596501465);
+    assert!(feed.price.expo == -8);
+    assert!(feed.price.publish_time == 1712589206);
+    assert!(feed.ema_price.price == 7181868900000);
+    assert!(feed.ema_price.conf == 4096812700);
+    assert!(feed.ema_price.expo == -8);
+    assert!(feed.ema_price.publish_time == 1712589206);
 }
 
 #[test]
@@ -407,10 +420,18 @@ fn test_get_no_older_works() {
     let fee_contract = deploy_fee_contract(user);
     let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
     let price_id = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
+    let err = pyth.get_price_unsafe(price_id).unwrap_err();
+    assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
+    let err = pyth.get_ema_price_unsafe(price_id).unwrap_err();
+    assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
+    let err = pyth.query_price_feed_unsafe(price_id).unwrap_err();
+    assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
     let err = pyth.get_price_no_older_than(price_id, 100).unwrap_err();
     assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
     let err = pyth.get_ema_price_no_older_than(price_id, 100).unwrap_err();
     assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
+    let err = pyth.query_price_feed_no_older_than(price_id, 100).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
 
     start_prank(CheatTarget::One(fee_contract.contract_address), user.try_into().unwrap());
     fee_contract.approve(pyth.contract_address, 10000);
@@ -425,6 +446,8 @@ fn test_get_no_older_works() {
     assert!(err == GetPriceNoOlderThanError::StalePrice);
     let err = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_err();
     assert!(err == GetPriceNoOlderThanError::StalePrice);
+    let err = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_err();
+    assert!(err == GetPriceNoOlderThanError::StalePrice);
 
     start_warp(CheatTarget::One(pyth.contract_address), 1712589208);
     let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
@@ -433,6 +456,9 @@ fn test_get_no_older_works() {
     let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
     assert!(val.publish_time == 1712589206);
     assert!(val.price == 7181868900000);
+    let val = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.price.publish_time == 1712589206);
+    assert!(val.price.price == 7192002930010);
 
     start_warp(CheatTarget::One(pyth.contract_address), 1712589204);
     let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
@@ -441,6 +467,9 @@ fn test_get_no_older_works() {
     let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
     assert!(val.publish_time == 1712589206);
     assert!(val.price == 7181868900000);
+    let val = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_with_felt252();
+    assert!(val.price.publish_time == 1712589206);
+    assert!(val.price.price == 7192002930010);
 
     stop_warp(CheatTarget::One(pyth.contract_address));
 }