Browse Source

Branchless ternary, min and max methods (#4976)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
Lohann Paterno Coutinho Ferreira 1 year ago
parent
commit
4032b42694

+ 5 - 0
.changeset/spotty-falcons-explain.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`, `SignedMath`: Add a branchless `ternary` function that computes`cond ? a : b` in constant gas cost.

+ 21 - 5
contracts/utils/math/Math.sol

@@ -73,18 +73,34 @@ library Math {
         }
     }
 
+    /**
+     * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
+     *
+     * IMPORTANT: This function may reduce bytecode size and consume less gas when used standalone.
+     * However, the compiler may optimize Solidity ternary operations (i.e. `a ? b : c`) to only compute
+     * one branch when needed, making this function more expensive.
+     */
+    function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256) {
+        unchecked {
+            // branchless ternary works because:
+            // b ^ (a ^ b) == a
+            // b ^ 0 == b
+            return b ^ ((a ^ b) * SafeCast.toUint(condition));
+        }
+    }
+
     /**
      * @dev Returns the largest of two numbers.
      */
     function max(uint256 a, uint256 b) internal pure returns (uint256) {
-        return a > b ? a : b;
+        return ternary(a > b, a, b);
     }
 
     /**
      * @dev Returns the smallest of two numbers.
      */
     function min(uint256 a, uint256 b) internal pure returns (uint256) {
-        return a < b ? a : b;
+        return ternary(a < b, a, b);
     }
 
     /**
@@ -114,7 +130,7 @@ library Math {
         // but the largest value we can obtain is type(uint256).max - 1, which happens
         // when a = type(uint256).max and b = 1.
         unchecked {
-            return a == 0 ? 0 : (a - 1) / b + 1;
+            return SafeCast.toUint(a > 0) * ((a - 1) / b + 1);
         }
     }
 
@@ -147,7 +163,7 @@ library Math {
 
             // Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
             if (denominator <= prod1) {
-                Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
+                Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
             }
 
             ///////////////////////////////////////////////
@@ -268,7 +284,7 @@ library Math {
             }
 
             if (gcd != 1) return 0; // No inverse exists.
-            return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
+            return ternary(x < 0, n - uint256(-x), uint256(x)); // Wrap the result if it's negative.
         }
     }
 

+ 20 - 2
contracts/utils/math/SignedMath.sol

@@ -3,22 +3,40 @@
 
 pragma solidity ^0.8.20;
 
+import {SafeCast} from "./SafeCast.sol";
+
 /**
  * @dev Standard signed math utilities missing in the Solidity language.
  */
 library SignedMath {
+    /**
+     * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
+     *
+     * IMPORTANT: This function may reduce bytecode size and consume less gas when used standalone.
+     * However, the compiler may optimize Solidity ternary operations (i.e. `a ? b : c`) to only compute
+     * one branch when needed, making this function more expensive.
+     */
+    function ternary(bool condition, int256 a, int256 b) internal pure returns (int256) {
+        unchecked {
+            // branchless terinary works because:
+            // b ^ (a ^ b) == a
+            // b ^ 0 == b
+            return b ^ ((a ^ b) * int256(SafeCast.toUint(condition)));
+        }
+    }
+
     /**
      * @dev Returns the largest of two signed numbers.
      */
     function max(int256 a, int256 b) internal pure returns (int256) {
-        return a > b ? a : b;
+        return ternary(a > b, a, b);
     }
 
     /**
      * @dev Returns the smallest of two signed numbers.
      */
     function min(int256 a, int256 b) internal pure returns (int256) {
-        return a < b ? a : b;
+        return ternary(a < b, a, b);
     }
 
     /**

+ 10 - 0
test/utils/math/Math.t.sol

@@ -7,6 +7,16 @@ import {Test, stdError} from "forge-std/Test.sol";
 import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
 
 contract MathTest is Test {
+    function testSelect(bool f, uint256 a, uint256 b) public {
+        assertEq(Math.ternary(f, a, b), f ? a : b);
+    }
+
+    // MIN & MAX
+    function testMinMax(uint256 a, uint256 b) public {
+        assertEq(Math.min(a, b), a < b ? a : b);
+        assertEq(Math.max(a, b), a > b ? a : b);
+    }
+
     // CEILDIV
     function testCeilDiv(uint256 a, uint256 b) public {
         vm.assume(b > 0);

+ 10 - 0
test/utils/math/SignedMath.t.sol

@@ -8,6 +8,16 @@ import {Math} from "../../../contracts/utils/math/Math.sol";
 import {SignedMath} from "../../../contracts/utils/math/SignedMath.sol";
 
 contract SignedMathTest is Test {
+    function testSelect(bool f, int256 a, int256 b) public {
+        assertEq(SignedMath.ternary(f, a, b), f ? a : b);
+    }
+
+    // MIN & MAX
+    function testMinMax(int256 a, int256 b) public {
+        assertEq(SignedMath.min(a, b), a < b ? a : b);
+        assertEq(SignedMath.max(a, b), a > b ? a : b);
+    }
+
     // MIN
     function testMin(int256 a, int256 b) public {
         int256 result = SignedMath.min(a, b);