// Pulse.t.sol (37 KB)


  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "forge-std/Test.sol";
  4. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  5. import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
  6. import "./utils/PulseTestUtils.t.sol";
  7. import "../contracts/pulse/PulseUpgradeable.sol";
  8. import "../contracts/pulse/IPulse.sol";
  9. import "../contracts/pulse/PulseState.sol";
  10. import "../contracts/pulse/PulseEvents.sol";
  11. import "../contracts/pulse/PulseErrors.sol";
  /// @notice Test double that records the data passed to its Pulse callback.
  /// @dev `getPulse()` returns the address supplied to the constructor.
  contract MockPulseConsumer is IPulseConsumer {
      address private _pulse;
      /// Sequence number passed to the most recent callback.
      uint64 public lastSequenceNumber;
      // Price feeds passed to the most recent callback; replaced on every callback.
      PythStructs.PriceFeed[] private _lastPriceFeeds;

      /// @param pulse Address that getPulse() will return.
      constructor(address pulse) {
          _pulse = pulse;
      }

      function getPulse() internal view override returns (address) {
          return _pulse;
      }

      /// @dev Records `sequenceNumber` and a copy of `priceFeeds`.
      ///      The previously stored feeds are cleared first so that
      ///      lastPriceFeeds() reflects only the latest callback instead of
      ///      accumulating feeds across every callback received.
      function pulseCallback(
          uint64 sequenceNumber,
          PythStructs.PriceFeed[] memory priceFeeds
      ) internal override {
          lastSequenceNumber = sequenceNumber;
          delete _lastPriceFeeds;
          for (uint256 i = 0; i < priceFeeds.length; i++) {
              _lastPriceFeeds.push(priceFeeds[i]);
          }
      }

      /// @return The price feeds received by the most recent callback.
      function lastPriceFeeds()
          external
          view
          returns (PythStructs.PriceFeed[] memory)
      {
          return _lastPriceFeeds;
      }
  }
  /// @notice Consumer whose callback always reverts with the string reason
  ///         "callback failed" (an `Error(string)` revert).
  contract FailingPulseConsumer is IPulseConsumer {
      address private _pulseAddress;

      /// @param pulse Address that getPulse() will return.
      constructor(address pulse) {
          _pulseAddress = pulse;
      }

      function getPulse() internal view override returns (address pulseAddress) {
          pulseAddress = _pulseAddress;
      }

      /// @dev Ignores both arguments and unconditionally reverts.
      function pulseCallback(
          uint64,
          PythStructs.PriceFeed[] memory
      ) internal pure override {
          require(false, "callback failed");
      }
  }
  /// @notice Consumer whose callback always reverts with a custom error
  ///         (not an `Error(string)`), carrying the message "callback failed".
  contract CustomErrorPulseConsumer is IPulseConsumer {
      error CustomError(string message);

      address private _pulseAddress;

      /// @param pulse Address that getPulse() will return.
      constructor(address pulse) {
          _pulseAddress = pulse;
      }

      function getPulse() internal view override returns (address pulseAddress) {
          pulseAddress = _pulseAddress;
      }

      /// @dev Ignores both arguments and unconditionally reverts with CustomError.
      function pulseCallback(
          uint64,
          PythStructs.PriceFeed[] memory
      ) internal pure override {
          string memory reason = "callback failed";
          revert CustomError(reason);
      }
  }
  /// @notice Foundry tests for PulseUpgradeable deployed behind an ERC1967 proxy.
  /// @dev Requests that need a callback target are made on behalf of `consumer`
  ///      (a MockPulseConsumer) via vm.prank, so this contract does not need to
  ///      implement IPulseConsumer itself.
  contract PulseTest is Test, PulseEvents, PulseTestUtils {
      ERC1967Proxy public proxy;
      PulseUpgradeable public pulse;
      MockPulseConsumer public consumer;
      address public owner;
      address public admin;
      address public pyth;
      address public defaultProvider;

      // Constants
      uint128 constant PYTH_FEE = 1 wei;
      uint128 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei;
      uint128 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei;
      uint128 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei;

      /// Deploys the implementation behind a proxy, initializes it with an
      /// exclusivity period of 15 seconds, registers `defaultProvider` with the
      /// default fees and deploys a MockPulseConsumer pointed at the proxy.
      function setUp() public {
          owner = address(1);
          admin = address(2);
          pyth = address(3);
          defaultProvider = address(4);
          PulseUpgradeable _pulse = new PulseUpgradeable();
          proxy = new ERC1967Proxy(address(_pulse), "");
          pulse = PulseUpgradeable(address(proxy));
          pulse.initialize(
              owner,
              admin,
              PYTH_FEE,
              pyth,
              defaultProvider,
              false,
              15
          );
          vm.prank(defaultProvider);
          pulse.registerProvider(
              DEFAULT_PROVIDER_BASE_FEE,
              DEFAULT_PROVIDER_FEE_PER_FEED,
              DEFAULT_PROVIDER_FEE_PER_GAS
          );
          consumer = new MockPulseConsumer(address(proxy));
      }

      /// Total fee (as reported by pulse.getFee) for a request to the default
      /// provider with CALLBACK_GAS_LIMIT and the IDs from createPriceIds().
      function calculateTotalFee() internal view returns (uint128) {
          return
              calculateTotalFee(
                  defaultProvider,
                  CALLBACK_GAS_LIMIT,
                  createPriceIds()
              );
      }

      /// Total fee (as reported by pulse.getFee) for an arbitrary provider,
      /// callback gas limit and set of price IDs.
      function calculateTotalFee(
          address provider,
          uint256 callbackGasLimit,
          bytes32[] memory priceIds
      ) internal view returns (uint128) {
          return pulse.getFee(provider, callbackGasLimit, priceIds);
      }

      function testRequestPriceUpdate() public {
          // Set a realistic gas price
          vm.txGasPrice(30 gwei);
          bytes32[] memory priceIds = createPriceIds();
          uint64 publishTime = SafeCast.toUint64(block.timestamp);
          // Fund the consumer contract with enough ETH for higher gas price
          vm.deal(address(consumer), 1 ether);
          uint128 totalFee = calculateTotalFee();
          // Create the event data we expect to see
          PulseState.Request memory expectedRequest = PulseState.Request({
              sequenceNumber: 1,
              publishTime: publishTime,
              priceIds: [
                  priceIds[0],
                  priceIds[1],
                  bytes32(0), // Fill remaining slots with zero
                  bytes32(0),
                  bytes32(0),
                  bytes32(0),
                  bytes32(0),
                  bytes32(0),
                  bytes32(0),
                  bytes32(0)
              ],
              numPriceIds: 2,
              callbackGasLimit: CALLBACK_GAS_LIMIT,
              requester: address(consumer),
              provider: defaultProvider,
              fee: totalFee - PYTH_FEE
          });
          vm.expectEmit();
          emit PriceUpdateRequested(expectedRequest, priceIds);
          vm.prank(address(consumer));
          pulse.requestPriceUpdatesWithCallback{value: totalFee}(
              defaultProvider,
              publishTime,
              priceIds,
              CALLBACK_GAS_LIMIT
          );
          // Additional assertions to verify event data was stored correctly
          PulseState.Request memory lastRequest = pulse.getRequest(1);
          assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber);
          assertEq(lastRequest.publishTime, expectedRequest.publishTime);
          assertEq(lastRequest.numPriceIds, expectedRequest.numPriceIds);
          for (uint8 i = 0; i < lastRequest.numPriceIds; i++) {
              assertEq(lastRequest.priceIds[i], expectedRequest.priceIds[i]);
          }
          assertEq(
              lastRequest.callbackGasLimit,
              expectedRequest.callbackGasLimit
          );
          assertEq(
              lastRequest.requester,
              expectedRequest.requester,
              "Requester mismatch"
          );
          assertEq(
              lastRequest.provider,
              expectedRequest.provider,
              "Provider mismatch"
          );
          assertEq(lastRequest.fee, expectedRequest.fee, "Fee mismatch");
      }

      function testRequestWithInsufficientFee() public {
          // Set a realistic gas price
          vm.txGasPrice(30 gwei);
          bytes32[] memory priceIds = createPriceIds();
          vm.deal(address(consumer), 1 ether);
          vm.prank(address(consumer));
          vm.expectRevert(InsufficientFee.selector);
          pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee
              defaultProvider,
              SafeCast.toUint64(block.timestamp),
              priceIds,
              CALLBACK_GAS_LIMIT
          );
      }

      function testExecuteCallback() public {
          bytes32[] memory priceIds = createPriceIds();
          uint64 publishTime = SafeCast.toUint64(block.timestamp);
          // Fund the consumer contract
          vm.deal(address(consumer), 1 gwei);
          uint128 totalFee = calculateTotalFee();
          // Step 1: Make the request as consumer
          vm.prank(address(consumer));
          uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
              value: totalFee
          }(defaultProvider, publishTime, priceIds, CALLBACK_GAS_LIMIT);
          // Step 2: Create mock price feeds and setup Pyth response
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          // FIXME: this test doesn't ensure the Pyth fee is paid.
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          // Create arrays for expected event data
          int64[] memory expectedPrices = new int64[](2);
          expectedPrices[0] = MOCK_BTC_PRICE;
          expectedPrices[1] = MOCK_ETH_PRICE;
          uint64[] memory expectedConf = new uint64[](2);
          expectedConf[0] = MOCK_BTC_CONF;
          expectedConf[1] = MOCK_ETH_CONF;
          int32[] memory expectedExpos = new int32[](2);
          expectedExpos[0] = MOCK_PRICE_FEED_EXPO;
          expectedExpos[1] = MOCK_PRICE_FEED_EXPO;
          uint64[] memory expectedPublishTimes = new uint64[](2);
          expectedPublishTimes[0] = publishTime;
          expectedPublishTimes[1] = publishTime;
          // Expect the PriceUpdateExecuted event with all price data
          vm.expectEmit();
          emit PriceUpdateExecuted(
              sequenceNumber,
              defaultProvider,
              priceIds,
              expectedPrices,
              expectedConf,
              expectedExpos,
              expectedPublishTimes
          );
          // Create mock update data and execute callback
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          vm.prank(defaultProvider);
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          // Verify callback was executed
          assertEq(consumer.lastSequenceNumber(), sequenceNumber);
          // Compare price feeds array length
          PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds();
          assertEq(lastFeeds.length, priceFeeds.length);
          // Compare each price feed
          for (uint256 i = 0; i < priceFeeds.length; i++) {
              assertEq(lastFeeds[i].id, priceFeeds[i].id);
              assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price);
              assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf);
              assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo);
              assertEq(
                  lastFeeds[i].price.publishTime,
                  priceFeeds[i].price.publishTime
              );
          }
      }

      function testExecuteCallbackFailure() public {
          FailingPulseConsumer failingConsumer = new FailingPulseConsumer(
              address(proxy)
          );
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(
                  pulse,
                  defaultProvider,
                  address(failingConsumer)
              );
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          vm.expectEmit();
          emit PriceUpdateCallbackFailed(
              sequenceNumber,
              defaultProvider,
              priceIds,
              address(failingConsumer),
              "callback failed"
          );
          vm.prank(defaultProvider);
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
      }

      function testExecuteCallbackCustomErrorFailure() public {
          CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(
              address(proxy)
          );
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(
                  pulse,
                  defaultProvider,
                  address(failingConsumer)
              );
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          vm.expectEmit();
          emit PriceUpdateCallbackFailed(
              sequenceNumber,
              defaultProvider,
              priceIds,
              address(failingConsumer),
              "low-level error (possibly out of gas)"
          );
          vm.prank(defaultProvider);
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
      }

      function testExecuteCallbackWithInsufficientGas() public {
          // Setup request with 1M gas limit
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          // Setup mock data
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Try executing with only 100K gas when 1M is required
          vm.prank(defaultProvider);
          vm.expectRevert(); // Just expect any revert since it will be an out-of-gas error
          pulse.executeCallback{gas: 100000}(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          ); // Will fail because gasleft() < callbackGasLimit
      }

      function testExecuteCallbackWithFutureTimestamp() public {
          // Setup request with future timestamp
          bytes32[] memory priceIds = createPriceIds();
          uint64 futureTime = SafeCast.toUint64(block.timestamp + 10); // 10 seconds in future
          vm.deal(address(consumer), 1 gwei);
          uint128 totalFee = calculateTotalFee();
          vm.prank(address(consumer));
          uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
              value: totalFee
          }(defaultProvider, futureTime, priceIds, CALLBACK_GAS_LIMIT);
          // Try to execute callback before the requested timestamp
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              futureTime // Mock price feeds with future timestamp
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          vm.prank(defaultProvider);
          // Should succeed because we're simulating receiving future-dated price updates
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          // Compare price feeds array length
          PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds();
          assertEq(lastFeeds.length, priceFeeds.length);
          // Compare each price feed publish time
          for (uint256 i = 0; i < priceFeeds.length; i++) {
              assertEq(
                  lastFeeds[i].price.publishTime,
                  priceFeeds[i].price.publishTime
              );
          }
      }

      function testRevertOnTooFarFutureTimestamp() public {
          bytes32[] memory priceIds = createPriceIds();
          uint64 farFutureTime = SafeCast.toUint64(block.timestamp + 61); // Just over 1 minute
          vm.deal(address(consumer), 1 gwei);
          uint128 totalFee = calculateTotalFee();
          vm.prank(address(consumer));
          vm.expectRevert("Too far in future");
          pulse.requestPriceUpdatesWithCallback{value: totalFee}(
              defaultProvider,
              farFutureTime,
              priceIds,
              CALLBACK_GAS_LIMIT
          );
      }

      function testDoubleExecuteCallback() public {
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // First execution
          vm.prank(defaultProvider);
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          // Second execution should fail
          vm.prank(defaultProvider);
          vm.expectRevert(NoSuchRequest.selector);
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
      }

      function testGetFee() public {
          // Test with different gas limits to verify fee calculation
          uint256[] memory gasLimits = new uint256[](3);
          gasLimits[0] = 100_000;
          gasLimits[1] = 500_000;
          gasLimits[2] = 1_000_000;
          bytes32[] memory priceIds = createPriceIds();
          for (uint256 i = 0; i < gasLimits.length; i++) {
              uint256 gasLimit = gasLimits[i];
              uint128 expectedFee = SafeCast.toUint128(
                  DEFAULT_PROVIDER_BASE_FEE +
                      DEFAULT_PROVIDER_FEE_PER_FEED *
                      priceIds.length +
                      DEFAULT_PROVIDER_FEE_PER_GAS *
                      gasLimit
              ) + PYTH_FEE;
              uint128 actualFee = pulse.getFee(
                  defaultProvider,
                  gasLimit,
                  priceIds
              );
              assertEq(
                  actualFee,
                  expectedFee,
                  "Fee calculation incorrect for gas limit"
              );
          }
          // Test with zero gas limit
          uint128 expectedMinFee = SafeCast.toUint128(
              PYTH_FEE +
                  DEFAULT_PROVIDER_BASE_FEE +
                  DEFAULT_PROVIDER_FEE_PER_FEED *
                  priceIds.length
          );
          uint128 actualMinFee = pulse.getFee(defaultProvider, 0, priceIds);
          assertEq(
              actualMinFee,
              expectedMinFee,
              "Minimum fee calculation incorrect"
          );
      }

      function testWithdrawFees() public {
          // Setup: Request price update to accrue some fees
          bytes32[] memory priceIds = createPriceIds();
          vm.deal(address(consumer), 1 gwei);
          vm.prank(address(consumer));
          pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
              defaultProvider,
              SafeCast.toUint64(block.timestamp),
              priceIds,
              CALLBACK_GAS_LIMIT
          );
          // Get admin's balance before withdrawal
          uint256 adminBalanceBefore = admin.balance;
          uint128 accruedFees = pulse.getAccruedPythFees();
          // Withdraw fees as admin
          vm.prank(admin);
          pulse.withdrawFees(accruedFees);
          // Verify balances
          assertEq(
              admin.balance,
              adminBalanceBefore + accruedFees,
              "Admin balance should increase by withdrawn amount"
          );
          assertEq(
              pulse.getAccruedPythFees(),
              0,
              "Contract should have no fees after withdrawal"
          );
      }

      function testWithdrawFeesUnauthorized() public {
          vm.prank(address(0xdead));
          vm.expectRevert("Only admin can withdraw fees");
          pulse.withdrawFees(1 ether);
      }

      function testWithdrawFeesInsufficientBalance() public {
          vm.prank(admin);
          vm.expectRevert("Insufficient balance");
          pulse.withdrawFees(1 ether);
      }

      function testSetAndWithdrawAsFeeManager() public {
          address feeManager = address(0x789);
          vm.prank(defaultProvider);
          pulse.setFeeManager(feeManager);
          // Setup: Request price update to accrue some fees
          bytes32[] memory priceIds = createPriceIds();
          vm.deal(address(consumer), 1 gwei);
          vm.prank(address(consumer));
          pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
              defaultProvider,
              SafeCast.toUint64(block.timestamp),
              priceIds,
              CALLBACK_GAS_LIMIT
          );
          // Get provider's accrued fees instead of total fees
          PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
              defaultProvider
          );
          uint128 providerAccruedFees = providerInfo.accruedFeesInWei;
          uint256 managerBalanceBefore = feeManager.balance;
          vm.prank(feeManager);
          pulse.withdrawAsFeeManager(defaultProvider, providerAccruedFees);
          assertEq(
              feeManager.balance,
              managerBalanceBefore + providerAccruedFees,
              "Fee manager balance should increase by withdrawn amount"
          );
          providerInfo = pulse.getProviderInfo(defaultProvider);
          assertEq(
              providerInfo.accruedFeesInWei,
              0,
              "Provider should have no fees after withdrawal"
          );
      }

      function testSetFeeManagerUnauthorized() public {
          address feeManager = address(0x789);
          vm.prank(address(0xdead));
          vm.expectRevert("Provider not registered");
          pulse.setFeeManager(feeManager);
      }

      function testWithdrawAsFeeManagerUnauthorized() public {
          vm.prank(address(0xdead));
          vm.expectRevert("Only fee manager");
          pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
      }

      function testWithdrawAsFeeManagerInsufficientBalance() public {
          // Set up fee manager first
          address feeManager = address(0x789);
          vm.prank(defaultProvider);
          pulse.setFeeManager(feeManager);
          vm.prank(feeManager);
          vm.expectRevert("Insufficient balance");
          pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
      }

      /// Executing with price IDs that differ from the stored request must revert
      /// with InvalidPriceIds(provided, expected).
      function testExecuteCallbackWithInvalidPriceIds() public {
          // Use the IDs and publish time the request was actually created with,
          // so the expected error matches the stored request exactly.
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          // Create different priceIds
          bytes32[] memory wrongPriceIds = new bytes32[](2);
          wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs
          wrongPriceIds[1] = bytes32(uint256(2));
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Should revert when trying to execute with wrong priceIds
          vm.prank(defaultProvider);
          vm.expectRevert(
              abi.encodeWithSelector(
                  InvalidPriceIds.selector,
                  wrongPriceIds[0],
                  priceIds[0]
              )
          );
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              wrongPriceIds
          );
      }

      function testRevertOnTooManyPriceIds() public {
          uint256 maxPriceIds = uint256(pulse.MAX_PRICE_IDS());
          // Create array with MAX_PRICE_IDS + 1 price IDs
          bytes32[] memory priceIds = new bytes32[](maxPriceIds + 1);
          for (uint256 i = 0; i < priceIds.length; i++) {
              priceIds[i] = bytes32(uint256(i + 1));
          }
          // Pay the fee for the IDs actually sent, so the only possible failure
          // is the TooManyPriceIds check (not an insufficient fee).
          uint128 totalFee = calculateTotalFee(
              defaultProvider,
              CALLBACK_GAS_LIMIT,
              priceIds
          );
          vm.deal(address(consumer), totalFee);
          vm.prank(address(consumer));
          vm.expectRevert(
              abi.encodeWithSelector(
                  TooManyPriceIds.selector,
                  maxPriceIds + 1,
                  maxPriceIds
              )
          );
          pulse.requestPriceUpdatesWithCallback{value: totalFee}(
              defaultProvider,
              SafeCast.toUint64(block.timestamp),
              priceIds,
              CALLBACK_GAS_LIMIT
          );
      }

      function testProviderRegistration() public {
          address provider = address(0x123);
          // Distinct values so each fee field is checked independently.
          uint128 baseFee = 1000;
          uint128 feePerFeed = 2000;
          uint128 feePerGas = 3000;
          vm.prank(provider);
          pulse.registerProvider(baseFee, feePerFeed, feePerGas);
          PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
          assertEq(info.baseFeeInWei, baseFee);
          assertEq(info.feePerFeedInWei, feePerFeed);
          assertEq(info.feePerGasInWei, feePerGas);
          assertTrue(info.isRegistered);
      }

      function testSetProviderFee() public {
          address provider = address(0x123);
          uint128 initialBaseFee = 1000;
          uint128 initialFeePerFeed = 2000;
          uint128 initialFeePerGas = 3000;
          uint128 newFeePerFeed = 4000;
          uint128 newBaseFee = 5000;
          uint128 newFeePerGas = 6000;
          vm.prank(provider);
          pulse.registerProvider(
              initialBaseFee,
              initialFeePerFeed,
              initialFeePerGas
          );
          vm.prank(provider);
          pulse.setProviderFee(provider, newBaseFee, newFeePerFeed, newFeePerGas);
          PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
          assertEq(info.baseFeeInWei, newBaseFee);
          assertEq(info.feePerFeedInWei, newFeePerFeed);
          assertEq(info.feePerGasInWei, newFeePerGas);
      }

      function testDefaultProvider() public {
          address provider = address(0x123);
          uint128 providerFee = 1000;
          vm.prank(provider);
          pulse.registerProvider(providerFee, providerFee, providerFee);
          vm.prank(admin);
          pulse.setDefaultProvider(provider);
          assertEq(pulse.getDefaultProvider(), provider);
      }

      function testRequestWithProvider() public {
          address provider = address(0x123);
          uint128 providerFee = 1000;
          vm.prank(provider);
          pulse.registerProvider(providerFee, providerFee, providerFee);
          bytes32[] memory priceIds = new bytes32[](1);
          priceIds[0] = bytes32(uint256(1));
          uint128 totalFee = calculateTotalFee(
              provider,
              CALLBACK_GAS_LIMIT,
              priceIds
          );
          vm.deal(address(consumer), totalFee);
          vm.prank(address(consumer));
          uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
              value: totalFee
          }(
              provider,
              SafeCast.toUint64(block.timestamp),
              priceIds,
              CALLBACK_GAS_LIMIT
          );
          PulseState.Request memory req = pulse.getRequest(sequenceNumber);
          assertEq(req.provider, provider);
      }

      function testExclusivityPeriod() public {
          // Test initial value
          assertEq(
              pulse.getExclusivityPeriod(),
              15,
              "Initial exclusivity period should be 15 seconds"
          );
          // Test setting new value
          vm.prank(admin);
          vm.expectEmit();
          emit ExclusivityPeriodUpdated(15, 30);
          pulse.setExclusivityPeriod(30);
          assertEq(
              pulse.getExclusivityPeriod(),
              30,
              "Exclusivity period should be updated"
          );
      }

      function testSetExclusivityPeriodUnauthorized() public {
          vm.prank(address(0xdead));
          vm.expectRevert("Only admin can set exclusivity period");
          pulse.setExclusivityPeriod(30);
      }

      function testExecuteCallbackDuringExclusivity() public {
          // Register a second provider
          address secondProvider = address(0x456);
          vm.prank(secondProvider);
          pulse.registerProvider(
              DEFAULT_PROVIDER_BASE_FEE,
              DEFAULT_PROVIDER_FEE_PER_FEED,
              DEFAULT_PROVIDER_FEE_PER_GAS
          );
          // Setup request
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          // Setup mock data
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Try to execute with second provider during exclusivity period
          vm.expectRevert("Only assigned provider during exclusivity period");
          pulse.executeCallback(
              secondProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          // Original provider should succeed
          pulse.executeCallback(
              defaultProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          assertEq(consumer.lastSequenceNumber(), sequenceNumber);
      }

      function testExecuteCallbackAfterExclusivity() public {
          // Register a second provider
          address secondProvider = address(0x456);
          vm.prank(secondProvider);
          pulse.registerProvider(
              DEFAULT_PROVIDER_BASE_FEE,
              DEFAULT_PROVIDER_FEE_PER_FEED,
              DEFAULT_PROVIDER_FEE_PER_GAS
          );
          // Setup request
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          // Setup mock data
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Wait for exclusivity period to end
          vm.warp(block.timestamp + pulse.getExclusivityPeriod() + 1);
          // Second provider should now succeed. It must be the credited
          // provider: crediting defaultProvider would succeed even inside the
          // exclusivity window and would not exercise this path.
          vm.prank(secondProvider);
          pulse.executeCallback(
              secondProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          assertEq(consumer.lastSequenceNumber(), sequenceNumber);
      }

      function testExecuteCallbackWithCustomExclusivityPeriod() public {
          // Register a second provider
          address secondProvider = address(0x456);
          vm.prank(secondProvider);
          pulse.registerProvider(
              DEFAULT_PROVIDER_BASE_FEE,
              DEFAULT_PROVIDER_FEE_PER_FEED,
              DEFAULT_PROVIDER_FEE_PER_GAS
          );
          // Set custom exclusivity period
          vm.prank(admin);
          pulse.setExclusivityPeriod(30);
          // Setup request
          (
              uint64 sequenceNumber,
              bytes32[] memory priceIds,
              uint256 publishTime
          ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));
          // Setup mock data
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Try at 29 seconds (should fail for second provider)
          vm.warp(block.timestamp + 29);
          vm.expectRevert("Only assigned provider during exclusivity period");
          pulse.executeCallback(
              secondProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          // Try at 31 seconds (should succeed for second provider)
          vm.warp(block.timestamp + 2);
          pulse.executeCallback(
              secondProvider,
              sequenceNumber,
              updateData,
              priceIds
          );
          assertEq(consumer.lastSequenceNumber(), sequenceNumber);
      }

      /// Creates requests 1-5, completes 2 and 4, then checks
      /// getFirstActiveRequests for several limits and after clearing everything.
      function testGetFirstActiveRequests() public {
          (
              bytes32[] memory priceIds,
              bytes[] memory updateData
          ) = setupTestData();
          createTestRequests(priceIds);
          completeRequests(updateData, priceIds);
          checkRequestScenarios(priceIds, updateData);
      }

      /// Builds a single-ID price list, mocks the Pyth parse response for the
      /// current timestamp, and returns update data built from the mocked feeds.
      function setupTestData()
          private
          returns (bytes32[] memory priceIds, bytes[] memory updateData)
      {
          priceIds = new bytes32[](1);
          priceIds[0] = bytes32(uint256(1));
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              SafeCast.toUint64(block.timestamp)
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          updateData = createMockUpdateData(priceFeeds);
      }

      /// Submits five requests on behalf of `consumer` to the default provider,
      /// attaching 1 ether to each.
      function createTestRequests(bytes32[] memory priceIds) private {
          uint64 publishTime = SafeCast.toUint64(block.timestamp);
          for (uint256 i = 0; i < 5; i++) {
              vm.deal(address(consumer), 1 ether);
              vm.prank(address(consumer));
              pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
                  defaultProvider,
                  publishTime,
                  priceIds,
                  1000000
              );
          }
      }

      /// Executes the callbacks for requests 2 and 4, leaving 1, 3 and 5 active.
      function completeRequests(
          bytes[] memory updateData,
          bytes32[] memory priceIds
      ) private {
          vm.deal(defaultProvider, 2 ether); // Increase ETH allocation to prevent OutOfFunds
          vm.startPrank(defaultProvider);
          pulse.executeCallback{value: 1 ether}(
              defaultProvider,
              2,
              updateData,
              priceIds
          );
          pulse.executeCallback{value: 1 ether}(
              defaultProvider,
              4,
              updateData,
              priceIds
          );
          vm.stopPrank();
      }

      // Named check* rather than test* so it is clearly a helper, not a test case.
      function checkRequestScenarios(
          bytes32[] memory priceIds,
          bytes[] memory updateData
      ) private {
          // Test 1: Request more than available
          checkMoreThanAvailable();
          // Test 2: Request exact number
          checkExactNumber();
          // Test 3: Request fewer than available
          checkFewerThanAvailable();
          // Test 4: Request zero
          checkZeroRequest();
          // Test 5: Clear all and check empty
          clearAllRequests(updateData, priceIds);
          checkEmptyState();
      }

      function checkMoreThanAvailable() private {
          (PulseState.Request[] memory requests, uint256 count) = pulse
              .getFirstActiveRequests(10);
          assertEq(count, 3, "Should find 3 active requests");
          assertEq(requests.length, 3, "Array should be resized to 3");
          assertEq(
              requests[0].sequenceNumber,
              1,
              "First request should be oldest"
          );
          assertEq(requests[1].sequenceNumber, 3, "Second request should be #3");
          assertEq(requests[2].sequenceNumber, 5, "Third request should be #5");
      }

      function checkExactNumber() private {
          (PulseState.Request[] memory requests, uint256 count) = pulse
              .getFirstActiveRequests(3);
          assertEq(count, 3, "Should find 3 active requests");
          assertEq(requests.length, 3, "Array should match requested size");
      }

      function checkFewerThanAvailable() private {
          (PulseState.Request[] memory requests, uint256 count) = pulse
              .getFirstActiveRequests(2);
          assertEq(count, 2, "Should find 2 active requests");
          assertEq(requests.length, 2, "Array should match requested size");
          assertEq(
              requests[0].sequenceNumber,
              1,
              "First request should be oldest"
          );
          assertEq(requests[1].sequenceNumber, 3, "Second request should be #3");
      }

      function checkZeroRequest() private {
          (PulseState.Request[] memory requests, uint256 count) = pulse
              .getFirstActiveRequests(0);
          assertEq(count, 0, "Should find 0 active requests");
          assertEq(requests.length, 0, "Array should be empty");
      }

      /// Executes the callbacks for the remaining active requests (1, 3 and 5).
      function clearAllRequests(
          bytes[] memory updateData,
          bytes32[] memory priceIds
      ) private {
          uint64[3] memory remaining = [uint64(1), uint64(3), uint64(5)];
          vm.deal(defaultProvider, 3 ether); // Increase ETH allocation
          vm.startPrank(defaultProvider);
          for (uint256 i = 0; i < remaining.length; i++) {
              pulse.executeCallback{value: 1 ether}(
                  defaultProvider,
                  remaining[i],
                  updateData,
                  priceIds
              );
          }
          vm.stopPrank();
      }

      function checkEmptyState() private {
          (PulseState.Request[] memory requests, uint256 count) = pulse
              .getFirstActiveRequests(10);
          assertEq(count, 0, "Should find 0 active requests");
          assertEq(requests.length, 0, "Array should be empty");
      }

      function testGetFirstActiveRequestsGasUsage() public {
          // Setup test data
          bytes32[] memory priceIds = new bytes32[](1);
          priceIds[0] = bytes32(uint256(1));
          uint64 publishTime = SafeCast.toUint64(block.timestamp);
          uint256 callbackGasLimit = 1000000;
          // Create mock price feeds and setup Pyth response
          PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
              publishTime
          );
          mockParsePriceFeedUpdates(pyth, priceFeeds);
          bytes[] memory updateData = createMockUpdateData(priceFeeds);
          // Create 20 requests with some gaps
          for (uint256 i = 0; i < 20; i++) {
              vm.deal(address(consumer), 1 ether);
              vm.prank(address(consumer));
              pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
                  defaultProvider,
                  publishTime,
                  priceIds,
                  callbackGasLimit
              );
              // Complete every third request to create gaps
              if (i % 3 == 0) {
                  vm.deal(defaultProvider, 1 ether);
                  vm.prank(defaultProvider);
                  pulse.executeCallback{value: 1 ether}(
                      defaultProvider,
                      uint64(i + 1),
                      updateData,
                      priceIds
                  );
              }
          }
          // Measure gas for different request counts
          uint256 gas1 = gasleft();
          pulse.getFirstActiveRequests(5);
          uint256 gas1Used = gas1 - gasleft();
          uint256 gas2 = gasleft();
          pulse.getFirstActiveRequests(10);
          uint256 gas2Used = gas2 - gasleft();
          // Log gas usage for analysis
          emit log_named_uint("Gas used for 5 requests", gas1Used);
          emit log_named_uint("Gas used for 10 requests", gas2Used);
          // Verify gas usage scales roughly linearly
          // Allow 10% margin for other factors
          assertApproxEqRel(
              gas2Used,
              gas1Used * 2,
              0.1e18, // 10% tolerance
              "Gas usage should scale roughly linearly"
          );
      }
  }