| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202 |
- // SPDX-License-Identifier: Apache 2
- pragma solidity ^0.8.0;
- import "forge-std/Test.sol";
- import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
- import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
- import "./utils/PulseTestUtils.t.sol";
- import "../contracts/pulse/PulseUpgradeable.sol";
- import "../contracts/pulse/IPulse.sol";
- import "../contracts/pulse/PulseState.sol";
- import "../contracts/pulse/PulseEvents.sol";
- import "../contracts/pulse/PulseErrors.sol";
- // Concrete PulseUpgradeable subclass with no overrides; setUp() deploys it as the implementation behind the ERC1967 proxy
- contract ConcretePulseUpgradeable is PulseUpgradeable {}
- // Test consumer that records the arguments of the most recent Pulse callback
- // so that tests can inspect what was delivered.
- contract MockPulseConsumer is IPulseConsumer {
-     // Address handed back by getPulse().
-     address private _pulse;
-     // Sequence number received by the most recent pulseCallback.
-     uint64 public lastSequenceNumber;
-     // Price feeds received by the most recent pulseCallback.
-     PythStructs.PriceFeed[] private _lastPriceFeeds;
-     // `pulse` is the address this consumer reports from getPulse().
-     constructor(address pulse) {
-         _pulse = pulse;
-     }
-     function getPulse() internal view override returns (address) {
-         return _pulse;
-     }
-     // Stores the callback arguments, replacing whatever an earlier callback recorded.
-     function pulseCallback(
-         uint64 sequenceNumber,
-         PythStructs.PriceFeed[] memory priceFeeds
-     ) internal override {
-         lastSequenceNumber = sequenceNumber;
-         // Clear feeds left over from a previous callback; without this the array
-         // keeps growing and lastPriceFeeds() would mix old and new deliveries.
-         delete _lastPriceFeeds;
-         for (uint256 i = 0; i < priceFeeds.length; i++) {
-             _lastPriceFeeds.push(priceFeeds[i]);
-         }
-     }
-     // Returns a memory copy of the feeds recorded by the most recent callback.
-     function lastPriceFeeds()
-         external
-         view
-         returns (PythStructs.PriceFeed[] memory)
-     {
-         return _lastPriceFeeds;
-     }
- }
- // Consumer whose callback always reverts with the string reason "callback failed".
- contract FailingPulseConsumer is IPulseConsumer {
-     // Address reported by getPulse().
-     address private _pulseContract;
-     constructor(address pulse) {
-         _pulseContract = pulse;
-     }
-     function getPulse() internal view override returns (address) {
-         return _pulseContract;
-     }
-     // Ignores both arguments and reverts unconditionally.
-     function pulseCallback(
-         uint64,
-         PythStructs.PriceFeed[] memory
-     ) internal pure override {
-         revert("callback failed");
-     }
- }
- // Consumer whose callback always reverts with a custom error instead of a string reason.
- contract CustomErrorPulseConsumer is IPulseConsumer {
-     // Custom error raised by every callback.
-     error CustomError(string message);
-     // Address reported by getPulse().
-     address private _pulseContract;
-     constructor(address pulse) {
-         _pulseContract = pulse;
-     }
-     function getPulse() internal view override returns (address) {
-         return _pulseContract;
-     }
-     // Ignores both arguments and reverts unconditionally with CustomError.
-     function pulseCallback(
-         uint64,
-         PythStructs.PriceFeed[] memory
-     ) internal pure override {
-         revert CustomError("callback failed");
-     }
- }
- // FIXME: this shouldn't be IPulseConsumer.
- contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils {
- ERC1967Proxy public proxy; // ERC1967 proxy deployed in setUp() in front of ConcretePulseUpgradeable
- PulseUpgradeable public pulse; // `proxy` cast to PulseUpgradeable; tests call Pulse through this handle
- MockPulseConsumer public consumer; // recording consumer created in setUp() and pointed at the proxy
- address public owner; // address(1); first argument to pulse.initialize in setUp()
- address public admin; // address(2); second argument to pulse.initialize, pranked for admin-only calls
- address public pyth; // address(3); passed to pulse.initialize and to mockParsePriceFeedUpdates
- address public defaultProvider; // address(4); passed to pulse.initialize and registered as a provider in setUp()
- // Constants: fee parameters used by setUp() and by the expected-fee arithmetic in testGetFee
- uint96 constant PYTH_FEE = 1 wei;
- uint96 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei;
- uint96 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei;
- uint96 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei;
- // Deploys Pulse behind an ERC1967 proxy, initializes it, registers the
- // default provider and creates the recording consumer used by the tests.
- function setUp() public {
-     // Fixed placeholder addresses for each role.
-     (owner, admin, pyth, defaultProvider) = (
-         address(1),
-         address(2),
-         address(3),
-         address(4)
-     );
-     // Implementation first, then the proxy that fronts it (no init calldata).
-     proxy = new ERC1967Proxy(address(new ConcretePulseUpgradeable()), "");
-     pulse = PulseUpgradeable(address(proxy));
-     pulse.initialize(
-         owner,
-         admin,
-         PYTH_FEE,
-         pyth,
-         defaultProvider,
-         false,
-         15
-     );
-     // Registration is performed by the provider itself.
-     vm.prank(defaultProvider);
-     pulse.registerProvider(
-         DEFAULT_PROVIDER_BASE_FEE,
-         DEFAULT_PROVIDER_FEE_PER_FEED,
-         DEFAULT_PROVIDER_FEE_PER_GAS
-     );
-     consumer = new MockPulseConsumer(address(proxy));
- }
- // Helper: total fee Pulse quotes for the default provider, CALLBACK_GAS_LIMIT
- // and the feed set returned by createPriceIds().
- // FIXME: I think this helper probably needs to take some arguments.
- function calculateTotalFee() internal view returns (uint96 quotedFee) {
-     quotedFee = pulse.getFee(
-         defaultProvider,
-         CALLBACK_GAS_LIMIT,
-         createPriceIds()
-     );
- }
- // Requests a price update as the consumer and verifies both the emitted
- // PriceUpdateRequested event and every field of the stored request.
- function testRequestPriceUpdate() public {
-     // Set a realistic gas price
-     vm.txGasPrice(30 gwei);
-     bytes32[] memory priceIds = createPriceIds();
-     uint64 publishTime = SafeCast.toUint64(block.timestamp);
-     // Fund the consumer contract with enough ETH for higher gas price
-     vm.deal(address(consumer), 1 ether);
-     uint96 totalFee = calculateTotalFee();
-     // Expected prefixes are the first 8 bytes of each price id. An explicit
-     // bytes8 conversion truncates cleanly, whereas assigning a 32-byte word to
-     // a bytes8 in inline assembly leaves the low 24 bytes dirty.
-     bytes8[] memory expectedPriceIdPrefixes = new bytes8[](priceIds.length);
-     for (uint256 i = 0; i < priceIds.length; i++) {
-         expectedPriceIdPrefixes[i] = bytes8(priceIds[i]);
-     }
-     PulseState.Request memory expectedRequest = PulseState.Request({
-         sequenceNumber: 1,
-         publishTime: publishTime,
-         priceIdPrefixes: expectedPriceIdPrefixes,
-         callbackGasLimit: uint32(CALLBACK_GAS_LIMIT),
-         requester: address(consumer),
-         provider: defaultProvider,
-         fee: totalFee - PYTH_FEE
-     });
-     vm.expectEmit();
-     emit PriceUpdateRequested(expectedRequest, priceIds);
-     vm.prank(address(consumer));
-     pulse.requestPriceUpdatesWithCallback{value: totalFee}(
-         defaultProvider,
-         publishTime,
-         priceIds,
-         CALLBACK_GAS_LIMIT
-     );
-     // Additional assertions to verify event data was stored correctly
-     PulseState.Request memory lastRequest = pulse.getRequest(1);
-     assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber);
-     assertEq(lastRequest.publishTime, expectedRequest.publishTime);
-     assertEq(
-         lastRequest.priceIdPrefixes.length,
-         expectedRequest.priceIdPrefixes.length
-     );
-     for (uint256 i = 0; i < lastRequest.priceIdPrefixes.length; i++) {
-         assertEq(
-             lastRequest.priceIdPrefixes[i],
-             expectedRequest.priceIdPrefixes[i]
-         );
-     }
-     assertEq(
-         lastRequest.callbackGasLimit,
-         expectedRequest.callbackGasLimit
-     );
-     assertEq(
-         lastRequest.requester,
-         expectedRequest.requester,
-         "Requester mismatch"
-     );
-     // The stored provider and fee were previously only checked via the event.
-     assertEq(
-         lastRequest.provider,
-         expectedRequest.provider,
-         "Provider mismatch"
-     );
-     assertEq(lastRequest.fee, expectedRequest.fee, "Fee mismatch");
- }
- // A request that pays only PYTH_FEE must be rejected with InsufficientFee.
- function testRequestWithInsufficientFee() public {
-     // Set a realistic gas price
-     vm.txGasPrice(30 gwei);
-     bytes32[] memory feedIds = createPriceIds();
-     uint64 requestedTime = SafeCast.toUint64(block.timestamp);
-     vm.deal(address(consumer), 1 ether);
-     vm.prank(address(consumer));
-     vm.expectRevert(InsufficientFee.selector);
-     // Intentionally low fee
-     pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}(
-         defaultProvider,
-         requestedTime,
-         feedIds,
-         CALLBACK_GAS_LIMIT
-     );
- }
- // Full happy path: the consumer requests an update, the provider fulfils it,
- // PriceUpdateExecuted is emitted and the consumer receives the mocked feeds.
- function testExecuteCallback() public {
-     bytes32[] memory feedIds = createPriceIds();
-     uint64 requestedTime = SafeCast.toUint64(block.timestamp);
-     // Fund the consumer contract
-     vm.deal(address(consumer), 1 gwei);
-     uint96 fee = calculateTotalFee();
-     // Step 1: the consumer places the request
-     vm.prank(address(consumer));
-     uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
-         defaultProvider,
-         requestedTime,
-         feedIds,
-         CALLBACK_GAS_LIMIT
-     );
-     // Step 2: mock the feeds Pyth will return
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     // FIXME: this test doesn't ensure the Pyth fee is paid.
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     // Expected event payload, one entry per feed (BTC first, then ETH)
-     int64[] memory wantPrices = new int64[](2);
-     uint64[] memory wantConf = new uint64[](2);
-     int32[] memory wantExpos = new int32[](2);
-     uint64[] memory wantPublishTimes = new uint64[](2);
-     wantPrices[0] = MOCK_BTC_PRICE;
-     wantConf[0] = MOCK_BTC_CONF;
-     wantExpos[0] = MOCK_PRICE_FEED_EXPO;
-     wantPublishTimes[0] = requestedTime;
-     wantPrices[1] = MOCK_ETH_PRICE;
-     wantConf[1] = MOCK_ETH_CONF;
-     wantExpos[1] = MOCK_PRICE_FEED_EXPO;
-     wantPublishTimes[1] = requestedTime;
-     // Expect the PriceUpdateExecuted event with all price data
-     vm.expectEmit();
-     emit PriceUpdateExecuted(
-         seq,
-         defaultProvider,
-         feedIds,
-         wantPrices,
-         wantConf,
-         wantExpos,
-         wantPublishTimes
-     );
-     // Step 3: the provider delivers the update
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     vm.prank(defaultProvider);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
-     // The consumer must have seen this sequence number...
-     assertEq(consumer.lastSequenceNumber(), seq);
-     // ...and exactly the mocked feeds, field by field
-     PythStructs.PriceFeed[] memory delivered = consumer.lastPriceFeeds();
-     assertEq(delivered.length, mockFeeds.length);
-     for (uint256 k = 0; k < mockFeeds.length; k++) {
-         PythStructs.PriceFeed memory got = delivered[k];
-         PythStructs.PriceFeed memory want = mockFeeds[k];
-         assertEq(got.id, want.id);
-         assertEq(got.price.price, want.price.price);
-         assertEq(got.price.conf, want.price.conf);
-         assertEq(got.price.expo, want.price.expo);
-         assertEq(got.price.publishTime, want.price.publishTime);
-     }
- }
- // A consumer reverting with a string reason must not revert executeCallback;
- // Pulse reports it via PriceUpdateCallbackFailed carrying that reason.
- function testExecuteCallbackFailure() public {
-     FailingPulseConsumer badConsumer = new FailingPulseConsumer(
-         address(proxy)
-     );
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(badConsumer));
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     vm.expectEmit();
-     emit PriceUpdateCallbackFailed(
-         seq,
-         defaultProvider,
-         feedIds,
-         address(badConsumer),
-         "callback failed"
-     );
-     vm.prank(defaultProvider);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
- }
- // A consumer reverting with a custom error is reported through
- // PriceUpdateCallbackFailed with the generic low-level error message.
- function testExecuteCallbackCustomErrorFailure() public {
-     CustomErrorPulseConsumer badConsumer = new CustomErrorPulseConsumer(
-         address(proxy)
-     );
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(badConsumer));
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     vm.expectEmit();
-     emit PriceUpdateCallbackFailed(
-         seq,
-         defaultProvider,
-         feedIds,
-         address(badConsumer),
-         "low-level error (possibly out of gas)"
-     );
-     vm.prank(defaultProvider);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
- }
- // executeCallback must revert when it is forwarded only 100000 gas.
- function testExecuteCallbackWithInsufficientGas() public {
-     // Setup request with 1M gas limit
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     // Mocked Pyth response and matching update payload
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     // Try executing with only 100K gas when 1M is required
-     vm.prank(defaultProvider);
-     // Any revert is accepted because the failure is expected to be an out-of-gas error
-     vm.expectRevert();
-     // Will fail because gasleft() < callbackGasLimit
-     pulse.executeCallback{gas: 100000}(
-         defaultProvider,
-         seq,
-         updateData,
-         feedIds
-     );
- }
- // A request for a publish time 10 seconds ahead can be fulfilled with
- // feeds carrying that future timestamp.
- function testExecuteCallbackWithFutureTimestamp() public {
-     bytes32[] memory feedIds = createPriceIds();
-     // 10 seconds in future
-     uint64 futureTime = SafeCast.toUint64(block.timestamp + 10);
-     vm.deal(address(consumer), 1 gwei);
-     uint96 fee = calculateTotalFee();
-     vm.prank(address(consumer));
-     uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
-         defaultProvider,
-         futureTime,
-         feedIds,
-         CALLBACK_GAS_LIMIT
-     );
-     // Mock price feeds with the future timestamp, so the mocked
-     // parsePriceFeedUpdates returns future-dated prices
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         futureTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     // Should succeed because we're simulating receiving future-dated price updates
-     vm.prank(defaultProvider);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
-     // The consumer must hold one feed per mocked feed, with matching publish times
-     PythStructs.PriceFeed[] memory delivered = consumer.lastPriceFeeds();
-     assertEq(delivered.length, mockFeeds.length);
-     for (uint256 k = 0; k < mockFeeds.length; k++) {
-         assertEq(
-             delivered[k].price.publishTime,
-             mockFeeds[k].price.publishTime
-         );
-     }
- }
- // A publish time 61 seconds ahead of the current block is rejected.
- function testRevertOnTooFarFutureTimestamp() public {
-     bytes32[] memory feedIds = createPriceIds();
-     // Just over 1 minute
-     uint64 tooFarAhead = SafeCast.toUint64(block.timestamp + 61);
-     vm.deal(address(consumer), 1 gwei);
-     uint96 fee = calculateTotalFee();
-     vm.prank(address(consumer));
-     vm.expectRevert("Too far in future");
-     pulse.requestPriceUpdatesWithCallback{value: fee}(
-         defaultProvider,
-         tooFarAhead,
-         feedIds,
-         CALLBACK_GAS_LIMIT
-     );
- }
- // A request can be fulfilled only once; a second attempt reverts with NoSuchRequest.
- function testDoubleExecuteCallback() public {
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     // First execution
-     vm.prank(defaultProvider);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
-     // Second execution should fail
-     vm.prank(defaultProvider);
-     vm.expectRevert(NoSuchRequest.selector);
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
- }
- // getFee must equal base fee + per-feed fee * feeds + per-gas fee * gas limit + Pyth fee.
- function testGetFee() public {
-     // Test with different gas limits to verify fee calculation
-     uint32[3] memory gasLimits = [uint32(100_000), 500_000, 1_000_000];
-     bytes32[] memory feedIds = createPriceIds();
-     // The feed component does not depend on the gas limit
-     uint256 feedFee = DEFAULT_PROVIDER_FEE_PER_FEED * feedIds.length;
-     for (uint256 k = 0; k < gasLimits.length; k++) {
-         uint32 limit = gasLimits[k];
-         uint256 gasFee = DEFAULT_PROVIDER_FEE_PER_GAS * limit;
-         uint96 want = SafeCast.toUint96(
-             DEFAULT_PROVIDER_BASE_FEE + feedFee + gasFee
-         ) + PYTH_FEE;
-         uint96 got = pulse.getFee(defaultProvider, limit, feedIds);
-         assertEq(got, want, "Fee calculation incorrect for gas limit");
-     }
-     // Test with zero gas limit
-     uint96 wantMin = SafeCast.toUint96(
-         PYTH_FEE + DEFAULT_PROVIDER_BASE_FEE + feedFee
-     );
-     uint96 gotMin = pulse.getFee(defaultProvider, 0, feedIds);
-     assertEq(gotMin, wantMin, "Minimum fee calculation incorrect");
- }
- // The admin can withdraw all accrued Pyth fees; the admin's balance grows by
- // that amount and the accrued total drops to zero.
- function testWithdrawFees() public {
-     // Setup: Request price update to accrue some fees
-     bytes32[] memory priceIds = createPriceIds();
-     vm.deal(address(consumer), 1 gwei);
-     // Quote the fee BEFORE pranking: vm.prank only affects the next external
-     // call, so evaluating pulse.getFee inside the {value: ...} option would
-     // consume it and the request would be sent by this test contract instead.
-     uint96 totalFee = calculateTotalFee();
-     vm.prank(address(consumer));
-     pulse.requestPriceUpdatesWithCallback{value: totalFee}(
-         defaultProvider,
-         SafeCast.toUint64(block.timestamp),
-         priceIds,
-         CALLBACK_GAS_LIMIT
-     );
-     // Get admin's balance before withdrawal
-     uint256 adminBalanceBefore = admin.balance;
-     uint128 accruedFees = pulse.getAccruedPythFees();
-     // Withdraw fees as admin
-     vm.prank(admin);
-     pulse.withdrawFees(accruedFees);
-     // Verify balances
-     assertEq(
-         admin.balance,
-         adminBalanceBefore + accruedFees,
-         "Admin balance should increase by withdrawn amount"
-     );
-     assertEq(
-         pulse.getAccruedPythFees(),
-         0,
-         "Contract should have no fees after withdrawal"
-     );
- }
- // withdrawFees called by a non-admin address must revert.
- function testWithdrawFeesUnauthorized() public {
-     address stranger = address(0xdead);
-     vm.prank(stranger);
-     vm.expectRevert("Only admin can withdraw fees");
-     pulse.withdrawFees(1 ether);
- }
- // The admin cannot withdraw 1 ether when no fees have been accrued.
- function testWithdrawFeesInsufficientBalance() public {
-     uint128 excessiveAmount = 1 ether;
-     vm.prank(admin);
-     vm.expectRevert("Insufficient balance");
-     pulse.withdrawFees(excessiveAmount);
- }
- // A provider can appoint a fee manager, who can then withdraw the provider's
- // accrued fees to its own address.
- function testSetAndWithdrawAsFeeManager() public {
-     address feeManager = address(0x789);
-     vm.prank(defaultProvider);
-     pulse.setFeeManager(feeManager);
-     // Setup: Request price update to accrue some fees
-     bytes32[] memory priceIds = createPriceIds();
-     vm.deal(address(consumer), 1 gwei);
-     // Quote the fee BEFORE pranking: vm.prank only affects the next external
-     // call, so evaluating pulse.getFee inside the {value: ...} option would
-     // consume it and the request would be sent by this test contract instead.
-     uint96 totalFee = calculateTotalFee();
-     vm.prank(address(consumer));
-     pulse.requestPriceUpdatesWithCallback{value: totalFee}(
-         defaultProvider,
-         SafeCast.toUint64(block.timestamp),
-         priceIds,
-         CALLBACK_GAS_LIMIT
-     );
-     // Get provider's accrued fees instead of total fees
-     PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
-         defaultProvider
-     );
-     uint128 providerAccruedFees = providerInfo.accruedFeesInWei;
-     uint256 managerBalanceBefore = feeManager.balance;
-     vm.prank(feeManager);
-     // Checked narrowing: fail loudly instead of silently truncating the amount
-     pulse.withdrawAsFeeManager(
-         defaultProvider,
-         SafeCast.toUint96(providerAccruedFees)
-     );
-     assertEq(
-         feeManager.balance,
-         managerBalanceBefore + providerAccruedFees,
-         "Fee manager balance should increase by withdrawn amount"
-     );
-     providerInfo = pulse.getProviderInfo(defaultProvider);
-     assertEq(
-         providerInfo.accruedFeesInWei,
-         0,
-         "Provider should have no fees after withdrawal"
-     );
- }
- // setFeeManager called from an address that never registered as a provider must revert.
- function testSetFeeManagerUnauthorized() public {
-     address stranger = address(0xdead);
-     address candidate = address(0x789);
-     vm.prank(stranger);
-     vm.expectRevert("Provider not registered");
-     pulse.setFeeManager(candidate);
- }
- // withdrawAsFeeManager called by an address that is not the fee manager must revert.
- function testWithdrawAsFeeManagerUnauthorized() public {
-     address stranger = address(0xdead);
-     vm.prank(stranger);
-     vm.expectRevert("Only fee manager");
-     pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
- }
- // A fee manager cannot withdraw 1 ether when the provider has accrued no fees.
- function testWithdrawAsFeeManagerInsufficientBalance() public {
-     address manager = address(0x789);
-     // The provider appoints the fee manager first
-     vm.prank(defaultProvider);
-     pulse.setFeeManager(manager);
-     // The withdrawal attempt then fails on the balance check
-     vm.prank(manager);
-     vm.expectRevert("Insufficient balance");
-     pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
- }
- // executeCallback must revert with InvalidPriceIds when the supplied price ids
- // do not match the ones stored with the request.
- function testExecuteCallbackWithInvalidPriceIds() public {
-     // Setup request; use the ids and publish time of the request actually made
-     // so the expected error data cannot drift from what Pulse stored.
-     (
-         uint64 sequenceNumber,
-         bytes32[] memory priceIds,
-         uint256 publishTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     // Create different priceIds
-     bytes32[] memory wrongPriceIds = new bytes32[](2);
-     wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs
-     wrongPriceIds[1] = bytes32(uint256(2));
-     PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
-         publishTime
-     );
-     mockParsePriceFeedUpdates(pyth, priceFeeds);
-     bytes[] memory updateData = createMockUpdateData(priceFeeds);
-     // First 8 bytes of the first requested id. An explicit conversion truncates
-     // cleanly; the previous raw mload left the low 24 bytes of the word dirty.
-     bytes8 storedPriceIdPrefix = bytes8(priceIds[0]);
-     // Should revert when trying to execute with wrong priceIds
-     vm.prank(defaultProvider);
-     vm.expectRevert(
-         abi.encodeWithSelector(
-             InvalidPriceIds.selector,
-             wrongPriceIds[0],
-             storedPriceIdPrefix
-         )
-     );
-     pulse.executeCallback(
-         defaultProvider,
-         sequenceNumber,
-         updateData,
-         wrongPriceIds
-     );
- }
- // Requesting MAX_PRICE_IDS + 1 feeds must revert with TooManyPriceIds.
- function testRevertOnTooManyPriceIds() public {
-     uint256 maxPriceIds = uint256(pulse.MAX_PRICE_IDS());
-     // Create array with MAX_PRICE_IDS + 1 price IDs
-     bytes32[] memory priceIds = new bytes32[](maxPriceIds + 1);
-     for (uint256 i = 0; i < priceIds.length; i++) {
-         priceIds[i] = bytes32(uint256(i + 1));
-     }
-     // Quote the fee for the ids actually being requested. calculateTotalFee()
-     // prices the default feed set, which underpays for this larger list and
-     // only passed because the contract happens to check the count first.
-     uint96 totalFee = pulse.getFee(
-         defaultProvider,
-         CALLBACK_GAS_LIMIT,
-         priceIds
-     );
-     // Fund the consumer with exactly what it is about to send
-     vm.deal(address(consumer), totalFee);
-     vm.prank(address(consumer));
-     vm.expectRevert(
-         abi.encodeWithSelector(
-             TooManyPriceIds.selector,
-             maxPriceIds + 1,
-             maxPriceIds
-         )
-     );
-     pulse.requestPriceUpdatesWithCallback{value: totalFee}(
-         defaultProvider,
-         SafeCast.toUint64(block.timestamp),
-         priceIds,
-         CALLBACK_GAS_LIMIT
-     );
- }
- // Registering a provider stores all three fee components and marks it registered.
- function testProviderRegistration() public {
-     address provider = address(0x123);
-     uint96 providerFee = 1000;
-     vm.prank(provider);
-     // The same value is used for the base, per-feed and per-gas fee
-     pulse.registerProvider(providerFee, providerFee, providerFee);
-     PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
-     // Check every stored component, not just the per-gas fee
-     assertEq(info.baseFeeInWei, providerFee);
-     assertEq(info.feePerFeedInWei, providerFee);
-     assertEq(info.feePerGasInWei, providerFee);
-     assertTrue(info.isRegistered);
- }
- // A registered provider can replace its base, per-feed and per-gas fees.
- function testSetProviderFee() public {
-     address provider = address(0x123);
-     // Fees at registration time
-     uint96 initialBaseFee = 1000;
-     uint96 initialFeePerFeed = 2000;
-     uint96 initialFeePerGas = 3000;
-     // Fees after the update
-     uint96 newBaseFee = 5000;
-     uint96 newFeePerFeed = 4000;
-     uint96 newFeePerGas = 6000;
-     // Both calls are made by the provider itself
-     vm.startPrank(provider);
-     pulse.registerProvider(
-         initialBaseFee,
-         initialFeePerFeed,
-         initialFeePerGas
-     );
-     pulse.setProviderFee(provider, newBaseFee, newFeePerFeed, newFeePerGas);
-     vm.stopPrank();
-     PulseState.ProviderInfo memory stored = pulse.getProviderInfo(provider);
-     assertEq(stored.baseFeeInWei, newBaseFee);
-     assertEq(stored.feePerFeedInWei, newFeePerFeed);
-     assertEq(stored.feePerGasInWei, newFeePerGas);
- }
- // The admin can switch the default provider to a newly registered one.
- function testDefaultProvider() public {
-     address newProvider = address(0x123);
-     uint96 flatFee = 1000;
-     // The provider registers itself
-     vm.prank(newProvider);
-     pulse.registerProvider(flatFee, flatFee, flatFee);
-     // The admin promotes it to default
-     vm.prank(admin);
-     pulse.setDefaultProvider(newProvider);
-     assertEq(pulse.getDefaultProvider(), newProvider);
- }
- // A request addressed to a non-default provider is stored against that provider.
- function testRequestWithProvider() public {
-     address chosenProvider = address(0x123);
-     uint96 flatFee = 1000;
-     vm.prank(chosenProvider);
-     pulse.registerProvider(flatFee, flatFee, flatFee);
-     // Single-feed request
-     bytes32[] memory feedIds = new bytes32[](1);
-     feedIds[0] = bytes32(uint256(1));
-     uint128 fee = pulse.getFee(chosenProvider, CALLBACK_GAS_LIMIT, feedIds);
-     // Fund the consumer with exactly the quoted fee
-     vm.deal(address(consumer), fee);
-     uint64 requestedTime = SafeCast.toUint64(block.timestamp);
-     vm.prank(address(consumer));
-     uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
-         chosenProvider,
-         requestedTime,
-         feedIds,
-         CALLBACK_GAS_LIMIT
-     );
-     PulseState.Request memory stored = pulse.getRequest(seq);
-     assertEq(stored.provider, chosenProvider);
- }
- // The exclusivity period starts at the value passed to initialize (15) and
- // the admin can change it, which emits ExclusivityPeriodUpdated.
- function testExclusivityPeriod() public {
-     // Test initial value
-     uint256 initialPeriod = pulse.getExclusivityPeriod();
-     assertEq(
-         initialPeriod,
-         15,
-         "Initial exclusivity period should be 15 seconds"
-     );
-     // Test setting new value
-     vm.prank(admin);
-     vm.expectEmit();
-     emit ExclusivityPeriodUpdated(15, 30);
-     pulse.setExclusivityPeriod(30);
-     uint256 updatedPeriod = pulse.getExclusivityPeriod();
-     assertEq(updatedPeriod, 30, "Exclusivity period should be updated");
- }
- // setExclusivityPeriod called by a non-admin address must revert.
- function testSetExclusivityPeriodUnauthorized() public {
-     address stranger = address(0xdead);
-     vm.prank(stranger);
-     vm.expectRevert("Only admin can set exclusivity period");
-     pulse.setExclusivityPeriod(30);
- }
- // While the exclusivity period is running, crediting a provider other than
- // the assigned one is rejected; crediting the assigned provider succeeds.
- function testExecuteCallbackDuringExclusivity() public {
-     // Register a second provider
-     address rivalProvider = address(0x456);
-     vm.prank(rivalProvider);
-     pulse.registerProvider(
-         DEFAULT_PROVIDER_BASE_FEE,
-         DEFAULT_PROVIDER_FEE_PER_FEED,
-         DEFAULT_PROVIDER_FEE_PER_GAS
-     );
-     // Request assigned to the default provider
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     // Mocked Pyth response and matching update payload
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     // Try to execute with second provider during exclusivity period
-     vm.expectRevert("Only assigned provider during exclusivity period");
-     pulse.executeCallback(rivalProvider, seq, updateData, feedIds);
-     // Original provider should succeed
-     pulse.executeCallback(defaultProvider, seq, updateData, feedIds);
- }
- // Once the exclusivity period has elapsed, a provider other than the assigned
- // one can fulfil the request and be credited for it.
- function testExecuteCallbackAfterExclusivity() public {
-     // Register a second provider
-     address secondProvider = address(0x456);
-     vm.prank(secondProvider);
-     pulse.registerProvider(
-         DEFAULT_PROVIDER_BASE_FEE,
-         DEFAULT_PROVIDER_FEE_PER_FEED,
-         DEFAULT_PROVIDER_FEE_PER_GAS
-     );
-     // Setup request
-     (
-         uint64 sequenceNumber,
-         bytes32[] memory priceIds,
-         uint256 publishTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     // Setup mock data
-     PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
-         publishTime
-     );
-     mockParsePriceFeedUpdates(pyth, priceFeeds);
-     bytes[] memory updateData = createMockUpdateData(priceFeeds);
-     // Wait for exclusivity period to end
-     vm.warp(block.timestamp + pulse.getExclusivityPeriod() + 1);
-     // Second provider should now succeed. Credit secondProvider: crediting the
-     // assigned provider is accepted even during exclusivity (see
-     // testExecuteCallbackDuringExclusivity), so it would not exercise the
-     // post-exclusivity path at all.
-     vm.prank(secondProvider);
-     pulse.executeCallback(
-         secondProvider,
-         sequenceNumber,
-         updateData,
-         priceIds
-     );
- }
- // With the exclusivity period raised to 30 seconds, a second provider is
- // rejected at +29 seconds and accepted at +31 seconds.
- function testExecuteCallbackWithCustomExclusivityPeriod() public {
-     // Register a second provider
-     address rivalProvider = address(0x456);
-     vm.prank(rivalProvider);
-     pulse.registerProvider(
-         DEFAULT_PROVIDER_BASE_FEE,
-         DEFAULT_PROVIDER_FEE_PER_FEED,
-         DEFAULT_PROVIDER_FEE_PER_GAS
-     );
-     // Set custom exclusivity period
-     vm.prank(admin);
-     pulse.setExclusivityPeriod(30);
-     // Request assigned to the default provider
-     (
-         uint64 seq,
-         bytes32[] memory feedIds,
-         uint256 requestedTime
-     ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
-     // Mocked Pyth response and matching update payload
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         requestedTime
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     bytes[] memory updateData = createMockUpdateData(mockFeeds);
-     // Try at 29 seconds (should fail for second provider)
-     vm.warp(block.timestamp + 29);
-     vm.expectRevert("Only assigned provider during exclusivity period");
-     pulse.executeCallback(rivalProvider, seq, updateData, feedIds);
-     // Try at 31 seconds (should succeed for second provider)
-     vm.warp(block.timestamp + 2);
-     pulse.executeCallback(rivalProvider, seq, updateData, feedIds);
- }
- // Drives the getFirstActiveRequests scenario: create five requests, fulfil
- // two of them, then run the query checks against what remains.
- function testGetFirstActiveRequests() public {
-     // Setup test data
-     (bytes32[] memory feedIds, bytes[] memory payload) = setupTestData();
-     createTestRequests(feedIds);
-     completeRequests(payload, feedIds);
-     testRequestScenarios(feedIds, payload);
- }
- // Returns a one-element price id array (id = 1) and a one-element update
- // data array whose single entry is left empty.
- function setupTestData()
-     private
-     pure
-     returns (bytes32[] memory, bytes[] memory)
- {
-     bytes32[] memory feedIds = new bytes32[](1);
-     bytes[] memory payload = new bytes[](1);
-     feedIds[0] = bytes32(uint256(1));
-     return (feedIds, payload);
- }
- // Places five requests from this test contract, each funded with and paying
- // 1 ether, for the current block timestamp and a 1000000 callback gas limit.
- function createTestRequests(bytes32[] memory priceIds) private {
-     uint64 requestedTime = SafeCast.toUint64(block.timestamp);
-     uint256 remaining = 5;
-     while (remaining > 0) {
-         // Top the balance back up before every request
-         vm.deal(address(this), 1 ether);
-         pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
-             defaultProvider,
-             requestedTime,
-             priceIds,
-             1000000
-         );
-         remaining--;
-     }
- }
- // Fulfils requests #2 and #4 as the default provider. The incoming updateData
- // is replaced locally by a payload built from freshly mocked feeds.
- function completeRequests(
-     bytes[] memory updateData,
-     bytes32[] memory priceIds
- ) private {
-     // Create mock price feeds and setup Pyth response
-     PythStructs.PriceFeed[] memory mockFeeds = createMockPriceFeeds(
-         SafeCast.toUint64(block.timestamp)
-     );
-     mockParsePriceFeedUpdates(pyth, mockFeeds);
-     updateData = createMockUpdateData(mockFeeds);
-     // Increase ETH allocation to prevent OutOfFunds
-     vm.deal(defaultProvider, 2 ether);
-     uint64[2] memory toComplete = [uint64(2), 4];
-     vm.startPrank(defaultProvider);
-     for (uint256 k = 0; k < toComplete.length; k++) {
-         pulse.executeCallback{value: 1 ether}(
-             defaultProvider,
-             toComplete[k],
-             updateData,
-             priceIds
-         );
-     }
-     vm.stopPrank();
- }
- // Runs the getFirstActiveRequests checks in order, then fulfils the
- // remaining requests and checks the empty state.
- function testRequestScenarios(
-     bytes32[] memory priceIds,
-     bytes[] memory updateData
- ) private {
-     checkMoreThanAvailable(); // Test 1: Request more than available
-     checkExactNumber(); // Test 2: Request exact number
-     checkFewerThanAvailable(); // Test 3: Request fewer than available
-     checkZeroRequest(); // Test 4: Request zero
-     // Test 5: Clear all and check empty
-     clearAllRequests(updateData, priceIds);
-     checkEmptyState();
- }
- // Split test scenarios into separate functions
- // Asking for 10 requests returns only the 3 still active (#1, #3, #5), oldest first.
- function checkMoreThanAvailable() private {
-     (PulseState.Request[] memory active, uint256 found) = pulse
-         .getFirstActiveRequests(10);
-     assertEq(found, 3, "Should find 3 active requests");
-     assertEq(active.length, 3, "Array should be resized to 3");
-     assertEq(
-         active[0].sequenceNumber,
-         1,
-         "First request should be oldest"
-     );
-     assertEq(active[1].sequenceNumber, 3, "Second request should be #3");
-     assertEq(active[2].sequenceNumber, 5, "Third request should be #5");
- }
- // Asking for exactly 3 requests returns all 3 active ones.
- function checkExactNumber() private {
-     (PulseState.Request[] memory active, uint256 found) = pulse
-         .getFirstActiveRequests(3);
-     assertEq(found, 3, "Should find 3 active requests");
-     assertEq(active.length, 3, "Array should match requested size");
- }
- // Asking for 2 requests returns the two oldest active ones (#1 and #3).
- function checkFewerThanAvailable() private {
-     (PulseState.Request[] memory active, uint256 found) = pulse
-         .getFirstActiveRequests(2);
-     assertEq(found, 2, "Should find 2 active requests");
-     assertEq(active.length, 2, "Array should match requested size");
-     assertEq(
-         active[0].sequenceNumber,
-         1,
-         "First request should be oldest"
-     );
-     assertEq(active[1].sequenceNumber, 3, "Second request should be #3");
- }
- // Asking for 0 requests returns an empty array and a zero count.
- function checkZeroRequest() private {
-     (PulseState.Request[] memory active, uint256 found) = pulse
-         .getFirstActiveRequests(0);
-     assertEq(found, 0, "Should find 0 active requests");
-     assertEq(active.length, 0, "Array should be empty");
- }
- // Fulfils the remaining requests (#1, #3 and #5) as the default provider,
- // sending 1 ether with each call.
- function clearAllRequests(
-     bytes[] memory updateData,
-     bytes32[] memory priceIds
- ) private {
-     // Increase ETH allocation
-     vm.deal(defaultProvider, 3 ether);
-     uint64[3] memory outstanding = [uint64(1), 3, 5];
-     vm.startPrank(defaultProvider);
-     for (uint256 k = 0; k < outstanding.length; k++) {
-         pulse.executeCallback{value: 1 ether}(
-             defaultProvider,
-             outstanding[k],
-             updateData,
-             priceIds
-         );
-     }
-     vm.stopPrank();
- }
- // After everything has been fulfilled, a query for 10 requests comes back empty.
- function checkEmptyState() private {
-     (PulseState.Request[] memory active, uint256 found) = pulse
-         .getFirstActiveRequests(10);
-     assertEq(found, 0, "Should find 0 active requests");
-     assertEq(active.length, 0, "Array should be empty");
- }
- function testGetFirstActiveRequestsGasUsage() public {
- // Setup test data: a single price id, the current block timestamp and a 1000000 callback gas limit
- bytes32[] memory priceIds = new bytes32[](1);
- priceIds[0] = bytes32(uint256(1));
- uint64 publishTime = SafeCast.toUint64(block.timestamp);
- uint256 callbackGasLimit = 1000000;
- // Create mock price feeds and mock the Pyth response used when requests are fulfilled below
- PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
- publishTime
- );
- mockParsePriceFeedUpdates(pyth, priceFeeds);
- bytes[] memory updateData = createMockUpdateData(priceFeeds);
- // Create 20 requests from this test contract, fulfilling some of them to leave gaps
- for (uint i = 0; i < 20; i++) {
- vm.deal(address(this), 1 ether);
- pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
- defaultProvider,
- publishTime,
- priceIds,
- uint32(callbackGasLimit)
- );
- // Complete every third request (loop index 0, 3, 6, ... i.e. sequence numbers 1, 4, 7, ...) to create gaps
- if (i % 3 == 0) {
- vm.deal(defaultProvider, 1 ether);
- vm.prank(defaultProvider);
- pulse.executeCallback{value: 1 ether}(
- defaultProvider,
- uint64(i + 1),
- updateData,
- priceIds
- );
- }
- }
- // Measure gas for different request counts; each figure is the gasleft() difference around a single external call
- uint256 gas1 = gasleft();
- pulse.getFirstActiveRequests(5);
- uint256 gas1Used = gas1 - gasleft();
- uint256 gas2 = gasleft();
- pulse.getFirstActiveRequests(10);
- uint256 gas2Used = gas2 - gasleft();
- // Log gas usage for analysis
- emit log_named_uint("Gas used for 5 requests", gas1Used);
- emit log_named_uint("Gas used for 10 requests", gas2Used);
- // Verify gas usage scales roughly linearly: fetching 10 should cost about twice as much as fetching 5
- // Allow 10% relative margin for other factors
- assertApproxEqRel(
- gas2Used,
- gas1Used * 2,
- 0.1e18, // 10% tolerance
- "Gas usage should scale roughly linearly"
- );
- }
- // IPulseConsumer hook: reports the proxy-backed `pulse` instance under test.
- function getPulse() internal view override returns (address pulseAddress) {
-     pulseAddress = address(pulse);
- }
- // No-op pulseCallback override, needed because PulseTest itself inherits IPulseConsumer
- function pulseCallback(
- uint64 sequenceNumber,
- PythStructs.PriceFeed[] memory priceFeeds
- ) internal override {
- // Intentionally empty: the callback is accepted and both arguments are ignored
- // NOTE(review): presumably this avoids a revert when requests placed by this test contract itself are fulfilled - confirm
- }
- }
|