Pulse.t.sol 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "forge-std/Test.sol";
  4. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  5. import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
  6. import "./utils/PulseTestUtils.t.sol";
  7. import "../contracts/pulse/PulseUpgradeable.sol";
  8. import "../contracts/pulse/IPulse.sol";
  9. import "../contracts/pulse/PulseState.sol";
  10. import "../contracts/pulse/PulseEvents.sol";
  11. import "../contracts/pulse/PulseErrors.sol";
  /// @dev Deployable subclass of `PulseUpgradeable` with no additions; the tests
  ///      instantiate this contract as the implementation under test.
  contract ConcretePulseUpgradeable is PulseUpgradeable {
  }
  /// @dev Test consumer that records the arguments of the most recent Pulse callback
  ///      so tests can inspect what was delivered.
  contract MockPulseConsumer is IPulseConsumer {
      /// @dev Address returned by `getPulse()`; fixed at construction.
      address private _pulse;
      /// @dev Sequence number passed to the most recent callback.
      uint64 public lastSequenceNumber;
      /// @dev Price feeds passed to the most recent callback.
      PythStructs.PriceFeed[] private _lastPriceFeeds;

      /// @param pulse Address to report from `getPulse()`.
      constructor(address pulse) {
          _pulse = pulse;
      }

      /// @dev Returns the Pulse address supplied at construction.
      function getPulse() internal view override returns (address) {
          return _pulse;
      }

      /// @dev Stores the callback arguments, replacing whatever a previous callback stored.
      /// @param sequenceNumber Sequence number delivered with the callback.
      /// @param priceFeeds Price feeds delivered with the callback.
      function pulseCallback(
          uint64 sequenceNumber,
          PythStructs.PriceFeed[] memory priceFeeds
      ) internal override {
          lastSequenceNumber = sequenceNumber;
          // Clear feeds from any earlier callback; without this the array would
          // accumulate across callbacks and no longer reflect only the last one.
          delete _lastPriceFeeds;
          for (uint256 i = 0; i < priceFeeds.length; i++) {
              _lastPriceFeeds.push(priceFeeds[i]);
          }
      }

      /// @return The price feeds stored by the most recent callback.
      function lastPriceFeeds()
          external
          view
          returns (PythStructs.PriceFeed[] memory)
      {
          return _lastPriceFeeds;
      }
  }
  /// @dev Consumer whose callback always reverts with the string reason "callback failed".
  contract FailingPulseConsumer is IPulseConsumer {
      /// @dev Address reported by `getPulse()`; set once in the constructor.
      address private _pulseAddress;

      /// @param pulse Address to report from `getPulse()`.
      constructor(address pulse) {
          _pulseAddress = pulse;
      }

      /// @dev Returns the address supplied at construction.
      function getPulse() internal view override returns (address pulseAddress) {
          pulseAddress = _pulseAddress;
      }

      /// @dev Ignores both arguments and unconditionally reverts.
      function pulseCallback(
          uint64 /* sequenceNumber */,
          PythStructs.PriceFeed[] memory /* priceFeeds */
      ) internal pure override {
          revert("callback failed");
      }
  }
  /// @dev Consumer whose callback always reverts with the custom error `CustomError`
  ///      rather than a plain string reason.
  contract CustomErrorPulseConsumer is IPulseConsumer {
      /// @dev Error raised by every callback invocation.
      error CustomError(string message);

      /// @dev Address reported by `getPulse()`; set once in the constructor.
      address private _pulseAddress;

      /// @param pulse Address to report from `getPulse()`.
      constructor(address pulse) {
          _pulseAddress = pulse;
      }

      /// @dev Returns the address supplied at construction.
      function getPulse() internal view override returns (address pulseAddress) {
          pulseAddress = _pulseAddress;
      }

      /// @dev Ignores both arguments and unconditionally reverts with `CustomError`.
      function pulseCallback(
          uint64 /* sequenceNumber */,
          PythStructs.PriceFeed[] memory /* priceFeeds */
      ) internal pure override {
          revert CustomError("callback failed");
      }
  }
  72. // FIXME: this shouldn't be IPulseConsumer.
  73. contract PulseTest is Test, PulseEvents, IPulseConsumer, PulseTestUtils {
  74. ERC1967Proxy public proxy;
  75. PulseUpgradeable public pulse;
  76. MockPulseConsumer public consumer;
  77. address public owner;
  78. address public admin;
  79. address public pyth;
  80. address public defaultProvider;
  81. // Constants
  82. uint96 constant PYTH_FEE = 1 wei;
  83. uint96 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei;
  84. uint96 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei;
  85. uint96 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei;
  /// @dev Per-test fixture: deploys Pulse behind an ERC1967 proxy, initialises it,
  ///      registers the default provider and deploys a fresh mock consumer.
  function setUp() public {
      // Fixed low-numbered addresses stand in for each role.
      owner = address(1);
      admin = address(2);
      pyth = address(3);
      defaultProvider = address(4);

      // Deploy the implementation and wrap it in a proxy with empty init calldata;
      // initialisation is done through the proxy below.
      PulseUpgradeable implementation = new ConcretePulseUpgradeable();
      proxy = new ERC1967Proxy(address(implementation), "");
      pulse = PulseUpgradeable(address(proxy));

      // NOTE(review): the meaning of the `false` flag is not visible in this file;
      // the trailing 15 matches the initial exclusivity period asserted in
      // testExclusivityPeriod.
      pulse.initialize(
          owner,
          admin,
          PYTH_FEE,
          pyth,
          defaultProvider,
          false,
          15
      );

      // Registration is performed as the provider itself.
      vm.prank(defaultProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      consumer = new MockPulseConsumer(address(proxy));
  }
  // Helper function to calculate total fee
  // FIXME: I think this helper probably needs to take some arguments.
  /// @dev Quotes `pulse.getFee` for the default provider, `CALLBACK_GAS_LIMIT`
  ///      and the ids returned by `createPriceIds()`.
  /// @return totalFee The fee quoted by the Pulse contract.
  function calculateTotalFee() internal view returns (uint96 totalFee) {
      bytes32[] memory ids = createPriceIds();
      totalFee = pulse.getFee(defaultProvider, CALLBACK_GAS_LIMIT, ids);
  }
  /// @dev Requests a price update as the consumer and checks both the emitted
  ///      `PriceUpdateRequested` event and the request stored under sequence number 1.
  function testRequestPriceUpdate() public {
      // Set a realistic gas price
      vm.txGasPrice(30 gwei);
      bytes32[] memory priceIds = createPriceIds();
      uint64 publishTime = SafeCast.toUint64(block.timestamp);

      // Fund the consumer contract with enough ETH for higher gas price
      vm.deal(address(consumer), 1 ether);
      uint96 totalFee = calculateTotalFee();

      // Expected prefixes are the first 8 bytes of each price id. The explicit
      // bytes32 -> bytes8 conversion truncates on the right and, unlike assigning
      // a full word in inline assembly, cannot leave dirty low-order bits behind.
      bytes8[] memory expectedPriceIdPrefixes = new bytes8[](2);
      expectedPriceIdPrefixes[0] = bytes8(priceIds[0]);
      expectedPriceIdPrefixes[1] = bytes8(priceIds[1]);

      PulseState.Request memory expectedRequest = PulseState.Request({
          sequenceNumber: 1,
          publishTime: publishTime,
          priceIdPrefixes: expectedPriceIdPrefixes,
          callbackGasLimit: uint32(CALLBACK_GAS_LIMIT),
          requester: address(consumer),
          provider: defaultProvider,
          fee: totalFee - PYTH_FEE
      });

      vm.expectEmit();
      emit PriceUpdateRequested(expectedRequest, priceIds);

      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          publishTime,
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // Additional assertions to verify event data was stored correctly
      PulseState.Request memory lastRequest = pulse.getRequest(1);
      assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber);
      assertEq(lastRequest.publishTime, expectedRequest.publishTime);
      assertEq(
          lastRequest.priceIdPrefixes.length,
          expectedRequest.priceIdPrefixes.length
      );
      for (uint8 i = 0; i < lastRequest.priceIdPrefixes.length; i++) {
          assertEq(
              lastRequest.priceIdPrefixes[i],
              expectedRequest.priceIdPrefixes[i]
          );
      }
      assertEq(
          lastRequest.callbackGasLimit,
          expectedRequest.callbackGasLimit
      );
      assertEq(
          lastRequest.requester,
          expectedRequest.requester,
          "Requester mismatch"
      );
      // The remaining fields of the expected request are checked as well so the
      // stored record is verified in full, not only through the event.
      assertEq(
          lastRequest.provider,
          expectedRequest.provider,
          "Provider mismatch"
      );
      assertEq(lastRequest.fee, expectedRequest.fee, "Fee mismatch");
  }
  /// @dev A request that sends only `PYTH_FEE` as value must revert with `InsufficientFee`.
  function testRequestWithInsufficientFee() public {
      // Set a realistic gas price
      vm.txGasPrice(30 gwei);

      bytes32[] memory ids = createPriceIds();
      address requester = address(consumer);
      vm.deal(requester, 1 ether);

      vm.prank(requester);
      vm.expectRevert(InsufficientFee.selector);
      // Intentionally low fee
      pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          ids,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @dev Happy path: the consumer requests an update, the default provider executes
  ///      the callback, and the test checks the `PriceUpdateExecuted` event plus the
  ///      feeds recorded by the mock consumer.
  function testExecuteCallback() public {
      bytes32[] memory ids = createPriceIds();
      uint64 requestedTime = SafeCast.toUint64(block.timestamp);

      // Fund the consumer contract
      vm.deal(address(consumer), 1 gwei);
      uint96 fee = calculateTotalFee();

      // Step 1: Make the request as consumer
      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          requestedTime,
          ids,
          CALLBACK_GAS_LIMIT
      );

      // Step 2: Create mock price feeds and setup Pyth response
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      // FIXME: this test doesn't ensure the Pyth fee is paid.
      mockParsePriceFeedUpdates(pyth, feeds);

      // Per-feed values the PriceUpdateExecuted event is expected to carry
      int64[] memory prices = new int64[](2);
      uint64[] memory confs = new uint64[](2);
      int32[] memory expos = new int32[](2);
      uint64[] memory publishTimes = new uint64[](2);

      prices[0] = MOCK_BTC_PRICE;
      confs[0] = MOCK_BTC_CONF;
      expos[0] = MOCK_PRICE_FEED_EXPO;
      publishTimes[0] = requestedTime;

      prices[1] = MOCK_ETH_PRICE;
      confs[1] = MOCK_ETH_CONF;
      expos[1] = MOCK_PRICE_FEED_EXPO;
      publishTimes[1] = requestedTime;

      // Expect the PriceUpdateExecuted event with all price data
      vm.expectEmit();
      emit PriceUpdateExecuted(
          seq,
          defaultProvider,
          ids,
          prices,
          confs,
          expos,
          publishTimes
      );

      // Create mock update data and execute callback
      bytes[] memory payload = createMockUpdateData(feeds);
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, payload, ids);

      // Verify callback was executed
      assertEq(consumer.lastSequenceNumber(), seq);

      // The consumer must have received the same number of feeds ...
      PythStructs.PriceFeed[] memory received = consumer.lastPriceFeeds();
      assertEq(received.length, feeds.length);

      // ... and each feed must match field by field
      for (uint256 k = 0; k < feeds.length; k++) {
          assertEq(received[k].id, feeds[k].id);
          assertEq(received[k].price.price, feeds[k].price.price);
          assertEq(received[k].price.conf, feeds[k].price.conf);
          assertEq(received[k].price.expo, feeds[k].price.expo);
          assertEq(
              received[k].price.publishTime,
              feeds[k].price.publishTime
          );
      }
  }
  /// @dev When the consumer callback reverts with a string reason, executing the
  ///      callback emits `PriceUpdateCallbackFailed` carrying that reason.
  function testExecuteCallbackFailure() public {
      FailingPulseConsumer badConsumer = new FailingPulseConsumer(
          address(proxy)
      );

      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(badConsumer));

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      vm.expectEmit();
      emit PriceUpdateCallbackFailed(
          seq,
          defaultProvider,
          ids,
          address(badConsumer),
          "callback failed"
      );

      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, payload, ids);
  }
  /// @dev When the consumer callback reverts with a custom error, the emitted
  ///      `PriceUpdateCallbackFailed` event is expected to carry the generic
  ///      low-level error message.
  function testExecuteCallbackCustomErrorFailure() public {
      CustomErrorPulseConsumer badConsumer = new CustomErrorPulseConsumer(
          address(proxy)
      );

      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(badConsumer));

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      vm.expectEmit();
      emit PriceUpdateCallbackFailed(
          seq,
          defaultProvider,
          ids,
          address(badConsumer),
          "low-level error (possibly out of gas)"
      );

      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, payload, ids);
  }
  /// @dev Executing the callback with an explicit 100k gas stipend must revert.
  function testExecuteCallbackWithInsufficientGas() public {
      // NOTE(review): the original comment says the helper requests a 1M gas
      // callback; the helper is defined outside this file section - confirm.
      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      // Setup mock data
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      // Forward only 100K gas to the call
      vm.prank(defaultProvider);
      vm.expectRevert(); // Just expect any revert since it will be an out-of-gas error
      pulse.executeCallback{gas: 100_000}(defaultProvider, seq, payload, ids); // Will fail because gasleft() < callbackGasLimit
  }
  /// @dev A request for a publish time 10 seconds ahead can be fulfilled with price
  ///      feeds stamped at that future time; the consumer receives those publish times.
  function testExecuteCallbackWithFutureTimestamp() public {
      bytes32[] memory ids = createPriceIds();
      // 10 seconds in future
      uint64 futureTime = SafeCast.toUint64(block.timestamp + 10);

      vm.deal(address(consumer), 1 gwei);
      uint96 fee = calculateTotalFee();

      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          futureTime,
          ids,
          CALLBACK_GAS_LIMIT
      );

      // Mock price feeds stamped with the future time; the mocked
      // parsePriceFeedUpdates will return these future-dated prices.
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(futureTime);
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      // Should succeed because we're simulating receiving future-dated price updates
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, payload, ids);

      // Same number of feeds delivered ...
      PythStructs.PriceFeed[] memory received = consumer.lastPriceFeeds();
      assertEq(received.length, feeds.length);

      // ... each with the mocked publish time
      for (uint256 k = 0; k < feeds.length; k++) {
          assertEq(
              received[k].price.publishTime,
              feeds[k].price.publishTime
          );
      }
  }
  /// @dev Requesting a publish time 61 seconds ahead must revert with "Too far in future".
  function testRevertOnTooFarFutureTimestamp() public {
      bytes32[] memory ids = createPriceIds();
      // Just over 1 minute
      uint64 tooLate = SafeCast.toUint64(block.timestamp + 61);

      address requester = address(consumer);
      vm.deal(requester, 1 gwei);
      uint96 fee = calculateTotalFee();

      vm.prank(requester);
      vm.expectRevert("Too far in future");
      pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          tooLate,
          ids,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @dev Executing the same request twice: the second attempt must revert with `NoSuchRequest`.
  function testDoubleExecuteCallback() public {
      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      // First execution
      vm.prank(defaultProvider);
      pulse.executeCallback(defaultProvider, seq, payload, ids);

      // Second execution should fail
      vm.prank(defaultProvider);
      vm.expectRevert(NoSuchRequest.selector);
      pulse.executeCallback(defaultProvider, seq, payload, ids);
  }
  /// @dev Checks `getFee` against base fee + per-feed fee * feed count +
  ///      per-gas fee * gas limit + Pyth fee, for several gas limits and for zero.
  function testGetFee() public {
      // Test with different gas limits to verify fee calculation
      uint32[3] memory gasLimits = [uint32(100_000), 500_000, 1_000_000];
      bytes32[] memory ids = createPriceIds();

      for (uint256 k = 0; k < gasLimits.length; k++) {
          uint32 limit = gasLimits[k];
          uint256 providerPortion = DEFAULT_PROVIDER_BASE_FEE +
              DEFAULT_PROVIDER_FEE_PER_FEED *
              ids.length +
              DEFAULT_PROVIDER_FEE_PER_GAS *
              limit;
          uint96 expected = SafeCast.toUint96(providerPortion) + PYTH_FEE;
          uint96 quoted = pulse.getFee(defaultProvider, limit, ids);
          assertEq(
              quoted,
              expected,
              "Fee calculation incorrect for gas limit"
          );
      }

      // Test with zero gas limit
      uint256 minimumPortion = PYTH_FEE +
          DEFAULT_PROVIDER_BASE_FEE +
          DEFAULT_PROVIDER_FEE_PER_FEED *
          ids.length;
      uint96 expectedMin = SafeCast.toUint96(minimumPortion);
      uint96 quotedMin = pulse.getFee(defaultProvider, 0, ids);
      assertEq(quotedMin, expectedMin, "Minimum fee calculation incorrect");
  }
  /// @dev After one paid request, the admin withdraws all accrued Pyth fees; the
  ///      admin balance grows by that amount and the accrued total drops to zero.
  function testWithdrawFees() public {
      // Setup: Request price update to accrue some fees
      bytes32[] memory priceIds = createPriceIds();
      vm.deal(address(consumer), 1 gwei);

      // Quote the fee BEFORE pranking: vm.prank only applies to the next call,
      // and calculateTotalFee() itself calls pulse.getFee. Evaluating it inside
      // the `{value: ...}` option would consume the prank, so the request would
      // be sent by the test contract instead of the consumer.
      uint96 totalFee = calculateTotalFee();

      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // Get admin's balance before withdrawal
      uint256 adminBalanceBefore = admin.balance;
      uint128 accruedFees = pulse.getAccruedPythFees();
      // Guard against a vacuous pass: withdrawing 0 would satisfy every check below.
      assertGt(accruedFees, 0, "Request should have accrued Pyth fees");

      // Withdraw fees as admin
      vm.prank(admin);
      pulse.withdrawFees(accruedFees);

      // Verify balances
      assertEq(
          admin.balance,
          adminBalanceBefore + accruedFees,
          "Admin balance should increase by withdrawn amount"
      );
      assertEq(
          pulse.getAccruedPythFees(),
          0,
          "Contract should have no fees after withdrawal"
      );
  }
  /// @dev `withdrawFees` called by a non-admin address must revert with
  ///      "Only admin can withdraw fees".
  function testWithdrawFeesUnauthorized() public {
      address outsider = address(0xdead);

      vm.prank(outsider);
      vm.expectRevert("Only admin can withdraw fees");
      pulse.withdrawFees(1 ether);
  }
  /// @dev With no request made in this test, an admin withdrawal of 1 ether must
  ///      revert with "Insufficient balance".
  function testWithdrawFeesInsufficientBalance() public {
      uint128 requested = 1 ether;

      vm.prank(admin);
      vm.expectRevert("Insufficient balance");
      pulse.withdrawFees(requested);
  }
  /// @dev The provider appoints a fee manager; after one paid request the manager
  ///      withdraws the provider's accrued fees, which lands in the manager's balance
  ///      and zeroes the provider's accrued amount.
  function testSetAndWithdrawAsFeeManager() public {
      address feeManager = address(0x789);
      vm.prank(defaultProvider);
      pulse.setFeeManager(feeManager);

      // Setup: Request price update to accrue some fees
      bytes32[] memory priceIds = createPriceIds();
      vm.deal(address(consumer), 1 gwei);

      // Quote the fee BEFORE pranking: vm.prank only applies to the next call,
      // and calculateTotalFee() itself calls pulse.getFee. Evaluating it inside
      // the `{value: ...}` option would consume the prank, so the request would
      // be sent by the test contract instead of the consumer.
      uint96 totalFee = calculateTotalFee();

      vm.prank(address(consumer));
      pulse.requestPriceUpdatesWithCallback{value: totalFee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          priceIds,
          CALLBACK_GAS_LIMIT
      );

      // Get provider's accrued fees instead of total fees
      PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
          defaultProvider
      );
      uint128 providerAccruedFees = providerInfo.accruedFeesInWei;
      // Guard against a vacuous pass: withdrawing 0 would satisfy every check below.
      assertGt(
          providerAccruedFees,
          0,
          "Request should have accrued provider fees"
      );

      uint256 managerBalanceBefore = feeManager.balance;
      vm.prank(feeManager);
      pulse.withdrawAsFeeManager(
          defaultProvider,
          uint96(providerAccruedFees)
      );

      assertEq(
          feeManager.balance,
          managerBalanceBefore + providerAccruedFees,
          "Fee manager balance should increase by withdrawn amount"
      );
      providerInfo = pulse.getProviderInfo(defaultProvider);
      assertEq(
          providerInfo.accruedFeesInWei,
          0,
          "Provider should have no fees after withdrawal"
      );
  }
  /// @dev `setFeeManager` called from an address that never registered as a provider
  ///      must revert with "Provider not registered".
  function testSetFeeManagerUnauthorized() public {
      address outsider = address(0xdead);
      address proposedManager = address(0x789);

      vm.prank(outsider);
      vm.expectRevert("Provider not registered");
      pulse.setFeeManager(proposedManager);
  }
  /// @dev `withdrawAsFeeManager` called by an address that is not the provider's
  ///      fee manager must revert with "Only fee manager".
  function testWithdrawAsFeeManagerUnauthorized() public {
      address outsider = address(0xdead);

      vm.prank(outsider);
      vm.expectRevert("Only fee manager");
      pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
  }
  /// @dev A properly appointed fee manager withdrawing 1 ether when no request has
  ///      been made in this test must hit "Insufficient balance".
  function testWithdrawAsFeeManagerInsufficientBalance() public {
      address manager = address(0x789);

      // Set up fee manager first
      vm.prank(defaultProvider);
      pulse.setFeeManager(manager);

      vm.prank(manager);
      vm.expectRevert("Insufficient balance");
      pulse.withdrawAsFeeManager(defaultProvider, 1 ether);
  }
  // Add new test for invalid priceIds
  /// @dev Executing a callback with price ids that differ from the requested ones
  ///      must revert with `InvalidPriceIds(providedId, storedPrefix)`.
  function testExecuteCallbackWithInvalidPriceIds() public {
      bytes32[] memory priceIds = createPriceIds();
      uint256 publishTime = block.timestamp;

      // Setup request
      (uint64 sequenceNumber, , ) = setupConsumerRequest(
          pulse,
          defaultProvider,
          address(consumer)
      );

      // Create different priceIds
      bytes32[] memory wrongPriceIds = new bytes32[](2);
      wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs
      wrongPriceIds[1] = bytes32(uint256(2));

      PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
          publishTime
      );
      mockParsePriceFeedUpdates(pyth, priceFeeds);
      bytes[] memory updateData = createMockUpdateData(priceFeeds);

      // First 8 bytes of the first price id, for the error expectation. The
      // explicit bytes32 -> bytes8 conversion truncates on the right and, unlike
      // loading a full word into a bytes8 via inline assembly, cannot leave dirty
      // low-order bits in the value that gets ABI-encoded below.
      bytes8 storedPriceIdPrefix = bytes8(priceIds[0]);

      // Should revert when trying to execute with wrong priceIds
      vm.prank(defaultProvider);
      vm.expectRevert(
          abi.encodeWithSelector(
              InvalidPriceIds.selector,
              wrongPriceIds[0],
              storedPriceIdPrefix
          )
      );
      pulse.executeCallback(
          defaultProvider,
          sequenceNumber,
          updateData,
          wrongPriceIds
      );
  }
  /// @dev Requesting `MAX_PRICE_IDS + 1` ids must revert with
  ///      `TooManyPriceIds(provided, max)`.
  function testRevertOnTooManyPriceIds() public {
      uint256 limit = uint256(pulse.MAX_PRICE_IDS());
      uint256 requestedCount = limit + 1;

      // Create array with MAX_PRICE_IDS + 1 price IDs (values 1..n)
      bytes32[] memory ids = new bytes32[](requestedCount);
      for (uint256 k = 0; k < ids.length; k++) {
          ids[k] = bytes32(uint256(k + 1));
      }

      address requester = address(consumer);
      vm.deal(requester, 1 gwei);
      uint96 fee = calculateTotalFee();

      vm.prank(requester);
      vm.expectRevert(
          abi.encodeWithSelector(
              TooManyPriceIds.selector,
              requestedCount,
              limit
          )
      );
      pulse.requestPriceUpdatesWithCallback{value: fee}(
          defaultProvider,
          SafeCast.toUint64(block.timestamp),
          ids,
          CALLBACK_GAS_LIMIT
      );
  }
  /// @dev Registering a new provider stores all three fee components and marks the
  ///      provider as registered.
  function testProviderRegistration() public {
      address provider = address(0x123);
      // Distinct values per component so a swapped argument or field would be
      // detected; identical values would hide such a mix-up.
      uint96 baseFee = 1000;
      uint96 feePerFeed = 2000;
      uint96 feePerGas = 3000;

      vm.prank(provider);
      pulse.registerProvider(baseFee, feePerFeed, feePerGas);

      PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
      assertEq(info.baseFeeInWei, baseFee, "Base fee mismatch");
      assertEq(info.feePerFeedInWei, feePerFeed, "Per-feed fee mismatch");
      assertEq(info.feePerGasInWei, feePerGas, "Per-gas fee mismatch");
      assertTrue(info.isRegistered);
  }
  /// @dev A registered provider can replace all three of its fee components via
  ///      `setProviderFee`; the stored info reflects the new values.
  function testSetProviderFee() public {
      address provider = address(0x123);

      // Register with an initial fee schedule (base, per-feed, per-gas)
      vm.prank(provider);
      pulse.registerProvider(1000, 2000, 3000);

      // Replace it with a new schedule
      uint96 updatedBaseFee = 5000;
      uint96 updatedFeePerFeed = 4000;
      uint96 updatedFeePerGas = 6000;
      vm.prank(provider);
      pulse.setProviderFee(
          provider,
          updatedBaseFee,
          updatedFeePerFeed,
          updatedFeePerGas
      );

      PulseState.ProviderInfo memory stored = pulse.getProviderInfo(provider);
      assertEq(stored.baseFeeInWei, updatedBaseFee);
      assertEq(stored.feePerFeedInWei, updatedFeePerFeed);
      assertEq(stored.feePerGasInWei, updatedFeePerGas);
  }
  /// @dev The admin can switch the default provider to a newly registered one.
  function testDefaultProvider() public {
      address newProvider = address(0x123);
      uint96 flatFee = 1000;

      vm.prank(newProvider);
      pulse.registerProvider(flatFee, flatFee, flatFee);

      vm.prank(admin);
      pulse.setDefaultProvider(newProvider);

      assertEq(pulse.getDefaultProvider(), newProvider);
  }
  /// @dev A request addressed to a specific, non-default provider is stored with
  ///      that provider.
  function testRequestWithProvider() public {
      address chosenProvider = address(0x123);
      uint96 flatFee = 1000;
      vm.prank(chosenProvider);
      pulse.registerProvider(flatFee, flatFee, flatFee);

      bytes32[] memory ids = new bytes32[](1);
      ids[0] = bytes32(uint256(1));

      // Fund the consumer with exactly the quoted fee
      uint128 fee = pulse.getFee(chosenProvider, CALLBACK_GAS_LIMIT, ids);
      vm.deal(address(consumer), fee);

      vm.prank(address(consumer));
      uint64 seq = pulse.requestPriceUpdatesWithCallback{value: fee}(
          chosenProvider,
          SafeCast.toUint64(block.timestamp),
          ids,
          CALLBACK_GAS_LIMIT
      );

      PulseState.Request memory stored = pulse.getRequest(seq);
      assertEq(stored.provider, chosenProvider);
  }
  /// @dev The exclusivity period starts at 15, and the admin can change it to 30,
  ///      which emits `ExclusivityPeriodUpdated(15, 30)`.
  function testExclusivityPeriod() public {
      // Test initial value
      uint256 initialPeriod = pulse.getExclusivityPeriod();
      assertEq(
          initialPeriod,
          15,
          "Initial exclusivity period should be 15 seconds"
      );

      // Test setting new value
      vm.prank(admin);
      vm.expectEmit();
      emit ExclusivityPeriodUpdated(15, 30);
      pulse.setExclusivityPeriod(30);

      uint256 updatedPeriod = pulse.getExclusivityPeriod();
      assertEq(updatedPeriod, 30, "Exclusivity period should be updated");
  }
  /// @dev `setExclusivityPeriod` called by a non-admin address must revert with
  ///      "Only admin can set exclusivity period".
  function testSetExclusivityPeriodUnauthorized() public {
      address outsider = address(0xdead);

      vm.prank(outsider);
      vm.expectRevert("Only admin can set exclusivity period");
      pulse.setExclusivityPeriod(30);
  }
  /// @dev Within the exclusivity period, executing the callback with a provider
  ///      other than the assigned one reverts, while the assigned provider succeeds.
  function testExecuteCallbackDuringExclusivity() public {
      // Register a second provider
      address otherProvider = address(0x456);
      vm.prank(otherProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Setup request
      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      // Setup mock data
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      // Try to execute with second provider during exclusivity period
      vm.expectRevert("Only assigned provider during exclusivity period");
      pulse.executeCallback(otherProvider, seq, payload, ids);

      // Original provider should succeed
      pulse.executeCallback(defaultProvider, seq, payload, ids);
  }
  /// @dev Once the exclusivity period has elapsed, a provider other than the
  ///      assigned one can execute the callback.
  function testExecuteCallbackAfterExclusivity() public {
      // Register a second provider
      address secondProvider = address(0x456);
      vm.prank(secondProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Setup request
      (
          uint64 sequenceNumber,
          bytes32[] memory priceIds,
          uint256 publishTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      // Setup mock data
      PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
          publishTime
      );
      mockParsePriceFeedUpdates(pyth, priceFeeds);
      bytes[] memory updateData = createMockUpdateData(priceFeeds);

      // Wait for exclusivity period to end
      vm.warp(block.timestamp + pulse.getExclusivityPeriod() + 1);

      // Second provider should now succeed. The provider argument must be the
      // second provider: passing the assigned provider is accepted even during
      // the exclusivity period, so it would not exercise the expiry at all.
      vm.prank(secondProvider);
      pulse.executeCallback(
          secondProvider,
          sequenceNumber,
          updateData,
          priceIds
      );
  }
  /// @dev With the exclusivity period set to 30, a non-assigned provider is
  ///      rejected 29 seconds after the request and accepted at 31 seconds.
  function testExecuteCallbackWithCustomExclusivityPeriod() public {
      // Register a second provider
      address otherProvider = address(0x456);
      vm.prank(otherProvider);
      pulse.registerProvider(
          DEFAULT_PROVIDER_BASE_FEE,
          DEFAULT_PROVIDER_FEE_PER_FEED,
          DEFAULT_PROVIDER_FEE_PER_GAS
      );

      // Set custom exclusivity period
      vm.prank(admin);
      pulse.setExclusivityPeriod(30);

      // Setup request
      (
          uint64 seq,
          bytes32[] memory ids,
          uint256 requestedTime
      ) = setupConsumerRequest(pulse, defaultProvider, address(consumer));

      // Setup mock data
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestedTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      // Try at 29 seconds (should fail for second provider)
      vm.warp(block.timestamp + 29);
      vm.expectRevert("Only assigned provider during exclusivity period");
      pulse.executeCallback(otherProvider, seq, payload, ids);

      // Try at 31 seconds (should succeed for second provider)
      vm.warp(block.timestamp + 2);
      pulse.executeCallback(otherProvider, seq, payload, ids);
  }
  /// @notice End-to-end check of `getFirstActiveRequests`: creates five
  ///         requests, completes two of them, then runs the query scenarios.
  function testGetFirstActiveRequests() public {
      (bytes32[] memory ids, bytes[] memory payload) = setupTestData();

      // Five open requests, then fulfil #2 and #4 to leave gaps.
      createTestRequests(ids);
      completeRequests(payload, ids);

      // Query with various limits and finally drain the remaining requests.
      testRequestScenarios(ids, payload);
  }
  /// @notice Builds the fixtures shared by the `getFirstActiveRequests` test.
  /// @return ids A single-element array holding price id `bytes32(uint256(1))`.
  /// @return blobs A single-element update-data array whose only entry is left empty.
  function setupTestData()
      private
      pure
      returns (bytes32[] memory ids, bytes[] memory blobs)
  {
      ids = new bytes32[](1);
      ids[0] = bytes32(uint256(1));
      blobs = new bytes[](1);
  }
  /// @notice Opens five price-update requests against the default provider.
  /// @dev Each iteration tops this contract up to 1 ether and forwards exactly
  ///      that amount with the request.
  /// @param priceIds Price feed ids to request.
  function createTestRequests(bytes32[] memory priceIds) private {
      uint64 requestedTime = SafeCast.toUint64(block.timestamp);
      uint256 remaining = 5;
      while (remaining > 0) {
          vm.deal(address(this), 1 ether);
          pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
              defaultProvider,
              requestedTime,
              priceIds,
              1000000
          );
          remaining--;
      }
  }
  /// @notice Fulfils requests #2 and #4 as the default provider.
  /// @dev The incoming `updateData` value is not used: it is overwritten with
  ///      freshly generated mock update data before any callback is executed.
  /// @param updateData Replaced locally by mock update data (see above).
  /// @param priceIds Price feed ids passed through to `executeCallback`.
  function completeRequests(
      bytes[] memory updateData,
      bytes32[] memory priceIds
  ) private {
      // Mock the Pyth response using the current block timestamp.
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          SafeCast.toUint64(block.timestamp)
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      updateData = createMockUpdateData(feeds);

      // Fund the provider for two 1-ether calls.
      vm.deal(defaultProvider, 2 ether);

      uint64[2] memory toComplete = [uint64(2), 4];
      vm.startPrank(defaultProvider);
      for (uint256 k = 0; k < toComplete.length; k++) {
          pulse.executeCallback{value: 1 ether}(
              defaultProvider,
              toComplete[k],
              updateData,
              priceIds
          );
      }
      vm.stopPrank();
  }
  /// @notice Runs the individual `getFirstActiveRequests` scenarios in order and
  ///         finishes by draining all remaining requests.
  /// @dev Private, so it is only run through `testGetFirstActiveRequests`.
  /// @param priceIds Price feed ids used when clearing the remaining requests.
  /// @param updateData Update data used when clearing the remaining requests.
  function testRequestScenarios(
      bytes32[] memory priceIds,
      bytes[] memory updateData
  ) private {
      checkMoreThanAvailable(); // limit above the number of active requests
      checkExactNumber(); // limit equal to the number of active requests
      checkFewerThanAvailable(); // limit below the number of active requests
      checkZeroRequest(); // limit of zero

      // Fulfil everything that is left, then expect an empty result.
      clearAllRequests(updateData, priceIds);
      checkEmptyState();
  }
  900. // Split test scenarios into separate functions
  /// @notice Asks for up to 10 active requests and expects exactly the three
  ///         still open ones (#1, #3, #5) in ascending order.
  function checkMoreThanAvailable() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(10);

      assertEq(found, 3, "Should find 3 active requests");
      assertEq(active.length, 3, "Array should be resized to 3");

      assertEq(
          active[0].sequenceNumber,
          1,
          "First request should be oldest"
      );
      assertEq(active[1].sequenceNumber, 3, "Second request should be #3");
      assertEq(active[2].sequenceNumber, 5, "Third request should be #5");
  }
  /// @notice Asks for exactly 3 active requests and expects both the count and
  ///         the returned array length to be 3.
  function checkExactNumber() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(3);

      assertEq(found, 3, "Should find 3 active requests");
      assertEq(active.length, 3, "Array should match requested size");
  }
  /// @notice Asks for only 2 active requests and expects the two oldest open
  ///         ones (#1 and #3).
  function checkFewerThanAvailable() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(2);

      assertEq(found, 2, "Should find 2 active requests");
      assertEq(active.length, 2, "Array should match requested size");

      assertEq(
          active[0].sequenceNumber,
          1,
          "First request should be oldest"
      );
      assertEq(active[1].sequenceNumber, 3, "Second request should be #3");
  }
  /// @notice Asks for 0 active requests and expects a zero count and an empty
  ///         array.
  function checkZeroRequest() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(0);

      assertEq(found, 0, "Should find 0 active requests");
      assertEq(active.length, 0, "Array should be empty");
  }
  /// @notice Fulfils the three requests that are still open (#1, #3 and #5) as
  ///         the default provider.
  /// @param updateData Update data forwarded to each `executeCallback` call.
  /// @param priceIds Price feed ids forwarded to each `executeCallback` call.
  function clearAllRequests(
      bytes[] memory updateData,
      bytes32[] memory priceIds
  ) private {
      // Fund the provider for three 1-ether calls.
      vm.deal(defaultProvider, 3 ether);

      uint64[3] memory stillOpen = [uint64(1), 3, 5];
      vm.startPrank(defaultProvider);
      for (uint256 k = 0; k < stillOpen.length; k++) {
          pulse.executeCallback{value: 1 ether}(
              defaultProvider,
              stillOpen[k],
              updateData,
              priceIds
          );
      }
      vm.stopPrank();
  }
  /// @notice Asks for up to 10 active requests and expects none to be reported.
  function checkEmptyState() private {
      (PulseState.Request[] memory active, uint256 found) = pulse
          .getFirstActiveRequests(10);

      assertEq(found, 0, "Should find 0 active requests");
      assertEq(active.length, 0, "Array should be empty");
  }
  /// @notice Measures the gas spent by `getFirstActiveRequests` for limits of 5
  ///         and 10 and asserts that the second is within 10% of twice the first.
  /// @dev Twenty requests are created; the request made on every iteration
  ///      where the zero-based index is a multiple of 3 is fulfilled right
  ///      away, so the active set contains gaps.
  function testGetFirstActiveRequestsGasUsage() public {
      // Fixtures: one price id, current timestamp, 1,000,000 callback gas.
      bytes32[] memory feedIds = new bytes32[](1);
      feedIds[0] = bytes32(uint256(1));
      uint64 requestTime = SafeCast.toUint64(block.timestamp);
      uint256 gasLimitForCallback = 1000000;

      // Mocked Pyth response shared by every fulfilled request.
      PythStructs.PriceFeed[] memory feeds = createMockPriceFeeds(
          requestTime
      );
      mockParsePriceFeedUpdates(pyth, feeds);
      bytes[] memory payload = createMockUpdateData(feeds);

      for (uint256 idx = 0; idx < 20; idx++) {
          vm.deal(address(this), 1 ether);
          pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
              defaultProvider,
              requestTime,
              feedIds,
              uint32(gasLimitForCallback)
          );

          // Only every third request is fulfilled; the rest stay active.
          if (idx % 3 != 0) {
              continue;
          }
          vm.deal(defaultProvider, 1 ether);
          vm.prank(defaultProvider);
          pulse.executeCallback{value: 1 ether}(
              defaultProvider,
              uint64(idx + 1),
              payload,
              feedIds
          );
      }

      // Gas consumed when asking for 5 requests.
      uint256 beforeFive = gasleft();
      pulse.getFirstActiveRequests(5);
      uint256 usedForFive = beforeFive - gasleft();

      // Gas consumed when asking for 10 requests.
      uint256 beforeTen = gasleft();
      pulse.getFirstActiveRequests(10);
      uint256 usedForTen = beforeTen - gasleft();

      // Emit both figures so they show up in the test output.
      emit log_named_uint("Gas used for 5 requests", usedForFive);
      emit log_named_uint("Gas used for 10 requests", usedForTen);

      // Expect roughly linear scaling, with a 10% relative tolerance.
      assertApproxEqRel(
          usedForTen,
          usedForFive * 2,
          0.1e18,
          "Gas usage should scale roughly linearly"
      );
  }
  /// @notice Returns the address of the `pulse` instance used by this test
  ///         contract.
  /// @return pulseAddress Address of `pulse`.
  function getPulse()
      internal
      view
      override
      returns (address pulseAddress)
  {
      pulseAddress = address(pulse);
  }
  1025. // Mock implementation of pulseCallback
  /// @notice Callback override that intentionally does nothing.
  /// @dev Both arguments are ignored; the empty body exists only so that the
  ///      callback is accepted instead of reverting.
  /// @param sequenceNumber Unused.
  /// @param priceFeeds Unused.
  function pulseCallback(
      uint64 sequenceNumber,
      PythStructs.PriceFeed[] memory priceFeeds
  ) internal override {
      // Deliberately empty: accept the callback and discard the data.
  }
  1033. }