Browse Source

Refactor abs without logical branching (#4497)

Co-authored-by: Francisco Giordano <fg@frang.io>
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
Vladislav Volosnikov 1 year ago
parent
commit
dfae50fa5b
2 changed files with 79 additions and 2 deletions
  1. 9 2
      contracts/utils/math/SignedMath.sol
  2. 70 0
      test/utils/math/SignedMath.t.sol

+ 9 - 2
contracts/utils/math/SignedMath.sol

@@ -36,8 +36,15 @@ library SignedMath {
      */
     function abs(int256 n) internal pure returns (uint256) {
         unchecked {
-            // must be unchecked in order to support `n = type(int256).min`
-            return uint256(n >= 0 ? n : -n);
+            // Formula from the "Bit Twiddling Hacks" by Sean Eron Anderson.
+            // Since `n` is a signed integer, the generated bytecode will use the SAR opcode to perform the right shift,
+            // taking advantage of the most significant (or "sign" bit) in two's complement representation.
+            // This opcode adds new most significant bits set to the value of the previous most significant bit. As a result,
+            // the mask will either be `bytes(0)` (if n is positive) or `~bytes32(0)` (if n is negative).
+            int256 mask = n >> 255;
+
+            // A `bytes(0)` mask leaves the input unchanged, while a `~bytes32(0)` mask complements it.
+            return uint256((n + mask) ^ mask);
         }
     }
 }

+ 70 - 0
test/utils/math/SignedMath.t.sol

@@ -0,0 +1,70 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Test} from "forge-std/Test.sol";
+
+import {Math} from "../../../contracts/utils/math/Math.sol";
+import {SignedMath} from "../../../contracts/utils/math/SignedMath.sol";
+
+contract SignedMathTest is Test {
+    // MIN
+    function testMin(int256 a, int256 b) public {
+        int256 result = SignedMath.min(a, b);
+
+        assertLe(result, a);
+        assertLe(result, b);
+        assertTrue(result == a || result == b);
+    }
+
+    // MAX
+    function testMax(int256 a, int256 b) public {
+        int256 result = SignedMath.max(a, b);
+
+        assertGe(result, a);
+        assertGe(result, b);
+        assertTrue(result == a || result == b);
+    }
+
+    // AVERAGE
+    // 1. simple test, not full int256 range
+    function testAverage1(int256 a, int256 b) public {
+        a = bound(a, type(int256).min / 2, type(int256).max / 2);
+        b = bound(b, type(int256).min / 2, type(int256).max / 2);
+
+        int256 result = SignedMath.average(a, b);
+
+        assertEq(result, (a + b) / 2);
+    }
+
+    // 2. more complex test, full int256 range
+    function testAverage2(int256 a, int256 b) public {
+        (int256 result, int256 min, int256 max) = (
+            SignedMath.average(a, b),
+            SignedMath.min(a, b),
+            SignedMath.max(a, b)
+        );
+
+        // average must be between `a` and `b`
+        assertGe(result, min);
+        assertLe(result, max);
+
+        unchecked {
+            // must be unchecked in order to support `a = type(int256).min, b = type(int256).max`
+            uint256 deltaLower = uint256(result - min);
+            uint256 deltaUpper = uint256(max - result);
+            uint256 remainder = uint256((a & 1) ^ (b & 1));
+            assertEq(remainder, Math.max(deltaLower, deltaUpper) - Math.min(deltaLower, deltaUpper));
+        }
+    }
+
+    // ABS
+    function testAbs(int256 a) public {
+        uint256 result = SignedMath.abs(a);
+
+        unchecked {
+            // must be unchecked in order to support `n = type(int256).min`
+            assertEq(result, a < 0 ? uint256(-a) : uint256(a));
+        }
+    }
+}