Parcourir la source

[ethereum] - charge updateFee per number of updates (#878)

* feat(ethereum): charge update fee per numUpdates for accumulator updates

* refactor(ethereum): refactor, add benchmarks for getUpdateFee

* refactor(ethereum): add back parseWormholeMerkleHeaderNumUpdates

* refactor: increment totalNumUdpates by 1 for batch prices

* test(ethereum): add test for checking getUpdateFee for accumulator, clean up unused code
swimricky il y a 2 ans
Parent
commit
19b77e2c84

+ 53 - 23
target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

@@ -71,24 +71,26 @@ abstract contract Pyth is
     function updatePriceFeeds(
         bytes[] calldata updateData
     ) public payable override {
-        // TODO: Is this fee model still good for accumulator?
-        uint requiredFee = getUpdateFee(updateData);
-        if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
-
+        uint totalNumUpdates = 0;
         for (uint i = 0; i < updateData.length; ) {
             if (
                 updateData[i].length > 4 &&
                 UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
             ) {
-                updatePriceInfosFromAccumulatorUpdate(updateData[i]);
+                totalNumUpdates += updatePriceInfosFromAccumulatorUpdate(
+                    updateData[i]
+                );
             } else {
                 updatePriceBatchFromVm(updateData[i]);
+                totalNumUpdates += 1;
             }
 
             unchecked {
                 i++;
             }
         }
+        uint requiredFee = getTotalFee(totalNumUpdates);
+        if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
     }
 
     /// This method is deprecated, please use the `getUpdateFee(bytes[])` instead.
@@ -101,7 +103,28 @@ abstract contract Pyth is
     function getUpdateFee(
         bytes[] calldata updateData
     ) public view override returns (uint feeAmount) {
-        return singleUpdateFeeInWei() * updateData.length;
+        uint totalNumUpdates = 0;
+        for (uint i = 0; i < updateData.length; i++) {
+            if (
+                updateData[i].length > 4 &&
+                UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
+            ) {
+                (
+                    uint offset,
+                    UpdateType updateType
+                ) = extractUpdateTypeFromAccumulatorHeader(updateData[i]);
+                if (updateType != UpdateType.WormholeMerkle) {
+                    revert PythErrors.InvalidUpdateData();
+                }
+                totalNumUpdates += parseWormholeMerkleHeaderNumUpdates(
+                    updateData[i],
+                    offset
+                );
+            } else {
+                totalNumUpdates += 1;
+            }
+        }
+        return getTotalFee(totalNumUpdates);
     }
 
     function verifyPythVM(
@@ -425,12 +448,7 @@ abstract contract Pyth is
         returns (PythStructs.PriceFeed[] memory priceFeeds)
     {
         unchecked {
-            {
-                uint requiredFee = getUpdateFee(updateData);
-                if (msg.value < requiredFee)
-                    revert PythErrors.InsufficientFee();
-            }
-
+            uint totalNumUpdates = 0;
             priceFeeds = new PythStructs.PriceFeed[](priceIds.length);
             for (uint i = 0; i < updateData.length; i++) {
                 if (
@@ -438,7 +456,6 @@ abstract contract Pyth is
                     UnsafeBytesLib.toUint32(updateData[i], 0) ==
                     ACCUMULATOR_MAGIC
                 ) {
-                    bytes memory accumulatorUpdate = updateData[i];
                     uint offset;
                     {
                         UpdateType updateType;
@@ -446,31 +463,30 @@ abstract contract Pyth is
                             offset,
                             updateType
                         ) = extractUpdateTypeFromAccumulatorHeader(
-                            accumulatorUpdate
+                            updateData[i]
                         );
 
                         if (updateType != UpdateType.WormholeMerkle) {
                             revert PythErrors.InvalidUpdateData();
                         }
                     }
+
                     bytes20 digest;
                     uint8 numUpdates;
-                    bytes memory encoded = UnsafeBytesLib.slice(
-                        accumulatorUpdate,
-                        offset,
-                        accumulatorUpdate.length - offset
-                    );
-
+                    bytes memory encoded;
                     (
                         offset,
                         digest,
-                        numUpdates
-                    ) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
+                        numUpdates,
+                        encoded
+                    ) = extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
+                        updateData[i],
+                        offset
+                    );
 
                     for (uint j = 0; j < numUpdates; j++) {
                         PythInternalStructs.PriceInfo memory info;
                         bytes32 priceId;
-
                         (
                             offset,
                             info,
@@ -509,6 +525,7 @@ abstract contract Pyth is
                             }
                         }
                     }
+                    totalNumUpdates += numUpdates;
                     if (offset != encoded.length)
                         revert PythErrors.InvalidUpdateData();
                 } else {
@@ -583,6 +600,7 @@ abstract contract Pyth is
 
                         index += attestationSize;
                     }
+                    totalNumUpdates += 1;
                 }
             }
 
@@ -591,9 +609,21 @@ abstract contract Pyth is
                     revert PythErrors.PriceFeedNotFoundWithinRange();
                 }
             }
+
+            {
+                uint requiredFee = getTotalFee(totalNumUpdates);
+                if (msg.value < requiredFee)
+                    revert PythErrors.InsufficientFee();
+            }
         }
     }
 
+    function getTotalFee(
+        uint totalNumUpdates
+    ) private view returns (uint requiredFee) {
+        return totalNumUpdates * singleUpdateFeeInWei();
+    }
+
     function findIndexOfPriceId(
         bytes32[] calldata priceIds,
         bytes32 targetPriceId

+ 53 - 59
target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol

@@ -107,9 +107,24 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
         }
     }
 
-    function extractWormholeMerkleHeaderDigestAndNumUpdates(
-        bytes memory encoded
-    ) internal view returns (uint offset, bytes20 digest, uint8 numUpdates) {
+    function extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
+        bytes calldata accumulatorUpdate,
+        uint encodedOffset
+    )
+        internal
+        view
+        returns (
+            uint offset,
+            bytes20 digest,
+            uint8 numUpdates,
+            bytes memory encoded
+        )
+    {
+        encoded = UnsafeBytesLib.slice(
+            accumulatorUpdate,
+            encodedOffset,
+            accumulatorUpdate.length - encodedOffset
+        );
         unchecked {
             offset = 0;
 
@@ -170,6 +185,19 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
         }
     }
 
+    function parseWormholeMerkleHeaderNumUpdates(
+        bytes calldata wormholeMerkleUpdate,
+        uint offset
+    ) internal view returns (uint8 numUpdates) {
+        uint16 whProofSize = UnsafeBytesLib.toUint16(
+            wormholeMerkleUpdate,
+            offset
+        );
+        offset += 2;
+        offset += whProofSize;
+        numUpdates = UnsafeBytesLib.toUint8(wormholeMerkleUpdate, offset);
+    }
+
     function extractPriceInfoFromMerkleProof(
         bytes20 digest,
         bytes memory encoded,
@@ -185,62 +213,28 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
     {
         unchecked {
             bytes memory encodedMessage;
-            (endOffset, encodedMessage) = extractMessageFromProof(
-                encoded,
-                offset,
-                digest
-            );
-
-            (priceInfo, priceId) = extractPriceInfoAndIdFromPriceFeedMessage(
-                encodedMessage
-            );
-
-            return (endOffset, priceInfo, priceId);
-        }
-    }
-
-    function extractMessageFromProof(
-        bytes memory encodedProof,
-        uint offset,
-        bytes20 merkleRoot
-    ) private pure returns (uint endOffset, bytes memory encodedMessage) {
-        unchecked {
-            uint16 messageSize = UnsafeBytesLib.toUint16(encodedProof, offset);
+            uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset);
             offset += 2;
 
-            encodedMessage = UnsafeBytesLib.slice(
-                encodedProof,
-                offset,
-                messageSize
-            );
+            encodedMessage = UnsafeBytesLib.slice(encoded, offset, messageSize);
             offset += messageSize;
 
             bool valid;
             (valid, endOffset) = MerkleTree.isProofValid(
-                encodedProof,
+                encoded,
                 offset,
-                merkleRoot,
+                digest,
                 encodedMessage
             );
             if (!valid) {
                 revert PythErrors.InvalidUpdateData();
             }
-        }
-    }
 
-    function extractPriceInfoAndIdFromPriceFeedMessage(
-        bytes memory encodedMessage
-    )
-        private
-        pure
-        returns (PythInternalStructs.PriceInfo memory info, bytes32 priceId)
-    {
-        unchecked {
             MessageType messageType = MessageType(
                 UnsafeBytesLib.toUint8(encodedMessage, 0)
             );
             if (messageType == MessageType.PriceFeed) {
-                (info, priceId) = parsePriceFeedMessage(
+                (priceInfo, priceId) = parsePriceFeedMessage(
                     UnsafeBytesLib.slice(
                         encodedMessage,
                         1,
@@ -250,6 +244,8 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
             } else {
                 revert PythErrors.InvalidUpdateData();
             }
+
+            return (endOffset, priceInfo, priceId);
         }
     }
 
@@ -315,32 +311,30 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
 
     function updatePriceInfosFromAccumulatorUpdate(
         bytes calldata accumulatorUpdate
-    ) internal {
+    ) internal returns (uint8 numUpdates) {
         (
-            uint offset,
+            uint encodedOffset,
             UpdateType updateType
         ) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate);
 
         if (updateType != UpdateType.WormholeMerkle) {
             revert PythErrors.InvalidUpdateData();
         }
-        updatePriceInfosFromWormholeMerkle(
-            UnsafeBytesLib.slice(
-                accumulatorUpdate,
-                offset,
-                accumulatorUpdate.length - offset
-            )
+
+        uint offset;
+        bytes20 digest;
+        bytes memory encoded;
+        (
+            offset,
+            digest,
+            numUpdates,
+            encoded
+        ) = extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
+            accumulatorUpdate,
+            encodedOffset
         );
-    }
 
-    function updatePriceInfosFromWormholeMerkle(bytes memory encoded) private {
         unchecked {
-            (
-                uint offset,
-                bytes20 digest,
-                uint8 numUpdates
-            ) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
-
             for (uint i = 0; i < numUpdates; i++) {
                 PythInternalStructs.PriceInfo memory priceInfo;
                 bytes32 priceId;
@@ -360,7 +354,7 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
                     );
                 }
             }
-            if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
         }
+        if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
     }
 }

+ 27 - 4
target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol

@@ -340,9 +340,12 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
         ids[0] = priceIds[4];
         ids[1] = priceIds[2];
         ids[2] = priceIds[0];
-        pyth.parsePriceFeedUpdates{
-            value: freshPricesWhMerkleUpdateFee[numIds - 1]
-        }(freshPricesWhMerkleUpdateData[4], ids, 0, 50);
+        pyth.parsePriceFeedUpdates{value: freshPricesWhMerkleUpdateFee[4]}( // updateFee based on number of priceFeeds in updateData
+            freshPricesWhMerkleUpdateData[4],
+            ids,
+            0,
+            50
+        );
     }
 
     function testBenchmarkParsePriceFeedUpdatesWhMerkleForOnePriceFeedNotWithinRange()
@@ -391,7 +394,27 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
         pyth.getEmaPrice(priceIds[0]);
     }
 
-    function testBenchmarkGetUpdateFee() public view {
+    function testBenchmarkGetUpdateFeeWhBatch() public view {
         pyth.getUpdateFee(freshPricesWhBatchUpdateData);
     }
+
+    function testBenchmarkGetUpdateFeeWhMerkle1() public view {
+        pyth.getUpdateFee(freshPricesWhMerkleUpdateData[0]);
+    }
+
+    function testBenchmarkGetUpdateFeeWhMerkle2() public view {
+        pyth.getUpdateFee(freshPricesWhMerkleUpdateData[1]);
+    }
+
+    function testBenchmarkGetUpdateFeeWhMerkle3() public view {
+        pyth.getUpdateFee(freshPricesWhMerkleUpdateData[2]);
+    }
+
+    function testBenchmarkGetUpdateFeeWhMerkle4() public view {
+        pyth.getUpdateFee(freshPricesWhMerkleUpdateData[3]);
+    }
+
+    function testBenchmarkGetUpdateFeeWhMerkle5() public view {
+        pyth.getUpdateFee(freshPricesWhMerkleUpdateData[4]);
+    }
 }

+ 35 - 1
target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol

@@ -520,7 +520,10 @@ contract PythWormholeMerkleAccumulatorTest is
             );
         }
 
-        updateFee = pyth.getUpdateFee(updateData);
+        // manually calculate the fee
+        // so this helper method doesn't trigger the error.
+        // updateFee = pyth.getUpdateFee(numPriceFeeds);
+        updateFee = singleUpdateFeeInWei() * numPriceFeeds;
     }
 
     function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadMagic()
@@ -913,6 +916,37 @@ contract PythWormholeMerkleAccumulatorTest is
         );
     }
 
+    function testGetUpdateFeeWorksForWhMerkle() public {
+        uint numPriceFeeds = (getRandUint() % 10) + 1;
+        PriceFeedMessage[]
+            memory priceFeedMessages = generateRandomPriceFeedMessage(
+                numPriceFeeds
+            );
+        (bytes[] memory updateData, ) = createWormholeMerkleUpdateData(
+            priceFeedMessages
+        );
+
+        uint updateFee = pyth.getUpdateFee(updateData);
+        assertEq(updateFee, SINGLE_UPDATE_FEE_IN_WEI * numPriceFeeds);
+    }
+
+    function testGetUpdateFeeWorksForWhMerkleBasedOnNumUpdates() public {
+        uint numPriceFeeds = (getRandUint() % 10) + 1;
+        PriceFeedMessage[]
+            memory priceFeedMessages = generateRandomPriceFeedMessage(
+                numPriceFeeds
+            );
+        // Set the priceId of the second message to be the same as the first.
+        priceFeedMessages[1].priceId = priceFeedMessages[0].priceId;
+        (bytes[] memory updateData, ) = createWormholeMerkleUpdateData(
+            priceFeedMessages
+        );
+
+        uint updateFee = pyth.getUpdateFee(updateData);
+        // updateFee should still be based on numUpdates not distinct number of priceIds
+        assertEq(updateFee, SINGLE_UPDATE_FEE_IN_WEI * numPriceFeeds);
+    }
+
     //TODO: add some tests of forward compatibility.
     // I.e., create a message where each part that can be expanded in size is expanded and make sure that parsing still works
 }

+ 6 - 1
target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol

@@ -24,6 +24,7 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
     uint16 constant GOVERNANCE_EMITTER_CHAIN_ID = 0x1;
     bytes32 constant GOVERNANCE_EMITTER_ADDRESS =
         0x0000000000000000000000000000000000000000000000000000000000000011;
+    uint constant SINGLE_UPDATE_FEE_IN_WEI = 1;
 
     function setUpPyth(address wormhole) public returns (address) {
         PythUpgradable implementation = new PythUpgradable();
@@ -47,12 +48,16 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
             GOVERNANCE_EMITTER_ADDRESS,
             0, // Initial governance sequence
             60, // Valid time period in seconds
-            1 // single update fee in wei
+            SINGLE_UPDATE_FEE_IN_WEI // single update fee in wei
         );
 
         return address(pyth);
     }
 
+    function singleUpdateFeeInWei() public view returns (uint) {
+        return SINGLE_UPDATE_FEE_IN_WEI;
+    }
+
     /// Utilities to help generating price attestations and VAAs for them
 
     enum PriceAttestationStatus {

+ 1 - 1
target_chains/ethereum/contracts/foundry.toml

@@ -1,7 +1,7 @@
 [profile.default]
 solc_version = '0.8.4'
 optimizer = true
-optimizer_runs = 5000
+optimizer_runs = 2000
 src = 'contracts'
 # We put the tests into the forge-test directory (instead of test) so that
 # truffle doesn't try to build them