Преглед изворни кода

Optimize Math operations using branchless bool to uint translation. (#4878)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
Igor Żuk пре 1 година
родитељ
комит
17a8955cd8

+ 5 - 0
.changeset/nervous-pans-grow.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': patch
+---
+
+`SafeCast`: Add `toUint(bool)` for operating on `bool` values as `uint256`.

+ 54 - 59
contracts/utils/math/Math.sol

@@ -5,6 +5,7 @@ pragma solidity ^0.8.20;
 
 import {Address} from "../Address.sol";
 import {Panic} from "../Panic.sol";
+import {SafeCast} from "./SafeCast.sol";
 
 /**
  * @dev Standard math utilities missing in the Solidity language.
@@ -210,11 +211,7 @@ library Math {
      * @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
      */
     function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
-        uint256 result = mulDiv(x, y, denominator);
-        if (unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0) {
-            result += 1;
-        }
-        return result;
+        return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
     }
 
     /**
@@ -383,7 +380,7 @@ library Math {
     function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
         unchecked {
             uint256 result = sqrt(a);
-            return result + (unsignedRoundsUp(rounding) && result * result < a ? 1 : 0);
+            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a);
         }
     }
 
@@ -393,38 +390,37 @@ library Math {
      */
     function log2(uint256 value) internal pure returns (uint256) {
         uint256 result = 0;
+        uint256 exp;
         unchecked {
-            if (value >> 128 > 0) {
-                value >>= 128;
-                result += 128;
-            }
-            if (value >> 64 > 0) {
-                value >>= 64;
-                result += 64;
-            }
-            if (value >> 32 > 0) {
-                value >>= 32;
-                result += 32;
-            }
-            if (value >> 16 > 0) {
-                value >>= 16;
-                result += 16;
-            }
-            if (value >> 8 > 0) {
-                value >>= 8;
-                result += 8;
-            }
-            if (value >> 4 > 0) {
-                value >>= 4;
-                result += 4;
-            }
-            if (value >> 2 > 0) {
-                value >>= 2;
-                result += 2;
-            }
-            if (value >> 1 > 0) {
-                result += 1;
-            }
+            exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
+            value >>= exp;
+            result += exp;
+
+            exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
+            value >>= exp;
+            result += exp;
+
+            result += SafeCast.toUint(value > 1);
         }
         return result;
     }
@@ -436,7 +432,7 @@ library Math {
     function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
         unchecked {
             uint256 result = log2(value);
-            return result + (unsignedRoundsUp(rounding) && 1 << result < value ? 1 : 0);
+            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value);
         }
     }
 
@@ -485,7 +481,7 @@ library Math {
     function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
         unchecked {
             uint256 result = log10(value);
-            return result + (unsignedRoundsUp(rounding) && 10 ** result < value ? 1 : 0);
+            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value);
         }
     }
 
@@ -497,26 +493,25 @@ library Math {
      */
     function log256(uint256 value) internal pure returns (uint256) {
         uint256 result = 0;
+        uint256 isGt;
         unchecked {
-            if (value >> 128 > 0) {
-                value >>= 128;
-                result += 16;
-            }
-            if (value >> 64 > 0) {
-                value >>= 64;
-                result += 8;
-            }
-            if (value >> 32 > 0) {
-                value >>= 32;
-                result += 4;
-            }
-            if (value >> 16 > 0) {
-                value >>= 16;
-                result += 2;
-            }
-            if (value >> 8 > 0) {
-                result += 1;
-            }
+            isGt = SafeCast.toUint(value > (1 << 128) - 1);
+            value >>= isGt * 128;
+            result += isGt * 16;
+
+            isGt = SafeCast.toUint(value > (1 << 64) - 1);
+            value >>= isGt * 64;
+            result += isGt * 8;
+
+            isGt = SafeCast.toUint(value > (1 << 32) - 1);
+            value >>= isGt * 32;
+            result += isGt * 4;
+
+            isGt = SafeCast.toUint(value > (1 << 16) - 1);
+            value >>= isGt * 16;
+            result += isGt * 2;
+
+            result += SafeCast.toUint(value > (1 << 8) - 1);
         }
         return result;
     }
@@ -528,7 +523,7 @@ library Math {
     function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
         unchecked {
             uint256 result = log256(value);
-            return result + (unsignedRoundsUp(rounding) && 1 << (result << 3) < value ? 1 : 0);
+            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value);
         }
     }
 

+ 11 - 1
contracts/utils/math/SafeCast.sol

@@ -5,7 +5,7 @@
 pragma solidity ^0.8.20;
 
 /**
- * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
+ * @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow
  * checks.
  *
  * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
@@ -1150,4 +1150,14 @@ library SafeCast {
         }
         return int256(value);
     }
+
+    /**
+     * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
+     */
+    function toUint(bool b) internal pure returns (uint256 u) {
+        /// @solidity memory-safe-assembly
+        assembly {
+            u := iszero(iszero(b))
+        }
+    }
 }

+ 14 - 2
scripts/generate/templates/SafeCast.js

@@ -7,7 +7,7 @@ const header = `\
 pragma solidity ^0.8.20;
 
 /**
- * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
+ * @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow
  * checks.
  *
  * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
@@ -116,11 +116,23 @@ function toUint${length}(int${length} value) internal pure returns (uint${length
 }
 `;
 
+const boolToUint = `
+  /**
+   * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
+   */
+  function toUint(bool b) internal pure returns (uint256 u) {
+      /// @solidity memory-safe-assembly
+      assembly {
+          u := iszero(iszero(b))
+      }
+  }
+`;
+
 // GENERATE
 module.exports = format(
   header.trimEnd(),
   'library SafeCast {',
   errors,
-  [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256)],
+  [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint],
   '}',
 );

+ 10 - 0
test/utils/math/SafeCast.test.js

@@ -146,4 +146,14 @@ describe('SafeCast', function () {
         .withArgs(ethers.MaxUint256);
     });
   });
+
+  describe('toUint (bool)', function () {
+    it('toUint(false) should be 0', async function () {
+      expect(await this.mock.$toUint(false)).to.equal(0n);
+    });
+
+    it('toUint(true) should be 1', async function () {
+      expect(await this.mock.$toUint(true)).to.equal(1n);
+    });
+  });
 });