ソースを参照

feat(target_chains/starknet): pyth contract upgrade (#1592)

* feat(target_chains/starknet): pyth contract upgrade

* doc(target_chains/starknet): add comment about class hash for contract upgrade
Pavel Strakhov 1 年間 前
コミット
5f3188af2b

+ 8 - 1
.github/workflows/ci-starknet-tools.yml

@@ -17,6 +17,13 @@ jobs:
           toolchain: 1.78.0
           components: rustfmt, clippy
           override: true
+      - uses: actions/checkout@v3
+      - name: Install Scarb
+        uses: software-mansion/setup-scarb@v1
+        with:
+          tool-versions: target_chains/starknet/contracts/.tool-versions
+      - name: Install Starkli
+        run: curl https://get.starkli.sh | sh && . ~/.config/.starkli/env && starkliup -v $(awk '/starkli/{print $2}' target_chains/starknet/contracts/.tool-versions)
       - name: Check formatting
         run: cargo fmt --manifest-path ./target_chains/starknet/tools/test_vaas/Cargo.toml -- --check
       - name: Run clippy
@@ -25,7 +32,7 @@ jobs:
         run: cargo run --manifest-path ./target_chains/starknet/tools/test_vaas/Cargo.toml --bin generate_keypair
       - name: Check test data
         run: |
-          cargo run --manifest-path ./target_chains/starknet/tools/test_vaas/Cargo.toml --bin generate_test_data > /tmp/data.cairo
+          . ~/.config/.starkli/env && cargo run --manifest-path ./target_chains/starknet/tools/test_vaas/Cargo.toml --bin generate_test_data > /tmp/data.cairo
           if ! diff ./target_chains/starknet/contracts/tests/data.cairo /tmp/data.cairo; then
             >&2 echo "Re-run generate_test_data to update data.cairo"
             exit 1

+ 40 - 3
target_chains/starknet/contracts/src/pyth.cairo

@@ -3,7 +3,11 @@ mod interface;
 mod price_update;
 mod governance;
 
-pub use pyth::{Event, PriceFeedUpdateEvent, WormholeAddressSet, GovernanceDataSourceSet};
+mod fake_upgrades;
+
+pub use pyth::{
+    Event, PriceFeedUpdateEvent, WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded
+};
 pub use errors::{GetPriceUnsafeError, GovernanceActionError, UpdatePriceFeedsError};
 pub use interface::{IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Price};
 
@@ -16,10 +20,15 @@ mod pyth {
     use pyth::reader::{Reader, ReaderImpl};
     use pyth::byte_array::{ByteArray, ByteArrayImpl};
     use core::panic_with_felt252;
-    use core::starknet::{ContractAddress, get_caller_address, get_execution_info};
+    use core::starknet::{
+        ContractAddress, get_caller_address, get_execution_info, ClassHash, SyscallResultTrait,
+        get_contract_address,
+    };
+    use core::starknet::syscalls::replace_class_syscall;
     use pyth::wormhole::{IWormholeDispatcher, IWormholeDispatcherTrait, VerifiedVM};
     use super::{
-        DataSource, UpdatePriceFeedsError, GovernanceActionError, Price, GetPriceUnsafeError
+        DataSource, UpdatePriceFeedsError, GovernanceActionError, Price, GetPriceUnsafeError,
+        IPythDispatcher, IPythDispatcherTrait,
     };
     use super::governance;
     use super::governance::GovernancePayload;
@@ -31,6 +40,7 @@ mod pyth {
         PriceFeedUpdate: PriceFeedUpdateEvent,
         WormholeAddressSet: WormholeAddressSet,
         GovernanceDataSourceSet: GovernanceDataSourceSet,
+        ContractUpgraded: ContractUpgraded,
     }
 
     #[derive(Drop, PartialEq, starknet::Event)]
@@ -55,6 +65,11 @@ mod pyth {
         pub last_executed_governance_sequence: u64,
     }
 
+    #[derive(Drop, PartialEq, starknet::Event)]
+    pub struct ContractUpgraded {
+        pub new_class_hash: ClassHash,
+    }
+
     #[storage]
     struct Storage {
         wormhole_address: ContractAddress,
@@ -243,8 +258,18 @@ mod pyth {
                 GovernancePayload::AuthorizeGovernanceDataSourceTransfer(payload) => {
                     self.authorize_governance_transfer(payload.claim_vaa);
                 },
+                GovernancePayload::UpgradeContract(payload) => {
+                    if instruction.target_chain_id == 0 {
+                        panic_with_felt252(GovernanceActionError::InvalidGovernanceTarget.into());
+                    }
+                    self.upgrade_contract(payload.new_implementation);
+                }
             }
         }
+
+        fn pyth_upgradable_magic(self: @ContractState) -> u32 {
+            0x97a6f304
+        }
     }
 
     #[generate_trait]
@@ -385,6 +410,18 @@ mod pyth {
             };
             self.emit(event);
         }
+
+        fn upgrade_contract(ref self: ContractState, new_implementation: ClassHash) {
+            let contract_address = get_contract_address();
+            replace_class_syscall(new_implementation).unwrap_syscall();
+            // Dispatcher uses `call_contract_syscall` so it will call the new implementation.
+            let magic = IPythDispatcher { contract_address }.pyth_upgradable_magic();
+            if magic != 0x97a6f304 {
+                panic_with_felt252(GovernanceActionError::InvalidGovernanceMessage.into());
+            }
+            let event = ContractUpgraded { new_class_hash: new_implementation };
+            self.emit(event);
+        }
     }
 
     fn apply_decimal_expo(value: u64, expo: u64) -> u256 {

+ 105 - 0
target_chains/starknet/contracts/src/pyth/fake_upgrades.cairo

@@ -0,0 +1,105 @@
+// Only used for tests.
+
+#[starknet::contract]
+mod pyth_fake_upgrade1 {
+    use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
+    use pyth::byte_array::ByteArray;
+
+    #[storage]
+    struct Storage {}
+
+    #[constructor]
+    fn constructor(ref self: ContractState) {}
+
+    #[abi(embed_v0)]
+    impl PythImpl of IPyth<ContractState> {
+        fn get_price_unsafe(
+            self: @ContractState, price_id: u256
+        ) -> Result<Price, GetPriceUnsafeError> {
+            let price = Price { price: 42, conf: 2, expo: -5, publish_time: 101, };
+            Result::Ok(price)
+        }
+        fn get_ema_price_unsafe(
+            self: @ContractState, price_id: u256
+        ) -> Result<Price, GetPriceUnsafeError> {
+            panic!("unsupported")
+        }
+        fn set_data_sources(ref self: ContractState, sources: Array<DataSource>) {
+            panic!("unsupported")
+        }
+        fn set_fee(ref self: ContractState, single_update_fee: u256) {
+            panic!("unsupported")
+        }
+        fn update_price_feeds(ref self: ContractState, data: ByteArray) {
+            panic!("unsupported")
+        }
+        fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
+            panic!("unsupported")
+        }
+        fn pyth_upgradable_magic(self: @ContractState) -> u32 {
+            0x97a6f304
+        }
+    }
+}
+
+#[starknet::contract]
+mod pyth_fake_upgrade_wrong_magic {
+    use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
+    use pyth::byte_array::ByteArray;
+
+    #[storage]
+    struct Storage {}
+
+    #[constructor]
+    fn constructor(ref self: ContractState) {}
+
+    #[abi(embed_v0)]
+    impl PythImpl of IPyth<ContractState> {
+        fn get_price_unsafe(
+            self: @ContractState, price_id: u256
+        ) -> Result<Price, GetPriceUnsafeError> {
+            panic!("unsupported")
+        }
+        fn get_ema_price_unsafe(
+            self: @ContractState, price_id: u256
+        ) -> Result<Price, GetPriceUnsafeError> {
+            panic!("unsupported")
+        }
+        fn set_data_sources(ref self: ContractState, sources: Array<DataSource>) {
+            panic!("unsupported")
+        }
+        fn set_fee(ref self: ContractState, single_update_fee: u256) {
+            panic!("unsupported")
+        }
+        fn update_price_feeds(ref self: ContractState, data: ByteArray) {
+            panic!("unsupported")
+        }
+        fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
+            panic!("unsupported")
+        }
+        fn pyth_upgradable_magic(self: @ContractState) -> u32 {
+            606
+        }
+    }
+}
+
+#[starknet::interface]
+pub trait INotPyth<T> {
+    fn test1(ref self: T) -> u32;
+}
+
+#[starknet::contract]
+mod pyth_fake_upgrade_not_pyth {
+    #[storage]
+    struct Storage {}
+
+    #[constructor]
+    fn constructor(ref self: ContractState) {}
+
+    #[abi(embed_v0)]
+    impl NotPythImpl of super::INotPyth<ContractState> {
+        fn test1(ref self: ContractState) -> u32 {
+            42
+        }
+    }
+}

+ 28 - 6
target_chains/starknet/contracts/src/pyth/governance.cairo

@@ -4,7 +4,7 @@ use pyth::reader::{Reader, ReaderImpl};
 use pyth::byte_array::ByteArray;
 use pyth::pyth::errors::GovernanceActionError;
 use core::panic_with_felt252;
-use core::starknet::ContractAddress;
+use core::starknet::{ContractAddress, ClassHash};
 use super::DataSource;
 
 const MAGIC: u32 = 0x5054474d;
@@ -45,12 +45,13 @@ pub struct GovernanceInstruction {
 
 #[derive(Drop, Debug)]
 pub enum GovernancePayload {
-    SetFee: SetFee,
+    UpgradeContract: UpgradeContract,
+    AuthorizeGovernanceDataSourceTransfer: AuthorizeGovernanceDataSourceTransfer,
     SetDataSources: SetDataSources,
-    SetWormholeAddress: SetWormholeAddress,
+    SetFee: SetFee,
+    // SetValidPeriod is unsupported
     RequestGovernanceDataSourceTransfer: RequestGovernanceDataSourceTransfer,
-    AuthorizeGovernanceDataSourceTransfer: AuthorizeGovernanceDataSourceTransfer,
-// TODO: others
+    SetWormholeAddress: SetWormholeAddress,
 }
 
 #[derive(Drop, Debug)]
@@ -84,6 +85,15 @@ pub struct AuthorizeGovernanceDataSourceTransfer {
     pub claim_vaa: ByteArray,
 }
 
+#[derive(Drop, Debug)]
+pub struct UpgradeContract {
+    // Class hash of the new contract class. The contract class must already be deployed on the network
+    // (e.g. with `starkli declare`). Class hash is a Poseidon hash of all properties
+    // of the contract code, including entry points, ABI, and bytecode,
+    // so specifying a hash securely identifies the new implementation.
+    pub new_implementation: ClassHash,
+}
+
 pub fn parse_instruction(payload: ByteArray) -> GovernanceInstruction {
     let mut reader = ReaderImpl::new(payload);
     let magic = reader.read_u32();
@@ -102,7 +112,19 @@ pub fn parse_instruction(payload: ByteArray) -> GovernanceInstruction {
     let target_chain_id = reader.read_u16();
 
     let payload = match action {
-        GovernanceAction::UpgradeContract => { panic_with_felt252('unimplemented') },
+        GovernanceAction::UpgradeContract => {
+            let new_implementation: felt252 = reader
+                .read_u256()
+                .try_into()
+                .expect(GovernanceActionError::InvalidGovernanceMessage.into());
+            if new_implementation == 0 {
+                panic_with_felt252(GovernanceActionError::InvalidGovernanceMessage.into());
+            }
+            let new_implementation = new_implementation
+                .try_into()
+                .expect(GovernanceActionError::InvalidGovernanceMessage.into());
+            GovernancePayload::UpgradeContract(UpgradeContract { new_implementation })
+        },
         GovernanceAction::AuthorizeGovernanceDataSourceTransfer => {
             let len = reader.len();
             let claim_vaa = reader.read_byte_array(len);

+ 1 - 0
target_chains/starknet/contracts/src/pyth/interface.cairo

@@ -9,6 +9,7 @@ pub trait IPyth<T> {
     fn set_fee(ref self: T, single_update_fee: u256);
     fn update_price_feeds(ref self: T, data: ByteArray);
     fn execute_governance_instruction(ref self: T, data: ByteArray);
+    fn pyth_upgradable_magic(self: @T) -> u32;
 }
 
 #[derive(Drop, Debug, Clone, Copy, PartialEq, Hash, Default, Serde, starknet::Store)]

+ 52 - 0
target_chains/starknet/contracts/tests/data.cairo

@@ -401,6 +401,58 @@ pub fn pyth_set_fee_alt_emitter() -> ByteArray {
     ByteArrayImpl::new(array_try_into(bytes), 23)
 }
 
+// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
+pub fn pyth_upgrade_fake1() -> ByteArray {
+    let bytes = array![
+        1766847064779996629845663150320144587116923255693918326671650041367743158,
+        364132889311386805107013139684624946082752766829168028617992585410993000794,
+        51883035205100148844906587684528048568133099046687734614207273174883631104,
+        49565958604199796163020368,
+        148907253453589022322377848805870968387690459124203915663278968232930838042,
+        8990748118247398873,
+    ];
+    ByteArrayImpl::new(array_try_into(bytes), 8)
+}
+
+// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
+pub fn pyth_upgrade_not_pyth() -> ByteArray {
+    let bytes = array![
+        1766847064779994185568390976518139178339359117743780499979078006447412818,
+        312550937452923367391560946919832045570249370029901542796468563830775031789,
+        297548922588419398887374641748895591794744646787122275140580663536136486912,
+        49565958604199796163020368,
+        148907253453589022305803196061110108233921773465491227564264876752079119569,
+        6736708290019375278,
+    ];
+    ByteArrayImpl::new(array_try_into(bytes), 8)
+}
+
+// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
+pub fn pyth_upgrade_wrong_magic() -> ByteArray {
+    let bytes = array![
+        1766847064779993973755828929481286552054924338108588685006773817619868900,
+        115412576669831747089146670964350761640626878638240568653908102512904321557,
+        370636636445427985046380928790855735958458201351067908027881134703845048320,
+        49565958604199796163020368,
+        148907253453589022358052376969903205134363123861005618128296481878738034337,
+        2645198310775210562,
+    ];
+    ByteArrayImpl::new(array_try_into(bytes), 8)
+}
+
+// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
+pub fn pyth_upgrade_invalid_hash() -> ByteArray {
+    let bytes = array![
+        1766847064779994789591381079184882258862460741769249190705097785479185254,
+        41574146205389297059177705721481778703981276127215462116602633512315608382,
+        266498984494471565033413055222808266936531835027750145459398687214975057920,
+        49565958604199796163020368,
+        148907253453589022218037939353255655322518022029545083499057126097303896064,
+        505,
+    ];
+    ByteArrayImpl::new(array_try_into(bytes), 8)
+}
+
 // An update pulled from Hermes and re-signed by the test guardian #1.
 pub fn test_price_update1() -> ByteArray {
     let bytes = array![

+ 71 - 1
target_chains/starknet/contracts/tests/pyth.cairo

@@ -4,7 +4,7 @@ use snforge_std::{
 };
 use pyth::pyth::{
     IPythDispatcher, IPythDispatcherTrait, DataSource, Event as PythEvent, PriceFeedUpdateEvent,
-    WormholeAddressSet, GovernanceDataSourceSet,
+    WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded,
 };
 use pyth::byte_array::{ByteArray, ByteArrayImpl};
 use pyth::util::{array_try_into, UnwrapWithFelt252};
@@ -51,6 +51,9 @@ fn decode_event(mut event: Event) -> PythEvent {
             last_executed_governance_sequence: event.data.pop(),
         };
         PythEvent::GovernanceDataSourceSet(event)
+    } else if key0 == event_name_hash('ContractUpgraded') {
+        let event = ContractUpgraded { new_class_hash: event.data.pop() };
+        PythEvent::ContractUpgraded(event)
     } else {
         panic!("unrecognized event")
     };
@@ -407,6 +410,73 @@ fn test_rejects_old_emitter_after_transfer() {
     pyth.execute_governance_instruction(data::pyth_set_fee());
 }
 
+#[test]
+fn test_upgrade_works() {
+    let owner = 'owner'.try_into().unwrap();
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(owner, wormhole.contract_address, fee_contract.contract_address);
+
+    let class = declare("pyth_fake_upgrade1");
+
+    let mut spy = spy_events(SpyOn::One(pyth.contract_address));
+
+    pyth.execute_governance_instruction(data::pyth_upgrade_fake1());
+
+    spy.fetch_events();
+    assert!(spy.events.len() == 1);
+    let (from, event) = spy.events.pop_front().unwrap();
+    assert!(from == pyth.contract_address);
+    let event = decode_event(event);
+    let expected = ContractUpgraded { new_class_hash: class.class_hash };
+    assert!(event == PythEvent::ContractUpgraded(expected));
+
+    let last_price = pyth.get_price_unsafe(1234).unwrap_with_felt252();
+    assert!(last_price.price == 42);
+}
+
+#[test]
+#[should_panic]
+#[ignore] // TODO: unignore when snforge is updated
+fn test_upgrade_rejects_invalid_hash() {
+    let owner = 'owner'.try_into().unwrap();
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(owner, wormhole.contract_address, fee_contract.contract_address);
+
+    pyth.execute_governance_instruction(data::pyth_upgrade_invalid_hash());
+}
+
+#[test]
+#[should_panic]
+#[ignore] // TODO: unignore when snforge is updated
+fn test_upgrade_rejects_not_pyth() {
+    let owner = 'owner'.try_into().unwrap();
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(owner, wormhole.contract_address, fee_contract.contract_address);
+
+    declare("pyth_fake_upgrade_not_pyth");
+    pyth.execute_governance_instruction(data::pyth_upgrade_not_pyth());
+}
+
+#[test]
+#[should_panic(expected: ('invalid governance message',))]
+fn test_upgrade_rejects_wrong_magic() {
+    let owner = 'owner'.try_into().unwrap();
+    let user = 'user'.try_into().unwrap();
+    let wormhole = super::wormhole::deploy_with_test_guardian();
+    let fee_contract = deploy_fee_contract(user);
+    let pyth = deploy_default(owner, wormhole.contract_address, fee_contract.contract_address);
+
+    declare("pyth_fake_upgrade_wrong_magic");
+    pyth.execute_governance_instruction(data::pyth_upgrade_wrong_magic());
+}
+
+
 fn deploy_default(
     owner: ContractAddress, wormhole_address: ContractAddress, fee_contract_address: ContractAddress
 ) -> IPythDispatcher {

+ 62 - 0
target_chains/starknet/tools/test_vaas/src/bin/generate_test_data.rs

@@ -1,4 +1,7 @@
+use std::{path::Path, process::Command, str};
+
 use libsecp256k1::SecretKey;
+use primitive_types::U256;
 use test_vaas::{
     locate_vaa_in_price_update, print_as_cairo_fn, re_sign_price_update, serialize_vaa, u256_to_be,
     DataSource, EthAddress, GuardianSet, GuardianSetUpgrade,
@@ -290,6 +293,49 @@ fn main() {
         "A Pyth governance instruction to set fee with alternative emitter signed by the test guardian #1.",
     );
 
+    let contracts_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../contracts");
+    let status = Command::new("scarb")
+        .arg("build")
+        .current_dir(&contracts_dir)
+        .output()
+        .unwrap()
+        .status;
+    assert!(status.success(), "scarb failed with {status:?}");
+
+    let upgrade_hashes = [
+        ("fake1", get_class_hash("fake_upgrade1", &contracts_dir)),
+        (
+            "not_pyth",
+            get_class_hash("fake_upgrade_not_pyth", &contracts_dir),
+        ),
+        (
+            "wrong_magic",
+            get_class_hash("fake_upgrade_wrong_magic", &contracts_dir),
+        ),
+        ("invalid_hash", 505.into()),
+    ];
+    for (name, hash) in upgrade_hashes {
+        let mut pyth_upgrade_payload = vec![80, 84, 71, 77, 1, 0, 234, 147];
+        pyth_upgrade_payload.extend_from_slice(&u256_to_be(hash));
+        let pyth_upgrade = serialize_vaa(guardians.sign_vaa(
+            &[0],
+            VaaBody {
+                timestamp: 1,
+                nonce: 2,
+                emitter_chain: 1,
+                emitter_address: u256_to_be(41.into()).into(),
+                sequence: 1.try_into().unwrap(),
+                consistency_level: 6,
+                payload: PayloadKind::Binary(pyth_upgrade_payload),
+            },
+        ));
+        print_as_cairo_fn(
+            &pyth_upgrade,
+            format!("pyth_upgrade_{name}"),
+            "A Pyth governance instruction to upgrade the contract signed by the test guardian #1.",
+        );
+    }
+
     let guardians2 = GuardianSet {
         set_index: 1,
         secrets: vec![SecretKey::parse_slice(&hex::decode(secret2).unwrap()).unwrap()],
@@ -330,3 +376,19 @@ fn main() {
         "An update pulled from Hermes and re-signed by the test guardian #2.",
     );
 }
+
+fn get_class_hash(name: &str, dir: &Path) -> U256 {
+    let output = Command::new("starkli")
+        .arg("class-hash")
+        .arg(format!("target/dev/pyth_pyth_{name}.contract_class.json"))
+        .current_dir(dir)
+        .output()
+        .unwrap();
+    assert!(
+        output.status.success(),
+        "starkli failed with {:?}",
+        output.status
+    );
+    let hash = str::from_utf8(&output.stdout).unwrap();
+    hash.trim().parse().unwrap()
+}