Pulse.t.sol 38 KB


  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "forge-std/Test.sol";
  4. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  5. import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
  6. import "./utils/PulseTestUtils.t.sol";
  7. import "../contracts/pulse/PulseUpgradeable.sol";
  8. import "../contracts/pulse/IPulse.sol";
  9. import "../contracts/pulse/PulseState.sol";
  10. import "../contracts/pulse/PulseEvents.sol";
  11. import "../contracts/pulse/PulseErrors.sol";
  /// @notice Test consumer that records the sequence number and price feeds
  ///         delivered by the most recent Pulse callback.
  contract MockPulseConsumer is IPulseConsumer {
      address private _pulse;
      /// @notice Sequence number passed to the most recent callback.
      uint64 public lastSequenceNumber;
      PythStructs.PriceFeed[] private _lastPriceFeeds;

      /// @param pulse Address reported by getPulse().
      constructor(address pulse) {
          _pulse = pulse;
      }

      function getPulse() internal view override returns (address) {
          return _pulse;
      }

      /// @dev Replaces (instead of appending to) the stored feeds so that
      ///      lastPriceFeeds() reflects only the latest callback, as its name says.
      function pulseCallback(
          uint64 sequenceNumber,
          PythStructs.PriceFeed[] memory priceFeeds
      ) internal override {
          lastSequenceNumber = sequenceNumber;
          delete _lastPriceFeeds;
          for (uint256 i = 0; i < priceFeeds.length; i++) {
              _lastPriceFeeds.push(priceFeeds[i]);
          }
      }

      /// @notice Price feeds delivered by the most recent callback.
      function lastPriceFeeds()
          external
          view
          returns (PythStructs.PriceFeed[] memory)
      {
          return _lastPriceFeeds;
      }
  }
  /// @notice Consumer whose callback always reverts with a reason string;
  ///         used to exercise Pulse's callback-failure path.
  contract FailingPulseConsumer is IPulseConsumer {
      address private _pulse;

      /// @param pulse Address reported by getPulse().
      constructor(address pulse) {
          _pulse = pulse;
      }

      function getPulse() internal view override returns (address) {
          return _pulse;
      }

      /// @dev Both arguments are intentionally ignored.
      function pulseCallback(
          uint64 /* sequenceNumber */,
          PythStructs.PriceFeed[] memory /* priceFeeds */
      ) internal pure override {
          revert("callback failed");
      }
  }
  /// @notice Consumer whose callback always reverts with a custom error
  ///         (rather than a reason string).
  contract CustomErrorPulseConsumer is IPulseConsumer {
      /// @notice Raised unconditionally by pulseCallback.
      error CustomError(string message);

      address private _pulse;

      /// @param pulse Address reported by getPulse().
      constructor(address pulse) {
          _pulse = pulse;
      }

      function getPulse() internal view override returns (address) {
          return _pulse;
      }

      /// @dev Both arguments are intentionally ignored.
      function pulseCallback(
          uint64 /* sequenceNumber */,
          PythStructs.PriceFeed[] memory /* priceFeeds */
      ) internal pure override {
          revert CustomError("callback failed");
      }
  }
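  // FIXME: PulseTest shouldn't need to implement IPulseConsumer. It currently does because some helpers (e.g. createTestRequests) issue requests from the test contract itself, so it must be able to receive callbacks.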
  71. contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils {
  // --- Deployment under test ---
  ERC1967Proxy public proxy;
  PulseUpgradeable public pulse; // the proxy, addressed through the Pulse ABI
  MockPulseConsumer public consumer;

  // --- Actor addresses (assigned in setUp) ---
  address public owner;
  address public admin;
  address public pyth;
  address public defaultProvider;

  // --- Fee constants, all denominated in wei ---
  uint96 internal constant PYTH_FEE = 1 wei;
  uint96 internal constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei;
  uint96 internal constant DEFAULT_PROVIDER_BASE_FEE = 1 wei;
  uint96 internal constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei;
  /// @notice Deploys PulseUpgradeable behind an ERC1967 proxy, initializes it,
  ///         registers the default provider and deploys the default consumer.
  function setUp() public {
      owner = address(1);
      admin = address(2);
      pyth = address(3);
      defaultProvider = address(4);

      PulseUpgradeable implementation = new PulseUpgradeable();
      proxy = new ERC1967Proxy(address(implementation), "");
      // Every interaction goes through the proxy.
      pulse = PulseUpgradeable(address(proxy));

      // Last two arguments: a boolean flag (false) and the initial exclusivity
      // period of 15 seconds (asserted by testExclusivityPeriod).
      pulse.initialize(owner, admin, PYTH_FEE, pyth, defaultProvider, false, 15);

      vm.prank(defaultProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      consumer = new MockPulseConsumer(address(proxy));
  }
  /// @notice Total fee (as quoted by pulse.getFee) for a request to the default
  ///         provider with CALLBACK_GAS_LIMIT and the standard test price IDs.
  function calculateTotalFee() internal view returns (uint96) {
      return calculateTotalFee(defaultProvider, createPriceIds());
  }

  /// @notice Total fee for a request to the default provider with
  ///         CALLBACK_GAS_LIMIT and the given price IDs.
  /// @param priceIds Price IDs the request will ask for.
  function calculateTotalFee(
      bytes32[] memory priceIds
  ) internal view returns (uint96) {
      return calculateTotalFee(defaultProvider, priceIds);
  }

  /// @notice Total fee for a request to `provider` with CALLBACK_GAS_LIMIT and
  ///         the given price IDs.
  /// @param provider Provider the request will be assigned to.
  /// @param priceIds Price IDs the request will ask for.
  function calculateTotalFee(
      address provider,
      bytes32[] memory priceIds
  ) internal view returns (uint96) {
      return pulse.getFee(provider, CALLBACK_GAS_LIMIT, priceIds);
  }
  /// @notice A funded consumer can request an update; the emitted
  ///         PriceUpdateRequested event and the stored request both match the
  ///         expected request in every field.
  function testRequestPriceUpdate() public {
      // Use a realistic gas price.
      vm.txGasPrice(30 gwei);
      bytes32[] memory priceIds = createPriceIds();
      uint64 publishTime = SafeCast.toUint64(block.timestamp);
      // Fund the consumer generously so the higher gas price is not an issue.
      vm.deal(address(consumer), 1 ether);
      uint96 totalFee = calculateTotalFee();

      // The request records an 8-byte prefix per price ID. Converting
      // bytes32 -> bytes8 keeps the leftmost 8 bytes and, unlike a raw
      // assembly assignment, leaves no dirty low-order bits.
      bytes8[] memory expectedPriceIdPrefixes = new bytes8[](priceIds.length);
      for (uint256 i = 0; i < priceIds.length; i++) {
          expectedPriceIdPrefixes[i] = bytes8(priceIds[i]);
      }

      PulseState.Request memory expectedRequest = PulseState.Request({
          sequenceNumber: 1,
          publishTime: publishTime,
          priceIdPrefixes: expectedPriceIdPrefixes,
          callbackGasLimit: uint32(CALLBACK_GAS_LIMIT),
          requester: address(consumer),
          provider: defaultProvider,
          // The request's fee excludes the Pyth fee.
          fee: totalFee - PYTH_FEE
      });

      vm.expectEmit();
      emit PriceUpdateRequested(expectedRequest, priceIds);
      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          publishTime,
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // The stored request must match the expected one field by field.
      PulseState.Request memory lastRequest = pulse.getRequest(1);
      assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber);
      assertEq(lastRequest.publishTime, expectedRequest.publishTime);
      assertEq(
          lastRequest.priceIdPrefixes.length,
          expectedRequest.priceIdPrefixes.length
      );
      for (uint256 i = 0; i < lastRequest.priceIdPrefixes.length; i++) {
          assertEq(
              lastRequest.priceIdPrefixes[i],
              expectedRequest.priceIdPrefixes[i]
          );
      }
      assertEq(
          lastRequest.callbackGasLimit,
          expectedRequest.callbackGasLimit
      );
      assertEq(
          lastRequest.requester,
          expectedRequest.requester,
          "Requester mismatch"
      );
      assertEq(
          lastRequest.provider,
          expectedRequest.provider,
          "Provider mismatch"
      );
      assertEq(lastRequest.fee, expectedRequest.fee, "Fee mismatch");
  }
  /// @notice Paying only PYTH_FEE for a request must revert with InsufficientFee.
  function testRequestWithInsufficientFee() public {
      vm.txGasPrice(30 gwei); // realistic gas price
      bytes32[] memory requestedIds = createPriceIds();
      vm.deal(address(consumer), 1 ether);

      vm.prank(address(consumer));
      vm.expectRevert(InsufficientFee.selector);
      // Deliberately underpay: only PYTH_FEE is attached.
      pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          requestedIds,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @notice The assigned provider fulfils a pending request: Pulse emits
  ///         PriceUpdateExecuted and the consumer receives the mocked feeds.
  function testExecuteCallback() public {
      bytes32[] memory priceIds = createPriceIds();
      uint64 publishTime = SafeCast.toUint64(block.timestamp);
      vm.deal(address(consumer), 1 gwei);
      uint96 fee = calculateTotalFee();

      // 1. The consumer submits the request.
      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          publishTime,
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // 2. Mock the Pyth response for these feeds.
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      // FIXME: this test doesn't ensure the Pyth fee is paid.
      mockParsePriceFeedUpdates(pyth, feeds);

      // 3. Expected per-feed values for the PriceUpdateExecuted event.
      int64[] memory prices = new int64[](2);
      uint64[] memory confs = new uint64[](2);
      int32[] memory expos = new int32[](2);
      uint64[] memory publishTimes = new uint64[](2);
      // Index 0: BTC feed.
      prices[0] = MOCK_BTC_PRICE;
      confs[0] = MOCK_BTC_CONF;
      // Index 1: ETH feed.
      prices[1] = MOCK_ETH_PRICE;
      confs[1] = MOCK_ETH_CONF;
      // Exponent and publish time are shared by both feeds.
      for (uint256 k = 0; k < 2; k++) {
          expos[k] = MOCK_PRICE_FEED_EXPO;
          publishTimes[k] = publishTime;
      }

      vm.expectEmit();
      emit PriceUpdateExecuted(
          seq,
          defaultProvider,
          priceIds,
          prices,
          confs,
          expos,
          publishTimes
      );

      // 4. The provider executes the callback.
      bytes[] memory updateData = createMockUpdateData(feeds);
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, updateData, priceIds);

      // 5. The consumer saw the same sequence number and the same feeds.
      assertEq(consumer.lastSequenceNumber(), seq);
      PythStructs.PriceFeed[] memory received = consumer.lastPriceFeeds();
      assertEq(received.length, feeds.length);
      for (uint256 i = 0; i < feeds.length; i++) {
          assertEq(received[i].id, feeds[i].id);
          assertEq(received[i].price.price, feeds[i].price.price);
          assertEq(received[i].price.conf, feeds[i].price.conf);
          assertEq(received[i].price.expo, feeds[i].price.expo);
          assertEq(received[i].price.publishTime, feeds[i].price.publishTime);
      }
  }
  /// @notice When the consumer's callback reverts with a reason string, Pulse
  ///         emits PriceUpdateCallbackFailed carrying that reason.
  function testExecuteCallbackFailure() public {
      FailingPulseConsumer reverter = new FailingPulseConsumer(address(proxy));
      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(reverter)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      vm.expectEmit();
      emit PriceUpdateCallbackFailed(
          seq,
          defaultProvider,
          ids,
          address(reverter),
          "callback failed"
      );
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, updateData, ids);
  }
  /// @notice When the consumer's callback reverts with a custom error, Pulse
  ///         emits PriceUpdateCallbackFailed with a generic low-level message.
  function testExecuteCallbackCustomErrorFailure() public {
      CustomErrorPulseConsumer reverter = new CustomErrorPulseConsumer(
          address(proxy)
      );
      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(reverter)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      // A custom error carries no reason string, so the generic message is expected.
      vm.expectEmit();
      emit PriceUpdateCallbackFailed(
          seq,
          defaultProvider,
          ids,
          address(reverter),
          "low-level error (possibly out of gas)"
      );
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, updateData, ids);
  }
  /// @notice Forwarding far less gas than the request's callback gas limit
  ///         makes executeCallback revert.
  function testExecuteCallbackWithInsufficientGas() public {
      // NOTE(review): relies on setupConsumerRequest using a callback gas limit
      // well above 100k (expected to be 1M).
      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(consumer)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      vm.prank(defaultProvider);
      // Any revert is accepted: the failure may surface as out-of-gas.
      vm.expectRevert();
      pulse.executeCallback{gas: 100000}(defaultProvider, seq, updateData, ids);
  }
  /// @notice A request for a publish time 10 seconds ahead can be fulfilled with
  ///         (mocked) price updates that carry that future timestamp.
  function testExecuteCallbackWithFutureTimestamp() public {
      bytes32[] memory ids = createPriceIds();
      uint64 requestedTime = SafeCast.toUint64(block.timestamp + 10);
      vm.deal(address(consumer), 1 gwei);
      uint96 fee = calculateTotalFee();

      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          requestedTime,
          ids,
          CALLBACK_GAS_LIMIT
      );

      // Mock Pyth so that parsing returns feeds stamped with the future time.
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(requestedTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      // Expected to succeed: the mocked updates already carry the requested time.
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, updateData, ids);

      PythStructs.PriceFeed[] memory received = consumer.lastPriceFeeds();
      assertEq(received.length, feeds.length);
      for (uint256 j = 0; j < feeds.length; j++) {
          assertEq(received[j].price.publishTime, feeds[j].price.publishTime);
      }
  }
  /// @notice Requesting a publish time 61 seconds ahead (just over one minute)
  ///         reverts with "Too far in future".
  function testRevertOnTooFarFutureTimestamp() public {
      bytes32[] memory ids = createPriceIds();
      uint64 tooFar = SafeCast.toUint64(block.timestamp + 61);
      vm.deal(address(consumer), 1 gwei);
      uint96 fee = calculateTotalFee();

      vm.prank(address(consumer));
      vm.expectRevert("Too far in future");
      pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          tooFar,
          ids,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @notice A request can be fulfilled only once; a second execution of the
  ///         same sequence number reverts with NoSuchRequest.
  function testDoubleExecuteCallback() public {
      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(consumer)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      // First execution fulfils the request.
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, updateData, ids);

      // Replaying it must fail.
      vm.prank(defaultProvider);
      vm.expectRevert(NoSuchRequest.selector);
      pulse.executeCallback(defaultProvider, seq, updateData, ids);
  }
  /// @notice getFee = PYTH_FEE + base fee + per-feed fee * #feeds
  ///         + per-gas fee * gas limit, checked at several gas limits and at zero.
  function testGetFee() public {
      bytes32[] memory priceIds = createPriceIds();
      uint32[3] memory gasLimits = [uint32(100_000), 500_000, 1_000_000];

      for (uint256 i = 0; i < gasLimits.length; i++) {
          uint32 gasLimit = gasLimits[i];
          uint256 providerFee = DEFAULT_PROVIDER_BASE_FEE +
              DEFAULT_PROVIDER_FEE_PER_FEED * priceIds.length +
              DEFAULT_PROVIDER_FEE_PER_GAS * gasLimit;
          uint96 expectedFee = SafeCast.toUint96(providerFee) + PYTH_FEE;
          assertEq(
              pulse.getFee(defaultProvider, gasLimit, priceIds),
              expectedFee,
              "Fee calculation incorrect for gas limit"
          );
      }

      // With a zero gas limit only the fixed components remain.
      uint96 expectedMinFee = SafeCast.toUint96(
          PYTH_FEE +
              DEFAULT_PROVIDER_BASE_FEE +
              DEFAULT_PROVIDER_FEE_PER_FEED * priceIds.length
      );
      assertEq(
          pulse.getFee(defaultProvider, 0, priceIds),
          expectedMinFee,
          "Minimum fee calculation incorrect"
      );
  }
  /// @notice The admin can withdraw all accrued Pyth fees, after which none remain.
  function testWithdrawFees() public {
      // Setup: make a request from the consumer so that fees accrue.
      bytes32[] memory priceIds = createPriceIds();
      vm.deal(address(consumer), 1 gwei);
      // Quote the fee BEFORE vm.prank. If calculateTotalFee() were evaluated
      // inside {value: ...}, its external pulse.getFee call would be the
      // "next call" after the prank, consume it, and the request would be
      // sent by this test contract instead of the consumer.
      uint96 totalFee = calculateTotalFee();
      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      uint256 adminBalanceBefore = admin.balance;
      uint128 accruedFees = pulse.getAccruedPythFees();

      vm.prank(admin);
      pulse.withdrawFees(accruedFees);

      assertEq(
          admin.balance,
          adminBalanceBefore + accruedFees,
          "Admin balance should increase by withdrawn amount"
      );
      assertEq(
          pulse.getAccruedPythFees(),
          0,
          "Contract should have no fees after withdrawal"
      );
  }
  /// @notice A non-admin caller cannot withdraw Pyth fees.
  function testWithdrawFeesUnauthorized() public {
      address stranger = address(0xdead);
      vm.prank(stranger);
      vm.expectRevert("Only admin can withdraw fees");
      pulse.withdrawFees(1 ether);
  }
  /// @notice With nothing accrued, even the admin's withdrawal of 1 ether fails.
  function testWithdrawFeesInsufficientBalance() public {
      vm.prank(admin);
      vm.expectRevert("Insufficient balance");
      pulse.withdrawFees(1 ether); // nothing has accrued in a fresh setUp
  }
  /// @notice A provider can appoint a fee manager, who can then withdraw the
  ///         provider's accrued fees.
  function testSetAndWithdrawAsFeeManager() public {
      address feeManager = address(0x789);
      vm.prank(defaultProvider);
      pulse.setFeeManager(feeManager);

      // Setup: make a request from the consumer so that provider fees accrue.
      bytes32[] memory priceIds = createPriceIds();
      vm.deal(address(consumer), 1 gwei);
      // Quote the fee BEFORE vm.prank: evaluating calculateTotalFee() inside
      // {value: ...} makes an external pulse.getFee call that would consume
      // the prank, so the request would not come from the consumer.
      uint96 totalFee = calculateTotalFee();
      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // The provider's accrued fees (not the Pyth fees) are what the manager withdraws.
      PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
          defaultProvider
      );
      uint128 providerAccruedFees = providerInfo.accruedFeesInWei;
      // Checked downcast instead of a silent truncating uint96(...) cast.
      uint96 withdrawAmount = SafeCast.toUint96(providerAccruedFees);
      uint256 managerBalanceBefore = feeManager.balance;

      vm.prank(feeManager);
      pulse.withdrawAsFeeManager(defaultProvider, withdrawAmount);

      assertEq(
          feeManager.balance,
          managerBalanceBefore + providerAccruedFees,
          "Fee manager balance should increase by withdrawn amount"
      );
      providerInfo = pulse.getProviderInfo(defaultProvider);
      assertEq(
          providerInfo.accruedFeesInWei,
          0,
          "Provider should have no fees after withdrawal"
      );
  }
  /// @notice An address that is not a registered provider cannot set a fee manager.
  function testSetFeeManagerUnauthorized() public {
      address unregistered = address(0xdead);
      vm.prank(unregistered);
      vm.expectRevert("Provider not registered");
      pulse.setFeeManager(address(0x789));
  }
  /// @notice An address that is not the provider's fee manager cannot withdraw for it.
  function testWithdrawAsFeeManagerUnauthorized() public {
      address stranger = address(0xdead);
      vm.prank(stranger);
      vm.expectRevert("Only fee manager");
      pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
  }
  /// @notice A valid fee manager still cannot withdraw more than the provider
  ///         has accrued (nothing, in a fresh setUp).
  function testWithdrawAsFeeManagerInsufficientBalance() public {
      address manager = address(0x789);

      // Appoint the fee manager first.
      vm.prank(defaultProvider);
      pulse.setFeeManager(manager);

      vm.prank(manager);
      vm.expectRevert("Insufficient balance");
      pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
  }
  /// @notice executeCallback reverts with InvalidPriceIds when the supplied
  ///         price IDs do not match the ones recorded for the request.
  function testExecuteCallbackWithInvalidPriceIds() public {
      // Use the price IDs and publish time of the request that was actually
      // created, so the expected stored prefix comes from the real request.
      (
          uint64 sequenceNumber,
          bytes32[] memory priceIds,
          uint256 publishTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      // Price IDs that differ from the requested ones.
      bytes32[] memory wrongPriceIds = new bytes32[](2);
      wrongPriceIds[0] = bytes32(uint256(1));
      wrongPriceIds[1] = bytes32(uint256(2));

      PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
          publishTime
      );
      mockParsePriceFeedUpdates(pyth, priceFeeds);
      bytes[] memory updateData = createMockUpdateData(priceFeeds);

      // The request records an 8-byte prefix per ID; bytes32 -> bytes8 keeps
      // the leftmost 8 bytes of the first requested ID.
      bytes8 storedPriceIdPrefix = bytes8(priceIds[0]);

      vm.prank(defaultProvider);
      vm.expectRevert(
          abi.encodeWithSelector(
              InvalidPriceIds.selector,
              wrongPriceIds[0],
              storedPriceIdPrefix
          )
      );
      pulse.executeCallback(
          defaultProvider,
          sequenceNumber,
          updateData,
          wrongPriceIds
      );
  }
  /// @notice Requesting MAX_PRICE_IDS + 1 feeds reverts with
  ///         TooManyPriceIds(requested, limit).
  function testRevertOnTooManyPriceIds() public {
      uint256 limit = uint256(pulse.MAX_PRICE_IDS());
      uint256 requested = limit + 1;

      // Distinct non-zero IDs 1..requested.
      bytes32[] memory ids = new bytes32[](requested);
      for (uint256 k = 0; k < requested; k++) {
          ids[k] = bytes32(k + 1);
      }

      vm.deal(address(consumer), 1 gwei);
      // NOTE(review): the fee is quoted for the standard price-ID set, not
      // `ids`; the test relies on the TooManyPriceIds check rejecting first.
      uint96 fee = calculateTotalFee();

      vm.prank(address(consumer));
      vm.expectRevert(
          abi.encodeWithSelector(TooManyPriceIds.selector, requested, limit)
      );
      pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          ids,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @notice Registering a provider marks it registered and stores each of its
  ///         three fees (base, per-feed, per-gas) in the right field.
  function testProviderRegistration() public {
      address provider = address(0x123);
      // Distinct values so that a swapped or dropped argument is caught.
      uint96 baseFee = 1000;
      uint96 feePerFeed = 2000;
      uint96 feePerGas = 3000;

      vm.prank(provider);
      pulse.registerProvider(baseFee, feePerFeed, feePerGas);

      PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
      assertTrue(info.isRegistered);
      assertEq(info.baseFeeInWei, baseFee);
      assertEq(info.feePerFeedInWei, feePerFeed);
      assertEq(info.feePerGasInWei, feePerGas);
  }
  /// @notice A registered provider can replace all three of its fees.
  function testSetProviderFee() public {
      address provider = address(0x123);

      // Register with initial fees: base 1000, per-feed 2000, per-gas 3000.
      uint96 startBase = 1000;
      uint96 startPerFeed = 2000;
      uint96 startPerGas = 3000;
      vm.prank(provider);
      pulse.registerProvider(startBase, startPerFeed, startPerGas);

      // Replace them.
      uint96 updatedBase = 5000;
      uint96 updatedPerFeed = 4000;
      uint96 updatedPerGas = 6000;
      vm.prank(provider);
      pulse.setProviderFee(provider, updatedBase, updatedPerFeed, updatedPerGas);

      PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
      assertEq(info.baseFeeInWei, updatedBase);
      assertEq(info.feePerFeedInWei, updatedPerFeed);
      assertEq(info.feePerGasInWei, updatedPerGas);
  }
  /// @notice The admin can make another registered provider the default.
  function testDefaultProvider() public {
      address newDefault = address(0x123);
      uint96 fee = 1000;

      vm.prank(newDefault);
      pulse.registerProvider(fee, fee, fee);

      vm.prank(admin);
      pulse.setDefaultProvider(newDefault);

      assertEq(pulse.getDefaultProvider(), newDefault);
  }
  /// @notice A request can be routed to a specific (non-default) provider, and
  ///         that provider is recorded on the stored request.
  function testRequestWithProvider() public {
      address provider = address(0x123);
      uint96 providerFee = 1000;
      vm.prank(provider);
      pulse.registerProvider(providerFee, providerFee, providerFee);

      bytes32[] memory ids = new bytes32[](1);
      ids[0] = bytes32(uint256(1));

      // Fund the consumer with exactly the quoted fee.
      uint128 fee = pulse.getFee(provider, CALLBACK_GAS_LIMIT, ids);
      vm.deal(address(consumer), fee);

      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          provider,
          SafeCast.toUint64(block.timestamp),
          ids,
          CALLBACK_GAS_LIMIT
      );

      PulseState.Request memory stored = pulse.getRequest(seq);
      assertEq(stored.provider, provider);
  }
  /// @notice The exclusivity period starts at 15 seconds, and the admin can
  ///         change it, which emits ExclusivityPeriodUpdated(old, new).
  function testExclusivityPeriod() public {
      // Initial value comes from setUp's initialize call.
      assertEq(
          pulse.getExclusivityPeriod(),
          15,
          "Initial exclusivity period should be 15 seconds"
      );

      // Update 15 -> 30 as the admin and expect the event.
      vm.prank(admin);
      vm.expectEmit();
      emit ExclusivityPeriodUpdated(15, 30);
      pulse.setExclusivityPeriod(30);

      assertEq(
          pulse.getExclusivityPeriod(),
          30,
          "Exclusivity period should be updated"
      );
  }
  /// @notice Only the admin may change the exclusivity period.
  function testSetExclusivityPeriodUnauthorized() public {
      address stranger = address(0xdead);
      vm.prank(stranger);
      vm.expectRevert("Only admin can set exclusivity period");
      pulse.setExclusivityPeriod(30);
  }
  /// @notice While the exclusivity window is open, only the assigned provider
  ///         can fulfil the request.
  function testExecuteCallbackDuringExclusivity() public {
      // A second, non-assigned provider.
      address otherProvider = address(0x456);
      vm.prank(otherProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Request assigned to the default provider.
      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(consumer)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      // No time has passed since the request: the other provider is rejected.
      vm.expectRevert("Only assigned provider during exclusivity period");
      pulse.executeCallback(otherProvider, seq, updateData, ids);

      // The assigned provider is accepted.
      pulse.executeCallback(defaultProvider, seq, updateData, ids);
  }
  /// @notice Once the exclusivity period has elapsed, a provider other than the
  ///         assigned one can fulfil the request.
  function testExecuteCallbackAfterExclusivity() public {
      // Register a second provider.
      address secondProvider = address(0x456);
      vm.prank(secondProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Request assigned to the default provider.
      (
          uint64 sequenceNumber,
          bytes32[] memory priceIds,
          uint256 publishTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
          publishTime
      );
      mockParsePriceFeedUpdates(pyth, priceFeeds);
      bytes[] memory updateData = createMockUpdateData(priceFeeds);

      // Move just past the exclusivity window.
      vm.warp(block.timestamp + pulse.getExclusivityPeriod() + 1);

      // The second provider must now be accepted. It has to be passed as the
      // executing provider: passing defaultProvider would succeed regardless
      // of exclusivity and so would not test anything.
      vm.prank(secondProvider);
      pulse.executeCallback(
          secondProvider,
          sequenceNumber,
          updateData,
          priceIds
      );

      assertEq(
          consumer.lastSequenceNumber(),
          sequenceNumber,
          "Callback should have been delivered"
      );
  }
  /// @notice With a 30-second exclusivity period, a non-assigned provider is
  ///         rejected 29 seconds after the request and accepted at 31 seconds.
  function testExecuteCallbackWithCustomExclusivityPeriod() public {
      address otherProvider = address(0x456);
      vm.prank(otherProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Widen the window to 30 seconds.
      vm.prank(admin);
      pulse.setExclusivityPeriod(30);

      (uint64 seq, bytes32[] memory ids, uint256 publishTime) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(consumer)
      );

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(publishTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory updateData = createMockUpdateData(feeds);

      // +29s: still inside the window.
      vm.warp(block.timestamp + 29);
      vm.expectRevert("Only assigned provider during exclusivity period");
      pulse.executeCallback(otherProvider, seq, updateData, ids);

      // +31s: window closed, the other provider is accepted.
      vm.warp(block.timestamp + 2);
      pulse.executeCallback(otherProvider, seq, updateData, ids);
  }
  /// @notice Exercises getFirstActiveRequests: five requests are made, #2 and
  ///         #4 are fulfilled, and the remaining ones are queried several ways.
  function testGetFirstActiveRequests() public {
      (
          bytes32[] memory priceIds,
          bytes[] memory updateData
      ) = setupTestData();
      createTestRequests(priceIds);
      // completeRequests builds the real mock update data. Reassigning its
      // memory parameter does not change this local, so take the returned
      // value explicitly and reuse it when the remaining requests are cleared.
      updateData = completeRequests(updateData, priceIds);
      testRequestScenarios(priceIds, updateData);
  }

  /// @dev One price ID (bytes32(1)) and a one-element placeholder update-data array.
  function setupTestData()
      private
      pure
      returns (bytes32[] memory, bytes[] memory)
  {
      bytes32[] memory priceIds = new bytes32[](1);
      priceIds[0] = bytes32(uint256(1));
      bytes[] memory updateData = new bytes[](1);
      return (priceIds, updateData);
  }

  /// @dev Makes five requests from this contract to the default provider,
  ///      paying 1 ether each with a 1,000,000 callback gas limit.
  function createTestRequests(bytes32[] memory priceIds) private {
      uint64 publishTime = SafeCast.toUint64(block.timestamp);
      for (uint256 i = 0; i < 5; i++) {
          vm.deal(address(this), 1 ether);
          pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
              defaultProvider,
              publishTime,
              priceIds,
              1000000
          );
      }
  }

  /// @dev Mocks Pyth to return feeds for the current timestamp and fulfils
  ///      requests #2 and #4 as the default provider (1 ether sent with each).
  /// @param updateData Placeholder; it is overwritten with the real mock update
  ///        data built here (kept in the signature for existing callers).
  /// @param priceIds Price IDs passed to executeCallback.
  /// @return The mock update data actually used, so callers can reuse it.
  function completeRequests(
      bytes[] memory updateData,
      bytes32[] memory priceIds
  ) private returns (bytes[] memory) {
      PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
          SafeCast.toUint64(block.timestamp)
      );
      mockParsePriceFeedUpdates(pyth, priceFeeds);
      updateData = createMockUpdateData(priceFeeds);

      vm.deal(defaultProvider, 2 ether); // 1 ether per executeCallback below
      vm.startPrank(defaultProvider);
      pulse.executeCallback{value: 1 ether}(
          defaultProvider,
          2,
          updateData,
          priceIds
      );
      pulse.executeCallback{value: 1 ether}(
          defaultProvider,
          4,
          updateData,
          priceIds
      );
      vm.stopPrank();
      return updateData;
  }
  /// @dev Runs the getFirstActiveRequests scenarios against the state left by
  ///      createTestRequests/completeRequests (requests #1, #3, #5 pending),
  ///      then clears everything and checks the empty state.
  function testRequestScenarios(
      bytes32[] memory priceIds,
      bytes[] memory updateData
  ) private {
      checkMoreThanAvailable(); // limit above the number pending
      checkExactNumber(); // limit equal to the number pending
      checkFewerThanAvailable(); // limit below the number pending
      checkZeroRequest(); // limit of zero

      // Fulfil the rest; nothing should remain active afterwards.
      clearAllRequests(updateData, priceIds);
      checkEmptyState();
  }
  /// @dev Asking for 10 when only 3 are active returns #1, #3, #5 (oldest
  ///      first) in an array shrunk to length 3.
  function checkMoreThanAvailable() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(10);
      assertEq(found, 3, "Should find 3 active requests");
      assertEq(active.length, 3, "Array should be resized to 3");
      assertEq(active[0].sequenceNumber, 1, "First request should be oldest");
      assertEq(active[1].sequenceNumber, 3, "Second request should be #3");
      assertEq(active[2].sequenceNumber, 5, "Third request should be #5");
  }
  912. function checkExactNumber() private {
  913. (PulseState.Request[] memory requests, uint256 count) = pulse
  914. .getFirstActiveRequests(3);
  915. assertEq(count, 3, "Should find 3 active requests");
  916. assertEq(requests.length, 3, "Array should match requested size");
  917. }
  918. function checkFewerThanAvailable() private {
  919. (PulseState.Request[] memory requests, uint256 count) = pulse
  920. .getFirstActiveRequests(2);
  921. assertEq(count, 2, "Should find 2 active requests");
  922. assertEq(requests.length, 2, "Array should match requested size");
  923. assertEq(
  924. requests[0].sequenceNumber,
  925. 1,
  926. "First request should be oldest"
  927. );
  928. assertEq(requests[1].sequenceNumber, 3, "Second request should be #3");
  929. }
  930. function checkZeroRequest() private {
  931. (PulseState.Request[] memory requests, uint256 count) = pulse
  932. .getFirstActiveRequests(0);
  933. assertEq(count, 0, "Should find 0 active requests");
  934. assertEq(requests.length, 0, "Array should be empty");
  935. }
  936. function clearAllRequests(
  937. bytes[] memory updateData,
  938. bytes32[] memory priceIds
  939. ) private {
  940. vm.deal(defaultProvider, 3 ether); // Increase ETH allocation
  941. vm.startPrank(defaultProvider);
  942. pulse.executeCallback{value: 1 ether}(
  943. defaultProvider,
  944. 1,
  945. updateData,
  946. priceIds
  947. );
  948. pulse.executeCallback{value: 1 ether}(
  949. defaultProvider,
  950. 3,
  951. updateData,
  952. priceIds
  953. );
  954. pulse.executeCallback{value: 1 ether}(
  955. defaultProvider,
  956. 5,
  957. updateData,
  958. priceIds
  959. );
  960. vm.stopPrank();
  961. }
  962. function checkEmptyState() private {
  963. (PulseState.Request[] memory requests, uint256 count) = pulse
  964. .getFirstActiveRequests(10);
  965. assertEq(count, 0, "Should find 0 active requests");
  966. assertEq(requests.length, 0, "Array should be empty");
  967. }
/// @notice Measures the gas of getFirstActiveRequests at two result sizes (5
///         and 10). Asserts that the 10-request call costs about twice the
///         5-request call, within a 10% relative tolerance.
/// @dev Creates 20 requests from this contract. Every third one (loop indices
///      0, 3, ..., 18) is fulfilled immediately to leave gaps in the active
///      set. Both calls are then timed with gasleft() deltas.
function testGetFirstActiveRequestsGasUsage() public {
    // Setup test data
    bytes32[] memory priceIds = new bytes32[](1);
    priceIds[0] = bytes32(uint256(1));
    uint64 publishTime = SafeCast.toUint64(block.timestamp);
    uint256 callbackGasLimit = 1000000;
    // Create mock price feeds and setup Pyth response
    PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
        publishTime
    );
    mockParsePriceFeedUpdates(pyth, priceFeeds);
    bytes[] memory updateData = createMockUpdateData(priceFeeds);
    // Create 20 requests with some gaps
    for (uint i = 0; i < 20; i++) {
        // Reset this contract's balance to exactly the 1 ether attached below.
        vm.deal(address(this), 1 ether);
        pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
            defaultProvider,
            publishTime,
            priceIds,
            uint32(callbackGasLimit)
        );
        // Complete every third request to create gaps
        // NOTE(review): `uint64(i + 1)` assumes sequence numbers start at 1
        // and that this test's requests are the only ones. Under that
        // assumption #1, #4, ..., #19 (7 requests) are fulfilled and 13 stay
        // active. Confirm against setUp.
        if (i % 3 == 0) {
            vm.deal(defaultProvider, 1 ether);
            vm.prank(defaultProvider);
            pulse.executeCallback{value: 1 ether}(
                defaultProvider,
                uint64(i + 1),
                updateData,
                priceIds
            );
        }
    }
    // Measure gas for different request counts
    // NOTE(review): these gasleft() deltas also include the external-call
    // overhead and the decoding of the returned array in this contract. The
    // second measurement may also be affected by memory the first call already
    // expanded. Keep the order and surrounding statements stable, or the
    // ratio checked below may drift.
    uint256 gas1 = gasleft();
    pulse.getFirstActiveRequests(5);
    uint256 gas1Used = gas1 - gasleft();
    uint256 gas2 = gasleft();
    pulse.getFirstActiveRequests(10);
    uint256 gas2Used = gas2 - gasleft();
    // Log gas usage for analysis
    emit log_named_uint("Gas used for 5 requests", gas1Used);
    emit log_named_uint("Gas used for 10 requests", gas2Used);
    // Verify gas usage scales roughly linearly
    // Allow 10% margin for other factors
    assertApproxEqRel(
        gas2Used,
        gas1Used * 2,
        0.1e18, // 10% tolerance
        "Gas usage should scale roughly linearly"
    );
}
  1020. function getPulse() internal view override returns (address) {
  1021. return address(pulse);
  1022. }
  1023. // Mock implementation of pulseCallback
  1024. function pulseCallback(
  1025. uint64 sequenceNumber,
  1026. PythStructs.PriceFeed[] memory priceFeeds
  1027. ) internal override {
  1028. // Just accept the callback, no need to do anything with the data
  1029. // This prevents the revert we're seeing
  1030. }
  1031. }