Pyth.Aave.t.sol 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. // SPDX-License-Identifier: Apache 2
  2. pragma solidity ^0.8.0;
  3. import "forge-std/Test.sol";
  4. import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
  5. import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
  6. import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
  7. import "./utils/WormholeTestUtils.t.sol";
  8. import "./utils/PythTestUtils.t.sol";
  9. import "./utils/RandTestUtils.t.sol";
  10. import "../contracts/aave/interfaces/IPriceOracleGetter.sol";
  11. import "../contracts/aave/PythPriceOracleGetter.sol";
  12. import "./Pyth.WormholeMerkleAccumulator.t.sol";
  13. contract PythAaveTest is PythWormholeMerkleAccumulatorTest {
  14. IPriceOracleGetter public pythOracleGetter;
  15. address[] assets;
  16. bytes32[] priceIds;
  17. uint constant NUM_PRICE_FEEDS = 5;
  18. uint256 constant BASE_CURRENCY_UNIT = 1e8;
  19. uint constant VALID_TIME_PERIOD_SECS = 60;
  20. function setUp() public override {
  21. pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
  22. assets = new address[](NUM_PRICE_FEEDS);
  23. PriceFeedMessage[]
  24. memory priceFeedMessages = generateRandomBoundedPriceFeedMessage(
  25. NUM_PRICE_FEEDS
  26. );
  27. priceIds = new bytes32[](NUM_PRICE_FEEDS);
  28. for (uint i = 0; i < NUM_PRICE_FEEDS; i++) {
  29. assets[i] = address(
  30. uint160(uint(keccak256(abi.encodePacked(i + NUM_PRICE_FEEDS))))
  31. );
  32. priceIds[i] = priceFeedMessages[i].priceId;
  33. }
  34. (
  35. bytes[] memory updateData,
  36. uint updateFee
  37. ) = createWormholeMerkleUpdateData(priceFeedMessages);
  38. pyth.updatePriceFeeds{value: updateFee}(updateData);
  39. pythOracleGetter = new PythPriceOracleGetter(
  40. address(pyth),
  41. assets,
  42. priceIds,
  43. address(0x0),
  44. BASE_CURRENCY_UNIT,
  45. VALID_TIME_PERIOD_SECS
  46. );
  47. }
  48. function testConversion(
  49. int64 pythPrice,
  50. int32 pythExpo,
  51. uint256 aavePrice,
  52. uint256 baseCurrencyUnit
  53. ) private {
  54. PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
  55. PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
  56. priceId: getRandBytes32(),
  57. price: pythPrice,
  58. conf: getRandUint64(),
  59. expo: pythExpo,
  60. publishTime: uint64(1),
  61. prevPublishTime: getRandUint64(),
  62. emaPrice: getRandInt64(),
  63. emaConf: getRandUint64()
  64. });
  65. priceFeedMessages[0] = priceFeedMessage;
  66. (
  67. bytes[] memory updateData,
  68. uint updateFee
  69. ) = createWormholeMerkleUpdateData(priceFeedMessages);
  70. pyth.updatePriceFeeds{value: updateFee}(updateData);
  71. priceIds = new bytes32[](1);
  72. priceIds[0] = priceFeedMessage.priceId;
  73. assets = new address[](1);
  74. assets[0] = address(
  75. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  76. );
  77. pythOracleGetter = new PythPriceOracleGetter(
  78. address(pyth),
  79. assets,
  80. priceIds,
  81. address(0x0),
  82. baseCurrencyUnit,
  83. VALID_TIME_PERIOD_SECS
  84. );
  85. assertEq(pythOracleGetter.getAssetPrice(assets[0]), aavePrice);
  86. }
  87. function testGetAssetPriceWorks() public {
  88. // "display" price is 529.30903
  89. testConversion(52_930_903, -5, 52_930_903_000, BASE_CURRENCY_UNIT);
  90. }
  91. function testGetAssetPriceWorksWithPositiveExponent() public {
  92. // "display" price is 5_293_000
  93. testConversion(5_293, 3, 529_300_000_000_000, BASE_CURRENCY_UNIT);
  94. }
  95. function testGetAssetPriceWorksWithZeroExponent() public {
  96. // "display" price is 5_293
  97. testConversion(5_293, 0, 529_300_000_000, BASE_CURRENCY_UNIT);
  98. }
  99. function testGetAssetPriceWorksWithNegativeNormalizerExponent() public {
  100. // "display" price is 5_293
  101. testConversion(
  102. 5_293_000_000_000_000,
  103. -12,
  104. 529_300_000_000,
  105. BASE_CURRENCY_UNIT
  106. );
  107. }
  108. function testGetAssetPriceWorksWithBaseCurrencyUnitOfOne() public {
  109. // "display" price is 529.30903
  110. testConversion(52_930_903, -5, 529, 1);
  111. }
  112. function testGetAssetPriceWorksWithBoundedRandomValues(uint seed) public {
  113. setRandSeed(seed);
  114. for (uint i = 0; i < assets.length; i++) {
  115. address asset = assets[i];
  116. uint256 assetPrice = pythOracleGetter.getAssetPrice(asset);
  117. uint256 aavePrice = assetPrice / BASE_CURRENCY_UNIT;
  118. bytes32 priceId = priceIds[i];
  119. PythStructs.Price memory price = pyth.getPriceNoOlderThan(
  120. priceId,
  121. 60
  122. );
  123. int64 pythRawPrice = price.price;
  124. uint pythNormalizer;
  125. uint pythPrice;
  126. if (price.expo < 0) {
  127. pythNormalizer = 10 ** uint32(-price.expo);
  128. pythPrice = uint64(pythRawPrice) / pythNormalizer;
  129. } else {
  130. pythNormalizer = 10 ** uint32(price.expo);
  131. pythPrice = uint64(pythRawPrice) * pythNormalizer;
  132. }
  133. assertEq(aavePrice, pythPrice);
  134. }
  135. }
  136. function testGetAssetPriceWorksIfGivenBaseCurrencyAddress() public {
  137. address usdAddress = address(0x0);
  138. uint256 assetPrice = pythOracleGetter.getAssetPrice(usdAddress);
  139. assertEq(assetPrice, BASE_CURRENCY_UNIT);
  140. }
  141. function testGetAssetRevertsIfPriceNotRecentEnough() public {
  142. uint timestamp = block.timestamp;
  143. vm.warp(timestamp + VALID_TIME_PERIOD_SECS);
  144. for (uint i = 0; i < assets.length; i++) {
  145. pythOracleGetter.getAssetPrice(assets[i]);
  146. }
  147. vm.warp(timestamp + VALID_TIME_PERIOD_SECS + 1);
  148. for (uint i = 0; i < assets.length; i++) {
  149. vm.expectRevert(PythErrors.StalePrice.selector);
  150. pythOracleGetter.getAssetPrice(assets[i]);
  151. }
  152. }
  153. function testGetAssetRevertsIfPriceFeedNotFound() public {
  154. address addr = address(
  155. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  156. );
  157. vm.expectRevert(PythErrors.PriceFeedNotFound.selector);
  158. pythOracleGetter.getAssetPrice(addr);
  159. }
  160. function testGetAssetPriceRevertsIfPriceIsNegative() public {
  161. PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
  162. PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
  163. priceId: getRandBytes32(),
  164. price: int64(-5),
  165. conf: getRandUint64(),
  166. expo: getRandInt32(),
  167. publishTime: uint64(1),
  168. prevPublishTime: getRandUint64(),
  169. emaPrice: getRandInt64(),
  170. emaConf: getRandUint64()
  171. });
  172. priceFeedMessages[0] = priceFeedMessage;
  173. (
  174. bytes[] memory updateData,
  175. uint updateFee
  176. ) = createWormholeMerkleUpdateData(priceFeedMessages);
  177. pyth.updatePriceFeeds{value: updateFee}(updateData);
  178. priceIds = new bytes32[](1);
  179. priceIds[0] = priceFeedMessage.priceId;
  180. assets = new address[](1);
  181. assets[0] = address(
  182. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  183. );
  184. pythOracleGetter = new PythPriceOracleGetter(
  185. address(pyth),
  186. assets,
  187. priceIds,
  188. address(0x0),
  189. BASE_CURRENCY_UNIT,
  190. VALID_TIME_PERIOD_SECS
  191. );
  192. vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
  193. pythOracleGetter.getAssetPrice(assets[0]);
  194. }
  195. function testGetAssetPriceRevertsIfNormalizerOverflows() public {
  196. PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
  197. PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
  198. priceId: getRandBytes32(),
  199. price: int64(1),
  200. conf: getRandUint64(),
  201. expo: int32(59), // type(uint192).max = ~6.27e58
  202. publishTime: uint64(1),
  203. prevPublishTime: getRandUint64(),
  204. emaPrice: getRandInt64(),
  205. emaConf: getRandUint64()
  206. });
  207. priceFeedMessages[0] = priceFeedMessage;
  208. (
  209. bytes[] memory updateData,
  210. uint updateFee
  211. ) = createWormholeMerkleUpdateData(priceFeedMessages);
  212. pyth.updatePriceFeeds{value: updateFee}(updateData);
  213. priceIds = new bytes32[](1);
  214. priceIds[0] = priceFeedMessage.priceId;
  215. assets = new address[](1);
  216. assets[0] = address(
  217. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  218. );
  219. pythOracleGetter = new PythPriceOracleGetter(
  220. address(pyth),
  221. assets,
  222. priceIds,
  223. address(0x0),
  224. BASE_CURRENCY_UNIT,
  225. VALID_TIME_PERIOD_SECS
  226. );
  227. vm.expectRevert(abi.encodeWithSignature("NormalizationOverflow()"));
  228. pythOracleGetter.getAssetPrice(assets[0]);
  229. }
  230. function testGetAssetPriceRevertsIfNormalizedToZero() public {
  231. PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
  232. PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
  233. priceId: getRandBytes32(),
  234. price: int64(1),
  235. conf: getRandUint64(),
  236. expo: int32(-75),
  237. publishTime: uint64(1),
  238. prevPublishTime: getRandUint64(),
  239. emaPrice: getRandInt64(),
  240. emaConf: getRandUint64()
  241. });
  242. priceFeedMessages[0] = priceFeedMessage;
  243. (
  244. bytes[] memory updateData,
  245. uint updateFee
  246. ) = createWormholeMerkleUpdateData(priceFeedMessages);
  247. pyth.updatePriceFeeds{value: updateFee}(updateData);
  248. priceIds = new bytes32[](1);
  249. priceIds[0] = priceFeedMessage.priceId;
  250. assets = new address[](1);
  251. assets[0] = address(
  252. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  253. );
  254. pythOracleGetter = new PythPriceOracleGetter(
  255. address(pyth),
  256. assets,
  257. priceIds,
  258. address(0x0),
  259. BASE_CURRENCY_UNIT,
  260. VALID_TIME_PERIOD_SECS
  261. );
  262. vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
  263. pythOracleGetter.getAssetPrice(assets[0]);
  264. }
  265. function testPythPriceOracleGetterConstructorRevertsIfAssetsAndPriceIdsLengthAreDifferent()
  266. public
  267. {
  268. priceIds = new bytes32[](2);
  269. priceIds[0] = getRandBytes32();
  270. priceIds[1] = getRandBytes32();
  271. assets = new address[](1);
  272. assets[0] = address(
  273. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  274. );
  275. vm.expectRevert(abi.encodeWithSignature("InconsistentParamsLength()"));
  276. pythOracleGetter = new PythPriceOracleGetter(
  277. address(pyth),
  278. assets,
  279. priceIds,
  280. address(0x0),
  281. BASE_CURRENCY_UNIT,
  282. VALID_TIME_PERIOD_SECS
  283. );
  284. }
  285. function testPythPriceOracleGetterConstructorRevertsIfInvalidBaseCurrencyUnit()
  286. public
  287. {
  288. priceIds = new bytes32[](1);
  289. priceIds[0] = getRandBytes32();
  290. assets = new address[](1);
  291. assets[0] = address(
  292. uint160(uint(keccak256(abi.encodePacked(uint(100)))))
  293. );
  294. vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
  295. pythOracleGetter = new PythPriceOracleGetter(
  296. address(pyth),
  297. assets,
  298. priceIds,
  299. address(0x0),
  300. 0,
  301. VALID_TIME_PERIOD_SECS
  302. );
  303. vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
  304. pythOracleGetter = new PythPriceOracleGetter(
  305. address(pyth),
  306. assets,
  307. priceIds,
  308. address(0x0),
  309. 11,
  310. VALID_TIME_PERIOD_SECS
  311. );
  312. vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
  313. pythOracleGetter = new PythPriceOracleGetter(
  314. address(pyth),
  315. assets,
  316. priceIds,
  317. address(0x0),
  318. 20,
  319. VALID_TIME_PERIOD_SECS
  320. );
  321. }
  322. }