Browse Source

Improved integer square root (#4403)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
Chris Gorman 1 year ago
parent
commit
140d66fad8
1 changed files with 103 additions and 29 deletions
  1. 103 29
      contracts/utils/math/Math.sol

+ 103 - 29
contracts/utils/math/Math.sol

@@ -385,38 +385,112 @@ library Math {
      * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
      * towards zero.
      *
-     * Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
+     * This method is based on Newton's method for computing square roots; the algorithm is restricted to only
+     * using integer operations.
      */
     function sqrt(uint256 a) internal pure returns (uint256) {
-        if (a == 0) {
-            return 0;
-        }
-
-        // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
-        //
-        // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
-        // `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
-        //
-        // This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
-        // → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
-        // → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
-        //
-        // Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
-        uint256 result = 1 << (log2(a) >> 1);
-
-        // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
-        // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
-        // every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
-        // into the expected uint128 result.
         unchecked {
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            result = (result + a / result) >> 1;
-            return min(result, a / result);
+            // Take care of easy edge cases when a == 0 or a == 1
+            if (a <= 1) {
+                return a;
+            }
+
+            // In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
+            // sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
+            // the current value as `ε_n = | x_n - sqrt(a) |`.
+            //
+            // For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root
+            // of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
+            // bigger than any uint256.
+            //
+            // By noticing that
+            // `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
+            // we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
+            // to the msb function.
+            uint256 aa = a;
+            uint256 xn = 1;
+
+            if (aa >= (1 << 128)) {
+                aa >>= 128;
+                xn <<= 64;
+            }
+            if (aa >= (1 << 64)) {
+                aa >>= 64;
+                xn <<= 32;
+            }
+            if (aa >= (1 << 32)) {
+                aa >>= 32;
+                xn <<= 16;
+            }
+            if (aa >= (1 << 16)) {
+                aa >>= 16;
+                xn <<= 8;
+            }
+            if (aa >= (1 << 8)) {
+                aa >>= 8;
+                xn <<= 4;
+            }
+            if (aa >= (1 << 4)) {
+                aa >>= 4;
+                xn <<= 2;
+            }
+            if (aa >= (1 << 2)) {
+                xn <<= 1;
+            }
+
+            // We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
+            //
+            // We can refine our estimation by noticing that the the middle of that interval minimizes the error.
+            // If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
+            // This is going to be our x_0 (and ε_0)
+            xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)
+
+            // From here, Newton's method give us:
+            // x_{n+1} = (x_n + a / x_n) / 2
+            //
+            // One should note that:
+            // x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
+            //              = ((x_n² + a) / (2 * x_n))² - a
+            //              = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
+            //              = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
+            //              = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
+            //              = (x_n² - a)² / (2 * x_n)²
+            //              = ((x_n² - a) / (2 * x_n))²
+            //              ≥ 0
+            // Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
+            //
+            // This gives us the proof of quadratic convergence of the sequence:
+            // ε_{n+1} = | x_{n+1} - sqrt(a) |
+            //         = | (x_n + a / x_n) / 2 - sqrt(a) |
+            //         = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
+            //         = | (x_n - sqrt(a))² / (2 * x_n) |
+            //         = | ε_n² / (2 * x_n) |
+            //         = ε_n² / | (2 * x_n) |
+            //
+            // For the first iteration, we have a special case where x_0 is known:
+            // ε_1 = ε_0² / | (2 * x_0) |
+            //     ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
+            //     ≤ 2**(2*e-4) / (3 * 2**(e-1))
+            //     ≤ 2**(e-3) / 3
+            //     ≤ 2**(e-3-log2(3))
+            //     ≤ 2**(e-4.5)
+            //
+            // For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
+            // ε_{n+1} = ε_n² / | (2 * x_n) |
+            //         ≤ (2**(e-k))² / (2 * 2**(e-1))
+            //         ≤ 2**(2*e-2*k) / 2**e
+            //         ≤ 2**(e-2*k)
+            xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5)  -- special case, see above
+            xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9)    -- general case with k = 4.5
+            xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18)   -- general case with k = 9
+            xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36)   -- general case with k = 18
+            xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72)   -- general case with k = 36
+            xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144)  -- general case with k = 72
+
+            // Because e ≤ 128 (as discussed during the first estimation phase), we know have reached a precision
+            // ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
+            // sqrt(a) or sqrt(a) + 1.
+            return xn - SafeCast.toUint(xn > a / xn);
         }
     }