Эх сурвалжийг харах

Add `bytes memory` version of `Math.modExp` (#4893)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Ernesto García 1 жил өмнө
parent
commit
4e7e6e54da

+ 1 - 1
.changeset/shiny-poets-whisper.md

@@ -2,4 +2,4 @@
 'openzeppelin-solidity': minor
 ---
 
-`Math`: Add `modExp` function that exposes the `EIP-198` precompile.
+`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions.

+ 52 - 6
contracts/utils/math/Math.sol

@@ -3,7 +3,6 @@
 
 pragma solidity ^0.8.20;
 
-import {Address} from "../Address.sol";
 import {Panic} from "../Panic.sol";
 import {SafeCast} from "./SafeCast.sol";
 
@@ -289,11 +288,7 @@ library Math {
     function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
         (bool success, uint256 result) = tryModExp(b, e, m);
         if (!success) {
-            if (m == 0) {
-                Panic.panic(Panic.DIVISION_BY_ZERO);
-            } else {
-                revert Address.FailedInnerCall();
-            }
+            Panic.panic(Panic.DIVISION_BY_ZERO);
         }
         return result;
     }
@@ -335,6 +330,57 @@ library Math {
         }
     }
 
+    /**
+     * @dev Variant of {modExp} that supports inputs of arbitrary length.
+     */
+    function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
+        (bool success, bytes memory result) = tryModExp(b, e, m);
+        if (!success) {
+            Panic.panic(Panic.DIVISION_BY_ZERO);
+        }
+        return result;
+    }
+
+    /**
+     * @dev Variant of {tryModExp} that supports inputs of arbitrary length.
+     */
+    function tryModExp(
+        bytes memory b,
+        bytes memory e,
+        bytes memory m
+    ) internal view returns (bool success, bytes memory result) {
+        if (_zeroBytes(m)) return (false, new bytes(0));
+
+        uint256 mLen = m.length;
+
+        // Encode call args in result and move the free memory pointer
+        result = abi.encodePacked(b.length, e.length, mLen, b, e, m);
+
+        /// @solidity memory-safe-assembly
+        assembly {
+            let dataPtr := add(result, 0x20)
+            // Write result on top of args to avoid allocating extra memory.
+            success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
+            // Overwrite the length.
+            // result.length > returndatasize() is guaranteed because returndatasize() == m.length
+            mstore(result, mLen)
+            // Set the memory pointer after the returned data.
+            mstore(0x40, add(dataPtr, mLen))
+        }
+    }
+
+    /**
+     * @dev Returns whether the provided byte array is zero.
+     */
+    function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
+        for (uint256 i = 0; i < byteArray.length; ++i) {
+            if (byteArray[i] != 0) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     /**
      * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
      * towards zero.

+ 23 - 0
test/helpers/math.js

@@ -3,8 +3,31 @@ const max = (...values) => values.slice(1).reduce((x, y) => (x > y ? x : y), val
 const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0));
 const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0));
 
+// Computes modexp without BigInt overflow for large numbers
+function modExp(b, e, m) {
+  let result = 1n;
+
+  // If e is a power of two, modexp can be calculated as:
+  // for (let result = b, i = 0; i < log2(e); i++) result = modexp(result, 2, m)
+  //
+  // Given any natural number can be written in terms of powers of 2 (i.e. binary)
+  // then modexp can be calculated for any e, by multiplying b**i for all i where
+  // binary(e)[i] is 1 (i.e. a power of two).
+  for (let base = b % m; e > 0n; base = base ** 2n % m) {
+    // Least significant bit is 1
+    if (e % 2n == 1n) {
+      result = (result * base) % m;
+    }
+
+    e /= 2n; // Binary pop
+  }
+
+  return result;
+}
+
 module.exports = {
   min,
   max,
   sum,
+  modExp,
 };

+ 27 - 0
test/utils/math/Math.t.sol

@@ -226,6 +226,33 @@ contract MathTest is Test {
         }
     }
 
+    function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
+        if (m == 0) {
+            vm.expectRevert(stdError.divisionError);
+        }
+        bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
+        assertEq(result.length, 0x20);
+        uint256 res = abi.decode(result, (uint256));
+        assertLt(res, m);
+        assertEq(res, _nativeModExp(b, e, m));
+    }
+
+    function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
+        (bool success, bytes memory result) = Math.tryModExp(
+            abi.encodePacked(b),
+            abi.encodePacked(e),
+            abi.encodePacked(m)
+        );
+        if (success) {
+            assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
+            uint256 res = abi.decode(result, (uint256));
+            assertLt(res, m);
+            assertEq(res, _nativeModExp(b, e, m));
+        } else {
+            assertEq(result.length, 0);
+        }
+    }
+
     function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
         if (m == 1) return 0;
         uint256 r = 1;

+ 77 - 29
test/utils/math/Math.test.js

@@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
 
 const { Rounding } = require('../../helpers/enums');
-const { min, max } = require('../../helpers/math');
+const { min, max, modExp } = require('../../helpers/math');
 const { generators } = require('../../helpers/random');
+const { range } = require('../../../scripts/helpers');
+const { product } = require('../../helpers/iterate');
 
 const RoundingDown = [Rounding.Floor, Rounding.Trunc];
 const RoundingUp = [Rounding.Ceil, Rounding.Expand];
 
+const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width));
+const uint256 = value => ethers.Typed.uint256(value);
+bytes.zero = '0x';
+uint256.zero = 0n;
+
 async function testCommutative(fn, lhs, rhs, expected, ...extra) {
   expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
   expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
@@ -141,24 +148,6 @@ describe('Math', function () {
     });
   });
 
-  describe('tryModExp', function () {
-    it('is correctly returning true and calculating modulus', async function () {
-      const base = 3n;
-      const exponent = 200n;
-      const modulus = 50n;
-
-      expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
-    });
-
-    it('is correctly returning false when modulus is 0', async function () {
-      const base = 3n;
-      const exponent = 200n;
-      const modulus = 0n;
-
-      expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
-    });
-  });
-
   describe('max', function () {
     it('is correctly detected in both position', async function () {
       await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
@@ -354,20 +343,79 @@ describe('Math', function () {
   });
 
   describe('modExp', function () {
-    it('is correctly calculating modulus', async function () {
-      const base = 3n;
-      const exponent = 200n;
-      const modulus = 50n;
+    for (const [name, type] of Object.entries({ uint256, bytes })) {
+      describe(`with ${name} inputs`, function () {
+        it('is correctly calculating modulus', async function () {
+          const b = 3n;
+          const e = 200n;
+          const m = 50n;
+
+          expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
+        });
 
-      expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
+        it('is correctly reverting when modulus is zero', async function () {
+          const b = 3n;
+          const e = 200n;
+          const m = 0n;
+
+          await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic(
+            PANIC_CODES.DIVISION_BY_ZERO,
+          );
+        });
+      });
+    }
+
+    describe('with large bytes inputs', function () {
+      for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
+        range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+        range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+        range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+      )) {
+        it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
+          const mLength = ethers.dataLength(ethers.toBeHex(m));
+
+          expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
+        });
+      }
     });
+  });
+
+  describe('tryModExp', function () {
+    for (const [name, type] of Object.entries({ uint256, bytes })) {
+      describe(`with ${name} inputs`, function () {
+        it('is correctly calculating modulus', async function () {
+          const b = 3n;
+          const e = 200n;
+          const m = 50n;
+
+          expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
+        });
 
-    it('is correctly reverting when modulus is zero', async function () {
-      const base = 3n;
-      const exponent = 200n;
-      const modulus = 0n;
+        it('is correctly reverting when modulus is zero', async function () {
+          const b = 3n;
+          const e = 200n;
+          const m = 0n;
 
-      await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
+          expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
+        });
+      });
+    }
+
+    describe('with large bytes inputs', function () {
+      for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
+        range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+        range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+        range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
+      )) {
+        it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
+          const mLength = ethers.dataLength(ethers.toBeHex(m));
+
+          expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
+            true,
+            bytes(modExp(b, e, m), mLength).value,
+          ]);
+        });
+      }
     });
   });