|
@@ -1,5 +1,5 @@
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
-// OpenZeppelin Contracts (last updated v5.1.0) (utils/math/Math.sol)
|
|
|
+// OpenZeppelin Contracts (last updated v5.3.0) (utils/math/Math.sol)
|
|
|
|
|
|
pragma solidity ^0.8.20;
|
|
|
|
|
@@ -18,38 +18,68 @@ library Math {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
|
|
|
+ * @dev Return the 512-bit addition of two uint256.
|
|
|
+ *
|
|
|
+ * The result is stored in two 256 variables such that sum = high * 2²⁵⁶ + low.
|
|
|
+ */
|
|
|
+ function add512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ low := add(a, b)
|
|
|
+ high := lt(low, a)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @dev Return the 512-bit multiplication of two uint256.
|
|
|
+ *
|
|
|
+ * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
|
|
|
+ */
|
|
|
+ function mul512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
|
|
|
+ // 512-bit multiply [high low] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
|
|
|
+ // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
|
|
|
+ // variables such that product = high * 2²⁵⁶ + low.
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ let mm := mulmod(a, b, not(0))
|
|
|
+ low := mul(a, b)
|
|
|
+ high := sub(sub(mm, low), lt(mm, low))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @dev Returns the addition of two unsigned integers, with a success flag (no overflow).
|
|
|
*/
|
|
|
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
|
|
unchecked {
|
|
|
uint256 c = a + b;
|
|
|
- if (c < a) return (false, 0);
|
|
|
- return (true, c);
|
|
|
+ success = c >= a;
|
|
|
+ result = c * SafeCast.toUint(success);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * @dev Returns the subtraction of two unsigned integers, with an success flag (no overflow).
|
|
|
+ * @dev Returns the subtraction of two unsigned integers, with a success flag (no overflow).
|
|
|
*/
|
|
|
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
|
|
unchecked {
|
|
|
- if (b > a) return (false, 0);
|
|
|
- return (true, a - b);
|
|
|
+ uint256 c = a - b;
|
|
|
+ success = c <= a;
|
|
|
+ result = c * SafeCast.toUint(success);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * @dev Returns the multiplication of two unsigned integers, with an success flag (no overflow).
|
|
|
+ * @dev Returns the multiplication of two unsigned integers, with a success flag (no overflow).
|
|
|
*/
|
|
|
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
|
|
unchecked {
|
|
|
- // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
|
|
|
- // benefit is lost if 'b' is also tested.
|
|
|
- // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
|
|
|
- if (a == 0) return (true, 0);
|
|
|
uint256 c = a * b;
|
|
|
- if (c / a != b) return (false, 0);
|
|
|
- return (true, c);
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ // Only true when the multiplication doesn't overflow
|
|
|
+ // (c / a == b) || (a == 0)
|
|
|
+ success := or(eq(div(c, a), b), iszero(a))
|
|
|
+ }
|
|
|
+ // equivalent to: success ? c : 0
|
|
|
+ result = c * SafeCast.toUint(success);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -58,8 +88,11 @@ library Math {
|
|
|
*/
|
|
|
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
|
|
unchecked {
|
|
|
- if (b == 0) return (false, 0);
|
|
|
- return (true, a / b);
|
|
|
+ success = b > 0;
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ // The `DIV` opcode returns zero when the denominator is 0.
|
|
|
+ result := div(a, b)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -68,11 +101,38 @@ library Math {
|
|
|
*/
|
|
|
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
|
|
unchecked {
|
|
|
- if (b == 0) return (false, 0);
|
|
|
- return (true, a % b);
|
|
|
+ success = b > 0;
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ // The `MOD` opcode returns zero when the denominator is 0.
|
|
|
+ result := mod(a, b)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
|
|
|
+ */
|
|
|
+ function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
|
|
|
+ (bool success, uint256 result) = tryAdd(a, b);
|
|
|
+ return ternary(success, result, type(uint256).max);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
|
|
|
+ */
|
|
|
+ function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
|
|
|
+ (, uint256 result) = trySub(a, b);
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
|
|
|
+ */
|
|
|
+ function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
|
|
|
+ (bool success, uint256 result) = tryMul(a, b);
|
|
|
+ return ternary(success, result, type(uint256).max);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
|
|
|
*
|
|
@@ -143,26 +203,18 @@ library Math {
|
|
|
*/
|
|
|
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
|
|
|
unchecked {
|
|
|
- // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
|
|
|
- // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
|
|
|
- // variables such that product = prod1 * 2²⁵⁶ + prod0.
|
|
|
- uint256 prod0 = x * y; // Least significant 256 bits of the product
|
|
|
- uint256 prod1; // Most significant 256 bits of the product
|
|
|
- assembly {
|
|
|
- let mm := mulmod(x, y, not(0))
|
|
|
- prod1 := sub(sub(mm, prod0), lt(mm, prod0))
|
|
|
- }
|
|
|
+ (uint256 high, uint256 low) = mul512(x, y);
|
|
|
|
|
|
// Handle non-overflow cases, 256 by 256 division.
|
|
|
- if (prod1 == 0) {
|
|
|
+ if (high == 0) {
|
|
|
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
|
|
|
// The surrounding unchecked block does not change this fact.
|
|
|
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
|
|
|
- return prod0 / denominator;
|
|
|
+ return low / denominator;
|
|
|
}
|
|
|
|
|
|
// Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
|
|
|
- if (denominator <= prod1) {
|
|
|
+ if (denominator <= high) {
|
|
|
Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
|
|
|
}
|
|
|
|
|
@@ -170,34 +222,34 @@ library Math {
|
|
|
// 512 by 256 division.
|
|
|
///////////////////////////////////////////////
|
|
|
|
|
|
- // Make division exact by subtracting the remainder from [prod1 prod0].
|
|
|
+ // Make division exact by subtracting the remainder from [high low].
|
|
|
uint256 remainder;
|
|
|
- assembly {
|
|
|
+ assembly ("memory-safe") {
|
|
|
// Compute remainder using mulmod.
|
|
|
remainder := mulmod(x, y, denominator)
|
|
|
|
|
|
// Subtract 256 bit number from 512 bit number.
|
|
|
- prod1 := sub(prod1, gt(remainder, prod0))
|
|
|
- prod0 := sub(prod0, remainder)
|
|
|
+ high := sub(high, gt(remainder, low))
|
|
|
+ low := sub(low, remainder)
|
|
|
}
|
|
|
|
|
|
// Factor powers of two out of denominator and compute largest power of two divisor of denominator.
|
|
|
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
|
|
|
|
|
|
uint256 twos = denominator & (0 - denominator);
|
|
|
- assembly {
|
|
|
+ assembly ("memory-safe") {
|
|
|
// Divide denominator by twos.
|
|
|
denominator := div(denominator, twos)
|
|
|
|
|
|
- // Divide [prod1 prod0] by twos.
|
|
|
- prod0 := div(prod0, twos)
|
|
|
+ // Divide [high low] by twos.
|
|
|
+ low := div(low, twos)
|
|
|
|
|
|
// Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
|
|
|
twos := add(div(sub(0, twos), twos), 1)
|
|
|
}
|
|
|
|
|
|
- // Shift in bits from prod1 into prod0.
|
|
|
- prod0 |= prod1 * twos;
|
|
|
+ // Shift in bits from high into low.
|
|
|
+ low |= high * twos;
|
|
|
|
|
|
// Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
|
|
|
// that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
|
|
@@ -215,9 +267,9 @@ library Math {
|
|
|
|
|
|
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
|
|
|
// This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
|
|
|
- // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
|
|
|
+ // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
|
|
|
// is no longer required.
|
|
|
- result = prod0 * inverse;
|
|
|
+ result = low * inverse;
|
|
|
return result;
|
|
|
}
|
|
|
}
|
|
@@ -229,6 +281,26 @@ library Math {
|
|
|
return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
|
|
|
+ */
|
|
|
+ function mulShr(uint256 x, uint256 y, uint8 n) internal pure returns (uint256 result) {
|
|
|
+ unchecked {
|
|
|
+ (uint256 high, uint256 low) = mul512(x, y);
|
|
|
+ if (high >= 1 << n) {
|
|
|
+ Panic.panic(Panic.UNDER_OVERFLOW);
|
|
|
+ }
|
|
|
+ return (high << (256 - n)) | (low >> n);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @dev Calculates x * y >> n with full precision, following the selected rounding direction.
|
|
|
+ */
|
|
|
+ function mulShr(uint256 x, uint256 y, uint8 n, Rounding rounding) internal pure returns (uint256) {
|
|
|
+ return mulShr(x, y, n) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, 1 << n) > 0);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
|
|
|
*
|
|
@@ -537,41 +609,45 @@ library Math {
|
|
|
* @dev Return the log in base 2 of a positive value rounded towards zero.
|
|
|
* Returns 0 if given 0.
|
|
|
*/
|
|
|
- function log2(uint256 value) internal pure returns (uint256) {
|
|
|
- uint256 result = 0;
|
|
|
- uint256 exp;
|
|
|
- unchecked {
|
|
|
- exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
|
|
|
- value >>= exp;
|
|
|
- result += exp;
|
|
|
-
|
|
|
- result += SafeCast.toUint(value > 1);
|
|
|
+ function log2(uint256 x) internal pure returns (uint256 r) {
|
|
|
+ // If value has upper 128 bits set, log2 result is at least 128
|
|
|
+ r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7;
|
|
|
+ // If upper 64 bits of 128-bit half set, add 64 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6;
|
|
|
+ // If upper 32 bits of 64-bit half set, add 32 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5;
|
|
|
+ // If upper 16 bits of 32-bit half set, add 16 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffff) << 4;
|
|
|
+ // If upper 8 bits of 16-bit half set, add 8 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xff) << 3;
|
|
|
+ // If upper 4 bits of 8-bit half set, add 4 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xf) << 2;
|
|
|
+
|
|
|
+ // Shifts value right by the current result and use it as an index into this lookup table:
|
|
|
+ //
|
|
|
+ // | x (4 bits) | index | table[index] = MSB position |
|
|
|
+ // |------------|---------|-----------------------------|
|
|
|
+ // | 0000 | 0 | table[0] = 0 |
|
|
|
+ // | 0001 | 1 | table[1] = 0 |
|
|
|
+ // | 0010 | 2 | table[2] = 1 |
|
|
|
+ // | 0011 | 3 | table[3] = 1 |
|
|
|
+ // | 0100 | 4 | table[4] = 2 |
|
|
|
+ // | 0101 | 5 | table[5] = 2 |
|
|
|
+ // | 0110 | 6 | table[6] = 2 |
|
|
|
+ // | 0111 | 7 | table[7] = 2 |
|
|
|
+ // | 1000 | 8 | table[8] = 3 |
|
|
|
+ // | 1001 | 9 | table[9] = 3 |
|
|
|
+ // | 1010 | 10 | table[10] = 3 |
|
|
|
+ // | 1011 | 11 | table[11] = 3 |
|
|
|
+ // | 1100 | 12 | table[12] = 3 |
|
|
|
+ // | 1101 | 13 | table[13] = 3 |
|
|
|
+ // | 1110 | 14 | table[14] = 3 |
|
|
|
+ // | 1111 | 15 | table[15] = 3 |
|
|
|
+ //
|
|
|
+ // The lookup table is represented as a 32-byte value with the MSB positions for 0-15 in the last 16 bytes.
|
|
|
+ assembly ("memory-safe") {
|
|
|
+ r := or(r, byte(shr(r, x), 0x0000010102020202030303030303030300000000000000000000000000000000))
|
|
|
}
|
|
|
- return result;
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -640,29 +716,17 @@ library Math {
|
|
|
*
|
|
|
* Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string.
|
|
|
*/
|
|
|
- function log256(uint256 value) internal pure returns (uint256) {
|
|
|
- uint256 result = 0;
|
|
|
- uint256 isGt;
|
|
|
- unchecked {
|
|
|
- isGt = SafeCast.toUint(value > (1 << 128) - 1);
|
|
|
- value >>= isGt * 128;
|
|
|
- result += isGt * 16;
|
|
|
-
|
|
|
- isGt = SafeCast.toUint(value > (1 << 64) - 1);
|
|
|
- value >>= isGt * 64;
|
|
|
- result += isGt * 8;
|
|
|
-
|
|
|
- isGt = SafeCast.toUint(value > (1 << 32) - 1);
|
|
|
- value >>= isGt * 32;
|
|
|
- result += isGt * 4;
|
|
|
-
|
|
|
- isGt = SafeCast.toUint(value > (1 << 16) - 1);
|
|
|
- value >>= isGt * 16;
|
|
|
- result += isGt * 2;
|
|
|
-
|
|
|
- result += SafeCast.toUint(value > (1 << 8) - 1);
|
|
|
- }
|
|
|
- return result;
|
|
|
+ function log256(uint256 x) internal pure returns (uint256 r) {
|
|
|
+ // If value has upper 128 bits set, log2 result is at least 128
|
|
|
+ r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7;
|
|
|
+ // If upper 64 bits of 128-bit half set, add 64 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6;
|
|
|
+ // If upper 32 bits of 64-bit half set, add 32 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5;
|
|
|
+ // If upper 16 bits of 32-bit half set, add 16 to result
|
|
|
+ r |= SafeCast.toUint((x >> r) > 0xffff) << 4;
|
|
|
+ // Add 1 if upper 8 bits of 16-bit half set, and divide accumulated result by 8
|
|
|
+ return (r >> 3) | SafeCast.toUint((x >> r) > 0xff);
|
|
|
}
|
|
|
|
|
|
/**
|