Browse Source

Add log2, log10 and log256 functions (#3670)

Hadrien Croubois 3 years ago
parent
commit
c1d6e39aab
5 changed files with 288 additions and 119 deletions
  1. 1 0
      CHANGELOG.md
  2. 12 0
      contracts/mocks/MathMock.sol
  3. 4 60
      contracts/utils/Strings.sol
  4. 156 37
      contracts/utils/math/Math.sol
  5. 115 22
      test/utils/math/Math.test.js

+ 1 - 0
CHANGELOG.md

@@ -31,6 +31,7 @@
  * `Strings`: optimize `toString`. ([#3573](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3573))
  * `Ownable2Step`: extension of `Ownable` that makes the ownership transfers a two step process. ([#3620](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3620))
  * `Math` and `SignedMath`: optimize function `max` by using `>` instead of `>=`. ([#3679](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3679))
+ * `Math`: Add `log2`, `log10` and `log256`. ([#3670](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3670))
 
 ### Breaking changes
 

+ 12 - 0
contracts/mocks/MathMock.sol

@@ -33,4 +33,16 @@ contract MathMock {
     function sqrt(uint256 a, Math.Rounding direction) public pure returns (uint256) {
         return Math.sqrt(a, direction);
     }
+
+    function log2(uint256 a, Math.Rounding direction) public pure returns (uint256) {
+        return Math.log2(a, direction);
+    }
+
+    function log10(uint256 a, Math.Rounding direction) public pure returns (uint256) {
+        return Math.log10(a, direction);
+    }
+
+    function log256(uint256 a, Math.Rounding direction) public pure returns (uint256) {
+        return Math.log256(a, direction);
+    }
 }

+ 4 - 60
contracts/utils/Strings.sol

@@ -3,6 +3,8 @@
 
 pragma solidity ^0.8.0;
 
+import "./math/Math.sol";
+
 /**
  * @dev String operations.
  */
@@ -15,39 +17,7 @@ library Strings {
      */
     function toString(uint256 value) internal pure returns (string memory) {
         unchecked {
-            uint256 length = 1;
-
-            // compute log10(value), and add it to length
-            uint256 valueCopy = value;
-            if (valueCopy >= 10**64) {
-                valueCopy /= 10**64;
-                length += 64;
-            }
-            if (valueCopy >= 10**32) {
-                valueCopy /= 10**32;
-                length += 32;
-            }
-            if (valueCopy >= 10**16) {
-                valueCopy /= 10**16;
-                length += 16;
-            }
-            if (valueCopy >= 10**8) {
-                valueCopy /= 10**8;
-                length += 8;
-            }
-            if (valueCopy >= 10**4) {
-                valueCopy /= 10**4;
-                length += 4;
-            }
-            if (valueCopy >= 10**2) {
-                valueCopy /= 10**2;
-                length += 2;
-            }
-            if (valueCopy >= 10**1) {
-                length += 1;
-            }
-            // now, length is log10(value) + 1
-
+            uint256 length = Math.log10(value) + 1;
             string memory buffer = new string(length);
             uint256 ptr;
             /// @solidity memory-safe-assembly
@@ -72,33 +42,7 @@ library Strings {
      */
     function toHexString(uint256 value) internal pure returns (string memory) {
         unchecked {
-            uint256 length = 1;
-
-            // compute log256(value), and add it to length
-            uint256 valueCopy = value;
-            if (valueCopy >= 1 << 128) {
-                valueCopy >>= 128;
-                length += 16;
-            }
-            if (valueCopy >= 1 << 64) {
-                valueCopy >>= 64;
-                length += 8;
-            }
-            if (valueCopy >= 1 << 32) {
-                valueCopy >>= 32;
-                length += 4;
-            }
-            if (valueCopy >= 1 << 16) {
-                valueCopy >>= 16;
-                length += 2;
-            }
-            if (valueCopy >= 1 << 8) {
-                valueCopy >>= 8;
-                length += 1;
-            }
-            // now, length is log256(value) + 1
-
-            return toHexString(value, length);
+            return toHexString(value, Math.log256(value) + 1);
         }
     }
 

+ 156 - 37
contracts/utils/math/Math.sol

@@ -161,41 +161,16 @@ library Math {
         }
 
         // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
+        //
         // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
-        // `msb(a) <= a < 2*msb(a)`.
-        // We also know that `k`, the position of the most significant bit, is such that `msb(a) = 2**k`.
-        // This gives `2**k < a <= 2**(k+1)` → `2**(k/2) <= sqrt(a) < 2 ** (k/2+1)`.
-        // Using an algorithm similar to the msb computation, we are able to compute `result = 2**(k/2)` which is a
-        // good first approximation of `sqrt(a)` with at least 1 correct bit.
-        uint256 result = 1;
-        uint256 x = a;
-        if (x >> 128 > 0) {
-            x >>= 128;
-            result <<= 64;
-        }
-        if (x >> 64 > 0) {
-            x >>= 64;
-            result <<= 32;
-        }
-        if (x >> 32 > 0) {
-            x >>= 32;
-            result <<= 16;
-        }
-        if (x >> 16 > 0) {
-            x >>= 16;
-            result <<= 8;
-        }
-        if (x >> 8 > 0) {
-            x >>= 8;
-            result <<= 4;
-        }
-        if (x >> 4 > 0) {
-            x >>= 4;
-            result <<= 2;
-        }
-        if (x >> 2 > 0) {
-            result <<= 1;
-        }
+        // `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
+        //
+        // This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
+        // → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
+        // → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
+        //
+        // Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
+        uint256 result = 1 << (log2(a) >> 1);
 
         // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
         // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
@@ -217,10 +192,154 @@ library Math {
      * @notice Calculates sqrt(a), following the selected rounding direction.
      */
     function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
-        uint256 result = sqrt(a);
-        if (rounding == Rounding.Up && result * result < a) {
-            result += 1;
+        unchecked {
+            uint256 result = sqrt(a);
+            return result + (rounding == Rounding.Up && result * result < a ? 1 : 0);
+        }
+    }
+
+    /**
+     * @dev Return the log in base 2, rounded down, of a positive value.
+     * Returns 0 if given 0.
+     */
+    function log2(uint256 value) internal pure returns (uint256) {
+        uint256 result = 0;
+        unchecked {
+            if (value >> 128 > 0) {
+                value >>= 128;
+                result += 128;
+            }
+            if (value >> 64 > 0) {
+                value >>= 64;
+                result += 64;
+            }
+            if (value >> 32 > 0) {
+                value >>= 32;
+                result += 32;
+            }
+            if (value >> 16 > 0) {
+                value >>= 16;
+                result += 16;
+            }
+            if (value >> 8 > 0) {
+                value >>= 8;
+                result += 8;
+            }
+            if (value >> 4 > 0) {
+                value >>= 4;
+                result += 4;
+            }
+            if (value >> 2 > 0) {
+                value >>= 2;
+                result += 2;
+            }
+            if (value >> 1 > 0) {
+                result += 1;
+            }
+        }
+        return result;
+    }
+
+    /**
+     * @dev Return the log in base 2, following the selected rounding direction, of a positive value.
+     * Returns 0 if given 0.
+     */
+    function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
+        unchecked {
+            uint256 result = log2(value);
+            return result + (rounding == Rounding.Up && 1 << result < value ? 1 : 0);
+        }
+    }
+
+    /**
+     * @dev Return the log in base 10, rounded down, of a positive value.
+     * Returns 0 if given 0.
+     */
+    function log10(uint256 value) internal pure returns (uint256) {
+        uint256 result = 0;
+        unchecked {
+            if (value >= 10**64) {
+                value /= 10**64;
+                result += 64;
+            }
+            if (value >= 10**32) {
+                value /= 10**32;
+                result += 32;
+            }
+            if (value >= 10**16) {
+                value /= 10**16;
+                result += 16;
+            }
+            if (value >= 10**8) {
+                value /= 10**8;
+                result += 8;
+            }
+            if (value >= 10**4) {
+                value /= 10**4;
+                result += 4;
+            }
+            if (value >= 10**2) {
+                value /= 10**2;
+                result += 2;
+            }
+            if (value >= 10**1) {
+                result += 1;
+            }
+        }
+        return result;
+    }
+
+    /**
+     * @dev Return the log in base 10, following the selected rounding direction, of a positive value.
+     * Returns 0 if given 0.
+     */
+    function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
+        unchecked {
+            uint256 result = log10(value);
+            return result + (rounding == Rounding.Up && 10**result < value ? 1 : 0);
+        }
+    }
+
+    /**
+     * @dev Return the log in base 256, rounded down, of a positive value.
+     * Returns 0 if given 0.
+     *
+     * Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string.
+     */
+    function log256(uint256 value) internal pure returns (uint256) {
+        uint256 result = 0;
+        unchecked {
+            if (value >> 128 > 0) {
+                value >>= 128;
+                result += 16;
+            }
+            if (value >> 64 > 0) {
+                value >>= 64;
+                result += 8;
+            }
+            if (value >> 32 > 0) {
+                value >>= 32;
+                result += 4;
+            }
+            if (value >> 16 > 0) {
+                value >>= 16;
+                result += 2;
+            }
+            if (value >> 8 > 0) {
+                result += 1;
+            }
         }
         return result;
     }
+
+    /**
+     * @dev Return the log in base 10, following the selected rounding direction, of a positive value.
+     * Returns 0 if given 0.
+     */
+    function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
+        unchecked {
+            uint256 result = log256(value);
+            return result + (rounding == Rounding.Up && 1 << (result * 8) < value ? 1 : 0);
+        }
+    }
 }

+ 115 - 22
test/utils/math/Math.test.js

@@ -185,35 +185,128 @@ contract('Math', function (accounts) {
 
   describe('sqrt', function () {
     it('rounds down', async function () {
-      expect(await this.math.sqrt(new BN('0'), Rounding.Down)).to.be.bignumber.equal('0');
-      expect(await this.math.sqrt(new BN('1'), Rounding.Down)).to.be.bignumber.equal('1');
-      expect(await this.math.sqrt(new BN('2'), Rounding.Down)).to.be.bignumber.equal('1');
-      expect(await this.math.sqrt(new BN('3'), Rounding.Down)).to.be.bignumber.equal('1');
-      expect(await this.math.sqrt(new BN('4'), Rounding.Down)).to.be.bignumber.equal('2');
-      expect(await this.math.sqrt(new BN('144'), Rounding.Down)).to.be.bignumber.equal('12');
-      expect(await this.math.sqrt(new BN('999999'), Rounding.Down)).to.be.bignumber.equal('999');
-      expect(await this.math.sqrt(new BN('1000000'), Rounding.Down)).to.be.bignumber.equal('1000');
-      expect(await this.math.sqrt(new BN('1000001'), Rounding.Down)).to.be.bignumber.equal('1000');
-      expect(await this.math.sqrt(new BN('1002000'), Rounding.Down)).to.be.bignumber.equal('1000');
-      expect(await this.math.sqrt(new BN('1002001'), Rounding.Down)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt('0', Rounding.Down)).to.be.bignumber.equal('0');
+      expect(await this.math.sqrt('1', Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt('2', Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt('3', Rounding.Down)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt('4', Rounding.Down)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt('144', Rounding.Down)).to.be.bignumber.equal('12');
+      expect(await this.math.sqrt('999999', Rounding.Down)).to.be.bignumber.equal('999');
+      expect(await this.math.sqrt('1000000', Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt('1000001', Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt('1002000', Rounding.Down)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt('1002001', Rounding.Down)).to.be.bignumber.equal('1001');
       expect(await this.math.sqrt(MAX_UINT256, Rounding.Down))
         .to.be.bignumber.equal('340282366920938463463374607431768211455');
     });
 
     it('rounds up', async function () {
-      expect(await this.math.sqrt(new BN('0'), Rounding.Up)).to.be.bignumber.equal('0');
-      expect(await this.math.sqrt(new BN('1'), Rounding.Up)).to.be.bignumber.equal('1');
-      expect(await this.math.sqrt(new BN('2'), Rounding.Up)).to.be.bignumber.equal('2');
-      expect(await this.math.sqrt(new BN('3'), Rounding.Up)).to.be.bignumber.equal('2');
-      expect(await this.math.sqrt(new BN('4'), Rounding.Up)).to.be.bignumber.equal('2');
-      expect(await this.math.sqrt(new BN('144'), Rounding.Up)).to.be.bignumber.equal('12');
-      expect(await this.math.sqrt(new BN('999999'), Rounding.Up)).to.be.bignumber.equal('1000');
-      expect(await this.math.sqrt(new BN('1000000'), Rounding.Up)).to.be.bignumber.equal('1000');
-      expect(await this.math.sqrt(new BN('1000001'), Rounding.Up)).to.be.bignumber.equal('1001');
-      expect(await this.math.sqrt(new BN('1002000'), Rounding.Up)).to.be.bignumber.equal('1001');
-      expect(await this.math.sqrt(new BN('1002001'), Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt('0', Rounding.Up)).to.be.bignumber.equal('0');
+      expect(await this.math.sqrt('1', Rounding.Up)).to.be.bignumber.equal('1');
+      expect(await this.math.sqrt('2', Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt('3', Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt('4', Rounding.Up)).to.be.bignumber.equal('2');
+      expect(await this.math.sqrt('144', Rounding.Up)).to.be.bignumber.equal('12');
+      expect(await this.math.sqrt('999999', Rounding.Up)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt('1000000', Rounding.Up)).to.be.bignumber.equal('1000');
+      expect(await this.math.sqrt('1000001', Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt('1002000', Rounding.Up)).to.be.bignumber.equal('1001');
+      expect(await this.math.sqrt('1002001', Rounding.Up)).to.be.bignumber.equal('1001');
       expect(await this.math.sqrt(MAX_UINT256, Rounding.Up))
         .to.be.bignumber.equal('340282366920938463463374607431768211456');
     });
   });
+
+  describe('log', function () {
+    describe('log2', function () {
+      it('rounds down', async function () {
+        expect(await this.math.log2('0', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log2('1', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log2('2', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log2('3', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log2('4', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('5', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('6', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('7', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('8', Rounding.Down)).to.be.bignumber.equal('3');
+        expect(await this.math.log2('9', Rounding.Down)).to.be.bignumber.equal('3');
+        expect(await this.math.log2(MAX_UINT256, Rounding.Down)).to.be.bignumber.equal('255');
+      });
+
+      it('rounds up', async function () {
+        expect(await this.math.log2('0', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log2('1', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log2('2', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log2('3', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('4', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log2('5', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log2('6', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log2('7', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log2('8', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log2(MAX_UINT256, Rounding.Up)).to.be.bignumber.equal('256');
+      });
+    });
+
+    describe('log10', function () {
+      it('rounds down', async function () {
+        expect(await this.math.log10('0', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('1', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('2', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('9', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('10', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('11', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('99', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('100', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('101', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('999', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('1000', Rounding.Down)).to.be.bignumber.equal('3');
+        expect(await this.math.log10('1001', Rounding.Down)).to.be.bignumber.equal('3');
+        expect(await this.math.log10(MAX_UINT256, Rounding.Down)).to.be.bignumber.equal('77');
+      });
+
+      it('rounds up', async function () {
+        expect(await this.math.log10('0', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('1', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log10('2', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('9', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('10', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log10('11', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('99', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('100', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log10('101', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log10('999', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log10('1000', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log10('1001', Rounding.Up)).to.be.bignumber.equal('4');
+        expect(await this.math.log10(MAX_UINT256, Rounding.Up)).to.be.bignumber.equal('78');
+      });
+    });
+
+    describe('log256', function () {
+      it('rounds down', async function () {
+        expect(await this.math.log256('0', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('1', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('2', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('255', Rounding.Down)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('256', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('257', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('65535', Rounding.Down)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('65536', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log256('65537', Rounding.Down)).to.be.bignumber.equal('2');
+        expect(await this.math.log256(MAX_UINT256, Rounding.Down)).to.be.bignumber.equal('31');
+      });
+
+      it('rounds up', async function () {
+        expect(await this.math.log256('0', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('1', Rounding.Up)).to.be.bignumber.equal('0');
+        expect(await this.math.log256('2', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('255', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('256', Rounding.Up)).to.be.bignumber.equal('1');
+        expect(await this.math.log256('257', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log256('65535', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log256('65536', Rounding.Up)).to.be.bignumber.equal('2');
+        expect(await this.math.log256('65537', Rounding.Up)).to.be.bignumber.equal('3');
+        expect(await this.math.log256(MAX_UINT256, Rounding.Up)).to.be.bignumber.equal('32');
+      });
+    });
+  });
 });