Pyth.Aave.t.sol 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "forge-std/Test.sol";
  4. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  5. import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
  6. import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
  7. import "./utils/WormholeTestUtils.t.sol";
  8. import "./utils/PythTestUtils.t.sol";
  9. import "./utils/RandTestUtils.t.sol";
  10. import "../contracts/aave/interfaces/IPriceOracleGetter.sol";
  11. import "../contracts/aave/PythPriceOracleGetter.sol";
  12. import "./Pyth.WormholeMerkleAccumulator.t.sol";
  contract PythAaveTest is PythWormholeMerkleAccumulatorTest {
      // Oracle under test. setUp() deploys it over NUM_PRICE_FEEDS random feeds;
      // several tests redeploy it over a single hand-built feed instead.
      IPriceOracleGetter public pythOracleGetter;
      // Asset addresses registered with the current oracle (parallel to priceIds).
      address[] assets;
      // Pyth price ids registered with the current oracle (parallel to assets).
      bytes32[] priceIds;

      uint constant NUM_PRICE_FEEDS = 5;
      uint256 constant BASE_CURRENCY_UNIT = 1e8;
      uint constant VALID_TIME_PERIOD_SECS = 60;

      function setUp() public override {
          pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
          _deployOracleWithRandomFeeds();
      }

      // ------------------------------------------------------------------
      // Helpers
      // ------------------------------------------------------------------

      /// @dev Address used as the sole asset in single-feed tests. It is derived
      /// from seed 100, which _deployOracleWithRandomFeeds() never uses (it uses
      /// seeds NUM_PRICE_FEEDS .. 2 * NUM_PRICE_FEEDS - 1), so this address is not
      /// registered in the oracle that setUp() deploys.
      function _singleAssetAddress() private pure returns (address) {
          return address(uint160(uint(keccak256(abi.encodePacked(uint(100))))));
      }

      /// @dev Publishes `messages` to Pyth as one Wormhole Merkle update,
      /// paying exactly the fee reported for that update.
      function _submitUpdates(PriceFeedMessage[] memory messages) private {
          (bytes[] memory updateData, uint updateFee) = createWormholeMerkleUpdateData(
              messages
          );
          pyth.updatePriceFeeds{value: updateFee}(updateData);
      }

      /// @dev Deploys a PythPriceOracleGetter over the current `assets` /
      /// `priceIds`, with address(0x0) as the base currency and the shared
      /// VALID_TIME_PERIOD_SECS staleness window.
      /// @param baseCurrencyUnit Base currency unit passed to the constructor.
      /// @return The freshly deployed oracle.
      function _newOracle(uint256 baseCurrencyUnit)
          private
          returns (IPriceOracleGetter)
      {
          return
              new PythPriceOracleGetter(
                  address(pyth),
                  assets,
                  priceIds,
                  address(0x0),
                  baseCurrencyUnit,
                  VALID_TIME_PERIOD_SECS
              );
      }

      /// @dev Generates NUM_PRICE_FEEDS bounded random feeds (presumably drawn
      /// from the RandTestUtils generator, so they depend on the current seed —
      /// confirm), publishes them, and redeploys the oracle over them with
      /// BASE_CURRENCY_UNIT.
      function _deployOracleWithRandomFeeds() private {
          assets = new address[](NUM_PRICE_FEEDS);
          PriceFeedMessage[] memory messages = generateRandomBoundedPriceFeedMessage(
              NUM_PRICE_FEEDS
          );
          priceIds = new bytes32[](NUM_PRICE_FEEDS);
          for (uint i = 0; i < NUM_PRICE_FEEDS; i++) {
              // Offsetting by NUM_PRICE_FEEDS keeps these seeds disjoint from
              // the seed used by _singleAssetAddress().
              assets[i] = address(
                  uint160(uint(keccak256(abi.encodePacked(i + NUM_PRICE_FEEDS))))
              );
              priceIds[i] = messages[i].priceId;
          }
          _submitUpdates(messages);
          pythOracleGetter = _newOracle(BASE_CURRENCY_UNIT);
      }

      /// @dev Publishes one feed with the given `price` and `expo` (other fields
      /// random, publishTime = 1). It then redeploys the oracle so that it maps
      /// _singleAssetAddress() (stored in assets[0]) to that feed.
      /// @param price Raw Pyth price.
      /// @param expo Pyth exponent.
      /// @param baseCurrencyUnit Base currency unit passed to the oracle constructor.
      function _deployOracleWithSingleFeed(
          int64 price,
          int32 expo,
          uint256 baseCurrencyUnit
      ) private {
          PriceFeedMessage[] memory messages = new PriceFeedMessage[](1);
          messages[0] = PriceFeedMessage({
              priceId: getRandBytes32(),
              price: price,
              conf: getRandUint64(),
              expo: expo,
              publishTime: uint64(1),
              prevPublishTime: getRandUint64(),
              emaPrice: getRandInt64(),
              emaConf: getRandUint64()
          });
          _submitUpdates(messages);
          priceIds = new bytes32[](1);
          priceIds[0] = messages[0].priceId;
          assets = new address[](1);
          assets[0] = _singleAssetAddress();
          pythOracleGetter = _newOracle(baseCurrencyUnit);
      }

      /// @dev Asserts that a Pyth price `pythPrice` with exponent `pythExpo` is
      /// reported by the oracle as exactly `aavePrice` when the oracle is
      /// configured with `baseCurrencyUnit`.
      function _assertConversion(
          int64 pythPrice,
          int32 pythExpo,
          uint256 aavePrice,
          uint256 baseCurrencyUnit
      ) private {
          _deployOracleWithSingleFeed(pythPrice, pythExpo, baseCurrencyUnit);
          assertEq(pythOracleGetter.getAssetPrice(assets[0]), aavePrice);
      }

      // ------------------------------------------------------------------
      // Price conversion
      // ------------------------------------------------------------------

      function testGetAssetPriceWorks() public {
          // "display" price is 529.30903
          _assertConversion(52_930_903, -5, 52_930_903_000, BASE_CURRENCY_UNIT);
      }

      function testGetAssetPriceWorksWithPositiveExponent() public {
          // "display" price is 5_293_000
          _assertConversion(5_293, 3, 529_300_000_000_000, BASE_CURRENCY_UNIT);
      }

      function testGetAssetPriceWorksWithZeroExponent() public {
          // "display" price is 5_293
          _assertConversion(5_293, 0, 529_300_000_000, BASE_CURRENCY_UNIT);
      }

      function testGetAssetPriceWorksWithNegativeNormalizerExponent() public {
          // "display" price is 5_293
          _assertConversion(
              5_293_000_000_000_000,
              -12,
              529_300_000_000,
              BASE_CURRENCY_UNIT
          );
      }

      function testGetAssetPriceWorksWithBaseCurrencyUnitOfOne() public {
          // "display" price is 529.30903; a unit of 1 truncates it to 529.
          _assertConversion(52_930_903, -5, 529, 1);
      }

      /// Fuzz test: for bounded random feeds, the integer part of the oracle's
      /// price (in display units) must equal the integer part of the raw Pyth
      /// price scaled by its exponent.
      function testGetAssetPriceWorksWithBoundedRandomValues(uint seed) public {
          // setUp() generated its feeds before the fuzzer supplied `seed`, so
          // re-seed and regenerate here; otherwise every fuzz run would check
          // the exact same feeds.
          setRandSeed(seed);
          _deployOracleWithRandomFeeds();
          for (uint i = 0; i < assets.length; i++) {
              uint256 displayPrice = pythOracleGetter.getAssetPrice(assets[i]) /
                  BASE_CURRENCY_UNIT;
              PythStructs.Price memory price = pyth.getPrice(priceIds[i]);
              uint256 rawPrice = uint64(price.price);
              uint256 expectedDisplayPrice = price.expo < 0
                  ? rawPrice / 10 ** uint32(-price.expo)
                  : rawPrice * 10 ** uint32(price.expo);
              assertEq(displayPrice, expectedDisplayPrice);
          }
      }

      function testGetAssetPriceWorksIfGivenBaseCurrencyAddress() public {
          // address(0x0) is configured as the base currency (USD) in setUp().
          address usdAddress = address(0x0);
          assertEq(pythOracleGetter.getAssetPrice(usdAddress), BASE_CURRENCY_UNIT);
      }

      // ------------------------------------------------------------------
      // Reverts
      // ------------------------------------------------------------------

      function testGetAssetRevertsIfPriceNotRecentEnough() public {
          // Assumes setUp()'s feeds were published at the current
          // block.timestamp (depends on generateRandomBoundedPriceFeedMessage —
          // confirm). They must then be accepted for exactly
          // VALID_TIME_PERIOD_SECS seconds and rejected one second later.
          uint timestamp = block.timestamp;
          vm.warp(timestamp + VALID_TIME_PERIOD_SECS);
          for (uint i = 0; i < assets.length; i++) {
              pythOracleGetter.getAssetPrice(assets[i]);
          }
          vm.warp(timestamp + VALID_TIME_PERIOD_SECS + 1);
          for (uint i = 0; i < assets.length; i++) {
              vm.expectRevert(PythErrors.StalePrice.selector);
              pythOracleGetter.getAssetPrice(assets[i]);
          }
      }

      function testGetAssetRevertsIfPriceFeedNotFound() public {
          // _singleAssetAddress() is not registered in the oracle from setUp().
          address addr = _singleAssetAddress();
          vm.expectRevert(PythErrors.PriceFeedNotFound.selector);
          pythOracleGetter.getAssetPrice(addr);
      }

      function testGetAssetPriceRevertsIfPriceIsNegative() public {
          // The exponent is arbitrary: a negative raw price must be rejected.
          _deployOracleWithSingleFeed(int64(-5), getRandInt32(), BASE_CURRENCY_UNIT);
          vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
          pythOracleGetter.getAssetPrice(assets[0]);
      }

      function testGetAssetPriceRevertsIfNormalizerOverflows() public {
          // 10^59 on its own already exceeds type(uint192).max (~6.28e57), so
          // normalizing a price with this exponent is expected to overflow.
          _deployOracleWithSingleFeed(int64(1), int32(59), BASE_CURRENCY_UNIT);
          vm.expectRevert(abi.encodeWithSignature("NormalizationOverflow()"));
          pythOracleGetter.getAssetPrice(assets[0]);
      }

      function testGetAssetPriceRevertsIfNormalizedToZero() public {
          // 1 * 10^-75 scaled by BASE_CURRENCY_UNIT (1e8) truncates to 0,
          // which must be rejected as a non-positive price.
          _deployOracleWithSingleFeed(int64(1), int32(-75), BASE_CURRENCY_UNIT);
          vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
          pythOracleGetter.getAssetPrice(assets[0]);
      }

      // ------------------------------------------------------------------
      // Constructor validation
      // ------------------------------------------------------------------

      function testPythPriceOracleGetterConstructorRevertsIfAssetsAndPriceIdsLengthAreDifferent()
          public
      {
          priceIds = new bytes32[](2);
          priceIds[0] = getRandBytes32();
          priceIds[1] = getRandBytes32();
          assets = new address[](1);
          assets[0] = _singleAssetAddress();
          vm.expectRevert(abi.encodeWithSignature("InconsistentParamsLength()"));
          pythOracleGetter = _newOracle(BASE_CURRENCY_UNIT);
      }

      function testPythPriceOracleGetterConstructorRevertsIfInvalidBaseCurrencyUnit()
          public
      {
          priceIds = new bytes32[](1);
          priceIds[0] = getRandBytes32();
          assets = new address[](1);
          assets[0] = _singleAssetAddress();
          // 0 is rejected, and so are 11 and 20, which are not powers of ten.
          uint256[3] memory invalidUnits = [uint256(0), 11, 20];
          for (uint i = 0; i < invalidUnits.length; i++) {
              vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
              pythOracleGetter = _newOracle(invalidUnits[i]);
          }
      }
  }