ソースを参照

Add fuzz tests for Math.sqrt & Math.logX using Foundry (#3676)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Francisco Giordano <frangio.1@gmail.com>
Nicolás Venturo 3 年 前
コミット
80ae402387

+ 13 - 0
.github/workflows/checks.yml

@@ -56,6 +56,19 @@ jobs:
         with:
           token: ${{ github.token }}
 
+  foundry-tests:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: recursive
+      - name: Install Foundry
+        uses: foundry-rs/foundry-toolchain@v1
+        with:
+          version: nightly
+      - name: Run tests
+        run: forge test -vv
+
   coverage:
     if: github.repository != 'OpenZeppelin/openzeppelin-contracts-upgradeable'
     runs-on: ubuntu-latest

+ 3 - 0
.gitmodules

@@ -0,0 +1,3 @@
+[submodule "lib/forge-std"]
+	path = lib/forge-std
+	url = https://github.com/foundry-rs/forge-std

+ 1 - 1
contracts/mocks/MulticallTest.sol

@@ -5,7 +5,7 @@ pragma solidity ^0.8.0;
 import "./MulticallTokenMock.sol";
 
 contract MulticallTest {
-    function testReturnValues(
+    function checkReturnValues(
         MulticallTokenMock multicallToken,
         address[] calldata recipients,
         uint256[] calldata amounts

+ 7 - 0
hardhat/skip-foundry-tests.js

@@ -0,0 +1,7 @@
+const { subtask } = require('hardhat/config');
+const { TASK_COMPILE_SOLIDITY_GET_SOURCE_PATHS } = require('hardhat/builtin-tasks/task-names');
+
+subtask(TASK_COMPILE_SOLIDITY_GET_SOURCE_PATHS)
+  .setAction(async (_, __, runSuper) =>
+    (await runSuper()).filter((path) => !path.endsWith('.t.sol')),
+  );

+ 1 - 0
lib/forge-std

@@ -0,0 +1 @@
+Subproject commit ca8d6e00ea9cb035f6856ff732203c9a3c48b966

+ 1 - 1
test/utils/Multicall.test.js

@@ -31,7 +31,7 @@ contract('MulticallToken', function (accounts) {
     const recipients = [alice, bob];
     const amounts = [amount / 2, amount / 3].map(n => new BN(n));
 
-    await multicallTest.testReturnValues(this.multicallToken.address, recipients, amounts);
+    await multicallTest.checkReturnValues(this.multicallToken.address, recipients, amounts);
   });
 
   it('reverts previous calls', async function () {

+ 118 - 0
test/utils/math/Math.t.sol

@@ -0,0 +1,118 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.0;
+
+import "forge-std/Test.sol";
+
+import "../../../contracts/utils/math/Math.sol";
+import "../../../contracts/utils/math/SafeMath.sol";
+
+contract MathTest is Test {
+    // SQRT
+    function testSqrt(uint256 input, uint8 r) public {
+        Math.Rounding rounding = _asRounding(r);
+
+        uint256 result = Math.sqrt(input, rounding);
+
+        // square of result is bigger than input
+        if (_squareBigger(result, input)) {
+            assertTrue(rounding == Math.Rounding.Up);
+            assertTrue(_squareSmaller(result - 1, input));
+        }
+        // square of result is smaller than input
+        else if (_squareSmaller(result, input)) {
+            assertFalse(rounding == Math.Rounding.Up);
+            assertTrue(_squareBigger(result + 1, input));
+        }
+    }
+
+    function _squareBigger(uint256 value, uint256 ref) private pure returns (bool) {
+        (bool noOverflow, uint256 square) = SafeMath.tryMul(value, value);
+        return !noOverflow || square > ref;
+    }
+
+    function _squareSmaller(uint256 value, uint256 ref) private pure returns (bool) {
+        return value * value < ref;
+    }
+
+    // LOG2
+    function testLog2(uint256 input, uint8 r) public {
+        Math.Rounding rounding = _asRounding(r);
+
+        uint256 result = Math.log2(input, rounding);
+
+        if (input == 0) {
+            assertEq(result, 0);
+        } else if (_powerOf2Bigger(result, input)) {
+            assertTrue(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf2Smaller(result - 1, input));
+        } else if (_powerOf2Smaller(result, input)) {
+            assertFalse(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf2Bigger(result + 1, input));
+        }
+    }
+
+    function _powerOf2Bigger(uint256 value, uint256 ref) private pure returns (bool) {
+        return value >= 256 || 2**value > ref; // 2**256 overflows uint256
+    }
+
+    function _powerOf2Smaller(uint256 value, uint256 ref) private pure returns (bool) {
+        return 2**value < ref;
+    }
+
+    // LOG10
+    function testLog10(uint256 input, uint8 r) public {
+        Math.Rounding rounding = _asRounding(r);
+
+        uint256 result = Math.log10(input, rounding);
+
+        if (input == 0) {
+            assertEq(result, 0);
+        } else if (_powerOf10Bigger(result, input)) {
+            assertTrue(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf10Smaller(result - 1, input));
+        } else if (_powerOf10Smaller(result, input)) {
+            assertFalse(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf10Bigger(result + 1, input));
+        }
+    }
+
+    function _powerOf10Bigger(uint256 value, uint256 ref) private pure returns (bool) {
+        return value >= 78 || 10**value > ref; // 10**78 overflows uint256
+    }
+
+    function _powerOf10Smaller(uint256 value, uint256 ref) private pure returns (bool) {
+        return 10**value < ref;
+    }
+
+    // LOG256
+    function testLog256(uint256 input, uint8 r) public {
+        Math.Rounding rounding = _asRounding(r);
+
+        uint256 result = Math.log256(input, rounding);
+
+        if (input == 0) {
+            assertEq(result, 0);
+        } else if (_powerOf256Bigger(result, input)) {
+            assertTrue(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf256Smaller(result - 1, input));
+        } else if (_powerOf256Smaller(result, input)) {
+            assertFalse(rounding == Math.Rounding.Up);
+            assertTrue(_powerOf256Bigger(result + 1, input));
+        }
+    }
+
+    function _powerOf256Bigger(uint256 value, uint256 ref) private pure returns (bool) {
+        return value >= 32 || 256**value > ref; // 256**32 overflows uint256
+    }
+
+    function _powerOf256Smaller(uint256 value, uint256 ref) private pure returns (bool) {
+        return 256**value < ref;
+    }
+
+    // Helpers
+    function _asRounding(uint8 r) private returns (Math.Rounding) {
+        vm.assume(r < uint8(type(Math.Rounding).max));
+        return Math.Rounding(r);
+    }
+}