Browse Source

Optimize `log2` with a lookup table (#5236)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: cairo <cairoeth@protonmail.com>
Lohann Paterno Coutinho Ferreira 11 months ago
parent
commit
448efeea66
1 changed files with 38 additions and 34 deletions
  1. 38 34
      contracts/utils/math/Math.sol

+ 38 - 34
contracts/utils/math/Math.sol

@@ -537,41 +537,45 @@ library Math {
      * @dev Return the log in base 2 of a positive value rounded towards zero.
      * Returns 0 if given 0.
      */
-    function log2(uint256 value) internal pure returns (uint256) {
-        uint256 result = 0;
-        uint256 exp;
-        unchecked {
-            exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
-            value >>= exp;
-            result += exp;
-
-            exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
-            value >>= exp;
-            result += exp;
-
-            result += SafeCast.toUint(value > 1);
+    function log2(uint256 x) internal pure returns (uint256 r) {
+        // If value has upper 128 bits set, log2 result is at least 128
+        r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7;
+        // If upper 64 bits of 128-bit half set, add 64 to result
+        r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6;
+        // If upper 32 bits of 64-bit half set, add 32 to result
+        r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5;
+        // If upper 16 bits of 32-bit half set, add 16 to result
+        r |= SafeCast.toUint((x >> r) > 0xffff) << 4;
+        // If upper 8 bits of 16-bit half set, add 8 to result
+        r |= SafeCast.toUint((x >> r) > 0xff) << 3;
+        // If upper 4 bits of 8-bit half set, add 4 to result
+        r |= SafeCast.toUint((x >> r) > 0xf) << 2;
+
+        // Shifts value right by the current result and use it as an index into this lookup table:
+        //
+        // | x (4 bits) |  index  | table[index] = MSB position |
+        // |------------|---------|-----------------------------|
+        // |    0000    |    0    |        table[0] = 0         |
+        // |    0001    |    1    |        table[1] = 0         |
+        // |    0010    |    2    |        table[2] = 1         |
+        // |    0011    |    3    |        table[3] = 1         |
+        // |    0100    |    4    |        table[4] = 2         |
+        // |    0101    |    5    |        table[5] = 2         |
+        // |    0110    |    6    |        table[6] = 2         |
+        // |    0111    |    7    |        table[7] = 2         |
+        // |    1000    |    8    |        table[8] = 3         |
+        // |    1001    |    9    |        table[9] = 3         |
+        // |    1010    |   10    |        table[10] = 3        |
+        // |    1011    |   11    |        table[11] = 3        |
+        // |    1100    |   12    |        table[12] = 3        |
+        // |    1101    |   13    |        table[13] = 3        |
+        // |    1110    |   14    |        table[14] = 3        |
+        // |    1111    |   15    |        table[15] = 3        |
+        //
+        // The lookup table is represented as a 32-byte value with the MSB positions for 0-15 in the last 16 bytes.
+        assembly ("memory-safe") {
+            r := or(r, byte(shr(r, x), 0x0000010102020202030303030303030300000000000000000000000000000000))
         }
-        return result;
     }
 
     /**