Переглянути джерело

feat(pulse): add getFirstActiveRequests function to retrieve active requests (#2371)

* feat(pulse): Add getLastActiveRequests function to retrieve active requests

* fix(pulse): Update gas usage comments for active requests function

* refactor(pulse): Rename getLastActiveRequests to getFirstActiveRequests and update related comments

* refactor(tests): Rename test functions for consistency in naming convention
Daniel Chew 9 місяців тому
батько
коміт
9b2b626ad5

+ 19 - 0
target_chains/ethereum/contracts/contracts/pulse/IPulse.sol

@@ -92,4 +92,23 @@ interface IPulse is PulseEvents {
     function setExclusivityPeriod(uint256 periodSeconds) external;
 
     function getExclusivityPeriod() external view returns (uint256);
+
+    /**
+     * @notice Gets the first N active requests
+     * @param count Maximum number of active requests to return
+     * @return requests Array of active requests, ordered from oldest to newest
+     * @return actualCount Number of active requests found (may be less than count)
+     * @dev Gas Usage: This function's gas cost scales linearly with the number of requests
+     *      between firstUnfulfilledSeq and currentSequenceNumber. Each iteration costs approximately:
+     *      - 2100 gas for cold storage reads, 100 gas for warm storage reads (SLOAD)
+     *      - Additional gas for array operations
+     *      The function starts from firstUnfulfilledSeq (all requests before this are fulfilled)
+     *      and scans forward until it finds enough active requests or reaches currentSequenceNumber.
+     */
+    function getFirstActiveRequests(
+        uint256 count
+    )
+        external
+        view
+        returns (PulseState.Request[] memory requests, uint256 actualCount);
 }

+ 43 - 1
target_chains/ethereum/contracts/contracts/pulse/Pulse.sol

@@ -164,6 +164,14 @@ abstract contract Pulse is IPulse, PulseState {
                 "low-level error (possibly out of gas)"
             );
         }
+
+        // After successful callback, update firstUnfulfilledSeq if needed
+        while (
+            _state.firstUnfulfilledSeq < _state.currentSequenceNumber &&
+            !isActive(findRequest(_state.firstUnfulfilledSeq))
+        ) {
+            _state.firstUnfulfilledSeq++;
+        }
     }
 
     function emitPriceUpdate(
@@ -293,7 +301,7 @@ abstract contract Pulse is IPulse, PulseState {
         }
     }
 
-    function isActive(Request storage req) internal view returns (bool) {
+    function isActive(Request memory req) internal pure returns (bool) {
         return req.sequenceNumber != 0;
     }
 
@@ -383,4 +391,38 @@ abstract contract Pulse is IPulse, PulseState {
     function getExclusivityPeriod() external view override returns (uint256) {
         return _state.exclusivityPeriodSeconds;
     }
+
+    function getFirstActiveRequests(
+        uint256 count
+    )
+        external
+        view
+        override
+        returns (Request[] memory requests, uint256 actualCount)
+    {
+        requests = new Request[](count);
+        actualCount = 0;
+
+        // Start from the first unfulfilled sequence and work forwards
+        uint64 currentSeq = _state.firstUnfulfilledSeq;
+
+        // Continue until we find enough active requests or reach current sequence
+        while (
+            actualCount < count && currentSeq < _state.currentSequenceNumber
+        ) {
+            Request memory req = findRequest(currentSeq);
+            if (isActive(req)) {
+                requests[actualCount] = req;
+                actualCount++;
+            }
+            currentSeq++;
+        }
+
+        // If we found fewer requests than asked for, resize the array
+        if (actualCount < count) {
+            assembly {
+                mstore(requests, actualCount)
+            }
+        }
+    }
 }

+ 1 - 0
target_chains/ethereum/contracts/contracts/pulse/PulseState.sol

@@ -37,6 +37,7 @@ contract PulseState {
         Request[NUM_REQUESTS] requests;
         mapping(bytes32 => Request) requestsOverflow;
         mapping(address => ProviderInfo) providers;
+        uint64 firstUnfulfilledSeq; // All sequences before this are fulfilled
     }
 
     State internal _state;

+ 204 - 1
target_chains/ethereum/contracts/forge-test/Pulse.t.sol

@@ -54,7 +54,7 @@ contract CustomErrorPulseConsumer is IPulseConsumer {
     }
 }
 
-contract PulseTest is Test, PulseEvents {
+contract PulseTest is Test, PulseEvents, IPulseConsumer {
     ERC1967Proxy public proxy;
     PulseUpgradeable public pulse;
     MockPulseConsumer public consumer;
@@ -876,4 +876,207 @@ contract PulseTest is Test, PulseEvents {
         vm.prank(secondProvider);
         pulse.executeCallback(sequenceNumber, updateData, priceIds);
     }
+
+    function testGetFirstActiveRequests() public {
+        // Setup test data
+        (
+            bytes32[] memory priceIds,
+            bytes[] memory updateData
+        ) = setupTestData();
+        createTestRequests(priceIds);
+        completeRequests(updateData, priceIds);
+
+        testRequestScenarios(priceIds, updateData);
+    }
+
+    function setupTestData()
+        private
+        pure
+        returns (bytes32[] memory, bytes[] memory)
+    {
+        bytes32[] memory priceIds = new bytes32[](1);
+        priceIds[0] = bytes32(uint256(1));
+
+        bytes[] memory updateData = new bytes[](1);
+        return (priceIds, updateData);
+    }
+
+    function createTestRequests(bytes32[] memory priceIds) private {
+        uint256 publishTime = block.timestamp;
+        for (uint i = 0; i < 5; i++) {
+            vm.deal(address(this), 1 ether);
+            pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
+                publishTime,
+                priceIds,
+                1000000
+            );
+        }
+    }
+
+    function completeRequests(
+        bytes[] memory updateData,
+        bytes32[] memory priceIds
+    ) private {
+        // Create mock price feeds and setup Pyth response
+        PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
+            block.timestamp
+        );
+        mockParsePriceFeedUpdates(priceFeeds);
+        updateData = createMockUpdateData(priceFeeds);
+
+        vm.deal(defaultProvider, 2 ether); // Increase ETH allocation to prevent OutOfFunds
+        vm.startPrank(defaultProvider);
+        pulse.executeCallback{value: 1 ether}(2, updateData, priceIds);
+        pulse.executeCallback{value: 1 ether}(4, updateData, priceIds);
+        vm.stopPrank();
+    }
+
+    function testRequestScenarios(
+        bytes32[] memory priceIds,
+        bytes[] memory updateData
+    ) private {
+        // Test 1: Request more than available
+        checkMoreThanAvailable();
+
+        // Test 2: Request exact number
+        checkExactNumber();
+
+        // Test 3: Request fewer than available
+        checkFewerThanAvailable();
+
+        // Test 4: Request zero
+        checkZeroRequest();
+
+        // Test 5: Clear all and check empty
+        clearAllRequests(updateData, priceIds);
+        checkEmptyState();
+    }
+
+    // Split test scenarios into separate functions
+    function checkMoreThanAvailable() private {
+        (PulseState.Request[] memory requests, uint256 count) = pulse
+            .getFirstActiveRequests(10);
+        assertEq(count, 3, "Should find 3 active requests");
+        assertEq(requests.length, 3, "Array should be resized to 3");
+        assertEq(
+            requests[0].sequenceNumber,
+            1,
+            "First request should be oldest"
+        );
+        assertEq(requests[1].sequenceNumber, 3, "Second request should be #3");
+        assertEq(requests[2].sequenceNumber, 5, "Third request should be #5");
+    }
+
+    function checkExactNumber() private {
+        (PulseState.Request[] memory requests, uint256 count) = pulse
+            .getFirstActiveRequests(3);
+        assertEq(count, 3, "Should find 3 active requests");
+        assertEq(requests.length, 3, "Array should match requested size");
+    }
+
+    function checkFewerThanAvailable() private {
+        (PulseState.Request[] memory requests, uint256 count) = pulse
+            .getFirstActiveRequests(2);
+        assertEq(count, 2, "Should find 2 active requests");
+        assertEq(requests.length, 2, "Array should match requested size");
+        assertEq(
+            requests[0].sequenceNumber,
+            1,
+            "First request should be oldest"
+        );
+        assertEq(requests[1].sequenceNumber, 3, "Second request should be #3");
+    }
+
+    function checkZeroRequest() private {
+        (PulseState.Request[] memory requests, uint256 count) = pulse
+            .getFirstActiveRequests(0);
+        assertEq(count, 0, "Should find 0 active requests");
+        assertEq(requests.length, 0, "Array should be empty");
+    }
+
+    function clearAllRequests(
+        bytes[] memory updateData,
+        bytes32[] memory priceIds
+    ) private {
+        vm.deal(defaultProvider, 3 ether); // Increase ETH allocation
+        vm.startPrank(defaultProvider);
+        pulse.executeCallback{value: 1 ether}(1, updateData, priceIds);
+        pulse.executeCallback{value: 1 ether}(3, updateData, priceIds);
+        pulse.executeCallback{value: 1 ether}(5, updateData, priceIds);
+        vm.stopPrank();
+    }
+
+    function checkEmptyState() private {
+        (PulseState.Request[] memory requests, uint256 count) = pulse
+            .getFirstActiveRequests(10);
+        assertEq(count, 0, "Should find 0 active requests");
+        assertEq(requests.length, 0, "Array should be empty");
+    }
+
+    function testGetFirstActiveRequestsGasUsage() public {
+        // Setup test data
+        bytes32[] memory priceIds = new bytes32[](1);
+        priceIds[0] = bytes32(uint256(1));
+        uint256 publishTime = block.timestamp;
+        uint256 callbackGasLimit = 1000000;
+
+        // Create mock price feeds and setup Pyth response
+        PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
+            publishTime
+        );
+        mockParsePriceFeedUpdates(priceFeeds);
+        bytes[] memory updateData = createMockUpdateData(priceFeeds);
+
+        // Create 20 requests with some gaps
+        for (uint i = 0; i < 20; i++) {
+            vm.deal(address(this), 1 ether);
+            pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
+                publishTime,
+                priceIds,
+                callbackGasLimit
+            );
+
+            // Complete every third request to create gaps
+            if (i % 3 == 0) {
+                vm.deal(defaultProvider, 1 ether);
+                vm.prank(defaultProvider);
+                pulse.executeCallback{value: 1 ether}(
+                    uint64(i + 1),
+                    updateData,
+                    priceIds
+                );
+            }
+        }
+
+        // Measure gas for different request counts
+        uint256 gas1 = gasleft();
+        pulse.getFirstActiveRequests(5);
+        uint256 gas1Used = gas1 - gasleft();
+
+        uint256 gas2 = gasleft();
+        pulse.getFirstActiveRequests(10);
+        uint256 gas2Used = gas2 - gasleft();
+
+        // Log gas usage for analysis
+        emit log_named_uint("Gas used for 5 requests", gas1Used);
+        emit log_named_uint("Gas used for 10 requests", gas2Used);
+
+        // Verify gas usage scales roughly linearly
+        // Allow 10% margin for other factors
+        assertApproxEqRel(
+            gas2Used,
+            gas1Used * 2,
+            0.1e18, // 10% tolerance
+            "Gas usage should scale roughly linearly"
+        );
+    }
+
+    // Mock implementation of pulseCallback
+    function pulseCallback(
+        uint64 sequenceNumber,
+        PythStructs.PriceFeed[] memory priceFeeds
+    ) external override {
+        // Just accept the callback, no need to do anything with the data
+        // This prevents the revert we're seeing
+    }
 }