Pyth.t.sol 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
  4. import "forge-std/Test.sol";
  5. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  6. import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
  7. import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
  8. import "./utils/WormholeTestUtils.t.sol";
  9. import "./utils/PythTestUtils.t.sol";
  10. import "./utils/RandTestUtils.t.sol";
  11. contract PythTest is Test, WormholeTestUtils, PythTestUtils {
  /// Pyth contract under test; assigned in `setUp`.
  IPyth public pyth;

  // Largest value a uint64 can hold (0xffffffffffffffff). The tests pass it
  // as the upper bound of the publish-time range when no upper limit is wanted.
  uint64 constant MAX_UINT64 = type(uint64).max;

  // 2/3 of the guardians should sign a message for a VAA, which is 13 out of 19 guardians.
  // It is possible to have more signers but the median seems to be 13.
  uint8 constant NUM_GUARDIAN_SIGNERS = 13;

  // We will have fewer than 512 (2^9) prices for the foreseeable future,
  // so a depth of 9 is used for the generated Merkle trees.
  uint8 constant MERKLE_TREE_DEPTH = 9;
  /// Test fixture: sets up a Wormhole receiver with NUM_GUARDIAN_SIGNERS
  /// guardians, sets up Pyth on top of it, and stores the result in `pyth`.
  function setUp() public {
      pyth = IPyth(
          setUpPyth(
              // The receiver is configured with the same number of guardians
              // that the generated updates are signed with.
              setUpWormholeReceiver(NUM_GUARDIAN_SIGNERS)
          )
      );
  }
  /// Builds `length` price feed messages whose numeric fields are drawn from
  /// the random test utilities.
  /// @param length Number of messages to generate.
  /// @return priceIds Ids of the generated messages, in the same order.
  /// @return messages The generated messages; message k has id k + 1.
  function generateRandomPriceMessages(
      uint length
  )
      internal
      returns (bytes32[] memory priceIds, PriceFeedMessage[] memory messages)
  {
      priceIds = new bytes32[](length);
      messages = new PriceFeedMessage[](length);

      for (uint idx = 0; idx < length; ++idx) {
          // Price ids should be non-zero and unique, hence the +1 offset.
          bytes32 id = bytes32(idx + 1);

          // `message` is a memory reference, so the writes below land in messages[idx].
          PriceFeedMessage memory message = messages[idx];
          message.priceId = id;
          message.price = getRandInt64();
          message.conf = getRandUint64();
          message.expo = getRandInt32();
          message.emaPrice = getRandInt64();
          message.emaConf = getRandUint64();
          message.publishTime = getRandUint64();
          message.prevPublishTime = getRandUint64();

          priceIds[idx] = id;
      }
  }
  /// Splits `messages` into consecutive batches of a randomly chosen size and
  /// creates one Merkle update per batch using `config`.
  /// @param messages Price feed messages to encode. May be empty, in which
  ///        case no update is generated and no random number is consumed.
  /// @param config Configuration forwarded to generateWhMerkleUpdateWithSource.
  /// @return updateData One encoded update per batch, in message order.
  /// @return updateFee Fee reported by `pyth.getUpdateFee` for `updateData`.
  function createBatchedUpdateDataFromMessagesWithConfig(
      PriceFeedMessage[] memory messages,
      MerkleUpdateConfig memory config
  ) internal returns (bytes[] memory updateData, uint updateFee) {
      uint total = messages.length;

      // Guard against `% 0` below, which would panic for an empty input.
      if (total == 0) {
          updateData = new bytes[](0);
          updateFee = pyth.getUpdateFee(updateData);
          return (updateData, updateFee);
      }

      // Batch size is in [1, total]; the last batch may be shorter.
      uint batchSize = 1 + (getRandUint() % total);
      uint numBatches = (total + batchSize - 1) / batchSize;
      updateData = new bytes[](numBatches);

      for (uint i = 0; i < total; i += batchSize) {
          uint len = batchSize;
          if (total - i < len) {
              len = total - i;
          }

          PriceFeedMessage[] memory batchMessages = new PriceFeedMessage[](
              len
          );
          for (uint j = i; j < i + len; j++) {
              batchMessages[j - i] = messages[j];
          }

          updateData[i / batchSize] = generateWhMerkleUpdateWithSource(
              batchMessages,
              config
          );
      }

      updateFee = pyth.getUpdateFee(updateData);
  }
  /// Convenience wrapper around createBatchedUpdateDataFromMessagesWithConfig
  /// that uses the default tree depth, signer count and source emitter.
  /// @param messages Price feed messages to encode.
  /// @return updateData Encoded updates, one per batch.
  /// @return updateFee Fee required to submit `updateData`.
  function createBatchedUpdateDataFromMessages(
      PriceFeedMessage[] memory messages
  ) internal returns (bytes[] memory updateData, uint updateFee) {
      // The last field is `false` here and `true` in the invalid-VAA test;
      // presumably it requests a corrupted VAA -- confirm in PythTestUtils.
      MerkleUpdateConfig memory defaultConfig = MerkleUpdateConfig(
          MERKLE_TREE_DEPTH,
          NUM_GUARDIAN_SIGNERS,
          SOURCE_EMITTER_CHAIN_ID,
          SOURCE_EMITTER_ADDRESS,
          false
      );

      return
          createBatchedUpdateDataFromMessagesWithConfig(
              messages,
              defaultConfig
          );
  }
  /// Testing parsePriceFeedUpdates method.
  ///
  /// Fuzz test: generates 1..10 random messages, encodes them in random
  /// batches and checks that parsing all ids over the full time range returns
  /// exactly one matching feed per requested id, in request order.
  /// @param seed Seed for the random test utilities.
  function testParsePriceFeedUpdatesWorks(uint seed) public {
      setRandSeed(seed);

      uint numMessages = 1 + (getRandUint() % 10);
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(numMessages);
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
          value: updateFee
      }(updateData, priceIds, 0, MAX_UINT64);

      // The loop below only inspects the first numMessages entries, so check
      // explicitly that nothing extra was returned.
      assertEq(priceFeeds.length, numMessages);

      for (uint i = 0; i < numMessages; i++) {
          assertEq(priceFeeds[i].id, priceIds[i]);
          assertEq(priceFeeds[i].price.price, messages[i].price);
          assertEq(priceFeeds[i].price.conf, messages[i].conf);
          assertEq(priceFeeds[i].price.expo, messages[i].expo);
          assertEq(priceFeeds[i].price.publishTime, messages[i].publishTime);
          // The EMA price shares expo and publishTime with the spot price.
          assertEq(priceFeeds[i].emaPrice.price, messages[i].emaPrice);
          assertEq(priceFeeds[i].emaPrice.conf, messages[i].emaConf);
          assertEq(priceFeeds[i].emaPrice.expo, messages[i].expo);
          assertEq(
              priceFeeds[i].emaPrice.publishTime,
              messages[i].publishTime
          );
      }
  }
  /// Fuzz test: encodes 1..30 random messages, then requests a random subset
  /// of their ids in shuffled order and checks that the parsed feeds follow
  /// the order of the requested ids rather than the order inside the updates.
  /// @param seed Seed for the random test utilities.
  function testParsePriceFeedUpdatesWorksWithRandomDistinctUpdatesInput(
      uint seed
  ) public {
      setRandSeed(seed);

      uint numMessages = 1 + (getRandUint() % 30);
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(numMessages);

      // The update data is created BEFORE shuffling, so it keeps the
      // original message order.
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      // Shuffle the messages (ids are swapped in lockstep so that
      // priceIds[i] stays the id of messages[i]).
      for (uint i = 1; i < numMessages; i++) {
          uint swapWith = getRandUint() % (i + 1);
          (messages[i], messages[swapWith]) = (
              messages[swapWith],
              messages[i]
          );
          (priceIds[i], priceIds[swapWith]) = (
              priceIds[swapWith],
              priceIds[i]
          );
      }

      // Select only first numSelectedMessages. numSelectedMessages will be in [0, numMessages]
      uint numSelectedMessages = getRandUint() % (numMessages + 1);
      PriceFeedMessage[] memory selectedMessages = new PriceFeedMessage[](
          numSelectedMessages
      );
      bytes32[] memory selectedPriceIds = new bytes32[](numSelectedMessages);
      for (uint i = 0; i < numSelectedMessages; i++) {
          selectedMessages[i] = messages[i];
          selectedPriceIds[i] = priceIds[i];
      }

      // Only parse selected messages
      PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
          value: updateFee
      }(updateData, selectedPriceIds, 0, MAX_UINT64);

      // Without this check an unexpected number of returned feeds would go
      // unnoticed, in particular when zero ids were selected and the loop
      // below never runs.
      assertEq(priceFeeds.length, numSelectedMessages);

      for (uint i = 0; i < numSelectedMessages; i++) {
          assertEq(priceFeeds[i].id, selectedPriceIds[i]);
          assertEq(priceFeeds[i].price.expo, selectedMessages[i].expo);
          assertEq(
              priceFeeds[i].emaPrice.price,
              selectedMessages[i].emaPrice
          );
          assertEq(priceFeeds[i].emaPrice.conf, selectedMessages[i].emaConf);
          assertEq(priceFeeds[i].emaPrice.expo, selectedMessages[i].expo);
          assertEq(priceFeeds[i].price.price, selectedMessages[i].price);
          assertEq(priceFeeds[i].price.conf, selectedMessages[i].conf);
          assertEq(
              priceFeeds[i].price.publishTime,
              selectedMessages[i].publishTime
          );
          assertEq(
              priceFeeds[i].emaPrice.publishTime,
              selectedMessages[i].publishTime
          );
      }
  }
  /// Two updates for the same price id both fall inside the requested time
  /// range; the test accepts either of them as the parsed result.
  function testParsePriceFeedUpdatesWorksWithOverlappingWithinTimeRangeUpdates()
      public
  {
      bytes32 feedId = bytes32(uint(1));

      PriceFeedMessage[] memory messages = new PriceFeedMessage[](2);
      // Older update.
      messages[0].priceId = feedId;
      messages[0].price = 1000;
      messages[0].publishTime = 10;
      // Newer update for the same feed.
      messages[1].priceId = feedId;
      messages[1].price = 2000;
      messages[1].publishTime = 20;

      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      bytes32[] memory priceIds = new bytes32[](1);
      priceIds[0] = feedId;

      // The range [0, 20] covers both publish times.
      PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
          value: updateFee
      }(updateData, priceIds, 0, 20);

      assertEq(priceFeeds.length, 1);
      assertEq(priceFeeds[0].id, feedId);

      bool matchesOlder = priceFeeds[0].price.price == 1000 &&
          priceFeeds[0].price.publishTime == 10;
      bool matchesNewer = priceFeeds[0].price.price == 2000 &&
          priceFeeds[0].price.publishTime == 20;
      assertTrue(matchesOlder || matchesNewer);
  }
  /// Two updates for the same price id, only one of which falls inside each
  /// requested time range; the parsed feed must be the one inside the range.
  function testParsePriceFeedUpdatesWorksWithOverlappingMixedTimeRangeUpdates()
      public
  {
      bytes32 feedId = bytes32(uint(1));

      PriceFeedMessage[] memory messages = new PriceFeedMessage[](2);
      messages[0].priceId = feedId;
      messages[0].price = 1000;
      messages[0].publishTime = 10;
      messages[1].priceId = feedId;
      messages[1].price = 2000;
      messages[1].publishTime = 20;

      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      bytes32[] memory priceIds = new bytes32[](1);
      priceIds[0] = feedId;

      // Range [5, 15] contains only the update published at 10.
      PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
          value: updateFee
      }(updateData, priceIds, 5, 15);
      assertEq(priceFeeds.length, 1);
      assertEq(priceFeeds[0].id, feedId);
      assertEq(priceFeeds[0].price.price, 1000);
      assertEq(priceFeeds[0].price.publishTime, 10);

      // Range [15, 25] contains only the update published at 20.
      priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          15,
          25
      );
      assertEq(priceFeeds.length, 1);
      assertEq(priceFeeds[0].id, feedId);
      assertEq(priceFeeds[0].price.price, 2000);
      assertEq(priceFeeds[0].price.publishTime, 20);
  }
  /// Paying one wei less than the required fee must revert with InsufficientFee.
  function testParsePriceFeedUpdatesRevertsIfUpdateFeeIsNotPaid() public {
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(10);
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      // Since messages are not empty the fee should be at least 1,
      // which also keeps the subtraction below from underflowing.
      assertGe(updateFee, 1);
      uint underpaidFee = updateFee - 1;

      vm.expectRevert(PythErrors.InsufficientFee.selector);
      pyth.parsePriceFeedUpdates{value: underpaidFee}(
          updateData,
          priceIds,
          0,
          MAX_UINT64
      );
  }
  /// Fuzz test: updates generated with the last config flag set to `true`
  /// must be rejected by parsePriceFeedUpdates.
  /// @param seed Seed for the random test utilities.
  function testParsePriceFeedUpdatesRevertsIfUpdateVAAIsInvalid(
      uint seed
  ) public {
      setRandSeed(seed);

      uint numMessages = 1 + (getRandUint() % 10);
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(numMessages);

      // Same as the default config except for the final flag, which is
      // `true` only in this test (presumably it corrupts the VAA -- confirm
      // in PythTestUtils).
      MerkleUpdateConfig memory invalidVaaConfig = MerkleUpdateConfig(
          MERKLE_TREE_DEPTH,
          NUM_GUARDIAN_SIGNERS,
          SOURCE_EMITTER_CHAIN_ID,
          SOURCE_EMITTER_ADDRESS,
          true
      );
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessagesWithConfig(
              messages,
              invalidVaaConfig
          );

      // It might revert due to different wormhole errors
      vm.expectRevert();
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          0,
          MAX_UINT64
      );
  }
  /// Updates emitted from a chain id other than the configured source chain
  /// must revert with InvalidUpdateDataSource.
  function testParsePriceFeedUpdatesRevertsIfUpdateSourceChainIsInvalid()
      public
  {
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(10);

      // Only the emitter chain id differs from the default configuration.
      MerkleUpdateConfig memory wrongChainConfig = MerkleUpdateConfig(
          MERKLE_TREE_DEPTH,
          NUM_GUARDIAN_SIGNERS,
          SOURCE_EMITTER_CHAIN_ID + 1,
          SOURCE_EMITTER_ADDRESS,
          false
      );
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessagesWithConfig(
              messages,
              wrongChainConfig
          );

      vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          0,
          MAX_UINT64
      );
  }
  /// Updates emitted from an address other than the configured source
  /// emitter must revert with InvalidUpdateDataSource.
  function testParsePriceFeedUpdatesRevertsIfUpdateSourceAddressIsInvalid()
      public
  {
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(10);

      // Only the emitter address differs from the default configuration.
      MerkleUpdateConfig memory wrongEmitterConfig = MerkleUpdateConfig(
          MERKLE_TREE_DEPTH,
          NUM_GUARDIAN_SIGNERS,
          SOURCE_EMITTER_CHAIN_ID,
          0x00000000000000000000000000000000000000000000000000000000000000aa, // Random wrong source address
          false
      );
      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessagesWithConfig(
              messages,
              wrongEmitterConfig
          );

      vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          0,
          MAX_UINT64
      );
  }
  /// Requesting a price id that is absent from the update data must revert
  /// with PriceFeedNotFoundWithinRange.
  function testParsePriceFeedUpdatesRevertsIfPriceIdNotIncluded() public {
      // The update data contains a single message, for price id 1.
      PriceFeedMessage[] memory messages = new PriceFeedMessage[](1);
      messages[0].priceId = bytes32(uint(1));
      messages[0].price = 1000;
      messages[0].publishTime = 10;

      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      // Price id 2 is requested instead.
      bytes32[] memory requestedIds = new bytes32[](1);
      requestedIds[0] = bytes32(uint(2));

      vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          requestedIds,
          0,
          MAX_UINT64
      );
  }
  /// Parsing succeeds when every publish time lies inside the requested
  /// range and reverts with PriceFeedNotFoundWithinRange when none does.
  function testParsePriceFeedUpdateRevertsIfPricesOutOfTimeRange() public {
      uint numMessages = 10;
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(numMessages);

      // Overwrite the random publish times so that all are in [100, 200].
      for (uint idx = 0; idx < numMessages; ++idx) {
          uint offset = getRandUint() % 101;
          messages[idx].publishTime = uint64(100 + offset);
      }

      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      // Request for parse within the given time range should work
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          100,
          200
      );

      // Request for parse after the time range should revert.
      vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          300,
          MAX_UINT64
      );
  }
  /// Checks that a successful parsePriceFeedUpdates call is reflected in
  /// getPriceUnsafe, and that a reverted call leaves the stored prices alone.
  ///
  /// Steps:
  ///  1. Parse updates published in [0, 100] and expect them to be stored.
  ///  2. Try to parse updates published in [100, 200] with a non-matching
  ///     range; the call reverts and stored publish times stay <= 100.
  ///  3. Parse the same updates with the matching range and expect the
  ///     stored publish times to equal the new ones.
  function testParsePriceFeedUpdatesLatestPriceIfNecessary() public {
      uint numMessages = 10;
      (
          bytes32[] memory priceIds,
          PriceFeedMessage[] memory messages
      ) = generateRandomPriceMessages(numMessages);

      for (uint i = 0; i < numMessages; i++) {
          messages[i].publishTime = uint64((getRandUint() % 101)); // All between [0, 100]
      }

      (
          bytes[] memory updateData,
          uint updateFee
      ) = createBatchedUpdateDataFromMessages(messages);

      // Request for parse within the given time range should work and update the latest price
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          0,
          100
      );

      // Check if the latest price is updated
      for (uint i = 0; i < numMessages; i++) {
          assertEq(
              pyth.getPriceUnsafe(priceIds[i]).publishTime,
              messages[i].publishTime
          );
      }

      for (uint i = 0; i < numMessages; i++) {
          messages[i].publishTime = uint64(100 + (getRandUint() % 101)); // All between [100, 200]
      }

      (updateData, updateFee) = createBatchedUpdateDataFromMessages(messages);

      // Request for parse after the time range should revert.
      vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          300,
          400
      );

      // parse function reverted so publishTimes should remain less than or equal to 100.
      // The actual value goes first so that a failure reports it as the left operand.
      for (uint i = 0; i < numMessages; i++) {
          assertLe(pyth.getPriceUnsafe(priceIds[i]).publishTime, 100);
      }

      // Time range is now fixed, so parse should work and update the latest price
      pyth.parsePriceFeedUpdates{value: updateFee}(
          updateData,
          priceIds,
          100,
          200
      );

      // Check if the latest price is updated
      for (uint i = 0; i < numMessages; i++) {
          assertEq(
              pyth.getPriceUnsafe(priceIds[i]).publishTime,
              messages[i].publishTime
          );
      }
  }
  448. }