Explorar o código

feat(target_chains/starknet): add get_update_fee method (#1613)

Pavel Strakhov hai 1 ano
pai
achega
7a5ac7d968

+ 15 - 3
target_chains/starknet/contracts/src/pyth.cairo

@@ -15,7 +15,7 @@ pub use interface::{IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Pr
 #[starknet::contract]
 mod pyth {
     use super::price_update::{
-        PriceInfo, PriceFeedMessage, read_and_verify_message, read_header_and_wormhole_proof,
+        PriceInfo, PriceFeedMessage, read_and_verify_message, read_and_verify_header,
         parse_wormhole_proof
     };
     use pyth::reader::{Reader, ReaderImpl};
@@ -177,7 +177,10 @@ mod pyth {
 
         fn update_price_feeds(ref self: ContractState, data: ByteArray) {
             let mut reader = ReaderImpl::new(data);
-            let wormhole_proof = read_header_and_wormhole_proof(ref reader);
+            read_and_verify_header(ref reader);
+            let wormhole_proof_size = reader.read_u16();
+            let wormhole_proof = reader.read_byte_array(wormhole_proof_size.into());
+
             let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
             let vm = wormhole.parse_and_verify_vm(wormhole_proof);
 
@@ -217,6 +220,15 @@ mod pyth {
             }
         }
 
+        fn get_update_fee(self: @ContractState, data: ByteArray) -> u256 {
+            let mut reader = ReaderImpl::new(data);
+            read_and_verify_header(ref reader);
+            let wormhole_proof_size = reader.read_u16();
+            reader.skip(wormhole_proof_size.into());
+            let num_updates = reader.read_u8();
+            self.get_total_fee(num_updates)
+        }
+
         fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
             let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
             let vm = wormhole.parse_and_verify_vm(data.clone());
@@ -323,7 +335,7 @@ mod pyth {
             }
         }
 
-        fn get_total_fee(ref self: ContractState, num_updates: u8) -> u256 {
+        fn get_total_fee(self: @ContractState, num_updates: u8) -> u256 {
             self.single_update_fee.read() * num_updates.into()
         }
 

+ 12 - 26
target_chains/starknet/contracts/src/pyth/fake_upgrades.cairo

@@ -1,8 +1,16 @@
+use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
+
 // Only used for tests.
 
+#[starknet::interface]
+pub trait IFakePyth<T> {
+    fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
+    fn pyth_upgradable_magic(self: @T) -> u32;
+}
+
 #[starknet::contract]
 mod pyth_fake_upgrade1 {
-    use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
+    use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
     use pyth::byte_array::ByteArray;
 
     #[storage]
@@ -12,24 +20,13 @@ mod pyth_fake_upgrade1 {
     fn constructor(ref self: ContractState) {}
 
     #[abi(embed_v0)]
-    impl PythImpl of IPyth<ContractState> {
+    impl PythImpl of super::IFakePyth<ContractState> {
         fn get_price_unsafe(
             self: @ContractState, price_id: u256
         ) -> Result<Price, GetPriceUnsafeError> {
             let price = Price { price: 42, conf: 2, expo: -5, publish_time: 101, };
             Result::Ok(price)
         }
-        fn get_ema_price_unsafe(
-            self: @ContractState, price_id: u256
-        ) -> Result<Price, GetPriceUnsafeError> {
-            panic!("unsupported")
-        }
-        fn update_price_feeds(ref self: ContractState, data: ByteArray) {
-            panic!("unsupported")
-        }
-        fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
-            panic!("unsupported")
-        }
         fn pyth_upgradable_magic(self: @ContractState) -> u32 {
             0x97a6f304
         }
@@ -38,7 +35,7 @@ mod pyth_fake_upgrade1 {
 
 #[starknet::contract]
 mod pyth_fake_upgrade_wrong_magic {
-    use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
+    use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
     use pyth::byte_array::ByteArray;
 
     #[storage]
@@ -48,23 +45,12 @@ mod pyth_fake_upgrade_wrong_magic {
     fn constructor(ref self: ContractState) {}
 
     #[abi(embed_v0)]
-    impl PythImpl of IPyth<ContractState> {
+    impl PythImpl of super::IFakePyth<ContractState> {
         fn get_price_unsafe(
             self: @ContractState, price_id: u256
         ) -> Result<Price, GetPriceUnsafeError> {
             panic!("unsupported")
         }
-        fn get_ema_price_unsafe(
-            self: @ContractState, price_id: u256
-        ) -> Result<Price, GetPriceUnsafeError> {
-            panic!("unsupported")
-        }
-        fn update_price_feeds(ref self: ContractState, data: ByteArray) {
-            panic!("unsupported")
-        }
-        fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
-            panic!("unsupported")
-        }
         fn pyth_upgradable_magic(self: @ContractState) -> u32 {
             606
         }

+ 1 - 0
target_chains/starknet/contracts/src/pyth/interface.cairo

@@ -6,6 +6,7 @@ pub trait IPyth<T> {
     fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
     fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
     fn update_price_feeds(ref self: T, data: ByteArray);
+    fn get_update_fee(self: @T, data: ByteArray) -> u256;
     fn execute_governance_instruction(ref self: T, data: ByteArray);
     fn pyth_upgradable_magic(self: @T) -> u32;
 }

+ 2 - 5
target_chains/starknet/contracts/src/pyth/price_update.cairo

@@ -64,7 +64,7 @@ pub struct PriceFeedMessage {
     pub ema_conf: u64,
 }
 
-pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
+pub fn read_and_verify_header(ref reader: Reader) {
     if reader.read_u32() != ACCUMULATOR_MAGIC {
         panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
     }
@@ -76,7 +76,7 @@ pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
     }
 
     let trailing_header_size = reader.read_u8();
-    reader.skip(trailing_header_size);
+    reader.skip(trailing_header_size.into());
 
     let update_type: UpdateType = reader
         .read_u8()
@@ -86,9 +86,6 @@ pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
     match update_type {
         UpdateType::WormholeMerkle => {}
     }
-
-    let wormhole_proof_size = reader.read_u16();
-    reader.read_byte_array(wormhole_proof_size.into())
 }
 
 pub fn parse_wormhole_proof(payload: ByteArray) -> u256 {

+ 4 - 2
target_chains/starknet/contracts/src/reader.cairo

@@ -86,13 +86,15 @@ pub impl ReaderImpl of ReaderTrait {
     }
 
     // TODO: skip without calculating values
-    fn skip(ref self: Reader, mut num_bytes: u8) {
+    fn skip(ref self: Reader, mut num_bytes: usize) {
         while num_bytes > 0 {
             if num_bytes > 16 {
                 self.read_num_bytes(16);
                 num_bytes -= 16;
             } else {
-                self.read_num_bytes(num_bytes);
+                // num_bytes <= 16 so it shouldn't overflow.
+                self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW));
+                break;
             }
         }
     }

+ 10 - 10
target_chains/starknet/contracts/tests/data.cairo

@@ -404,12 +404,12 @@ pub fn pyth_set_fee_alt_emitter() -> ByteArray {
 // A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
 pub fn pyth_upgrade_fake1() -> ByteArray {
     let bytes = array![
-        1766847064779997312831656888004304663648863693096357069129843988620764542,
-        372087717591229137403366610731035855939366039700111817924253553748324215495,
-        182456855699527413949626507253841174034899423702925950617111971827806109696,
+        1766847064779994791169214817472264547450542145364282319310439743685771618,
+        175385590228001769706203572954671062839210335359545531991708252078677402742,
+        338282801975945534678621806212670914146735662234331326855531973960850735104,
         49565958604199796163020368,
-        148907253453589022235416579439991212386300560409198472807590534281503440988,
-        7311947531350894019,
+        148907253453589022320407306335457538262203456299261498528172020674942501293,
+        9624434269354675143,
     ];
     ByteArrayImpl::new(array_try_into(bytes), 8)
 }
@@ -430,12 +430,12 @@ pub fn pyth_upgrade_not_pyth() -> ByteArray {
 // A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
 pub fn pyth_upgrade_wrong_magic() -> ByteArray {
     let bytes = array![
-        1766847064779991597204876565434227784957683976813807912095437759207426783,
-        116073945271795196915694593374349818132109707660825728503969190995475470190,
-        402267564262237040559156170656516235895865250461329916178126910059500797952,
+        1766847064779993581380818181711092803131812037068363180730038764700119064,
+        43179698701133869693008541869474965453366967663087320291846878688486859828,
+        257191826617037171240065659464096594985467828231875472974396182656981139456,
         49565958604199796163020368,
-        148907253453589022397792005599092877068906138702361966208625267621388965397,
-        10856656060318424790,
+        148907253453589022340563264373887392414227070562033595690783947835630084766,
+        5698494087895763928,
     ];
     ByteArrayImpl::new(array_try_into(bytes), 8)
 }

+ 10 - 1
target_chains/starknet/contracts/tests/pyth.cairo

@@ -89,8 +89,11 @@ fn update_price_feeds_works() {
     let fee_contract = deploy_fee_contract(user);
     let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
 
+    let fee = pyth.get_update_fee(data::good_update1());
+    assert!(fee == 1000);
+
     start_prank(CheatTarget::One(fee_contract.contract_address), user.try_into().unwrap());
-    fee_contract.approve(pyth.contract_address, 10000);
+    fee_contract.approve(pyth.contract_address, fee);
     stop_prank(CheatTarget::One(fee_contract.contract_address));
 
     let mut spy = spy_events(SpyOn::One(pyth.contract_address));
@@ -136,6 +139,9 @@ fn test_governance_set_fee_works() {
     let fee_contract = deploy_fee_contract(user);
     let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
 
+    let fee1 = pyth.get_update_fee(data::test_price_update1());
+    assert!(fee1 == 1000);
+
     start_prank(CheatTarget::One(fee_contract.contract_address), user);
     fee_contract.approve(pyth.contract_address, 10000);
     stop_prank(CheatTarget::One(fee_contract.contract_address));
@@ -164,6 +170,9 @@ fn test_governance_set_fee_works() {
     let expected = FeeSet { old_fee: 1000, new_fee: 4200, };
     assert!(event == PythEvent::FeeSet(expected));
 
+    let fee2 = pyth.get_update_fee(data::test_price_update2());
+    assert!(fee2 == 4200);
+
     start_prank(CheatTarget::One(pyth.contract_address), user);
     pyth.update_price_feeds(data::test_price_update2());
     stop_prank(CheatTarget::One(pyth.contract_address));