Browse Source

Add fuzz testing of mulDiv (#3717)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Francisco 3 years ago
parent
commit
c08c6e1b84
3 changed files with 109 additions and 2 deletions
  1. 2 0
      foundry.toml
  2. 2 2
      package.json
  3. 105 0
      test/utils/math/Math.t.sol

+ 2 - 0
foundry.toml

@@ -0,0 +1,2 @@
+[fuzz]
+runs = 10000

+ 2 - 2
package.json

@@ -20,8 +20,8 @@
     "lint:fix": "npm run lint:js:fix && npm run lint:sol:fix",
     "lint:js": "eslint --ignore-path .gitignore .",
     "lint:js:fix": "eslint --ignore-path .gitignore . --fix",
-    "lint:sol": "solhint 'contracts/**/*.sol' && prettier -c 'contracts/**/*.sol'",
-    "lint:sol:fix": "prettier --write \"contracts/**/*.sol\"",
+    "lint:sol": "solhint '{contracts,test}/**/*.sol' && prettier -c '{contracts,test}/**/*.sol'",
+    "lint:sol:fix": "prettier --write '{contracts,test}/**/*.sol'",
     "clean": "hardhat clean && rimraf build contracts/build",
     "prepare": "scripts/prepare.sh",
     "prepack": "scripts/prepack.sh",

+ 105 - 0
test/utils/math/Math.t.sol

@@ -8,6 +8,22 @@ import "../../../contracts/utils/math/Math.sol";
 import "../../../contracts/utils/math/SafeMath.sol";
 
 contract MathTest is Test {
+    // CEILDIV
+    function testCeilDiv(uint256 a, uint256 b) public {
+        vm.assume(b > 0);
+
+        uint256 result = Math.ceilDiv(a, b);
+
+        if (result == 0) {
+            assertEq(a, 0);
+        } else {
+            uint256 maxdiv = UINT256_MAX / b;
+            bool overflow = maxdiv * b < a;
+            assertTrue(a > b * (result - 1));
+            assertTrue(overflow ? result == maxdiv + 1 : a <= b * result);
+        }
+    }
+
     // SQRT
     function testSqrt(uint256 input, uint8 r) public {
         Math.Rounding rounding = _asRounding(r);
@@ -120,9 +136,98 @@ contract MathTest is Test {
         return 256**value < ref;
     }
 
+    // MULDIV
+    function testMulDiv(
+        uint256 x,
+        uint256 y,
+        uint256 d
+    ) public {
+        // Full precision for x * y
+        (uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);
+
+        // Assume result won't overflow (see {testMulDivDomain})
+        // This also checks that `d` is positive
+        vm.assume(xyHi < d);
+
+        // Perform muldiv
+        uint256 q = Math.mulDiv(x, y, d);
+
+        // Full precision for q * d
+        (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
+        // Add reminder of x * y / d (computed as rem = (x * y % d))
+        (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d));
+        uint256 qdRemHi = qdHi + c;
+
+        // Full precision check that x * y = q * d + rem
+        assertEq(xyHi, qdRemHi);
+        assertEq(xyLo, qdRemLo);
+    }
+
+    function testMulDivDomain(
+        uint256 x,
+        uint256 y,
+        uint256 d
+    ) public {
+        (uint256 xyHi, ) = _mulHighLow(x, y);
+
+        // Violate {testMulDiv} assumption (covers d is 0 and result overflow)
+        vm.assume(xyHi >= d);
+
+        // we are outside the scope of {testMulDiv}, we expect muldiv to revert
+        try this.muldiv(x, y, d) returns (uint256) {
+            fail();
+        } catch {}
+    }
+
+    // External call
+    function muldiv(
+        uint256 x,
+        uint256 y,
+        uint256 d
+    ) external pure returns (uint256) {
+        return Math.mulDiv(x, y, d);
+    }
+
     // Helpers
     function _asRounding(uint8 r) private returns (Math.Rounding) {
         vm.assume(r < uint8(type(Math.Rounding).max));
         return Math.Rounding(r);
     }
+
+    function _mulmod(
+        uint256 x,
+        uint256 y,
+        uint256 z
+    ) private pure returns (uint256 r) {
+        assembly {
+            r := mulmod(x, y, z)
+        }
+    }
+
+    function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
+        (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
+        (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
+
+        // Karatsuba algorithm
+        // https://en.wikipedia.org/wiki/Karatsuba_algorithm
+        uint256 z2 = x1 * y1;
+        uint256 z1a = x1 * y0;
+        uint256 z1b = x0 * y1;
+        uint256 z0 = x0 * y0;
+
+        uint256 carry = ((z1a & type(uint128).max) + (z1b & type(uint128).max) + (z0 >> 128)) >> 128;
+
+        high = z2 + (z1a >> 128) + (z1b >> 128) + carry;
+
+        unchecked {
+            low = x * y;
+        }
+    }
+
+    function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
+        unchecked {
+            res = x + y;
+        }
+        carry = res < x ? 1 : 0;
+    }
 }