Przeglądaj źródła

feat(tests): add tests for contract upgrade functionality in PythGovernance

Daniel Chew 7 miesięcy temu
rodzic
commit
d92749dfb7

+ 136 - 0
target_chains/ethereum/contracts/forge-test/PythGovernance.t.sol

@@ -3,6 +3,7 @@
 pragma solidity ^0.8.0;
 
 import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
+import "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol";
 import "forge-std/Test.sol";
 
 import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
@@ -24,6 +25,7 @@ import "../contracts/wormhole-receiver/ReceiverGovernanceStructs.sol";
 import "../contracts/wormhole-receiver/ReceiverStructs.sol";
 import "../contracts/wormhole-receiver/ReceiverGovernance.sol";
 import "../contracts/libraries/external/BytesLib.sol";
+import "../contracts/pyth/mock/MockUpgradeableProxy.sol";
 import "./utils/WormholeTestUtils.t.sol";
 import "./utils/PythTestUtils.t.sol";
 import "./utils/RandTestUtils.t.sol";
@@ -380,6 +382,140 @@ contract PythGovernanceTest is
         PythGovernance(address(pyth)).executeGovernanceInstruction(vaa2);
     }
 
+    function testUpgradeContractWithChainIdZeroIsInvalid() public {
+        // Deploy a new PythUpgradable contract
+        PythUpgradable newImplementation = new PythUpgradable();
+
+        // Create governance VAA with chain ID 0 (unset)
+        bytes memory data = abi.encodePacked(
+            MAGIC,
+            uint8(GovernanceModule.Target),
+            uint8(GovernanceAction.UpgradeContract),
+            uint16(0), // Chain ID 0 (unset)
+            address(newImplementation) // New implementation address
+        );
+
+        bytes memory vaa = encodeAndSignMessage(
+            data,
+            TEST_GOVERNANCE_CHAIN_ID,
+            TEST_GOVERNANCE_EMITTER,
+            1
+        );
+
+        // Should revert with InvalidGovernanceTarget
+        vm.expectRevert(PythErrors.InvalidGovernanceTarget.selector);
+        PythGovernance(address(pyth)).executeGovernanceInstruction(vaa);
+    }
+
+    // Helper function to get the second address from event data
+    function getSecondAddressFromEventData(
+        bytes memory data
+    ) internal pure returns (address) {
+        (, address secondAddr) = abi.decode(data, (address, address));
+        return secondAddr;
+    }
+
+    function testUpgradeContractShouldWork() public {
+        // Deploy a new PythUpgradable contract
+        PythUpgradable newImplementation = new PythUpgradable();
+
+        // Create governance VAA to upgrade the contract
+        bytes memory data = abi.encodePacked(
+            MAGIC,
+            uint8(GovernanceModule.Target),
+            uint8(GovernanceAction.UpgradeContract),
+            TARGET_CHAIN_ID, // Valid target chain ID
+            address(newImplementation) // New implementation address
+        );
+
+        bytes memory vaa = encodeAndSignMessage(
+            data,
+            TEST_GOVERNANCE_CHAIN_ID,
+            TEST_GOVERNANCE_EMITTER,
+            1
+        );
+
+        // Create a custom event checker for ContractUpgraded event
+        // Since we only care about the newImplementation parameter
+        vm.recordLogs();
+
+        // Execute the governance instruction
+        PythGovernance(address(pyth)).executeGovernanceInstruction(vaa);
+
+        // Get emitted logs and check the event parameters
+        Vm.Log[] memory entries = vm.getRecordedLogs();
+        bool foundUpgradeEvent = false;
+
+        for (uint i = 0; i < entries.length; i++) {
+            // The event signature for ContractUpgraded
+            bytes32 eventSignature = keccak256(
+                "ContractUpgraded(address,address)"
+            );
+
+            if (entries[i].topics[0] == eventSignature) {
+                // This is a ContractUpgraded event
+                // Get just the new implementation address using our helper
+                address recordedNewImplementation = getSecondAddressFromEventData(
+                        entries[i].data
+                    );
+
+                // Check newImplementation
+                assertEq(recordedNewImplementation, address(newImplementation));
+                foundUpgradeEvent = true;
+                break;
+            }
+        }
+
+        // Make sure we found the event
+        assertTrue(foundUpgradeEvent, "ContractUpgraded event not found");
+
+        // Verify the upgrade worked by checking the magic number
+        assertEq(
+            PythUpgradable(address(pyth)).pythUpgradableMagic(),
+            0x97a6f304
+        );
+
+        // Verify the implementation was upgraded to our new implementation
+        // Access implementation using the ERC1967 storage slot
+        address implAddr = address(
+            uint160(
+                uint256(
+                    vm.load(
+                        address(pyth),
+                        0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc // ERC1967 implementation slot
+                    )
+                )
+            )
+        );
+        assertEq(implAddr, address(newImplementation));
+    }
+
+    function testUpgradeContractToNonPythContractWontWork() public {
+        // Deploy a mock upgradeable proxy that isn't a proper Pyth implementation
+        MockUpgradeableProxy newImplementation = new MockUpgradeableProxy();
+
+        // Create governance VAA to upgrade to an invalid implementation
+        bytes memory data = abi.encodePacked(
+            MAGIC,
+            uint8(GovernanceModule.Target),
+            uint8(GovernanceAction.UpgradeContract),
+            TARGET_CHAIN_ID, // Valid target chain ID
+            address(newImplementation) // Invalid implementation address
+        );
+
+        bytes memory vaa = encodeAndSignMessage(
+            data,
+            TEST_GOVERNANCE_CHAIN_ID,
+            TEST_GOVERNANCE_EMITTER,
+            1
+        );
+
+        // Should revert with no specific error message because the mock implementation
+        // doesn't have the pythUpgradableMagic method
+        vm.expectRevert();
+        PythGovernance(address(pyth)).executeGovernanceInstruction(vaa);
+    }
+
     function testSetTransactionFee() public {
         // Set transaction fee to 1000 (1000 = 1 * 10^3)
         bytes memory setTransactionFeeMessage = abi.encodePacked(