Forráskód Böngészése

Add saturating (unsigned) math operations and optimize try operations (#5527)

Hadrien Croubois 7 hónapja
szülő
commit
a9b1f58b00

+ 5 - 0
.changeset/fair-pumpkins-compete.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.

+ 48 - 16
contracts/utils/math/Math.sol

@@ -51,8 +51,8 @@ library Math {
     function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
             uint256 c = a + b;
-            if (c < a) return (false, 0);
-            return (true, c);
+            success = c >= a;
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -61,8 +61,9 @@ library Math {
      */
     function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b > a) return (false, 0);
-            return (true, a - b);
+            uint256 c = a - b;
+            success = c <= a;
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -71,13 +72,14 @@ library Math {
      */
     function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
-            // benefit is lost if 'b' is also tested.
-            // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
-            if (a == 0) return (true, 0);
             uint256 c = a * b;
-            if (c / a != b) return (false, 0);
-            return (true, c);
+            assembly ("memory-safe") {
+                // Only true when the multiplication doesn't overflow
+                // (c / a == b) || (a == 0)
+                success := or(eq(div(c, a), b), iszero(a))
+            }
+            // equivalent to: success ? c : 0
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -86,8 +88,11 @@ library Math {
      */
     function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b == 0) return (false, 0);
-            return (true, a / b);
+            success = b > 0;
+            assembly ("memory-safe") {
+                // The `DIV` opcode returns zero when the denominator is 0.
+                result := div(a, b)
+            }
         }
     }
 
@@ -96,11 +101,38 @@ library Math {
      */
     function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b == 0) return (false, 0);
-            return (true, a % b);
+            success = b > 0;
+            assembly ("memory-safe") {
+                // The `MOD` opcode returns zero when the denominator is 0.
+                result := mod(a, b)
+            }
         }
     }
 
+    /**
+     * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
+     */
+    function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
+        (bool success, uint256 result) = tryAdd(a, b);
+        return ternary(success, result, type(uint256).max);
+    }
+
+    /**
+     * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
+     */
+    function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
+        (, uint256 result) = trySub(a, b);
+        return result;
+    }
+
+    /**
+     * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
+     */
+    function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
+        (bool success, uint256 result) = tryMul(a, b);
+        return ternary(success, result, type(uint256).max);
+    }
+
     /**
      * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
      *
@@ -192,7 +224,7 @@ library Math {
 
             // Make division exact by subtracting the remainder from [high low].
             uint256 remainder;
-            assembly {
+            assembly ("memory-safe") {
                 // Compute remainder using mulmod.
                 remainder := mulmod(x, y, denominator)
 
@@ -205,7 +237,7 @@ library Math {
             // Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
 
             uint256 twos = denominator & (0 - denominator);
-            assembly {
+            assembly ("memory-safe") {
                 // Divide denominator by twos.
                 denominator := div(denominator, twos)
 

+ 56 - 0
test/utils/math/Math.test.js

@@ -168,6 +168,62 @@ describe('Math', function () {
     });
   });
 
+  describe('saturatingAdd', function () {
+    it('adds correctly', async function () {
+      const a = 5678n;
+      const b = 1234n;
+      await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
+      await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
+      await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
+    });
+
+    it('bounds on addition overflow', async function () {
+      await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
+      await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
+        ethers.MaxUint256,
+      );
+    });
+  });
+
+  describe('saturatingSub', function () {
+    it('subtracts correctly', async function () {
+      const a = 5678n;
+      const b = 1234n;
+      await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
+      await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
+      await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
+    });
+
+    it('bounds on subtraction overflow', async function () {
+      await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
+    });
+  });
+
+  describe('saturatingMul', function () {
+    it('multiplies correctly', async function () {
+      const a = 1234n;
+      const b = 5678n;
+      await testCommutative(this.mock.$saturatingMul, a, b, a * b);
+    });
+
+    it('multiplies by zero correctly', async function () {
+      const a = 0n;
+      const b = 5678n;
+      await testCommutative(this.mock.$saturatingMul, a, b, 0n);
+    });
+
+    it('bounds on multiplication overflow', async function () {
+      const a = ethers.MaxUint256;
+      const b = 2n;
+      await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
+    });
+  });
+
   describe('max', function () {
     it('is correctly detected in both position', async function () {
       await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));