소스 검색

Add checkUpdateDataIsMinimal to ParseConfig and implement strict parse mode

Co-Authored-By: Tejas Badadare <tejas@dourolabs.xyz>
Devin AI 6 달 전
부모
커밋
708724d7a4

+ 1 - 1
target_chains/ethereum/contracts/contracts/pulse/Scheduler.sol

@@ -293,7 +293,7 @@ abstract contract Scheduler is IScheduler, SchedulerState {
         (
             PythStructs.PriceFeed[] memory priceFeeds,
             uint64[] memory slots
-        ) = pyth.parsePriceFeedUpdatesWithSlots{value: pythFee}(
+        ) = pyth.parsePriceFeedUpdatesWithSlotsStrict{value: pythFee}(
                 updateData,
                 priceIds,
                 curTime > PAST_TIMESTAMP_MAX_VALIDITY_PERIOD

+ 64 - 2
target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

@@ -314,6 +314,17 @@ abstract contract Pyth is
             }
         }
 
+        // In minimal update data mode, revert if we have more updates than price IDs
+        if (config.checkUpdateDataIsMinimal) {
+            uint64 totalUpdatesAcrossBlobs = 0;
+            for (uint i = 0; i < updateData.length; i++) {
+                totalUpdatesAcrossBlobs += getTotalUpdatesInBlob(updateData[i]);
+            }
+            if (totalUpdatesAcrossBlobs > priceIds.length) {
+                revert PythErrors.InvalidArgument();
+            }
+        }
+
         // Check all price feeds were found
         for (uint k = 0; k < priceIds.length; k++) {
             if (context.priceFeeds[k].id == 0) {
@@ -341,6 +352,7 @@ abstract contract Pyth is
             PythInternalStructs.ParseConfig(
                 minPublishTime,
                 maxPublishTime,
+                false,
                 false
             )
         );
@@ -367,11 +379,39 @@ abstract contract Pyth is
                 PythInternalStructs.ParseConfig(
                     minPublishTime,
                     maxPublishTime,
+                    false,
                     false
                 )
             );
     }
 
+    function parsePriceFeedUpdatesWithSlotsStrict(
+        bytes[] calldata updateData,
+        bytes32[] calldata priceIds,
+        uint64 minPublishTime,
+        uint64 maxPublishTime
+    )
+        external
+        payable
+        override
+        returns (
+            PythStructs.PriceFeed[] memory priceFeeds,
+            uint64[] memory slots
+        )
+    {
+        return
+            parsePriceFeedUpdatesInternal(
+                updateData,
+                priceIds,
+                PythInternalStructs.ParseConfig(
+                    minPublishTime,
+                    maxPublishTime,
+                    false,
+                    true
+                )
+            );
+    }
+
     function processSingleTwapUpdate(
         bytes calldata updateData
     )
@@ -550,7 +590,8 @@ abstract contract Pyth is
             PythInternalStructs.ParseConfig(
                 minPublishTime,
                 maxPublishTime,
-                true
+                true,
+                false
             )
         );
     }
@@ -624,7 +665,7 @@ abstract contract Pyth is
     }
 
     function version() public pure returns (string memory) {
-        return "1.4.4-alpha.5";
+        return "1.4.4-alpha.6";
     }
 
     /// @notice Calculates TWAP from two price points
@@ -676,4 +717,25 @@ abstract contract Pyth is
 
         return twapPriceFeed;
     }
+
+    /// @dev Helper function to count the total number of updates in a single blob
+    function getTotalUpdatesInBlob(bytes calldata singleUpdateData) internal view returns (uint64) {
+        if (
+            singleUpdateData.length <= 4 ||
+            UnsafeCalldataBytesLib.toUint32(singleUpdateData, 0) !=
+            ACCUMULATOR_MAGIC
+        ) {
+            revert PythErrors.InvalidUpdateData();
+        }
+
+        uint offset;
+        UpdateType updateType;
+        (offset, updateType) = extractUpdateTypeFromAccumulatorHeader(singleUpdateData);
+
+        if (updateType != UpdateType.WormholeMerkle) {
+            revert PythErrors.InvalidUpdateData();
+        }
+
+        return parseWormholeMerkleHeaderNumUpdates(singleUpdateData, offset);
+    }
 }

+ 1 - 0
target_chains/ethereum/contracts/contracts/pyth/PythInternalStructs.sol

@@ -13,6 +13,7 @@ contract PythInternalStructs {
         uint64 minPublishTime;
         uint64 maxPublishTime;
         bool checkUniqueness;
+        bool checkUpdateDataIsMinimal;
     }
 
     /// Internal struct to hold parameters for update processing

+ 17 - 17
target_chains/ethereum/contracts/forge-test/PulseScheduler.t.sol

@@ -248,7 +248,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             numInitialFeeds
         );
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, initialPriceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, initialPriceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(initialPriceFeeds);
 
         vm.prank(pusher);
@@ -822,7 +822,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             priceIds.length
         );
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds1, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds1, slots);
         bytes[] memory updateData1 = createMockUpdateData(priceFeeds1);
 
         // Perform first update
@@ -873,7 +873,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             priceFeeds2[i].emaPrice.publishTime = publishTime2;
         }
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds2, slots); // Mock for the second call
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds2, slots); // Mock for the second call
         bytes[] memory updateData2 = createMockUpdateData(priceFeeds2);
 
         // Perform second update
@@ -934,7 +934,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             );
 
         uint256 mockPythFee = MOCK_PYTH_FEE_PER_FEED * params.priceIds.length;
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         // Get state before
@@ -1019,7 +1019,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             priceIds.length
         );
         uint256 mockPythFee = MOCK_PYTH_FEE_PER_FEED * priceIds.length;
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         // Calculate minimum keeper fee (overhead + feed-specific fee)
@@ -1078,7 +1078,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds1;
         uint64[] memory slots1;
         (priceFeeds1, slots1) = createMockPriceFeedsWithSlots(publishTime1, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds1, slots1);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds1, slots1);
         bytes[] memory updateData1 = createMockUpdateData(priceFeeds1);
         vm.prank(pusher);
         scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds);
@@ -1089,7 +1089,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds2;
         uint64[] memory slots2;
         (priceFeeds2, slots2) = createMockPriceFeedsWithSlots(publishTime2, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds2, slots2);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds2, slots2);
         bytes[] memory updateData2 = createMockUpdateData(priceFeeds2);
 
         // Expect revert because heartbeat condition is not met
@@ -1126,7 +1126,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds1;
         uint64[] memory slots;
         (priceFeeds1, slots) = createMockPriceFeedsWithSlots(publishTime1, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds1, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds1, slots);
         bytes[] memory updateData1 = createMockUpdateData(priceFeeds1);
         vm.prank(pusher);
         scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds);
@@ -1152,7 +1152,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             priceFeeds2[i].price.publishTime = publishTime2;
         }
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds2, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds2, slots);
         bytes[] memory updateData2 = createMockUpdateData(priceFeeds2);
 
         // Expect revert because deviation condition is not met
@@ -1178,7 +1178,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds1;
         uint64[] memory slots1;
         (priceFeeds1, slots1) = createMockPriceFeedsWithSlots(publishTime1, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds1, slots1);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds1, slots1);
         bytes[] memory updateData1 = createMockUpdateData(priceFeeds1);
 
         vm.prank(pusher);
@@ -1190,7 +1190,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         uint64[] memory slots2;
         (priceFeeds2, slots2) = createMockPriceFeedsWithSlots(publishTime2, 2);
         // Mock Pyth response to return feeds with the older timestamp
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds2, slots2);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds2, slots2);
         bytes[] memory updateData2 = createMockUpdateData(priceFeeds2);
 
         // Expect revert with TimestampOlderThanLastUpdate (checked in _validateShouldUpdatePrices)
@@ -1231,7 +1231,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         slots[1] = 200; // Different slot
 
         // Mock Pyth response to return these feeds with mismatched slots
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         // Expect revert with PriceSlotMismatch error
@@ -1346,7 +1346,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds;
         uint64[] memory slots;
         (priceFeeds, slots) = createMockPriceFeedsWithSlots(publishTime, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         vm.prank(pusher);
@@ -1388,7 +1388,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds;
         uint64[] memory slots;
         (priceFeeds, slots) = createMockPriceFeedsWithSlots(publishTime, 3);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         vm.prank(pusher);
@@ -1443,7 +1443,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
         PythStructs.PriceFeed[] memory priceFeeds;
         uint64[] memory slots;
         (priceFeeds, slots) = createMockPriceFeedsWithSlots(publishTime, 2);
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
         bytes32[] memory priceIds = params.priceIds;
 
@@ -1488,7 +1488,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             publishTime,
             priceIds.length
         );
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         vm.prank(pusher);
@@ -1555,7 +1555,7 @@ contract SchedulerTest is Test, SchedulerEvents, PulseSchedulerTestUtils {
             priceFeeds[i].emaPrice.expo = priceFeeds[i].price.expo;
         }
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         vm.prank(pusher);

+ 2 - 2
target_chains/ethereum/contracts/forge-test/PulseSchedulerGasBenchmark.t.sol

@@ -70,7 +70,7 @@ contract PulseSchedulerGasBenchmark is Test, PulseSchedulerTestUtils {
         );
 
         // Mock Pyth response for the benchmark
-        mockParsePriceFeedUpdatesWithSlots(pyth, newPriceFeeds, newSlots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, newPriceFeeds, newSlots);
 
         // Actual benchmark: Measure gas for updating price feeds
         uint256 startGas = gasleft();
@@ -128,7 +128,7 @@ contract PulseSchedulerGasBenchmark is Test, PulseSchedulerTestUtils {
             numFeeds
         );
 
-        mockParsePriceFeedUpdatesWithSlots(pyth, priceFeeds, slots);
+        mockParsePriceFeedUpdatesWithSlotsStrict(pyth, priceFeeds, slots);
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         // Update the price feeds. We should have enough balance to cover the update

+ 33 - 0
target_chains/ethereum/contracts/forge-test/Pyth.t.sol

@@ -392,6 +392,39 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
         );
     }
 
+    function testParsePriceFeedUpdatesWithSlotsStrictRevertsWithExcessUpdateData()
+        public
+    {
+        // Create a price update with more price updates than requested price IDs
+        uint numPriceIds = 2;
+        uint numMessages = numPriceIds + 1; // One more than the number of price IDs
+        
+        (
+            bytes32[] memory priceIds,
+            PriceFeedMessage[] memory messages
+        ) = generateRandomPriceMessages(numMessages);
+        
+        // Only use a subset of the price IDs to trigger the strict check
+        bytes32[] memory requestedPriceIds = new bytes32[](numPriceIds);
+        for (uint i = 0; i < numPriceIds; i++) {
+            requestedPriceIds[i] = priceIds[i];
+        }
+        
+        (
+            bytes[] memory updateData,
+            uint updateFee
+        ) = createBatchedUpdateDataFromMessages(messages);
+        
+        // Should revert in strict mode
+        vm.expectRevert(PythErrors.InvalidArgument.selector);
+        pyth.parsePriceFeedUpdatesWithSlotsStrict{value: updateFee}(
+            updateData,
+            requestedPriceIds,
+            0,
+            MAX_UINT64
+        );
+    }
+
     function testParsePriceFeedUpdatesRevertsIfUpdateSourceChainIsInvalid()
         public
     {

+ 3 - 3
target_chains/ethereum/contracts/forge-test/utils/MockPriceFeedTestUtils.sol

@@ -169,8 +169,8 @@ abstract contract MockPriceFeedTestUtils is Test {
         );
     }
 
-    // Helper function to mock Pyth response with slots
-    function mockParsePriceFeedUpdatesWithSlots(
+    // Helper function to mock Pyth response with slots (strict mode)
+    function mockParsePriceFeedUpdatesWithSlotsStrict(
         address pyth,
         PythStructs.PriceFeed[] memory priceFeeds,
         uint64[] memory slots
@@ -187,7 +187,7 @@ abstract contract MockPriceFeedTestUtils is Test {
             pyth,
             expectedFee,
             abi.encodeWithSelector(
-                IPyth.parsePriceFeedUpdatesWithSlots.selector
+                IPyth.parsePriceFeedUpdatesWithSlotsStrict.selector
             ),
             abi.encode(priceFeeds, slots)
         );

+ 21 - 0
target_chains/ethereum/sdk/solidity/IPyth.sol

@@ -185,4 +185,25 @@ interface IPyth is IPythEvents {
             PythStructs.PriceFeed[] memory priceFeeds,
             uint64[] memory slots
         );
+        
+    /// @dev Same as `parsePriceFeedUpdatesWithSlots`, but with minimal update data check enabled.
+    /// When checkUpdateDataIsMinimal is true, the function will revert if the number of updates exceeds the length of priceIds.
+    /// @param updateData Array of price update data.
+    /// @param priceIds Array of price ids.
+    /// @param minPublishTime minimum acceptable publishTime for the given `priceIds`.
+    /// @param maxPublishTime maximum acceptable publishTime for the given `priceIds`.
+    /// @return priceFeeds Array of the price feeds corresponding to the given `priceIds` (with the same order).
+    /// @return slots Array of the Pythnet slot corresponding to the given `priceIds` (with the same order).
+    function parsePriceFeedUpdatesWithSlotsStrict(
+        bytes[] calldata updateData,
+        bytes32[] calldata priceIds,
+        uint64 minPublishTime,
+        uint64 maxPublishTime
+    )
+        external
+        payable
+        returns (
+            PythStructs.PriceFeed[] memory priceFeeds,
+            uint64[] memory slots
+        );
 }