Pyth.t.sol 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
  4. import "forge-std/Test.sol";
  5. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  6. import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
  7. import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
  8. import "./utils/WormholeTestUtils.t.sol";
  9. import "./utils/PythTestUtils.t.sol";
  10. import "./utils/RandTestUtils.t.sol";
  contract PythTest is Test, WormholeTestUtils, PythTestUtils {
      IPyth public pyth;

      // Largest uint64 value (0xffffffffffffffff). Passed as the upper
      // publish-time bound to parsePriceFeedUpdates when no upper limit is wanted.
      uint64 constant MAX_UINT64 = type(uint64).max;

      // 2/3 of the guardians should sign a message for a VAA, which is 13 out of
      // 19 guardians. It is possible to have more signers, but the median seems
      // to be 13.
      uint8 constant NUM_GUARDIAN_SIGNERS = 13;

      // Chosen on the assumption that we will have fewer than 512 (= 2^9) prices
      // for the foreseeable future.
      uint8 constant MERKLE_TREE_DEPTH = 9;

      // Deploys the Wormhole receiver (with NUM_GUARDIAN_SIGNERS signers) and the
      // Pyth contract built on top of it, and stores the Pyth handle.
      function setUp() public {
          pyth = IPyth(setUpPyth(setUpWormholeReceiver(NUM_GUARDIAN_SIGNERS)));
      }

      // Builds `length` price feed messages with random field values.
      // Price ids are bytes32(1) .. bytes32(length): non-zero and unique.
      // Returns the ids (in message order) alongside the messages.
      function generateRandomPriceMessages(
          uint length
      )
          internal
          returns (bytes32[] memory priceIds, PriceFeedMessage[] memory messages)
      {
          messages = new PriceFeedMessage[](length);
          priceIds = new bytes32[](length);

          for (uint i = 0; i < length; i++) {
              messages[i].priceId = bytes32(i + 1); // price ids should be non-zero and unique
              messages[i].price = getRandInt64();
              messages[i].conf = getRandUint64();
              messages[i].expo = getRandInt32();
              messages[i].emaPrice = getRandInt64();
              messages[i].emaConf = getRandUint64();
              messages[i].publishTime = getRandUint64();
              messages[i].prevPublishTime = getRandUint64();
              priceIds[i] = messages[i].priceId;
          }
      }

      // Splits `messages` into consecutive batches of a single random size in
      // [1, messages.length] (the last batch may be shorter). It builds one
      // Merkle update per batch using `config` and returns the resulting
      // updateData together with the fee Pyth quotes for it.
      // An empty `messages` array yields empty updateData instead of reverting
      // on a modulo by zero.
      function createBatchedUpdateDataFromMessagesWithConfig(
          PriceFeedMessage[] memory messages,
          MerkleUpdateConfig memory config
      ) internal returns (bytes[] memory updateData, uint updateFee) {
          if (messages.length == 0) {
              updateData = new bytes[](0);
              updateFee = pyth.getUpdateFee(updateData);
              return (updateData, updateFee);
          }

          uint batchSize = 1 + (getRandUint() % messages.length);
          // Ceiling division: number of batches needed to cover every message.
          uint numBatches = (messages.length + batchSize - 1) / batchSize;
          updateData = new bytes[](numBatches);

          for (uint i = 0; i < messages.length; i += batchSize) {
              // The final batch holds whatever messages remain.
              uint len = batchSize;
              if (messages.length - i < len) {
                  len = messages.length - i;
              }

              PriceFeedMessage[] memory batchMessages = new PriceFeedMessage[](
                  len
              );
              for (uint j = 0; j < len; j++) {
                  batchMessages[j] = messages[i + j];
              }

              updateData[i / batchSize] = generateWhMerkleUpdateWithSource(
                  batchMessages,
                  config
              );
          }

          updateFee = pyth.getUpdateFee(updateData);
      }

      // Same as createBatchedUpdateDataFromMessagesWithConfig, using the default
      // tree depth, signer count, and the expected source emitter chain/address.
      // NOTE(review): the trailing `false` appears to mean "produce a valid VAA".
      // The invalid-VAA test passes `true` and expects a revert. Confirm against
      // MerkleUpdateConfig in PythTestUtils.
      function createBatchedUpdateDataFromMessages(
          PriceFeedMessage[] memory messages
      ) internal returns (bytes[] memory updateData, uint updateFee) {
          (updateData, updateFee) = createBatchedUpdateDataFromMessagesWithConfig(
              messages,
              MerkleUpdateConfig(
                  MERKLE_TREE_DEPTH,
                  NUM_GUARDIAN_SIGNERS,
                  SOURCE_EMITTER_CHAIN_ID,
                  SOURCE_EMITTER_ADDRESS,
                  false
              )
          );
      }

      /// Testing parsePriceFeedUpdates method.
      // Random messages are split into random batches. Every returned feed must
      // appear in priceIds order and mirror its source message field by field.
      function testParsePriceFeedUpdatesWorks(uint seed) public {
          setRandSeed(seed);
          uint numMessages = 1 + (getRandUint() % 10);
          (
              bytes32[] memory priceIds,
              PriceFeedMessage[] memory messages
          ) = generateRandomPriceMessages(numMessages);
          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessages(messages);

          PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
              value: updateFee
          }(updateData, priceIds, 0, MAX_UINT64);

          // Exactly one feed per requested id; no extra entries.
          assertEq(priceFeeds.length, numMessages);

          for (uint i = 0; i < numMessages; i++) {
              assertEq(priceFeeds[i].id, priceIds[i]);
              assertEq(priceFeeds[i].price.price, messages[i].price);
              assertEq(priceFeeds[i].price.conf, messages[i].conf);
              assertEq(priceFeeds[i].price.expo, messages[i].expo);
              assertEq(priceFeeds[i].price.publishTime, messages[i].publishTime);
              assertEq(priceFeeds[i].emaPrice.price, messages[i].emaPrice);
              assertEq(priceFeeds[i].emaPrice.conf, messages[i].emaConf);
              // The EMA price shares the message's expo and publishTime.
              assertEq(priceFeeds[i].emaPrice.expo, messages[i].expo);
              assertEq(
                  priceFeeds[i].emaPrice.publishTime,
                  messages[i].publishTime
              );
          }
      }

      // Two updates for the same id (publish times 10 and 20) both fall inside
      // the [0, 20] window. Either may be returned; the test accepts both.
      function testParsePriceFeedUpdatesWorksWithOverlappingWithinTimeRangeUpdates()
          public
      {
          PriceFeedMessage[] memory messages = new PriceFeedMessage[](2);

          messages[0].priceId = bytes32(uint(1));
          messages[0].price = 1000;
          messages[0].publishTime = 10;

          messages[1].priceId = bytes32(uint(1));
          messages[1].price = 2000;
          messages[1].publishTime = 20;

          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessages(messages);

          bytes32[] memory priceIds = new bytes32[](1);
          priceIds[0] = bytes32(uint(1));

          PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
              value: updateFee
          }(updateData, priceIds, 0, 20);

          assertEq(priceFeeds.length, 1);
          assertEq(priceFeeds[0].id, bytes32(uint(1)));
          assertTrue(
              (priceFeeds[0].price.price == 1000 &&
                  priceFeeds[0].price.publishTime == 10) ||
                  (priceFeeds[0].price.price == 2000 &&
                      priceFeeds[0].price.publishTime == 20)
          );
      }

      // Two updates for the same id (publish times 10 and 20). Each window
      // ([5, 15] and [15, 25]) covers exactly one of them, so that update is
      // the one that must be returned.
      function testParsePriceFeedUpdatesWorksWithOverlappingMixedTimeRangeUpdates()
          public
      {
          PriceFeedMessage[] memory messages = new PriceFeedMessage[](2);

          messages[0].priceId = bytes32(uint(1));
          messages[0].price = 1000;
          messages[0].publishTime = 10;

          messages[1].priceId = bytes32(uint(1));
          messages[1].price = 2000;
          messages[1].publishTime = 20;

          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessages(messages);

          bytes32[] memory priceIds = new bytes32[](1);
          priceIds[0] = bytes32(uint(1));

          PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
              value: updateFee
          }(updateData, priceIds, 5, 15);

          assertEq(priceFeeds.length, 1);
          assertEq(priceFeeds[0].id, bytes32(uint(1)));
          assertEq(priceFeeds[0].price.price, 1000);
          assertEq(priceFeeds[0].price.publishTime, 10);

          priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}(
              updateData,
              priceIds,
              15,
              25
          );

          assertEq(priceFeeds.length, 1);
          assertEq(priceFeeds[0].id, bytes32(uint(1)));
          assertEq(priceFeeds[0].price.price, 2000);
          assertEq(priceFeeds[0].price.publishTime, 20);
      }

      // Update data built with the config's boolean flag set to `true` must be
      // rejected.
      function testParsePriceFeedUpdatesRevertsIfUpdateVAAIsInvalid(
          uint seed
      ) public {
          setRandSeed(seed);
          uint numMessages = 1 + (getRandUint() % 10);
          (
              bytes32[] memory priceIds,
              PriceFeedMessage[] memory messages
          ) = generateRandomPriceMessages(numMessages);
          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessagesWithConfig(
                  messages,
                  MerkleUpdateConfig(
                      MERKLE_TREE_DEPTH,
                      NUM_GUARDIAN_SIGNERS,
                      SOURCE_EMITTER_CHAIN_ID,
                      SOURCE_EMITTER_ADDRESS,
                      true
                  )
              );

          // It might revert due to different wormhole errors
          vm.expectRevert();
          pyth.parsePriceFeedUpdates{value: updateFee}(
              updateData,
              priceIds,
              0,
              MAX_UINT64
          );
      }

      // Updates emitted from an unexpected chain id must revert with
      // InvalidUpdateDataSource.
      function testParsePriceFeedUpdatesRevertsIfUpdateSourceChainIsInvalid()
          public
      {
          uint numMessages = 10;
          (
              bytes32[] memory priceIds,
              PriceFeedMessage[] memory messages
          ) = generateRandomPriceMessages(numMessages);
          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessagesWithConfig(
                  messages,
                  MerkleUpdateConfig(
                      MERKLE_TREE_DEPTH,
                      NUM_GUARDIAN_SIGNERS,
                      SOURCE_EMITTER_CHAIN_ID + 1,
                      SOURCE_EMITTER_ADDRESS,
                      false
                  )
              );

          vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
          pyth.parsePriceFeedUpdates{value: updateFee}(
              updateData,
              priceIds,
              0,
              MAX_UINT64
          );
      }

      // Updates emitted from an unexpected emitter address must revert with
      // InvalidUpdateDataSource.
      function testParsePriceFeedUpdatesRevertsIfUpdateSourceAddressIsInvalid()
          public
      {
          uint numMessages = 10;
          (
              bytes32[] memory priceIds,
              PriceFeedMessage[] memory messages
          ) = generateRandomPriceMessages(numMessages);
          (
              bytes[] memory updateData,
              uint updateFee
          ) = createBatchedUpdateDataFromMessagesWithConfig(
                  messages,
                  MerkleUpdateConfig(
                      MERKLE_TREE_DEPTH,
                      NUM_GUARDIAN_SIGNERS,
                      SOURCE_EMITTER_CHAIN_ID,
                      0x00000000000000000000000000000000000000000000000000000000000000aa, // Random wrong source address
                      false
                  )
              );

          vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
          pyth.parsePriceFeedUpdates{value: updateFee}(
              updateData,
              priceIds,
              0,
              MAX_UINT64
          );
      }
  }