Browse Source

Add `Math.modExp` and a `Panic` library (#3298)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
Mihir Wadekar 1 year ago
parent
commit
192e873fcb

+ 5 - 0
.changeset/shiny-poets-whisper.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: Add `modExp` function that exposes the `EIP-198` precompile.

+ 5 - 0
.changeset/silver-swans-promise.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Panic`: Add a library for reverting with panic codes.

+ 5 - 0
.changeset/smart-bugs-switch.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: MathOverflowedMulDiv error was replaced with native panic codes.

+ 40 - 0
contracts/mocks/_import.sol

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Address} from "../utils/Address.sol";
+import {Arrays} from "../utils/Arrays.sol";
+import {Base64} from "../utils/Base64.sol";
+import {BitMaps} from "../utils/structs/BitMaps.sol";
+import {Checkpoints} from "../utils/structs/Checkpoints.sol";
+import {Context} from "../utils/Context.sol";
+import {Create2} from "../utils/Create2.sol";
+import {DoubleEndedQueue} from "../utils/structs/DoubleEndedQueue.sol";
+import {ECDSA} from "../utils/cryptography/ECDSA.sol";
+import {EIP712} from "../utils/cryptography/EIP712.sol";
+import {EnumerableMap} from "../utils/structs/EnumerableMap.sol";
+import {EnumerableSet} from "../utils/structs/EnumerableSet.sol";
+import {ERC165} from "../utils/introspection/ERC165.sol";
+import {ERC165Checker} from "../utils/introspection/ERC165Checker.sol";
+import {IERC165} from "../utils/introspection/IERC165.sol";
+import {Math} from "../utils/math/Math.sol";
+import {MerkleProof} from "../utils/cryptography/MerkleProof.sol";
+import {MessageHashUtils} from "../utils/cryptography/MessageHashUtils.sol";
+import {Multicall} from "../utils/Multicall.sol";
+import {Nonces} from "../utils/Nonces.sol";
+import {Panic} from "../utils/Panic.sol";
+import {Pausable} from "../utils/Pausable.sol";
+import {ReentrancyGuard} from "../utils/ReentrancyGuard.sol";
+import {SafeCast} from "../utils/math/SafeCast.sol";
+import {ShortStrings} from "../utils/ShortStrings.sol";
+import {SignatureChecker} from "../utils/cryptography/SignatureChecker.sol";
+import {SignedMath} from "../utils/math/SignedMath.sol";
+import {StorageSlot} from "../utils/StorageSlot.sol";
+import {Strings} from "../utils/Strings.sol";
+import {Time} from "../utils/types/Time.sol";
+
+abstract contract ExposeImports {
+    // This will be transpiled, causing all the imports above to be transpiled when running the upgradeable tests.
+    // This trick is necessary for testing libraries such as Panic.sol (which are not imported by any other transpiled
+    // contracts and would otherwise not be exposed).
+}

+ 55 - 0
contracts/utils/Panic.sol

@@ -0,0 +1,55 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+/**
+ * @dev Helper library for emitting standardized panic codes.
+ *
+ * ```solidity
+ * contract Example {
+ *      using Panic for uint256;
+ *
+ *      // Use any of the declared internal constants
+ *      function foo() { Panic.GENERIC.panic(); }
+ *
+ *      // Alternatively
+ *      function foo() { Panic.panic(Panic.GENERIC); }
+ * }
+ * ```
+ *
+ * Follows the list from libsolutil: https://github.com/ethereum/solidity/blob/v0.8.24/libsolutil/ErrorCodes.h
+ */
+// slither-disable-next-line unused-state
+library Panic {
+    /// @dev generic / unspecified error
+    uint256 internal constant GENERIC = 0x00;
+    /// @dev used by the assert() builtin
+    uint256 internal constant ASSERT = 0x01;
+    /// @dev arithmetic underflow or overflow
+    uint256 internal constant UNDER_OVERFLOW = 0x11;
+    /// @dev division or modulo by zero
+    uint256 internal constant DIVISION_BY_ZERO = 0x12;
+    /// @dev enum conversion error
+    uint256 internal constant ENUM_CONVERSION_ERROR = 0x21;
+    /// @dev invalid encoding in storage
+    uint256 internal constant STORAGE_ENCODING_ERROR = 0x22;
+    /// @dev empty array pop
+    uint256 internal constant EMPTY_ARRAY_POP = 0x31;
+    /// @dev array out of bounds access
+    uint256 internal constant ARRAY_OUT_OF_BOUNDS = 0x32;
+    /// @dev resource error (too large allocation or too large array)
+    uint256 internal constant RESOURCE_ERROR = 0x41;
+    /// @dev calling invalid internal function
+    uint256 internal constant INVALID_INTERNAL_FUNCTION = 0x51;
+
+    /// @dev Reverts with a panic code. Recommended to use with
+    /// the internal constants with predefined codes.
+    function panic(uint256 code) internal pure {
+        /// @solidity memory-safe-assembly
+        assembly {
+            mstore(0x00, shl(0xe0, 0x4e487b71))
+            mstore(0x04, code)
+            revert(0x00, 0x24)
+        }
+    }
+}

+ 70 - 7
contracts/utils/math/Math.sol

@@ -3,15 +3,13 @@
 
 pragma solidity ^0.8.20;
 
+import {Address} from "../Address.sol";
+import {Panic} from "../Panic.sol";
+
 /**
  * @dev Standard math utilities missing in the Solidity language.
  */
 library Math {
-    /**
-     * @dev Muldiv operation overflow.
-     */
-    error MathOverflowedMulDiv();
-
     enum Rounding {
         Floor, // Toward negative infinity
         Ceil, // Toward positive infinity
@@ -107,7 +105,7 @@ library Math {
     function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) {
         if (b == 0) {
             // Guarantee the same behavior as in a regular Solidity division.
-            return a / b;
+            Panic.panic(Panic.DIVISION_BY_ZERO);
         }
 
         // The following calculation ensures accurate ceiling division without overflow.
@@ -149,7 +147,7 @@ library Math {
 
             // Make sure the result is less than 2^256. Also prevents denominator == 0.
             if (denominator <= prod1) {
-                revert MathOverflowedMulDiv();
+                Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
             }
 
             ///////////////////////////////////////////////
@@ -226,6 +224,9 @@ library Math {
      * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
      *
      * If the input value is not inversible, 0 is returned.
+     *
+     * NOTE: If you know for sure that n is (big) a prime, it may be cheaper to use Ferma's little theorem and get the
+     * inverse using `Math.modExp(a, n - 2, n)`.
      */
     function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
         unchecked {
@@ -275,6 +276,68 @@ library Math {
         }
     }
 
+    /**
+     * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m)
+     *
+     * Requirements:
+     * - modulus can't be zero
+     * - underlying staticcall to precompile must succeed
+     *
+     * IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make
+     * sure the chain you're using it on supports the precompiled contract for modular exponentiation
+     * at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise,
+     * the underlying function will succeed given the lack of a revert, but the result may be incorrectly
+     * interpreted as 0.
+     */
+    function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
+        (bool success, uint256 result) = tryModExp(b, e, m);
+        if (!success) {
+            if (m == 0) {
+                Panic.panic(Panic.DIVISION_BY_ZERO);
+            } else {
+                revert Address.FailedInnerCall();
+            }
+        }
+        return result;
+    }
+
+    /**
+     * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m).
+     * It includes a success flag indicating if the operation succeeded. Operation will be marked has failed if trying
+     * to operate modulo 0 or if the underlying precompile reverted.
+     *
+     * IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain
+     * you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in
+     * https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack
+     * of a revert, but the result may be incorrectly interpreted as 0.
+     */
+    function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) {
+        if (m == 0) return (false, 0);
+        /// @solidity memory-safe-assembly
+        assembly {
+            let ptr := mload(0x40)
+            // | Offset    | Content    | Content (Hex)                                                      |
+            // |-----------|------------|--------------------------------------------------------------------|
+            // | 0x00:0x1f | size of b  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
+            // | 0x20:0x3f | size of e  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
+            // | 0x40:0x5f | size of m  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
+            // | 0x60:0x7f | value of b | 0x<.............................................................b> |
+            // | 0x80:0x9f | value of e | 0x<.............................................................e> |
+            // | 0xa0:0xbf | value of m | 0x<.............................................................m> |
+            mstore(ptr, 0x20)
+            mstore(add(ptr, 0x20), 0x20)
+            mstore(add(ptr, 0x40), 0x20)
+            mstore(add(ptr, 0x60), b)
+            mstore(add(ptr, 0x80), e)
+            mstore(add(ptr, 0xa0), m)
+
+            // Given the result < m, it's guaranteed to fit in 32 bytes,
+            // so we can use the memory scratch space located at offset 0.
+            success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20)
+            result := mload(0x00)
+        }
+    }
+
     /**
      * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
      * towards zero.

+ 2 - 2
scripts/generate/templates/Checkpoints.t.js

@@ -26,14 +26,14 @@ function _bound${capitalize(opts.keyTypeName)}(
     ${opts.keyTypeName} x,
     ${opts.keyTypeName} min,
     ${opts.keyTypeName} max
-) internal view returns (${opts.keyTypeName}) {
+) internal pure returns (${opts.keyTypeName}) {
     return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max)));
 }
 
 function _prepareKeys(
     ${opts.keyTypeName}[] memory keys,
     ${opts.keyTypeName} maxSpread
-) internal view {
+) internal pure {
     ${opts.keyTypeName} lastKey = 0;
     for (uint256 i = 0; i < keys.length; ++i) {
         ${opts.keyTypeName} key = _bound${capitalize(opts.keyTypeName)}(keys[i], lastKey, lastKey + maxSpread);

+ 37 - 0
test/utils/Panic.test.js

@@ -0,0 +1,37 @@
+const { ethers } = require('hardhat');
+const { expect } = require('chai');
+const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
+const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
+
+async function fixture() {
+  return { mock: await ethers.deployContract('$Panic') };
+}
+
+describe('Panic', function () {
+  beforeEach(async function () {
+    Object.assign(this, await loadFixture(fixture));
+  });
+
+  for (const [name, code] of Object.entries({
+    GENERIC: 0x0,
+    ASSERT: PANIC_CODES.ASSERTION_ERROR,
+    UNDER_OVERFLOW: PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
+    DIVISION_BY_ZERO: PANIC_CODES.DIVISION_BY_ZERO,
+    ENUM_CONVERSION_ERROR: PANIC_CODES.ENUM_CONVERSION_OUT_OF_BOUNDS,
+    STORAGE_ENCODING_ERROR: PANIC_CODES.INCORRECTLY_ENCODED_STORAGE_BYTE_ARRAY,
+    EMPTY_ARRAY_POP: PANIC_CODES.POP_ON_EMPTY_ARRAY,
+    ARRAY_OUT_OF_BOUNDS: PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS,
+    RESOURCE_ERROR: PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED,
+    INVALID_INTERNAL_FUNCTION: PANIC_CODES.ZERO_INITIALIZED_VARIABLE,
+  })) {
+    describe(`${name} (${ethers.toBeHex(code)})`, function () {
+      it('exposes panic code as constant', async function () {
+        expect(await this.mock.getFunction(`$${name}`)()).to.equal(code);
+      });
+
+      it('reverts with panic when called', async function () {
+        await expect(this.mock.$panic(code)).to.be.revertedWithPanic(code);
+      });
+    });
+  }
+});

+ 36 - 14
test/utils/math/Math.t.sol

@@ -2,7 +2,7 @@
 
 pragma solidity ^0.8.20;
 
-import {Test} from "forge-std/Test.sol";
+import {Test, stdError} from "forge-std/Test.sol";
 
 import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
 
@@ -186,7 +186,7 @@ contract MathTest is Test {
         // Full precision for q * d
         (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
         // Add remainder of x * y / d (computed as rem = (x * y % d))
-        (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d));
+        (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
         uint256 qdRemHi = qdHi + c;
 
         // Full precision check that x * y = q * d + rem
@@ -201,14 +201,42 @@ contract MathTest is Test {
         vm.assume(xyHi >= d);
 
         // we are outside the scope of {testMulDiv}, we expect muldiv to revert
-        try this.muldiv(x, y, d) returns (uint256) {
-            fail();
-        } catch {}
+        vm.expectRevert(d == 0 ? stdError.divisionError : stdError.arithmeticError);
+        Math.mulDiv(x, y, d);
     }
 
-    // External call
-    function muldiv(uint256 x, uint256 y, uint256 d) external pure returns (uint256) {
-        return Math.mulDiv(x, y, d);
+    // MOD EXP
+    function testModExp(uint256 b, uint256 e, uint256 m) public {
+        if (m == 0) {
+            vm.expectRevert(stdError.divisionError);
+        }
+        uint256 result = Math.modExp(b, e, m);
+        assertLt(result, m);
+        assertEq(result, _nativeModExp(b, e, m));
+    }
+
+    function testTryModExp(uint256 b, uint256 e, uint256 m) public {
+        (bool success, uint256 result) = Math.tryModExp(b, e, m);
+        assertEq(success, m != 0);
+        if (success) {
+            assertLt(result, m);
+            assertEq(result, _nativeModExp(b, e, m));
+        } else {
+            assertEq(result, 0);
+        }
+    }
+
+    function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
+        if (m == 1) return 0;
+        uint256 r = 1;
+        while (e > 0) {
+            if (e % 2 > 0) {
+                r = mulmod(r, b, m);
+            }
+            b = mulmod(b, b, m);
+            e >>= 1;
+        }
+        return r;
     }
 
     // Helpers
@@ -217,12 +245,6 @@ contract MathTest is Test {
         return Math.Rounding(r);
     }
 
-    function _mulmod(uint256 x, uint256 y, uint256 z) private pure returns (uint256 r) {
-        assembly {
-            r := mulmod(x, y, z)
-        }
-    }
-
     function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
         (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
         (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);

+ 39 - 4
test/utils/math/Math.test.js

@@ -141,6 +141,24 @@ describe('Math', function () {
     });
   });
 
+  describe('tryModExp', function () {
+    it('is correctly returning true and calculating modulus', async function () {
+      const base = 3n;
+      const exponent = 200n;
+      const modulus = 50n;
+
+      expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
+    });
+
+    it('is correctly returning false when modulus is 0', async function () {
+      const base = 3n;
+      const exponent = 200n;
+      const modulus = 0n;
+
+      expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
+    });
+  });
+
   describe('max', function () {
     it('is correctly detected in both position', async function () {
       await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
@@ -222,7 +240,7 @@ describe('Math', function () {
     });
   });
 
-  describe('muldiv', function () {
+  describe('mulDiv', function () {
     it('divide by 0', async function () {
       const a = 1n;
       const b = 1n;
@@ -234,9 +252,8 @@ describe('Math', function () {
       const a = 5n;
       const b = ethers.MaxUint256;
       const c = 2n;
-      await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithCustomError(
-        this.mock,
-        'MathOverflowedMulDiv',
+      await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithPanic(
+        PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
       );
     });
 
@@ -336,6 +353,24 @@ describe('Math', function () {
     }
   });
 
+  describe('modExp', function () {
+    it('is correctly calculating modulus', async function () {
+      const base = 3n;
+      const exponent = 200n;
+      const modulus = 50n;
+
+      expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
+    });
+
+    it('is correctly reverting when modulus is zero', async function () {
+      const base = 3n;
+      const exponent = 200n;
+      const modulus = 0n;
+
+      await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
+    });
+  });
+
   describe('sqrt', function () {
     it('rounds down', async function () {
       for (const rounding of RoundingDown) {

+ 6 - 6
test/utils/structs/Checkpoints.t.sol

@@ -17,11 +17,11 @@ contract CheckpointsTrace224Test is Test {
     Checkpoints.Trace224 internal _ckpts;
 
     // helpers
-    function _boundUint32(uint32 x, uint32 min, uint32 max) internal view returns (uint32) {
+    function _boundUint32(uint32 x, uint32 min, uint32 max) internal pure returns (uint32) {
         return SafeCast.toUint32(bound(uint256(x), uint256(min), uint256(max)));
     }
 
-    function _prepareKeys(uint32[] memory keys, uint32 maxSpread) internal view {
+    function _prepareKeys(uint32[] memory keys, uint32 maxSpread) internal pure {
         uint32 lastKey = 0;
         for (uint256 i = 0; i < keys.length; ++i) {
             uint32 key = _boundUint32(keys[i], lastKey, lastKey + maxSpread);
@@ -125,11 +125,11 @@ contract CheckpointsTrace208Test is Test {
     Checkpoints.Trace208 internal _ckpts;
 
     // helpers
-    function _boundUint48(uint48 x, uint48 min, uint48 max) internal view returns (uint48) {
+    function _boundUint48(uint48 x, uint48 min, uint48 max) internal pure returns (uint48) {
         return SafeCast.toUint48(bound(uint256(x), uint256(min), uint256(max)));
     }
 
-    function _prepareKeys(uint48[] memory keys, uint48 maxSpread) internal view {
+    function _prepareKeys(uint48[] memory keys, uint48 maxSpread) internal pure {
         uint48 lastKey = 0;
         for (uint256 i = 0; i < keys.length; ++i) {
             uint48 key = _boundUint48(keys[i], lastKey, lastKey + maxSpread);
@@ -233,11 +233,11 @@ contract CheckpointsTrace160Test is Test {
     Checkpoints.Trace160 internal _ckpts;
 
     // helpers
-    function _boundUint96(uint96 x, uint96 min, uint96 max) internal view returns (uint96) {
+    function _boundUint96(uint96 x, uint96 min, uint96 max) internal pure returns (uint96) {
         return SafeCast.toUint96(bound(uint256(x), uint256(min), uint256(max)));
     }
 
-    function _prepareKeys(uint96[] memory keys, uint96 maxSpread) internal view {
+    function _prepareKeys(uint96[] memory keys, uint96 maxSpread) internal pure {
         uint96 lastKey = 0;
         for (uint256 i = 0; i < keys.length; ++i) {
             uint96 key = _boundUint96(keys[i], lastKey, lastKey + maxSpread);