Эх сурвалжийг харах

feat(pyth): enhance TWAP processing to support multiple price feeds and improve validation

Daniel Chew 6 сар өмнө
parent
commit
3e5de9bb24

+ 110 - 63
target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

@@ -379,11 +379,11 @@ abstract contract Pyth is
         view
         returns (
             /// @return newOffset The next position in the update data after processing this TWAP update
-            /// @return twapPriceInfo The extracted time-weighted average price information
-            /// @return priceId The unique identifier for this price feed
+            /// @return priceInfos Array of extracted TWAP price information
+            /// @return priceIds Array of corresponding price feed IDs
             uint newOffset,
-            PythStructs.TwapPriceInfo memory twapPriceInfo,
-            bytes32 priceId
+            PythStructs.TwapPriceInfo[] memory twapPriceInfos,
+            bytes32[] memory priceIds
         )
     {
         UpdateType updateType;
@@ -417,12 +417,22 @@ abstract contract Pyth is
             revert PythErrors.InvalidUpdateData();
         }
 
-        // Extract start TWAP data with robust error checking
-        (offset, twapPriceInfo, priceId) = extractTwapPriceInfoFromMerkleProof(
-            digest,
-            encoded,
-            offset
-        );
+        // Initialize arrays to store all price infos and ids from this update
+        twapPriceInfos = new PythStructs.TwapPriceInfo[](numUpdates);
+        priceIds = new bytes32[](numUpdates);
+
+        // Extract each TWAP price info from the merkle proof
+        for (uint i = 0; i < numUpdates; i++) {
+            PythStructs.TwapPriceInfo memory twapPriceInfo;
+            bytes32 priceId;
+            (
+                offset,
+                twapPriceInfo,
+                priceId
+            ) = extractTwapPriceInfoFromMerkleProof(digest, encoded, offset);
+            twapPriceInfos[i] = twapPriceInfo;
+            priceIds[i] = priceId;
+        }
 
         if (offset != encoded.length) {
             revert PythErrors.InvalidTwapUpdateData();
@@ -439,72 +449,109 @@ abstract contract Pyth is
         override
         returns (PythStructs.TwapPriceFeed[] memory twapPriceFeeds)
     {
-        // TWAP requires pairs of updates (start and end points) for each price feed
-        // So updateData length must be exactly 2 * number of price feeds
-        if (updateData.length != priceIds.length * 2) {
+        // TWAP requires exactly 2 updates: one for the start point and one for the end point
+        if (updateData.length != 2) {
             revert PythErrors.InvalidUpdateData();
         }
 
         uint requiredFee = getUpdateFee(updateData);
         if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
 
-        unchecked {
-            twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
-            // Iterate over pairs of updates
-            for (uint i = 0; i < updateData.length; i += 2) {
+        // Process start update data
+        PythStructs.TwapPriceInfo[] memory startTwapPriceInfos;
+        bytes32[] memory startPriceIds;
+        {
+            uint offsetStart;
+            (
+                offsetStart,
+                startTwapPriceInfos,
+                startPriceIds
+            ) = processSingleTwapUpdate(updateData[0]);
+        }
+
+        // Process end update data
+        PythStructs.TwapPriceInfo[] memory endTwapPriceInfos;
+        bytes32[] memory endPriceIds;
+        {
+            uint offsetEnd;
+            (
+                offsetEnd,
+                endTwapPriceInfos,
+                endPriceIds
+            ) = processSingleTwapUpdate(updateData[1]);
+        }
+
+        // Verify that we have the same number of price feeds in start and end updates
+        if (startPriceIds.length != endPriceIds.length) {
+            revert PythErrors.InvalidTwapUpdateDataSet();
+        }
+
+        // Create a mapping to check that every startPriceId has a matching endPriceId
+        // This ensures price feed continuity between start and end points
+        bool[] memory endPriceIdMatched = new bool[](endPriceIds.length);
+        for (uint i = 0; i < startPriceIds.length; i++) {
+            bool foundMatch = false;
+            for (uint j = 0; j < endPriceIds.length; j++) {
                 if (
-                    (updateData[i].length > 4 &&
-                        UnsafeCalldataBytesLib.toUint32(updateData[i], 0) ==
-                        ACCUMULATOR_MAGIC) &&
-                    (updateData[i + 1].length > 4 &&
-                        UnsafeCalldataBytesLib.toUint32(updateData[i + 1], 0) ==
-                        ACCUMULATOR_MAGIC)
+                    startPriceIds[i] == endPriceIds[j] && !endPriceIdMatched[j]
                 ) {
-                    uint offsetStart;
-                    uint offsetEnd;
-                    bytes32 priceIdStart;
-                    bytes32 priceIdEnd;
-                    PythStructs.TwapPriceInfo memory twapPriceInfoStart;
-                    PythStructs.TwapPriceInfo memory twapPriceInfoEnd;
-                    (
-                        offsetStart,
-                        twapPriceInfoStart,
-                        priceIdStart
-                    ) = processSingleTwapUpdate(updateData[i]);
-                    (
-                        offsetEnd,
-                        twapPriceInfoEnd,
-                        priceIdEnd
-                    ) = processSingleTwapUpdate(updateData[i + 1]);
-
-                    if (priceIdStart != priceIdEnd)
-                        revert PythErrors.InvalidTwapUpdateDataSet();
-
-                    validateTwapPriceInfo(twapPriceInfoStart, twapPriceInfoEnd);
-
-                    uint k = findIndexOfPriceId(priceIds, priceIdStart);
-
-                    // If priceFeed[k].id != 0 then it means that there was a valid
-                    // update for priceIds[k] and we don't need to process this one.
-                    if (k == priceIds.length || twapPriceFeeds[k].id != 0) {
-                        continue;
-                    }
-
-                    twapPriceFeeds[k] = calculateTwap(
-                        priceIdStart,
-                        twapPriceInfoStart,
-                        twapPriceInfoEnd
-                    );
-                } else {
-                    revert PythErrors.InvalidUpdateData();
+                    endPriceIdMatched[j] = true;
+                    foundMatch = true;
+                    break;
                 }
             }
+            // If a price ID in start doesn't have a match in end, it's invalid
+            if (!foundMatch) {
+                revert PythErrors.InvalidTwapUpdateDataSet();
+            }
+        }
+
+        // Initialize the output array
+        twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
 
-            for (uint k = 0; k < priceIds.length; k++) {
-                if (twapPriceFeeds[k].id == 0) {
-                    revert PythErrors.PriceFeedNotFoundWithinRange();
+        // For each requested price ID, find matching start and end data points
+        for (uint i = 0; i < priceIds.length; i++) {
+            bytes32 requestedPriceId = priceIds[i];
+            int startIdx = -1;
+            int endIdx = -1;
+
+            // Find the index of this price ID in start and end arrays
+            for (uint j = 0; j < startPriceIds.length; j++) {
+                if (startPriceIds[j] == requestedPriceId) {
+                    startIdx = int(j);
+                    break;
                 }
             }
+
+            for (uint j = 0; j < endPriceIds.length; j++) {
+                if (endPriceIds[j] == requestedPriceId) {
+                    endIdx = int(j);
+                    break;
+                }
+            }
+
+            // If we found both start and end data for this price ID
+            if (startIdx >= 0 && endIdx >= 0) {
+                // Validate the pair of price infos
+                validateTwapPriceInfo(
+                    startTwapPriceInfos[uint(startIdx)],
+                    endTwapPriceInfos[uint(endIdx)]
+                );
+
+                // Calculate TWAP from these data points
+                twapPriceFeeds[i] = calculateTwap(
+                    requestedPriceId,
+                    startTwapPriceInfos[uint(startIdx)],
+                    endTwapPriceInfos[uint(endIdx)]
+                );
+            }
+        }
+
+        // Ensure all requested price IDs were found
+        for (uint k = 0; k < priceIds.length; k++) {
+            if (twapPriceFeeds[k].id == 0) {
+                revert PythErrors.PriceFeedNotFoundWithinRange();
+            }
         }
     }
 

+ 106 - 73
target_chains/ethereum/contracts/forge-test/Pyth.t.sol

@@ -164,44 +164,61 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
     ) public returns (bytes[] memory updateData, uint updateFee) {
         require(messages.length >= 2, "At least 2 messages required for TWAP");
 
-        // Create TWAP messages from regular price feed messages
-        // For TWAP calculation, we need cumulative values that increase over time
+        // Create arrays to hold the start and end TWAP messages for all price feeds
         TwapPriceFeedMessage[]
-            memory startTwapMessages = new TwapPriceFeedMessage[](1);
-        startTwapMessages[0].priceId = messages[0].priceId;
-        // For test purposes, we'll set cumulative values for start message
-        startTwapMessages[0].cumulativePrice = int128(messages[0].price) * 1000;
-        startTwapMessages[0].cumulativeConf = uint128(messages[0].conf) * 1000;
-        startTwapMessages[0].numDownSlots = 0; // No down slots for testing
-        startTwapMessages[0].expo = messages[0].expo;
-        startTwapMessages[0].publishTime = messages[0].publishTime;
-        startTwapMessages[0].prevPublishTime = messages[0].prevPublishTime;
-        startTwapMessages[0].publishSlot = 1000; // Start slot
-
+            memory startTwapMessages = new TwapPriceFeedMessage[](
+                messages.length / 2
+            );
         TwapPriceFeedMessage[]
-            memory endTwapMessages = new TwapPriceFeedMessage[](1);
-        endTwapMessages[0].priceId = messages[1].priceId;
-        // For end message, make sure cumulative values are higher than start
-        endTwapMessages[0].cumulativePrice =
-            int128(messages[1].price) *
-            1000 +
-            startTwapMessages[0].cumulativePrice;
-        endTwapMessages[0].cumulativeConf =
-            uint128(messages[1].conf) *
-            1000 +
-            startTwapMessages[0].cumulativeConf;
-        endTwapMessages[0].numDownSlots = 0; // No down slots for testing
-        endTwapMessages[0].expo = messages[1].expo;
-        endTwapMessages[0].publishTime = messages[1].publishTime;
-        endTwapMessages[0].prevPublishTime = messages[1].prevPublishTime;
-        endTwapMessages[0].publishSlot = 1100; // End slot (100 slots after start)
-
-        // Create the updateData array with exactly 2 elements as required by parseTwapPriceFeedUpdates
+            memory endTwapMessages = new TwapPriceFeedMessage[](
+                messages.length / 2
+            );
+
+        // Fill the arrays with all price feeds' start and end points
+        for (uint i = 0; i < messages.length / 2; i++) {
+            // Create start message for this price feed
+            startTwapMessages[i].priceId = messages[i * 2].priceId;
+            startTwapMessages[i].cumulativePrice =
+                int128(messages[i * 2].price) *
+                1000;
+            startTwapMessages[i].cumulativeConf =
+                uint128(messages[i * 2].conf) *
+                1000;
+            startTwapMessages[i].numDownSlots = 0; // No down slots for testing
+            startTwapMessages[i].expo = messages[i * 2].expo;
+            startTwapMessages[i].publishTime = messages[i * 2].publishTime;
+            startTwapMessages[i].prevPublishTime = messages[i * 2]
+                .prevPublishTime;
+            startTwapMessages[i].publishSlot = 1000; // Start slot
+
+            // Create end message for this price feed
+            endTwapMessages[i].priceId = messages[i * 2 + 1].priceId;
+            endTwapMessages[i].cumulativePrice =
+                int128(messages[i * 2 + 1].price) *
+                1000 +
+                startTwapMessages[i].cumulativePrice;
+            endTwapMessages[i].cumulativeConf =
+                uint128(messages[i * 2 + 1].conf) *
+                1000 +
+                startTwapMessages[i].cumulativeConf;
+            endTwapMessages[i].numDownSlots = 0; // No down slots for testing
+            endTwapMessages[i].expo = messages[i * 2 + 1].expo;
+            endTwapMessages[i].publishTime = messages[i * 2 + 1].publishTime;
+            endTwapMessages[i].prevPublishTime = messages[i * 2 + 1]
+                .prevPublishTime;
+            endTwapMessages[i].publishSlot = 1100; // End slot (100 slots after start)
+        }
+
+        // Create exactly 2 updateData entries as required by parseTwapPriceFeedUpdates
         updateData = new bytes[](2);
+
+        // First update contains all start points
         updateData[0] = generateWhMerkleTwapUpdateWithSource(
             startTwapMessages,
             config
         );
+
+        // Second update contains all end points
         updateData[1] = generateWhMerkleTwapUpdateWithSource(
             endTwapMessages,
             config
@@ -797,28 +814,23 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
         priceIds[0] = basePriceIds[0];
         priceIds[1] = basePriceIds[1];
 
-        // Create update data for both price feeds
-        bytes[] memory updateData = new bytes[](4); // 2 updates (start/end) for each price feed
+        // Create update data with both price feeds in the same updates
+        bytes[] memory updateData = new bytes[](2); // Just 2 updates (start/end) for both price feeds
 
-        // First price feed updates
+        // Combine both price feeds in the same messages
         TwapPriceFeedMessage[]
-            memory startMessages1 = new TwapPriceFeedMessage[](1);
-        TwapPriceFeedMessage[] memory endMessages1 = new TwapPriceFeedMessage[](
-            1
+            memory startMessages = new TwapPriceFeedMessage[](2);
+        TwapPriceFeedMessage[] memory endMessages = new TwapPriceFeedMessage[](
+            2
         );
-        startMessages1[0] = baseTwapStartMessages[0];
-        endMessages1[0] = baseTwapEndMessages[0];
 
-        // Second price feed updates
-        TwapPriceFeedMessage[]
-            memory startMessages2 = new TwapPriceFeedMessage[](1);
-        TwapPriceFeedMessage[] memory endMessages2 = new TwapPriceFeedMessage[](
-            1
-        );
-        startMessages2[0] = baseTwapStartMessages[1];
-        endMessages2[0] = baseTwapEndMessages[1];
+        // Add both price feeds to the start and end messages
+        startMessages[0] = baseTwapStartMessages[0];
+        startMessages[1] = baseTwapStartMessages[1];
+        endMessages[0] = baseTwapEndMessages[0];
+        endMessages[1] = baseTwapEndMessages[1];
 
-        // Generate Merkle updates for both price feeds
+        // Generate Merkle updates with both price feeds included
         MerkleUpdateConfig memory config = MerkleUpdateConfig(
             MERKLE_TREE_DEPTH,
             NUM_GUARDIAN_SIGNERS,
@@ -827,20 +839,13 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
             false
         );
 
+        // Create just 2 updates that contain both price feeds
         updateData[0] = generateWhMerkleTwapUpdateWithSource(
-            startMessages1,
+            startMessages,
             config
         );
         updateData[1] = generateWhMerkleTwapUpdateWithSource(
-            endMessages1,
-            config
-        );
-        updateData[2] = generateWhMerkleTwapUpdateWithSource(
-            startMessages2,
-            config
-        );
-        updateData[3] = generateWhMerkleTwapUpdateWithSource(
-            endMessages2,
+            endMessages,
             config
         );
 
@@ -882,12 +887,12 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
     function testParseTwapPriceFeedUpdatesRevertsWithMismatchedArrayLengths()
         public
     {
-        // Case 1: More updates than needed for price feeds
-        bytes32[] memory priceIds = new bytes32[](1); // One price feed
+        // Case 1: Too many updates (more than 2)
+        bytes32[] memory priceIds = new bytes32[](1);
         priceIds[0] = basePriceIds[0];
 
-        // Create 4 updates (should only be 2 for one price feed)
-        bytes[] memory updateData = new bytes[](4);
+        // Create 3 updates (should only be 2)
+        bytes[] memory updateData = new bytes[](3);
 
         TwapPriceFeedMessage[]
             memory startMessages = new TwapPriceFeedMessage[](1);
@@ -918,23 +923,50 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
             startMessages,
             config
         );
-        updateData[3] = generateWhMerkleTwapUpdateWithSource(
-            endMessages,
+
+        uint updateFee = pyth.getUpdateFee(updateData);
+
+        vm.expectRevert(PythErrors.InvalidUpdateData.selector);
+        pyth.parseTwapPriceFeedUpdates{value: updateFee}(updateData, priceIds);
+
+        // Case 2: Too few updates (less than 2)
+        updateData = new bytes[](1); // Only 1 update (should be 2)
+        updateData[0] = generateWhMerkleTwapUpdateWithSource(
+            startMessages,
             config
         );
 
-        uint updateFee = pyth.getUpdateFee(updateData);
+        updateFee = pyth.getUpdateFee(updateData);
 
         vm.expectRevert(PythErrors.InvalidUpdateData.selector);
         pyth.parseTwapPriceFeedUpdates{value: updateFee}(updateData, priceIds);
+    }
 
-        // Case 2: Fewer updates than needed for price feeds
-        priceIds = new bytes32[](2); // Two price feeds
-        priceIds[0] = basePriceIds[0];
-        priceIds[1] = basePriceIds[1];
+    function testParseTwapPriceFeedUpdatesWithRequestedButNotFoundPriceId()
+        public
+    {
+        // Create price IDs, including one that's not in the updates
+        bytes32[] memory priceIds = new bytes32[](2);
+        priceIds[0] = basePriceIds[0]; // This one exists in our updates
+        priceIds[1] = bytes32(uint256(999)); // This one doesn't exist in our updates
 
-        // Create only 2 updates (should be 4 for two price feeds)
-        updateData = new bytes[](2);
+        TwapPriceFeedMessage[]
+            memory startMessages = new TwapPriceFeedMessage[](1);
+        TwapPriceFeedMessage[] memory endMessages = new TwapPriceFeedMessage[](
+            1
+        );
+        startMessages[0] = baseTwapStartMessages[0]; // Only includes priceIds[0]
+        endMessages[0] = baseTwapEndMessages[0]; // Only includes priceIds[0]
+
+        MerkleUpdateConfig memory config = MerkleUpdateConfig(
+            MERKLE_TREE_DEPTH,
+            NUM_GUARDIAN_SIGNERS,
+            SOURCE_EMITTER_CHAIN_ID,
+            SOURCE_EMITTER_ADDRESS,
+            false
+        );
+
+        bytes[] memory updateData = new bytes[](2);
         updateData[0] = generateWhMerkleTwapUpdateWithSource(
             startMessages,
             config
@@ -944,9 +976,10 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils {
             config
         );
 
-        updateFee = pyth.getUpdateFee(updateData);
+        uint updateFee = pyth.getUpdateFee(updateData);
 
-        vm.expectRevert(PythErrors.InvalidUpdateData.selector);
+        // Should revert because one of the requested price IDs is not found in the updates
+        vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
         pyth.parseTwapPriceFeedUpdates{value: updateFee}(updateData, priceIds);
     }
 }