Browse Source

Add sqrt for math (#3242)

jjz 3 years ago
parent
commit
3ac4add548
4 changed files with 113 additions and 0 deletions
  1. 1 0
      CHANGELOG.md
  2. 4 0
      contracts/mocks/MathMock.sol
  3. 74 0
      contracts/utils/math/Math.sol
  4. 34 0
      test/utils/math/Math.test.js

+ 1 - 0
CHANGELOG.md

@@ -7,6 +7,7 @@
  * `ERC20FlashMint`: Add customizable flash fee receiver. ([#3327](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3327))
  * `ERC20TokenizedVault`: add an extension of `ERC20` that implements the ERC4626 Tokenized Vault Standard. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171))
  * `Math`: add a `mulDiv` function that can round the result either up or down. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171))
+ * `Math`: Add a `sqrt` function to compute square roots of integers, rounding either up or down. ([#3242](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3242))
  * `Strings`: add a new overloaded function `toHexString` that converts an `address` with fixed length of 20 bytes to its not checksummed ASCII `string` hexadecimal representation. ([#3403](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3403))
  * `EnumerableMap`: add new `UintToUintMap` map type. ([#3338](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3338))
  * `EnumerableMap`: add new `Bytes32ToUintMap` map type. ([#3416](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3416))

+ 4 - 0
contracts/mocks/MathMock.sol

@@ -29,4 +29,8 @@ contract MathMock {
     ) public pure returns (uint256) {
         return Math.mulDiv(a, b, denominator, direction);
     }
+
+    function sqrt(uint256 a, Math.Rounding direction) public pure returns (uint256) {
+        return Math.sqrt(a, direction);
+    }
 }

+ 74 - 0
contracts/utils/math/Math.sol

@@ -149,4 +149,78 @@ library Math {
         }
         return result;
     }
+
+    /**
+     * @dev Returns the square root of a number. It the number is not a perfect square, the value is rounded down.
+     *
+     * Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
+     */
+    function sqrt(uint256 a) internal pure returns (uint256) {
+        if (a == 0) {
+            return 0;
+        }
+
+        // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
+        // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
+        // `msb(a) <= a < 2*msb(a)`.
+        // We also know that `k`, the position of the most significant bit, is such that `msb(a) = 2**k`.
+        // This gives `2**k < a <= 2**(k+1)` → `2**(k/2) <= sqrt(a) < 2 ** (k/2+1)`.
+        // Using an algorithm similar to the msb conmputation, we are able to compute `result = 2**(k/2)` which is a
+        // good first aproximation of `sqrt(a)` with at least 1 correct bit.
+        uint256 result = 1;
+        uint256 x = a;
+        if (x >> 128 > 0) {
+            x >>= 128;
+            result <<= 64;
+        }
+        if (x >> 64 > 0) {
+            x >>= 64;
+            result <<= 32;
+        }
+        if (x >> 32 > 0) {
+            x >>= 32;
+            result <<= 16;
+        }
+        if (x >> 16 > 0) {
+            x >>= 16;
+            result <<= 8;
+        }
+        if (x >> 8 > 0) {
+            x >>= 8;
+            result <<= 4;
+        }
+        if (x >> 4 > 0) {
+            x >>= 4;
+            result <<= 2;
+        }
+        if (x >> 2 > 0) {
+            result <<= 1;
+        }
+
+        // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
+        // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
+        // every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
+        // into the expected uint128 result.
+        unchecked {
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            result = (result + a / result) >> 1;
+            return min(result, a / result);
+        }
+    }
+
+    /**
+     * @notice Calculates sqrt(a), following the selected rounding direction.
+     */
+    function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
+        uint256 result = sqrt(a);
+        if (rounding == Rounding.Up && result * result < a) {
+            result += 1;
+        }
+        return result;
+    }
 }

+ 34 - 0
test/utils/math/Math.test.js

@@ -182,4 +182,38 @@ contract('Math', function (accounts) {
       });
     });
   });
+
+  describe('sqrt', function () {
+    it('rounds down', async function () {
+      expect(await this.math.sqrt(new BN('0'), Rounding.Down)).to.be.bignumber.equal('0');
+      expect(await this.math.sqrt(new BN('1'), Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt(new BN('2'), Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt(new BN('3'), Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt(new BN('4'), Rounding.Down)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt(new BN('144'), Rounding.Down)).to.be.bignumber.equal('12');
+      expect(await this.math.sqrt(new BN('999999'), Rounding.Down)).to.be.bignumber.equal('999');
+      expect(await this.math.sqrt(new BN('1000000'), Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt(new BN('1000001'), Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt(new BN('1002000'), Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt(new BN('1002001'), Rounding.Down)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt(MAX_UINT256, Rounding.Down))
+        .to.be.bignumber.equal('340282366920938463463374607431768211455');
+    });
+
+    it('rounds up', async function () {
+      expect(await this.math.sqrt(new BN('0'), Rounding.Up)).to.be.bignumber.equal('0');
+      expect(await this.math.sqrt(new BN('1'), Rounding.Up)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt(new BN('2'), Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt(new BN('3'), Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt(new BN('4'), Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt(new BN('144'), Rounding.Up)).to.be.bignumber.equal('12');
+      expect(await this.math.sqrt(new BN('999999'), Rounding.Up)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt(new BN('1000000'), Rounding.Up)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt(new BN('1000001'), Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt(new BN('1002000'), Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt(new BN('1002001'), Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt(MAX_UINT256, Rounding.Up))
+        .to.be.bignumber.equal('340282366920938463463374607431768211456');
+    });
+  });
 });