فهرست منبع

Add a Math.inv function that inverse a number in Z/nZ (#4839)

Co-authored-by: ernestognw <ernestognw@gmail.com>
Hadrien Croubois 1 سال پیش
والد
کامیت
e86bb45477
4فایلهای تغییر یافته به همراه139 افزوده شده و 4 حذف شده
  1. 5 0
      .changeset/cool-mangos-compare.md
  2. 61 4
      contracts/utils/math/Math.sol
  3. 35 0
      test/utils/math/Math.t.sol
  4. 38 0
      test/utils/math/Math.test.js

+ 5 - 0
.changeset/cool-mangos-compare.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: add an `invMod` function to get the modular multiplicative inverse of a number in Z/nZ.

+ 61 - 4
contracts/utils/math/Math.sol

@@ -121,9 +121,10 @@ library Math {
     }
 
     /**
-     * @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
+     * @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
      * denominator == 0.
-     * @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
+     *
+     * Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
      * Uniswap Labs also under MIT license.
      */
     function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
@@ -208,7 +209,7 @@ library Math {
     }
 
     /**
-     * @notice Calculates x * y / denominator with full precision, following the selected rounding direction.
+     * @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
      */
     function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
         uint256 result = mulDiv(x, y, denominator);
@@ -218,6 +219,62 @@ library Math {
         return result;
     }
 
+    /**
+     * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
+     *
+     * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0.
+     * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
+     *
+     * If the input value is not inversible, 0 is returned.
+     */
+    function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
+        unchecked {
+            if (n == 0) return 0;
+
+            // The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
+            // Used to compute integers x and y such that: ax + ny = gcd(a, n).
+            // When the gcd is 1, then the inverse of a modulo n exists and it's x.
+            // ax + ny = 1
+            // ax = 1 + (-y)n
+            // ax ≡ 1 (mod n) # x is the inverse of a modulo n
+
+            // If the remainder is 0 the gcd is n right away.
+            uint256 remainder = a % n;
+            uint256 gcd = n;
+
+            // Therefore the initial coefficients are:
+            // ax + ny = gcd(a, n) = n
+            // 0a + 1n = n
+            int256 x = 0;
+            int256 y = 1;
+
+            while (remainder != 0) {
+                uint256 quotient = gcd / remainder;
+
+                (gcd, remainder) = (
+                    // The old remainder is the next gcd to try.
+                    remainder,
+                    // Compute the next remainder.
+                    // Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
+                    // where gcd is at most n (capped to type(uint256).max)
+                    gcd - remainder * quotient
+                );
+
+                (x, y) = (
+                    // Increment the coefficient of a.
+                    y,
+                    // Decrement the coefficient of n.
+                    // Can overflow, but the result is casted to uint256 so that the
+                    // next value of y is "wrapped around" to a value between 0 and n - 1.
+                    x - y * int256(quotient)
+                );
+            }
+
+            if (gcd != 1) return 0; // No inverse exists.
+            return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
+        }
+    }
+
     /**
      * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
      * towards zero.
@@ -258,7 +315,7 @@ library Math {
     }
 
     /**
-     * @notice Calculates sqrt(a), following the selected rounding direction.
+     * @dev Calculates sqrt(a), following the selected rounding direction.
      */
     function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
         unchecked {

+ 35 - 0
test/utils/math/Math.t.sol

@@ -55,6 +55,41 @@ contract MathTest is Test {
         return value * value < ref;
     }
 
+    // INV
+    function testInvMod(uint256 value, uint256 p) public {
+        _testInvMod(value, p, true);
+    }
+
+    function testInvMod2(uint256 seed) public {
+        uint256 p = 2; // prime
+        _testInvMod(bound(seed, 1, p - 1), p, false);
+    }
+
+    function testInvMod17(uint256 seed) public {
+        uint256 p = 17; // prime
+        _testInvMod(bound(seed, 1, p - 1), p, false);
+    }
+
+    function testInvMod65537(uint256 seed) public {
+        uint256 p = 65537; // prime
+        _testInvMod(bound(seed, 1, p - 1), p, false);
+    }
+
+    function testInvModP256(uint256 seed) public {
+        uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime
+        _testInvMod(bound(seed, 1, p - 1), p, false);
+    }
+
+    function _testInvMod(uint256 value, uint256 p, bool allowZero) private {
+        uint256 inverse = Math.invMod(value, p);
+        if (inverse != 0) {
+            assertEq(mulmod(value, inverse, p), 1);
+            assertLt(inverse, p);
+        } else {
+            assertTrue(allowZero);
+        }
+    }
+
     // LOG2
     function testLog2(uint256 input, uint8 r) public {
         Math.Rounding rounding = _asRounding(r);

+ 38 - 0
test/utils/math/Math.test.js

@@ -5,6 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
 
 const { Rounding } = require('../../helpers/enums');
 const { min, max } = require('../../helpers/math');
+const { randomArray, generators } = require('../../helpers/random');
 
 const RoundingDown = [Rounding.Floor, Rounding.Trunc];
 const RoundingUp = [Rounding.Ceil, Rounding.Expand];
@@ -298,6 +299,43 @@ describe('Math', function () {
     });
   });
 
+  describe('invMod', function () {
+    for (const factors of [
+      [0n],
+      [1n],
+      [2n],
+      [17n],
+      [65537n],
+      [0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn],
+      [3n, 5n],
+      [3n, 7n],
+      [47n, 53n],
+    ]) {
+      const p = factors.reduce((acc, f) => acc * f, 1n);
+
+      describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () {
+        it('trying to inverse 0 returns 0', async function () {
+          expect(await this.mock.$invMod(0, p)).to.equal(0n);
+          expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p
+        });
+
+        if (p != 0) {
+          for (const value of randomArray(generators.uint256, 16)) {
+            const isInversible = factors.every(f => value % f);
+            it(`trying to inverse ${value}`, async function () {
+              const result = await this.mock.$invMod(value, p);
+              if (isInversible) {
+                expect((value * result) % p).to.equal(1n);
+              } else {
+                expect(result).to.equal(0n);
+              }
+            });
+          }
+        }
+      });
+    }
+  });
+
   describe('sqrt', function () {
     it('rounds down', async function () {
       for (const rounding of RoundingDown) {