Эх сурвалжийг харах

Release v5.3 cherrypick #2 (#5526)

Signed-off-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com>
Co-authored-by: Voronor <129545215+voronor@users.noreply.github.com>
Co-authored-by: StackOverflowExcept1on <109800286+StackOverflowExcept1on@users.noreply.github.com>
Co-authored-by: Michalis Kargakis <kargakis@protonmail.com>
Hadrien Croubois 7 сар өмнө
parent
commit
887697148d
34 өөрчлөгдсөн 1641 нэмэгдсэн , 262 устгасан
  1. 5 0
      .changeset/blue-nails-give.md
  2. 5 0
      .changeset/fair-pumpkins-compete.md
  3. 5 0
      .changeset/fast-coats-try.md
  4. 5 0
      .changeset/fuzzy-crews-poke.md
  5. 5 0
      .changeset/good-zebras-ring.md
  6. 5 0
      .changeset/nice-cherries-reply.md
  7. 5 0
      .changeset/ninety-rings-suffer.md
  8. 5 0
      .changeset/quiet-shrimps-kiss.md
  9. 2 2
      contracts/governance/Governor.sol
  10. 8 0
      contracts/governance/README.adoc
  11. 1 1
      contracts/governance/extensions/GovernorCountingFractional.sol
  12. 59 0
      contracts/governance/extensions/GovernorSuperQuorum.sol
  13. 16 13
      contracts/governance/extensions/GovernorVotesQuorumFraction.sol
  14. 132 0
      contracts/governance/extensions/GovernorVotesSuperQuorumFraction.sol
  15. 8 0
      contracts/mocks/MerkleTreeMock.sol
  16. 95 0
      contracts/mocks/governance/GovernorSuperQuorumMock.sol
  17. 37 0
      contracts/mocks/governance/GovernorVotesSuperQuorumFractionMock.sol
  18. 11 1
      contracts/proxy/utils/Initializable.sol
  19. 45 0
      contracts/utils/Strings.sol
  20. 15 0
      contracts/utils/cryptography/MessageHashUtils.sol
  21. 109 37
      contracts/utils/math/Math.sol
  22. 92 0
      contracts/utils/structs/MerkleTree.sol
  23. 15 0
      hardhat/common-contracts.js
  24. 5 1
      scripts/checks/coverage.sh
  25. 168 0
      test/governance/extensions/GovernorSuperQuorum.test.js
  26. 79 0
      test/governance/extensions/GovernorSuperQuorumGreaterThanQuorum.t.sol
  27. 160 0
      test/governance/extensions/GovernorVotesSuperQuorumFraction.test.js
  28. 7 5
      test/helpers/enums.js
  29. 7 0
      test/utils/Strings.test.js
  30. 33 0
      test/utils/cryptography/MessageHashUtils.t.sol
  31. 34 5
      test/utils/cryptography/MessageHashUtils.test.js
  32. 57 22
      test/utils/math/Math.t.sol
  33. 298 147
      test/utils/math/Math.test.js
  34. 108 28
      test/utils/structs/MerkleTree.test.js

+ 5 - 0
.changeset/blue-nails-give.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: Add `add512`, `mul512` and `mulShr`.

+ 5 - 0
.changeset/fair-pumpkins-compete.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.

+ 5 - 0
.changeset/fast-coats-try.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Initializable`: Add `_initializableStorageSlot` function that returns a pointer to the storage struct. The function allows customizing with a custom storage slot with an `override`.

+ 5 - 0
.changeset/fuzzy-crews-poke.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`GovernorSuperQuorum`: Add a governance extension to support a super quorum. Proposals that meet the super quorum (and have a majority of for votes) advance to the `Succeeded` state before the proposal deadline.

+ 5 - 0
.changeset/good-zebras-ring.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way.

+ 5 - 0
.changeset/nice-cherries-reply.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Strings`: Add `espaceJSON` that escapes special characters in JSON strings.

+ 5 - 0
.changeset/ninety-rings-suffer.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`GovernorVotesSuperQuorumFraction`: Add a variant of the `GovernorSuperQuorum` extensions where the super quorum is expressed as a fraction of the total supply.

+ 5 - 0
.changeset/quiet-shrimps-kiss.md

@@ -0,0 +1,5 @@
+---
+"openzeppelin-solidity": patch
+---
+
+`MessageHashUtils`: Add `toDataWithIntendedValidatorHash(address, bytes32)`.

+ 2 - 2
contracts/governance/Governor.sol

@@ -21,9 +21,9 @@ import {IGovernor, IERC6372} from "./IGovernor.sol";
  *
  * This contract is abstract and requires several functions to be implemented in various modules:
  *
- * - A counting module must implement {quorum}, {_quorumReached}, {_voteSucceeded} and {_countVote}
+ * - A counting module must implement {_quorumReached}, {_voteSucceeded} and {_countVote}
  * - A voting module must implement {_getVotes}
- * - Additionally, {votingPeriod} must also be implemented
+ * - Additionally, {votingPeriod}, {votingDelay}, and {quorum} must also be implemented
  */
 abstract contract Governor is Context, ERC165, EIP712, Nonces, IGovernor, IERC721Receiver, IERC1155Receiver {
     using DoubleEndedQueue for DoubleEndedQueue.Bytes32Deque;

+ 8 - 0
contracts/governance/README.adoc

@@ -24,6 +24,8 @@ Votes modules determine the source of voting power, and sometimes quorum number.
 
 * {GovernorVotesQuorumFraction}: Combines with `GovernorVotes` to set the quorum as a fraction of the total token supply.
 
+* {GovernorVotesSuperQuorumFraction}: Combines `GovernorSuperQuorum` with `GovernorVotesQuorumFraction` to set the super quorum as a fraction of the total token supply.
+
 Counting modules determine valid voting options.
 
 * {GovernorCountingSimple}: Simple voting mechanism with 3 voting options: Against, For and Abstain.
@@ -50,6 +52,8 @@ Other extensions can customize the behavior or interface in multiple ways.
 
 * {GovernorProposalGuardian}: Adds a proposal guardian that can cancel proposals at any stage in their lifecycle--this permission is passed on to the proposers if the guardian is not set.
 
+* {GovernorSuperQuorum}: Extension of {Governor} with a super quorum. Proposals that meet the super quorum (and have a majority of for votes) advance to the `Succeeded` state before the proposal deadline.
+
 In addition to modules and extensions, the core contract requires a few virtual functions to be implemented to your particular specifications:
 
 * <<Governor-votingDelay-,`votingDelay()`>>: Delay (in ERC-6372 clock) since the proposal is submitted until voting power is fixed and voting starts. This can be used to enforce a delay after a proposal is published for users to buy tokens, or delegate their votes.
@@ -76,6 +80,8 @@ NOTE: Functions of the `Governor` contract do not include access control. If you
 
 {{GovernorVotesQuorumFraction}}
 
+{{GovernorVotesSuperQuorumFraction}}
+
 === Extensions
 
 {{GovernorTimelockAccess}}
@@ -92,6 +98,8 @@ NOTE: Functions of the `Governor` contract do not include access control. If you
 
 {{GovernorProposalGuardian}}
 
+{{GovernorSuperQuorum}}
+
 == Utils
 
 {{Votes}}

+ 1 - 1
contracts/governance/extensions/GovernorCountingFractional.sol

@@ -27,7 +27,7 @@ import {Math} from "../../utils/math/Math.sol";
  * * Voting from an L2 with tokens held by a bridge
  * * Voting privately from a shielded pool using zero knowledge proofs.
  *
- * Based on ScopeLift's GovernorCountingFractional[https://github.com/ScopeLift/flexible-voting/blob/e5de2efd1368387b840931f19f3c184c85842761/src/GovernorCountingFractional.sol]
+ * Based on ScopeLift's https://github.com/ScopeLift/flexible-voting/blob/e5de2efd1368387b840931f19f3c184c85842761/src/GovernorCountingFractional.sol[`GovernorCountingFractional`]
  *
  * _Available since v5.1._
  */

+ 59 - 0
contracts/governance/extensions/GovernorSuperQuorum.sol

@@ -0,0 +1,59 @@
+// SPDX-License-Identifier: MIT
+pragma solidity ^0.8.20;
+
+import {Governor} from "../Governor.sol";
+import {SafeCast} from "../../utils/math/SafeCast.sol";
+import {Checkpoints} from "../../utils/structs/Checkpoints.sol";
+
+/**
+ * @dev Extension of {Governor} with a super quorum. Proposals that meet the super quorum (and have a majority of for
+ * votes) advance to the `Succeeded` state before the proposal deadline. Counting modules that want to use this
+ * extension must implement {proposalVotes}.
+ */
+abstract contract GovernorSuperQuorum is Governor {
+    /**
+     * @dev Minimum number of cast votes required for a proposal to reach super quorum. Only FOR votes are counted
+     * towards the super quorum. Once the super quorum is reached, an active proposal can proceed to the next state
+     * without waiting for the proposal deadline.
+     *
+     * NOTE: The `timepoint` parameter corresponds to the snapshot used for counting the vote. This enables scaling of the
+     * quorum depending on values such as the `totalSupply` of a token at this timepoint (see {ERC20Votes}).
+     *
+     * NOTE: Make sure the value specified for the super quorum is greater than {quorum}, otherwise, it may be
+     * possible to pass a proposal with less votes than the default quorum.
+     */
+    function superQuorum(uint256 timepoint) public view virtual returns (uint256);
+
+    /**
+     * @dev Accessor to the internal vote counts. This must be implemented by the counting module. Counting modules
+     * that don't implement this function are incompatible with this module
+     */
+    function proposalVotes(
+        uint256 proposalId
+    ) public view virtual returns (uint256 againstVotes, uint256 forVotes, uint256 abstainVotes);
+
+    /**
+     * @dev Overridden version of the {Governor-state} function that checks if the proposal has reached the super
+     * quorum.
+     *
+     * NOTE: If the proposal reaches super quorum but {_voteSucceeded} returns false, eg, assuming the super quorum
+     * has been set low enough that both FOR and AGAINST votes have exceeded it and AGAINST votes exceed FOR votes,
+     * the proposal continues to be active until {_voteSucceeded} returns true or the proposal deadline is reached.
+     * This means that with a low super quorum it is also possible that a vote can succeed prematurely before enough
+     * AGAINST voters have a chance to vote. Hence, it is recommended to set a high enough super quorum to avoid these
+     * types of scenarios.
+     */
+    function state(uint256 proposalId) public view virtual override returns (ProposalState) {
+        ProposalState currentState = super.state(proposalId);
+        if (currentState != ProposalState.Active) return currentState;
+
+        (, uint256 forVotes, ) = proposalVotes(proposalId);
+        if (forVotes < superQuorum(proposalSnapshot(proposalId)) || !_voteSucceeded(proposalId)) {
+            return ProposalState.Active;
+        } else if (proposalEta(proposalId) == 0) {
+            return ProposalState.Succeeded;
+        } else {
+            return ProposalState.Queued;
+        }
+    }
+}

+ 16 - 13
contracts/governance/extensions/GovernorVotesQuorumFraction.sol

@@ -4,6 +4,7 @@
 pragma solidity ^0.8.20;
 
 import {GovernorVotes} from "./GovernorVotes.sol";
+import {Math} from "../../utils/math/Math.sol";
 import {SafeCast} from "../../utils/math/SafeCast.sol";
 import {Checkpoints} from "../../utils/structs/Checkpoints.sol";
 
@@ -45,18 +46,7 @@ abstract contract GovernorVotesQuorumFraction is GovernorVotes {
      * @dev Returns the quorum numerator at a specific timepoint. See {quorumDenominator}.
      */
     function quorumNumerator(uint256 timepoint) public view virtual returns (uint256) {
-        uint256 length = _quorumNumeratorHistory._checkpoints.length;
-
-        // Optimistic search, check the latest checkpoint
-        Checkpoints.Checkpoint208 storage latest = _quorumNumeratorHistory._checkpoints[length - 1];
-        uint48 latestKey = latest._key;
-        uint208 latestValue = latest._value;
-        if (latestKey <= timepoint) {
-            return latestValue;
-        }
-
-        // Otherwise, do the binary search
-        return _quorumNumeratorHistory.upperLookupRecent(SafeCast.toUint48(timepoint));
+        return _optimisticUpperLookupRecent(_quorumNumeratorHistory, timepoint);
     }
 
     /**
@@ -70,7 +60,7 @@ abstract contract GovernorVotesQuorumFraction is GovernorVotes {
      * @dev Returns the quorum for a timepoint, in terms of number of votes: `supply * numerator / denominator`.
      */
     function quorum(uint256 timepoint) public view virtual override returns (uint256) {
-        return (token().getPastTotalSupply(timepoint) * quorumNumerator(timepoint)) / quorumDenominator();
+        return Math.mulDiv(token().getPastTotalSupply(timepoint), quorumNumerator(timepoint), quorumDenominator());
     }
 
     /**
@@ -107,4 +97,17 @@ abstract contract GovernorVotesQuorumFraction is GovernorVotes {
 
         emit QuorumNumeratorUpdated(oldQuorumNumerator, newQuorumNumerator);
     }
+
+    /**
+     * @dev Returns the numerator at a specific timepoint.
+     */
+    function _optimisticUpperLookupRecent(
+        Checkpoints.Trace208 storage ckpts,
+        uint256 timepoint
+    ) internal view returns (uint256) {
+        // If trace is empty, key and value are both equal to 0.
+        // In that case `key <= timepoint` is true, and it is ok to return 0.
+        (, uint48 key, uint208 value) = ckpts.latestCheckpoint();
+        return key <= timepoint ? value : ckpts.upperLookupRecent(SafeCast.toUint48(timepoint));
+    }
 }

+ 132 - 0
contracts/governance/extensions/GovernorVotesSuperQuorumFraction.sol

@@ -0,0 +1,132 @@
+// SPDX-License-Identifier: MIT
+pragma solidity ^0.8.20;
+
+import {Governor} from "../Governor.sol";
+import {GovernorSuperQuorum} from "./GovernorSuperQuorum.sol";
+import {GovernorVotesQuorumFraction} from "./GovernorVotesQuorumFraction.sol";
+import {Math} from "../../utils/math/Math.sol";
+import {SafeCast} from "../../utils/math/SafeCast.sol";
+import {Checkpoints} from "../../utils/structs/Checkpoints.sol";
+
+/**
+ * @dev Extension of {GovernorVotesQuorumFraction} with a super quorum expressed as a
+ * fraction of the total supply. Proposals that meet the super quorum (and have a majority of for votes) advance to
+ * the `Succeeded` state before the proposal deadline.
+ */
+abstract contract GovernorVotesSuperQuorumFraction is GovernorVotesQuorumFraction, GovernorSuperQuorum {
+    using Checkpoints for Checkpoints.Trace208;
+
+    Checkpoints.Trace208 private _superQuorumNumeratorHistory;
+
+    event SuperQuorumNumeratorUpdated(uint256 oldSuperQuorumNumerator, uint256 newSuperQuorumNumerator);
+
+    /**
+     * @dev The super quorum set is not valid as it exceeds the quorum denominator.
+     */
+    error GovernorInvalidSuperQuorumFraction(uint256 superQuorumNumerator, uint256 denominator);
+
+    /**
+     * @dev The super quorum set is not valid as it is smaller or equal to the quorum.
+     */
+    error GovernorInvalidSuperQuorumTooSmall(uint256 superQuorumNumerator, uint256 quorumNumerator);
+
+    /**
+     * @dev The quorum set is not valid as it exceeds the super quorum.
+     */
+    error GovernorInvalidQuorumTooLarge(uint256 quorumNumerator, uint256 superQuorumNumerator);
+
+    /**
+     * @dev Initialize super quorum as a fraction of the token's total supply.
+     *
+     * The super quorum is specified as a fraction of the token's total supply and has to
+     * be greater than the quorum.
+     */
+    constructor(uint256 superQuorumNumeratorValue) {
+        _updateSuperQuorumNumerator(superQuorumNumeratorValue);
+    }
+
+    /**
+     * @dev Returns the current super quorum numerator.
+     */
+    function superQuorumNumerator() public view virtual returns (uint256) {
+        return _superQuorumNumeratorHistory.latest();
+    }
+
+    /**
+     * @dev Returns the super quorum numerator at a specific `timepoint`.
+     */
+    function superQuorumNumerator(uint256 timepoint) public view virtual returns (uint256) {
+        return _optimisticUpperLookupRecent(_superQuorumNumeratorHistory, timepoint);
+    }
+
+    /**
+     * @dev Returns the super quorum for a `timepoint`, in terms of number of votes: `supply * numerator / denominator`.
+     */
+    function superQuorum(uint256 timepoint) public view virtual override returns (uint256) {
+        return Math.mulDiv(token().getPastTotalSupply(timepoint), superQuorumNumerator(timepoint), quorumDenominator());
+    }
+
+    /**
+     * @dev Changes the super quorum numerator.
+     *
+     * Emits a {SuperQuorumNumeratorUpdated} event.
+     *
+     * Requirements:
+     *
+     * - Must be called through a governance proposal.
+     * - New super quorum numerator must be smaller or equal to the denominator.
+     * - New super quorum numerator must be greater than or equal to the quorum numerator.
+     */
+    function updateSuperQuorumNumerator(uint256 newSuperQuorumNumerator) public virtual onlyGovernance {
+        _updateSuperQuorumNumerator(newSuperQuorumNumerator);
+    }
+
+    /**
+     * @dev Changes the super quorum numerator.
+     *
+     * Emits a {SuperQuorumNumeratorUpdated} event.
+     *
+     * Requirements:
+     *
+     * - New super quorum numerator must be smaller or equal to the denominator.
+     * - New super quorum numerator must be greater than or equal to the quorum numerator.
+     */
+    function _updateSuperQuorumNumerator(uint256 newSuperQuorumNumerator) internal virtual {
+        uint256 denominator = quorumDenominator();
+        if (newSuperQuorumNumerator > denominator) {
+            revert GovernorInvalidSuperQuorumFraction(newSuperQuorumNumerator, denominator);
+        }
+
+        uint256 quorumNumerator = quorumNumerator();
+        if (newSuperQuorumNumerator < quorumNumerator) {
+            revert GovernorInvalidSuperQuorumTooSmall(newSuperQuorumNumerator, quorumNumerator);
+        }
+
+        uint256 oldSuperQuorumNumerator = _superQuorumNumeratorHistory.latest();
+        _superQuorumNumeratorHistory.push(clock(), SafeCast.toUint208(newSuperQuorumNumerator));
+
+        emit SuperQuorumNumeratorUpdated(oldSuperQuorumNumerator, newSuperQuorumNumerator);
+    }
+
+    /**
+     * @dev Overrides {GovernorVotesQuorumFraction-_updateQuorumNumerator} to ensure the super
+     * quorum numerator is greater than or equal to the quorum numerator.
+     */
+    function _updateQuorumNumerator(uint256 newQuorumNumerator) internal virtual override {
+        // Ignoring check when the superQuorum was never set (construction sets quorum before superQuorum)
+        if (_superQuorumNumeratorHistory.length() > 0) {
+            uint256 superQuorumNumerator_ = superQuorumNumerator();
+            if (newQuorumNumerator > superQuorumNumerator_) {
+                revert GovernorInvalidQuorumTooLarge(newQuorumNumerator, superQuorumNumerator_);
+            }
+        }
+        super._updateQuorumNumerator(newQuorumNumerator);
+    }
+
+    /// @inheritdoc GovernorSuperQuorum
+    function state(
+        uint256 proposalId
+    ) public view virtual override(Governor, GovernorSuperQuorum) returns (ProposalState) {
+        return super.state(proposalId);
+    }
+}

+ 8 - 0
contracts/mocks/MerkleTreeMock.sol

@@ -14,6 +14,7 @@ contract MerkleTreeMock {
     bytes32 public root;
 
     event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
+    event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root);
 
     function setup(uint8 _depth, bytes32 _zero) public {
         root = _tree.setup(_depth, _zero);
@@ -25,6 +26,13 @@ contract MerkleTreeMock {
         root = currentRoot;
     }
 
+    function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public {
+        (bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof);
+        if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof();
+        emit LeafUpdated(oldValue, newValue, index, newRoot);
+        root = newRoot;
+    }
+
     function depth() public view returns (uint256) {
         return _tree.depth();
     }

+ 95 - 0
contracts/mocks/governance/GovernorSuperQuorumMock.sol

@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Governor} from "../../governance/Governor.sol";
+import {GovernorSettings} from "../../governance/extensions/GovernorSettings.sol";
+import {GovernorVotes} from "../../governance/extensions/GovernorVotes.sol";
+import {GovernorSuperQuorum} from "../../governance/extensions/GovernorSuperQuorum.sol";
+import {GovernorCountingSimple} from "../../governance/extensions/GovernorCountingSimple.sol";
+import {GovernorTimelockControl} from "../../governance/extensions/GovernorTimelockControl.sol";
+
+abstract contract GovernorSuperQuorumMock is
+    GovernorSettings,
+    GovernorVotes,
+    GovernorTimelockControl,
+    GovernorSuperQuorum,
+    GovernorCountingSimple
+{
+    uint256 private _quorum;
+    uint256 private _superQuorum;
+
+    constructor(uint256 quorum_, uint256 superQuorum_) {
+        _quorum = quorum_;
+        _superQuorum = superQuorum_;
+    }
+
+    function quorum(uint256) public view override returns (uint256) {
+        return _quorum;
+    }
+
+    function superQuorum(uint256) public view override returns (uint256) {
+        return _superQuorum;
+    }
+
+    function state(
+        uint256 proposalId
+    ) public view override(Governor, GovernorSuperQuorum, GovernorTimelockControl) returns (ProposalState) {
+        return super.state(proposalId);
+    }
+
+    function proposalThreshold() public view override(Governor, GovernorSettings) returns (uint256) {
+        return super.proposalThreshold();
+    }
+
+    function proposalVotes(
+        uint256 proposalId
+    )
+        public
+        view
+        virtual
+        override(GovernorCountingSimple, GovernorSuperQuorum)
+        returns (uint256 againstVotes, uint256 forVotes, uint256 abstainVotes)
+    {
+        return super.proposalVotes(proposalId);
+    }
+
+    function _cancel(
+        address[] memory targets,
+        uint256[] memory values,
+        bytes[] memory calldatas,
+        bytes32 descriptionHash
+    ) internal override(Governor, GovernorTimelockControl) returns (uint256) {
+        return super._cancel(targets, values, calldatas, descriptionHash);
+    }
+
+    function _executeOperations(
+        uint256 proposalId,
+        address[] memory targets,
+        uint256[] memory values,
+        bytes[] memory calldatas,
+        bytes32 descriptionHash
+    ) internal override(Governor, GovernorTimelockControl) {
+        super._executeOperations(proposalId, targets, values, calldatas, descriptionHash);
+    }
+
+    function _executor() internal view override(Governor, GovernorTimelockControl) returns (address) {
+        return super._executor();
+    }
+
+    function _queueOperations(
+        uint256 proposalId,
+        address[] memory targets,
+        uint256[] memory values,
+        bytes[] memory calldatas,
+        bytes32 descriptionHash
+    ) internal override(Governor, GovernorTimelockControl) returns (uint48) {
+        return super._queueOperations(proposalId, targets, values, calldatas, descriptionHash);
+    }
+
+    function proposalNeedsQueuing(
+        uint256 proposalId
+    ) public view override(Governor, GovernorTimelockControl) returns (bool) {
+        return super.proposalNeedsQueuing(proposalId);
+    }
+}

+ 37 - 0
contracts/mocks/governance/GovernorVotesSuperQuorumFractionMock.sol

@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Governor} from "../../governance/Governor.sol";
+import {GovernorSettings} from "../../governance/extensions/GovernorSettings.sol";
+import {GovernorSuperQuorum} from "../../governance/extensions/GovernorSuperQuorum.sol";
+import {GovernorCountingSimple} from "../../governance/extensions/GovernorCountingSimple.sol";
+import {GovernorVotesSuperQuorumFraction} from "../../governance/extensions/GovernorVotesSuperQuorumFraction.sol";
+
+abstract contract GovernorVotesSuperQuorumFractionMock is
+    GovernorSettings,
+    GovernorVotesSuperQuorumFraction,
+    GovernorCountingSimple
+{
+    function proposalThreshold() public view override(Governor, GovernorSettings) returns (uint256) {
+        return super.proposalThreshold();
+    }
+
+    function proposalVotes(
+        uint256 proposalId
+    )
+        public
+        view
+        virtual
+        override(GovernorCountingSimple, GovernorSuperQuorum)
+        returns (uint256 againstVotes, uint256 forVotes, uint256 abstainVotes)
+    {
+        return super.proposalVotes(proposalId);
+    }
+
+    function state(
+        uint256 proposalId
+    ) public view override(Governor, GovernorVotesSuperQuorumFraction) returns (ProposalState) {
+        return super.state(proposalId);
+    }
+}

+ 11 - 1
contracts/proxy/utils/Initializable.sol

@@ -216,13 +216,23 @@ abstract contract Initializable {
         return _getInitializableStorage()._initializing;
     }
 
+    /**
+     * @dev Pointer to storage slot. Allows integrators to override it with a custom storage location.
+     *
+     * NOTE: Consider following the ERC-7201 formula to derive storage locations.
+     */
+    function _initializableStorageSlot() internal pure virtual returns (bytes32) {
+        return INITIALIZABLE_STORAGE;
+    }
+
     /**
      * @dev Returns a pointer to the storage namespace.
      */
     // solhint-disable-next-line var-name-mixedcase
     function _getInitializableStorage() private pure returns (InitializableStorage storage $) {
+        bytes32 slot = _initializableStorageSlot();
         assembly {
-            $.slot := INITIALIZABLE_STORAGE
+            $.slot := slot
         }
     }
 }

+ 45 - 0
contracts/utils/Strings.sol

@@ -15,6 +15,14 @@ library Strings {
 
     bytes16 private constant HEX_DIGITS = "0123456789abcdef";
     uint8 private constant ADDRESS_LENGTH = 20;
+    uint256 private constant SPECIAL_CHARS_LOOKUP =
+        (1 << 0x08) | // backspace
+            (1 << 0x09) | // tab
+            (1 << 0x0a) | // newline
+            (1 << 0x0c) | // form feed
+            (1 << 0x0d) | // carriage return
+            (1 << 0x22) | // double quote
+            (1 << 0x5c); // backslash
 
     /**
      * @dev The `value` string doesn't fit in the specified `length`.
@@ -426,6 +434,43 @@ library Strings {
         return value;
     }
 
+    /**
+     * @dev Escape special characters in JSON strings. This can be useful to prevent JSON injection in NFT metadata.
+     *
+     * WARNING: This function should only be used in double quoted JSON strings. Single quotes are not escaped.
+     */
+    function escapeJSON(string memory input) internal pure returns (string memory) {
+        bytes memory buffer = bytes(input);
+        bytes memory output = new bytes(2 * buffer.length); // worst case scenario
+        uint256 outputLength = 0;
+
+        for (uint256 i; i < buffer.length; ++i) {
+            bytes1 char = bytes1(_unsafeReadBytesOffset(buffer, i));
+            if (((SPECIAL_CHARS_LOOKUP & (1 << uint8(char))) != 0)) {
+                output[outputLength++] = "\\";
+                if (char == 0x08) output[outputLength++] = "b";
+                else if (char == 0x09) output[outputLength++] = "t";
+                else if (char == 0x0a) output[outputLength++] = "n";
+                else if (char == 0x0c) output[outputLength++] = "f";
+                else if (char == 0x0d) output[outputLength++] = "r";
+                else if (char == 0x5c) output[outputLength++] = "\\";
+                else if (char == 0x22) {
+                    // solhint-disable-next-line quotes
+                    output[outputLength++] = '"';
+                }
+            } else {
+                output[outputLength++] = char;
+            }
+        }
+        // write the actual length and deallocate unused memory
+        assembly ("memory-safe") {
+            mstore(output, outputLength)
+            mstore(0x40, add(output, shl(5, shr(5, add(outputLength, 63)))))
+        }
+
+        return string(output);
+    }
+
     /**
      * @dev Reads a bytes32 from a bytes array without bounds checking.
      *

+ 15 - 0
contracts/utils/cryptography/MessageHashUtils.sol

@@ -63,6 +63,21 @@ library MessageHashUtils {
         return keccak256(abi.encodePacked(hex"19_00", validator, data));
     }
 
+    /**
+     * @dev Variant of {toDataWithIntendedValidatorHash-address-bytes} optimized for cases where `data` is a bytes32.
+     */
+    function toDataWithIntendedValidatorHash(
+        address validator,
+        bytes32 messageHash
+    ) internal pure returns (bytes32 digest) {
+        assembly ("memory-safe") {
+            mstore(0x00, hex"19_00")
+            mstore(0x02, shl(96, validator))
+            mstore(0x16, messageHash)
+            digest := keccak256(0x00, 0x36)
+        }
+    }
+
     /**
      * @dev Returns the keccak256 digest of an EIP-712 typed data (ERC-191 version `0x01`).
      *

+ 109 - 37
contracts/utils/math/Math.sol

@@ -17,14 +17,42 @@ library Math {
         Expand // Away from zero
     }
 
+    /**
+     * @dev Return the 512-bit addition of two uint256.
+     *
+     * The result is stored in two 256 variables such that sum = high * 2²⁵⁶ + low.
+     */
+    function add512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
+        assembly ("memory-safe") {
+            low := add(a, b)
+            high := lt(low, a)
+        }
+    }
+
+    /**
+     * @dev Return the 512-bit multiplication of two uint256.
+     *
+     * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
+     */
+    function mul512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
+        // 512-bit multiply [high low] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
+        // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
+        // variables such that product = high * 2²⁵⁶ + low.
+        assembly ("memory-safe") {
+            let mm := mulmod(a, b, not(0))
+            low := mul(a, b)
+            high := sub(sub(mm, low), lt(mm, low))
+        }
+    }
+
     /**
      * @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
      */
     function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
             uint256 c = a + b;
-            if (c < a) return (false, 0);
-            return (true, c);
+            success = c >= a;
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -33,8 +61,9 @@ library Math {
      */
     function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b > a) return (false, 0);
-            return (true, a - b);
+            uint256 c = a - b;
+            success = c <= a;
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -43,13 +72,14 @@ library Math {
      */
     function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
-            // benefit is lost if 'b' is also tested.
-            // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
-            if (a == 0) return (true, 0);
             uint256 c = a * b;
-            if (c / a != b) return (false, 0);
-            return (true, c);
+            assembly ("memory-safe") {
+                // Only true when the multiplication doesn't overflow
+                // (c / a == b) || (a == 0)
+                success := or(eq(div(c, a), b), iszero(a))
+            }
+            // equivalent to: success ? c : 0
+            result = c * SafeCast.toUint(success);
         }
     }
 
@@ -58,8 +88,11 @@ library Math {
      */
     function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b == 0) return (false, 0);
-            return (true, a / b);
+            success = b > 0;
+            assembly ("memory-safe") {
+                // The `DIV` opcode returns zero when the denominator is 0.
+                result := div(a, b)
+            }
         }
     }
 
@@ -68,11 +101,38 @@ library Math {
      */
     function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
         unchecked {
-            if (b == 0) return (false, 0);
-            return (true, a % b);
+            success = b > 0;
+            assembly ("memory-safe") {
+                // The `MOD` opcode returns zero when the denominator is 0.
+                result := mod(a, b)
+            }
         }
     }
 
+    /**
+     * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
+     */
+    function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
+        (bool success, uint256 result) = tryAdd(a, b);
+        return ternary(success, result, type(uint256).max);
+    }
+
+    /**
+     * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
+     */
+    function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
+        (, uint256 result) = trySub(a, b);
+        return result;
+    }
+
+    /**
+     * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
+     */
+    function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
+        (bool success, uint256 result) = tryMul(a, b);
+        return ternary(success, result, type(uint256).max);
+    }
+
     /**
      * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
      *
@@ -143,26 +203,18 @@ library Math {
      */
     function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
         unchecked {
-            // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
-            // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
-            // variables such that product = prod1 * 2²⁵⁶ + prod0.
-            uint256 prod0 = x * y; // Least significant 256 bits of the product
-            uint256 prod1; // Most significant 256 bits of the product
-            assembly {
-                let mm := mulmod(x, y, not(0))
-                prod1 := sub(sub(mm, prod0), lt(mm, prod0))
-            }
+            (uint256 high, uint256 low) = mul512(x, y);
 
             // Handle non-overflow cases, 256 by 256 division.
-            if (prod1 == 0) {
+            if (high == 0) {
                 // Solidity will revert if denominator == 0, unlike the div opcode on its own.
                 // The surrounding unchecked block does not change this fact.
                 // See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
-                return prod0 / denominator;
+                return low / denominator;
             }
 
             // Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
-            if (denominator <= prod1) {
+            if (denominator <= high) {
                 Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
             }
 
@@ -170,34 +222,34 @@ library Math {
             // 512 by 256 division.
             ///////////////////////////////////////////////
 
-            // Make division exact by subtracting the remainder from [prod1 prod0].
+            // Make division exact by subtracting the remainder from [high low].
             uint256 remainder;
-            assembly {
+            assembly ("memory-safe") {
                 // Compute remainder using mulmod.
                 remainder := mulmod(x, y, denominator)
 
                 // Subtract 256 bit number from 512 bit number.
-                prod1 := sub(prod1, gt(remainder, prod0))
-                prod0 := sub(prod0, remainder)
+                high := sub(high, gt(remainder, low))
+                low := sub(low, remainder)
             }
 
             // Factor powers of two out of denominator and compute largest power of two divisor of denominator.
             // Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
 
             uint256 twos = denominator & (0 - denominator);
-            assembly {
+            assembly ("memory-safe") {
                 // Divide denominator by twos.
                 denominator := div(denominator, twos)
 
-                // Divide [prod1 prod0] by twos.
-                prod0 := div(prod0, twos)
+                // Divide [high low] by twos.
+                low := div(low, twos)
 
                 // Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
                 twos := add(div(sub(0, twos), twos), 1)
             }
 
-            // Shift in bits from prod1 into prod0.
-            prod0 |= prod1 * twos;
+            // Shift in bits from high into low.
+            low |= high * twos;
 
             // Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
             // that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
@@ -215,9 +267,9 @@ library Math {
 
             // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
             // This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
-            // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
+            // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
             // is no longer required.
-            result = prod0 * inverse;
+            result = low * inverse;
             return result;
         }
     }
@@ -229,6 +281,26 @@ library Math {
         return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
     }
 
+    /**
+     * @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
+     */
+    function mulShr(uint256 x, uint256 y, uint8 n) internal pure returns (uint256 result) {
+        unchecked {
+            (uint256 high, uint256 low) = mul512(x, y);
+            if (high >= 1 << n) {
+                Panic.panic(Panic.UNDER_OVERFLOW);
+            }
+            return (high << (256 - n)) | (low >> n);
+        }
+    }
+
+    /**
+     * @dev Calculates x * y >> n with full precision, following the selected rounding direction.
+     */
+    function mulShr(uint256 x, uint256 y, uint8 n, Rounding rounding) internal pure returns (uint256) {
+        return mulShr(x, y, n) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, 1 << n) > 0);
+    }
+
     /**
      * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
      *

+ 92 - 0
contracts/utils/structs/MerkleTree.sol

@@ -6,6 +6,7 @@ pragma solidity ^0.8.20;
 import {Hashes} from "../cryptography/Hashes.sol";
 import {Arrays} from "../Arrays.sol";
 import {Panic} from "../Panic.sol";
+import {StorageSlot} from "../StorageSlot.sol";
 
 /**
  * @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
@@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol";
  * _Available since v5.1._
  */
 library MerkleTree {
+    /// @dev Error emitted when trying to update a leaf that was not previously pushed.
+    error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length);
+
+    /// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side).
+    error MerkleTreeUpdateInvalidProof();
+
     /**
      * @dev A complete `bytes32` Merkle tree.
      *
@@ -166,6 +173,91 @@ library MerkleTree {
         return (index, currentLevelHash);
     }
 
+    /**
+     * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
+     * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
+     * root is the last known one.
+     *
+     * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
+     * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
+     * all "in flight" updates invalid.
+     *
+     * This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees
+     * that were setup using the same (default) hashing function (i.e. by calling
+     * {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function).
+     */
+    function update(
+        Bytes32PushTree storage self,
+        uint256 index,
+        bytes32 oldValue,
+        bytes32 newValue,
+        bytes32[] memory proof
+    ) internal returns (bytes32 oldRoot, bytes32 newRoot) {
+        return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256);
+    }
+
+    /**
+     * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
+     * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
+     * root is the last known one.
+     *
+     * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
+     * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
+     * all "in flight" updates invalid.
+     *
+     * This variant uses a custom hashing function to hash internal nodes. It should only be called with the same
+     * function as the one used during the initial setup of the merkle tree.
+     */
+    function update(
+        Bytes32PushTree storage self,
+        uint256 index,
+        bytes32 oldValue,
+        bytes32 newValue,
+        bytes32[] memory proof,
+        function(bytes32, bytes32) view returns (bytes32) fnHash
+    ) internal returns (bytes32 oldRoot, bytes32 newRoot) {
+        unchecked {
+            // Check index range
+            uint256 length = self._nextLeafIndex;
+            if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length);
+
+            // Cache read
+            uint256 treeDepth = depth(self);
+
+            // Workaround stack too deep
+            bytes32[] storage sides = self._sides;
+
+            // This cannot overflow because: 0 <= index < length
+            uint256 lastIndex = length - 1;
+            uint256 currentIndex = index;
+            bytes32 currentLevelHashOld = oldValue;
+            bytes32 currentLevelHashNew = newValue;
+            for (uint32 i = 0; i < treeDepth; i++) {
+                bool isLeft = currentIndex % 2 == 0;
+
+                lastIndex >>= 1;
+                currentIndex >>= 1;
+
+                if (isLeft && currentIndex == lastIndex) {
+                    StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i);
+                    if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof();
+                    side.value = currentLevelHashNew;
+                }
+
+                bytes32 sibling = proof[i];
+                currentLevelHashOld = fnHash(
+                    isLeft ? currentLevelHashOld : sibling,
+                    isLeft ? sibling : currentLevelHashOld
+                );
+                currentLevelHashNew = fnHash(
+                    isLeft ? currentLevelHashNew : sibling,
+                    isLeft ? sibling : currentLevelHashNew
+                );
+            }
+            return (currentLevelHashOld, currentLevelHashNew);
+        }
+    }
+
     /**
      * @dev Tree's depth (set at initialization)
      */

+ 15 - 0
hardhat/common-contracts.js

@@ -6,6 +6,7 @@ const fs = require('fs');
 const path = require('path');
 
 const INSTANCES = {
+  // Entrypoint v0.7.0
   entrypoint: {
     address: '0x0000000071727De22E5E9d8BAf0edAc6f37da032',
     abi: JSON.parse(fs.readFileSync(path.resolve(__dirname, '../test/bin/EntryPoint070.abi'), 'utf-8')),
@@ -16,6 +17,20 @@ const INSTANCES = {
     abi: JSON.parse(fs.readFileSync(path.resolve(__dirname, '../test/bin/SenderCreator070.abi'), 'utf-8')),
     bytecode: fs.readFileSync(path.resolve(__dirname, '../test/bin/SenderCreator070.bytecode'), 'hex'),
   },
+  // Arachnid's deterministic deployment proxy
+  // See: https://github.com/Arachnid/deterministic-deployment-proxy/tree/master
+  arachnidDeployer: {
+    address: '0x4e59b44847b379578588920cA78FbF26c0B4956C',
+    abi: [],
+    bytecode:
+      '0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe03601600081602082378035828234f58015156039578182fd5b8082525050506014600cf3',
+  },
+  // Micah's deployer
+  micahDeployer: {
+    address: '0x7A0D94F55792C434d74a40883C6ed8545E406D12',
+    abi: [],
+    bytecode: '0x60003681823780368234f58015156014578182fd5b80825250506014600cf3',
+  },
 };
 
 task(TASK_TEST_SETUP_TEST_ENVIRONMENT).setAction((_, env, runSuper) =>

+ 5 - 1
scripts/checks/coverage.sh

@@ -14,7 +14,11 @@ if [ "${CI:-"false"}" == "true" ]; then
   # Foundry coverage
   forge coverage --report lcov --ir-minimum
   # Remove zero hits
-  sed -i '/,0/d' lcov.info
+  if [[ "$OSTYPE" == "darwin"* ]]; then
+    sed -i '' '/,0/d' lcov.info
+  else
+    sed -i '/,0/d' lcov.info
+  fi
 fi
 
 # Reports are then uploaded to Codecov automatically by workflow, and merged.

+ 168 - 0
test/governance/extensions/GovernorSuperQuorum.test.js

@@ -0,0 +1,168 @@
+const { ethers } = require('hardhat');
+const { expect } = require('chai');
+const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
+
+const { GovernorHelper } = require('../../helpers/governance');
+const { ProposalState, VoteType } = require('../../helpers/enums');
+const time = require('../../helpers/time');
+
+const TOKENS = [
+  { Token: '$ERC20Votes', mode: 'blocknumber' },
+  { Token: '$ERC20VotesTimestampMock', mode: 'timestamp' },
+];
+
+const DEFAULT_ADMIN_ROLE = ethers.ZeroHash;
+const PROPOSER_ROLE = ethers.id('PROPOSER_ROLE');
+const EXECUTOR_ROLE = ethers.id('EXECUTOR_ROLE');
+const CANCELLER_ROLE = ethers.id('CANCELLER_ROLE');
+
+const name = 'OZ-Governor';
+const version = '1';
+const tokenName = 'MockToken';
+const tokenSymbol = 'MTKN';
+const tokenSupply = ethers.parseEther('100');
+const votingDelay = 4n;
+const votingPeriod = 16n;
+const quorum = 10n;
+const superQuorum = 40n;
+const value = ethers.parseEther('1');
+const delay = time.duration.hours(1n);
+
+describe('GovernorSuperQuorum', function () {
+  for (const { Token, mode } of TOKENS) {
+    const fixture = async () => {
+      const [proposer, voter1, voter2, voter3, voter4, voter5] = await ethers.getSigners();
+      const receiver = await ethers.deployContract('CallReceiverMock');
+
+      const timelock = await ethers.deployContract('TimelockController', [delay, [], [], proposer]);
+      const token = await ethers.deployContract(Token, [tokenName, tokenSymbol, tokenName, version]);
+      const mock = await ethers.deployContract('$GovernorSuperQuorumMock', [
+        name,
+        votingDelay, // initialVotingDelay
+        votingPeriod, // initialVotingPeriod
+        0n, // initialProposalThreshold
+        token,
+        timelock,
+        quorum,
+        superQuorum,
+      ]);
+
+      await proposer.sendTransaction({ to: timelock, value });
+      await token.$_mint(proposer, tokenSupply);
+      await timelock.grantRole(PROPOSER_ROLE, mock);
+      await timelock.grantRole(PROPOSER_ROLE, proposer);
+      await timelock.grantRole(CANCELLER_ROLE, mock);
+      await timelock.grantRole(CANCELLER_ROLE, proposer);
+      await timelock.grantRole(EXECUTOR_ROLE, ethers.ZeroAddress);
+      await timelock.revokeRole(DEFAULT_ADMIN_ROLE, proposer);
+
+      const helper = new GovernorHelper(mock, mode);
+      await helper.connect(proposer).delegate({ token, to: voter1, value: 40 });
+      await helper.connect(proposer).delegate({ token, to: voter2, value: 30 });
+      await helper.connect(proposer).delegate({ token, to: voter3, value: 20 });
+      await helper.connect(proposer).delegate({ token, to: voter4, value: 15 });
+      await helper.connect(proposer).delegate({ token, to: voter5, value: 5 });
+
+      return { proposer, voter1, voter2, voter3, voter4, voter5, receiver, token, mock, timelock, helper };
+    };
+
+    describe(`using ${Token}`, function () {
+      beforeEach(async function () {
+        Object.assign(this, await loadFixture(fixture));
+
+        // default proposal
+        this.proposal = this.helper.setProposal(
+          [
+            {
+              target: this.receiver.target,
+              value,
+              data: this.receiver.interface.encodeFunctionData('mockFunction'),
+            },
+          ],
+          '<proposal description>',
+        );
+      });
+
+      it('deployment check', async function () {
+        await expect(this.mock.name()).to.eventually.equal(name);
+        await expect(this.mock.token()).to.eventually.equal(this.token);
+        await expect(this.mock.quorum(0)).to.eventually.equal(quorum);
+        await expect(this.mock.superQuorum(0)).to.eventually.equal(superQuorum);
+      });
+
+      it('proposal succeeds early when super quorum is reached', async function () {
+        await this.helper.connect(this.proposer).propose();
+        await this.helper.waitForSnapshot();
+
+        // Vote with voter2 (30) - above quorum (10) but below super quorum (40)
+        await this.helper.connect(this.voter2).vote({ support: VoteType.For });
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Active);
+
+        // Vote with voter3 (20) to reach super quorum (50 total > 40)
+        await this.helper.connect(this.voter3).vote({ support: VoteType.For });
+
+        await expect(this.mock.proposalEta(this.proposal.id)).to.eventually.equal(0);
+
+        // Should be succeeded since we reached super quorum and no eta is set
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Succeeded);
+      });
+
+      it('proposal remains active if super quorum is not reached', async function () {
+        await this.helper.connect(this.proposer).propose();
+        await this.helper.waitForSnapshot();
+
+        // Vote with voter4 (15) - below super quorum (40) but above quorum (10)
+        await this.helper.connect(this.voter4).vote({ support: VoteType.For });
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Active);
+
+        // Vote with voter5 (5) - still below super quorum (total 20 < 40)
+        await this.helper.connect(this.voter5).vote({ support: VoteType.For });
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Active);
+
+        // Wait for deadline
+        await this.helper.waitForDeadline(1n);
+
+        // Should succeed since deadline passed and we have enough support (20 > 10 quorum)
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Succeeded);
+      });
+
+      it('proposal remains active if super quorum is reached but vote fails', async function () {
+        await this.helper.connect(this.proposer).propose();
+        await this.helper.waitForSnapshot();
+
+        // Vote against with voter2 and voter3 (50)
+        await this.helper.connect(this.voter2).vote({ support: VoteType.Against });
+        await this.helper.connect(this.voter3).vote({ support: VoteType.Against });
+
+        // Vote for with voter1 (40) (reaching super quorum)
+        await this.helper.connect(this.voter1).vote({ support: VoteType.For });
+
+        // should be active since super quorum is reached but vote fails
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Active);
+
+        // wait for deadline
+        await this.helper.waitForDeadline(1n);
+
+        // should be defeated since against votes are higher
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Defeated);
+      });
+
+      it('proposal is queued if super quorum is reached and eta is set', async function () {
+        await this.helper.connect(this.proposer).propose();
+
+        await this.helper.waitForSnapshot();
+
+        // Vote with voter1 (40) - reaching super quorum
+        await this.helper.connect(this.voter1).vote({ support: VoteType.For });
+
+        await this.helper.queue();
+
+        // Queueing should set eta
+        await expect(this.mock.proposalEta(this.proposal.id)).to.eventually.not.equal(0);
+
+        // Should be queued since we reached super quorum and eta is set
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Queued);
+      });
+    });
+  }
+});

+ 79 - 0
test/governance/extensions/GovernorSuperQuorumGreaterThanQuorum.t.sol

@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Test} from "forge-std/Test.sol";
+import {GovernorVotesSuperQuorumFractionMock} from "../../../contracts/mocks/governance/GovernorVotesSuperQuorumFractionMock.sol";
+import {GovernorVotesQuorumFraction} from "../../../contracts/governance/extensions/GovernorVotesQuorumFraction.sol";
+import {GovernorVotesSuperQuorumFraction} from "../../../contracts/governance/extensions/GovernorVotesSuperQuorumFraction.sol";
+import {GovernorSettings} from "../../../contracts/governance/extensions/GovernorSettings.sol";
+import {GovernorVotes} from "../../../contracts/governance/extensions/GovernorVotes.sol";
+import {Governor} from "../../../contracts/governance/Governor.sol";
+import {IVotes} from "../../../contracts/governance/utils/IVotes.sol";
+import {ERC20VotesExtendedTimestampMock} from "../../../contracts/mocks/token/ERC20VotesAdditionalCheckpointsMock.sol";
+import {EIP712} from "../../../contracts/utils/cryptography/EIP712.sol";
+import {ERC20} from "../../../contracts/token/ERC20/ERC20.sol";
+
+contract TokenMock is ERC20VotesExtendedTimestampMock {
+    constructor() ERC20("Mock Token", "MTK") EIP712("Mock Token", "1") {}
+}
+
+/**
+ * Main responsibility: expose the functions that are relevant to the simulation
+ */
+contract GovernorHandler is GovernorVotesSuperQuorumFractionMock {
+    constructor(
+        string memory name_,
+        uint48 votingDelay_,
+        uint32 votingPeriod_,
+        uint256 proposalThreshold_,
+        IVotes token_,
+        uint256 quorumNumerator_,
+        uint256 superQuorumNumerator_
+    )
+        Governor(name_)
+        GovernorSettings(votingDelay_, votingPeriod_, proposalThreshold_)
+        GovernorVotes(token_)
+        GovernorVotesQuorumFraction(quorumNumerator_)
+        GovernorVotesSuperQuorumFraction(superQuorumNumerator_)
+    {}
+
+    // solhint-disable-next-line func-name-mixedcase
+    function $_updateSuperQuorumNumerator(uint256 newSuperQuorumNumerator) public {
+        _updateSuperQuorumNumerator(newSuperQuorumNumerator);
+    }
+
+    // solhint-disable-next-line func-name-mixedcase
+    function $_updateQuorumNumerator(uint256 newQuorumNumerator) public {
+        _updateQuorumNumerator(newQuorumNumerator);
+    }
+}
+
+contract GovernorSuperQuorumGreaterThanQuorum is Test {
+    GovernorHandler private _governorHandler;
+
+    function setUp() external {
+        _governorHandler = new GovernorHandler(
+            "GovernorName",
+            0, // votingDelay
+            1e4, // votingPeriod
+            0, // proposalThreshold
+            new TokenMock(), // token
+            10, // quorumNumerator
+            50 // superQuorumNumerator
+        );
+
+        // limit the fuzzer scope
+        bytes4[] memory selectors = new bytes4[](2);
+        selectors[0] = GovernorHandler.$_updateSuperQuorumNumerator.selector;
+        selectors[1] = GovernorHandler.$_updateQuorumNumerator.selector;
+
+        targetContract(address(_governorHandler));
+        targetSelector(FuzzSelector(address(_governorHandler), selectors));
+    }
+
+    // solhint-disable-next-line func-name-mixedcase
+    function invariant_superQuorumGreaterThanQuorum() external view {
+        assertGe(_governorHandler.superQuorumNumerator(), _governorHandler.quorumNumerator());
+    }
+}

+ 160 - 0
test/governance/extensions/GovernorVotesSuperQuorumFraction.test.js

@@ -0,0 +1,160 @@
+const { ethers } = require('hardhat');
+const { expect } = require('chai');
+const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
+
+const { GovernorHelper } = require('../../helpers/governance');
+const { ProposalState, VoteType } = require('../../helpers/enums');
+const time = require('../../helpers/time');
+
+const TOKENS = [
+  { Token: '$ERC20Votes', mode: 'blocknumber' },
+  { Token: '$ERC20VotesTimestampMock', mode: 'timestamp' },
+];
+
+const name = 'OZ-Governor';
+const version = '1';
+const tokenName = 'MockToken';
+const tokenSymbol = 'MTKN';
+const tokenSupply = ethers.parseEther('100');
+const quorumRatio = 8n; // percents
+const superQuorumRatio = 50n; // percents
+const newSuperQuorumRatio = 15n; // percents
+const votingDelay = 4n;
+const votingPeriod = 16n;
+const value = ethers.parseEther('1');
+
+describe('GovernorVotesSuperQuorumFraction', function () {
+  for (const { Token, mode } of TOKENS) {
+    const fixture = async () => {
+      const [owner, voter1, voter2, voter3, voter4] = await ethers.getSigners();
+      const receiver = await ethers.deployContract('CallReceiverMock');
+
+      const token = await ethers.deployContract(Token, [tokenName, tokenSymbol, tokenName, version]);
+      const mock = await ethers.deployContract('$GovernorVotesSuperQuorumFractionMock', [
+        name,
+        votingDelay,
+        votingPeriod,
+        0n,
+        token,
+        quorumRatio,
+        superQuorumRatio,
+      ]);
+
+      await owner.sendTransaction({ to: mock, value });
+      await token.$_mint(owner, tokenSupply);
+
+      const helper = new GovernorHelper(mock, mode);
+      await helper.connect(owner).delegate({ token, to: voter1, value: ethers.parseEther('30') });
+      await helper.connect(owner).delegate({ token, to: voter2, value: ethers.parseEther('20') });
+      await helper.connect(owner).delegate({ token, to: voter3, value: ethers.parseEther('15') });
+      await helper.connect(owner).delegate({ token, to: voter4, value: ethers.parseEther('5') });
+
+      return { owner, voter1, voter2, voter3, voter4, receiver, token, mock, helper };
+    };
+
+    describe(`using ${Token}`, function () {
+      beforeEach(async function () {
+        Object.assign(this, await loadFixture(fixture));
+
+        // default proposal
+        this.proposal = this.helper.setProposal(
+          [
+            {
+              target: this.receiver.target,
+              value,
+              data: this.receiver.interface.encodeFunctionData('mockFunction'),
+            },
+          ],
+          '<proposal description>',
+        );
+      });
+
+      it('deployment check', async function () {
+        await expect(this.mock.name()).to.eventually.eventually.equal(name);
+        await expect(this.mock.token()).to.eventually.equal(this.token);
+        await expect(this.mock.votingDelay()).to.eventually.equal(votingDelay);
+        await expect(this.mock.votingPeriod()).to.eventually.equal(votingPeriod);
+        await expect(this.mock.quorumNumerator()).to.eventually.equal(quorumRatio);
+        await expect(this.mock.superQuorumNumerator()).to.eventually.equal(superQuorumRatio);
+        await expect(this.mock.quorumDenominator()).to.eventually.equal(100n);
+        await expect(time.clock[mode]().then(clock => this.mock.superQuorum(clock - 1n))).to.eventually.equal(
+          (tokenSupply * superQuorumRatio) / 100n,
+        );
+      });
+
+      it('proposal remains active until super quorum is reached', async function () {
+        await this.helper.propose();
+        await this.helper.waitForSnapshot();
+
+        // Vote with voter1 (30%) - above quorum (8%) but below super quorum (50%)
+        await this.helper.connect(this.voter1).vote({ support: VoteType.For });
+
+        // Check proposal is still active
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Active);
+
+        // Vote with voter2 (20%) - now matches super quorum
+        await this.helper.connect(this.voter2).vote({ support: VoteType.For });
+
+        // Proposal should no longer be active
+        await expect(this.mock.state(this.proposal.id)).to.eventually.equal(ProposalState.Succeeded);
+      });
+
+      describe('super quorum updates', function () {
+        it('updateSuperQuorumNumerator is protected', async function () {
+          await expect(this.mock.connect(this.owner).updateSuperQuorumNumerator(newSuperQuorumRatio))
+            .to.be.revertedWithCustomError(this.mock, 'GovernorOnlyExecutor')
+            .withArgs(this.owner);
+        });
+
+        it('can update super quorum through governance', async function () {
+          this.helper.setProposal(
+            [
+              {
+                target: this.mock.target,
+                data: this.mock.interface.encodeFunctionData('updateSuperQuorumNumerator', [newSuperQuorumRatio]),
+              },
+            ],
+            '<proposal description>',
+          );
+
+          await this.helper.propose();
+          await this.helper.waitForSnapshot();
+          await this.helper.connect(this.voter1).vote({ support: VoteType.For });
+          await this.helper.connect(this.voter2).vote({ support: VoteType.For });
+          await this.helper.waitForDeadline();
+
+          await expect(this.helper.execute())
+            .to.emit(this.mock, 'SuperQuorumNumeratorUpdated')
+            .withArgs(superQuorumRatio, newSuperQuorumRatio);
+
+          await expect(this.mock.superQuorumNumerator()).to.eventually.equal(newSuperQuorumRatio);
+        });
+
+        it('cannot set super quorum below quorum', async function () {
+          const invalidSuperQuorum = quorumRatio - 1n;
+
+          await expect(this.mock.$_updateSuperQuorumNumerator(invalidSuperQuorum))
+            .to.be.revertedWithCustomError(this.mock, 'GovernorInvalidSuperQuorumTooSmall')
+            .withArgs(invalidSuperQuorum, quorumRatio);
+        });
+
+        it('cannot set super quorum above denominator', async function () {
+          const denominator = await this.mock.quorumDenominator();
+          const invalidSuperQuorum = BigInt(denominator) + 1n;
+
+          await expect(this.mock.$_updateSuperQuorumNumerator(invalidSuperQuorum))
+            .to.be.revertedWithCustomError(this.mock, 'GovernorInvalidSuperQuorumFraction')
+            .withArgs(invalidSuperQuorum, denominator);
+        });
+
+        it('cannot set quorum above super quorum', async function () {
+          const invalidQuorum = superQuorumRatio + 1n;
+
+          await expect(this.mock.$_updateQuorumNumerator(invalidQuorum))
+            .to.be.revertedWithCustomError(this.mock, 'GovernorInvalidQuorumTooLarge')
+            .withArgs(invalidQuorum, superQuorumRatio);
+        });
+      });
+    });
+  }
+});

+ 7 - 5
test/helpers/enums.js

@@ -1,12 +1,14 @@
-function Enum(...options) {
-  return Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
-}
+const { ethers } = require('ethers');
+
+const Enum = (...options) => Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
+const EnumTyped = (...options) => Object.fromEntries(options.map((key, i) => [key, ethers.Typed.uint8(i)]));
 
 module.exports = {
   Enum,
+  EnumTyped,
   ProposalState: Enum('Pending', 'Active', 'Canceled', 'Defeated', 'Succeeded', 'Queued', 'Expired', 'Executed'),
   VoteType: Object.assign(Enum('Against', 'For', 'Abstain'), { Parameters: 255n }),
-  Rounding: Enum('Floor', 'Ceil', 'Trunc', 'Expand'),
+  Rounding: EnumTyped('Floor', 'Ceil', 'Trunc', 'Expand'),
   OperationState: Enum('Unset', 'Waiting', 'Ready', 'Done'),
-  RevertType: Enum('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
+  RevertType: EnumTyped('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
 };

+ 7 - 0
test/utils/Strings.test.js

@@ -339,4 +339,11 @@ describe('Strings', function () {
       }
     });
   });
+
+  describe('Escape JSON string', function () {
+    for (const input of ['', 'a', '{"a":"b/c"}', 'a\tb\nc\\d"e\rf/g\fh\bi'])
+      it(`escape ${JSON.stringify(input)}`, async function () {
+        await expect(this.mock.$escapeJSON(input)).to.eventually.equal(JSON.stringify(input).slice(1, -1));
+      });
+  });
 });

+ 33 - 0
test/utils/cryptography/MessageHashUtils.t.sol

@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Test} from "forge-std/Test.sol";
+import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
+
+contract MessageHashUtilsTest is Test {
+    function testToDataWithIntendedValidatorHash(address validator, bytes memory data) external pure {
+        assertEq(
+            MessageHashUtils.toDataWithIntendedValidatorHash(validator, data),
+            MessageHashUtils.toDataWithIntendedValidatorHash(_dirty(validator), data)
+        );
+    }
+
+    function testToDataWithIntendedValidatorHash(address validator, bytes32 messageHash) external pure {
+        assertEq(
+            MessageHashUtils.toDataWithIntendedValidatorHash(validator, messageHash),
+            MessageHashUtils.toDataWithIntendedValidatorHash(_dirty(validator), messageHash)
+        );
+
+        assertEq(
+            MessageHashUtils.toDataWithIntendedValidatorHash(validator, messageHash),
+            MessageHashUtils.toDataWithIntendedValidatorHash(validator, abi.encodePacked(messageHash))
+        );
+    }
+
+    function _dirty(address input) private pure returns (address output) {
+        assembly ("memory-safe") {
+            output := or(input, shl(160, not(0)))
+        }
+    }
+}

+ 34 - 5
test/utils/cryptography/MessageHashUtils.test.js

@@ -19,14 +19,16 @@ describe('MessageHashUtils', function () {
       const message = ethers.randomBytes(32);
       const expectedHash = ethers.hashMessage(message);
 
-      expect(await this.mock.getFunction('$toEthSignedMessageHash(bytes32)')(message)).to.equal(expectedHash);
+      await expect(this.mock.getFunction('$toEthSignedMessageHash(bytes32)')(message)).to.eventually.equal(
+        expectedHash,
+      );
     });
 
     it('prefixes dynamic length data correctly', async function () {
       const message = ethers.randomBytes(128);
       const expectedHash = ethers.hashMessage(message);
 
-      expect(await this.mock.getFunction('$toEthSignedMessageHash(bytes)')(message)).to.equal(expectedHash);
+      await expect(this.mock.getFunction('$toEthSignedMessageHash(bytes)')(message)).to.eventually.equal(expectedHash);
     });
 
     it('version match for bytes32', async function () {
@@ -39,7 +41,20 @@ describe('MessageHashUtils', function () {
   });
 
   describe('toDataWithIntendedValidatorHash', function () {
-    it('returns the digest correctly', async function () {
+    it('returns the digest of `bytes32 messageHash` correctly', async function () {
+      const verifier = ethers.Wallet.createRandom().address;
+      const message = ethers.randomBytes(32);
+      const expectedHash = ethers.solidityPackedKeccak256(
+        ['string', 'address', 'bytes32'],
+        ['\x19\x00', verifier, message],
+      );
+
+      await expect(
+        this.mock.getFunction('$toDataWithIntendedValidatorHash(address,bytes32)')(verifier, message),
+      ).to.eventually.equal(expectedHash);
+    });
+
+    it('returns the digest of `bytes memory message` correctly', async function () {
       const verifier = ethers.Wallet.createRandom().address;
       const message = ethers.randomBytes(128);
       const expectedHash = ethers.solidityPackedKeccak256(
@@ -47,7 +62,21 @@ describe('MessageHashUtils', function () {
         ['\x19\x00', verifier, message],
       );
 
-      expect(await this.mock.$toDataWithIntendedValidatorHash(verifier, message)).to.equal(expectedHash);
+      await expect(
+        this.mock.getFunction('$toDataWithIntendedValidatorHash(address,bytes)')(verifier, message),
+      ).to.eventually.equal(expectedHash);
+    });
+
+    it('version match for bytes32', async function () {
+      const verifier = ethers.Wallet.createRandom().address;
+      const message = ethers.randomBytes(32);
+      const fixed = await this.mock.getFunction('$toDataWithIntendedValidatorHash(address,bytes)')(verifier, message);
+      const dynamic = await this.mock.getFunction('$toDataWithIntendedValidatorHash(address,bytes32)')(
+        verifier,
+        message,
+      );
+
+      expect(fixed).to.equal(dynamic);
     });
   });
 
@@ -62,7 +91,7 @@ describe('MessageHashUtils', function () {
       const structhash = ethers.randomBytes(32);
       const expectedHash = hashTypedData(domain, structhash);
 
-      expect(await this.mock.$toTypedDataHash(domainSeparator(domain), structhash)).to.equal(expectedHash);
+      await expect(this.mock.$toTypedDataHash(domainSeparator(domain), structhash)).to.eventually.equal(expectedHash);
     });
   });
 });

+ 57 - 22
test/utils/math/Math.t.sol

@@ -11,6 +11,48 @@ contract MathTest is Test {
         assertEq(Math.ternary(f, a, b), f ? a : b);
     }
 
+    // ADD512 & MUL512
+    function testAdd512(uint256 a, uint256 b) public pure {
+        (uint256 high, uint256 low) = Math.add512(a, b);
+
+        // test against tryAdd
+        (bool success, uint256 result) = Math.tryAdd(a, b);
+        if (success) {
+            assertEq(high, 0);
+            assertEq(low, result);
+        } else {
+            assertEq(high, 1);
+        }
+
+        // test against unchecked
+        unchecked {
+            assertEq(low, a + b); // unchecked allow overflow
+        }
+    }
+
+    function testMul512(uint256 a, uint256 b) public pure {
+        (uint256 high, uint256 low) = Math.mul512(a, b);
+
+        // test against tryMul
+        (bool success, uint256 result) = Math.tryMul(a, b);
+        if (success) {
+            assertEq(high, 0);
+            assertEq(low, result);
+        } else {
+            assertGt(high, 0);
+        }
+
+        // test against unchecked
+        unchecked {
+            assertEq(low, a * b); // unchecked allow overflow
+        }
+
+        // test against alternative method
+        (uint256 _high, uint256 _low) = _mulKaratsuba(a, b);
+        assertEq(high, _high);
+        assertEq(low, _low);
+    }
+
     // MIN & MAX
     function testSymbolicMinMax(uint256 a, uint256 b) public pure {
         assertEq(Math.min(a, b), a < b ? a : b);
@@ -184,7 +226,7 @@ contract MathTest is Test {
     // MULDIV
     function testMulDiv(uint256 x, uint256 y, uint256 d) public pure {
         // Full precision for x * y
-        (uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);
+        (uint256 xyHi, uint256 xyLo) = Math.mul512(x, y);
 
         // Assume result won't overflow (see {testMulDivDomain})
         // This also checks that `d` is positive
@@ -194,9 +236,9 @@ contract MathTest is Test {
         uint256 q = Math.mulDiv(x, y, d);
 
         // Full precision for q * d
-        (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
+        (uint256 qdHi, uint256 qdLo) = Math.mul512(q, d);
         // Add remainder of x * y / d (computed as rem = (x * y % d))
-        (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
+        (uint256 c, uint256 qdRemLo) = Math.add512(qdLo, mulmod(x, y, d));
         uint256 qdRemHi = qdHi + c;
 
         // Full precision check that x * y = q * d + rem
@@ -206,7 +248,7 @@ contract MathTest is Test {
 
     /// forge-config: default.allow_internal_expect_revert = true
     function testMulDivDomain(uint256 x, uint256 y, uint256 d) public {
-        (uint256 xyHi, ) = _mulHighLow(x, y);
+        (uint256 xyHi, ) = Math.mul512(x, y);
 
         // Violate {testMulDiv} assumption (covers d is 0 and result overflow)
         vm.assume(xyHi >= d);
@@ -266,26 +308,13 @@ contract MathTest is Test {
         }
     }
 
-    function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
-        if (m == 1) return 0;
-        uint256 r = 1;
-        while (e > 0) {
-            if (e % 2 > 0) {
-                r = mulmod(r, b, m);
-            }
-            b = mulmod(b, b, m);
-            e >>= 1;
-        }
-        return r;
-    }
-
     // Helpers
     function _asRounding(uint8 r) private pure returns (Math.Rounding) {
         vm.assume(r < uint8(type(Math.Rounding).max));
         return Math.Rounding(r);
     }
 
-    function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
+    function _mulKaratsuba(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
         (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
         (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
 
@@ -305,10 +334,16 @@ contract MathTest is Test {
         }
     }
 
-    function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
-        unchecked {
-            res = x + y;
+    function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
+        if (m == 1) return 0;
+        uint256 r = 1;
+        while (e > 0) {
+            if (e % 2 > 0) {
+                r = mulmod(r, b, m);
+            }
+            b = mulmod(b, b, m);
+            e >>= 1;
         }
-        carry = res < x ? 1 : 0;
+        return r;
     }
 }

+ 298 - 147
test/utils/math/Math.test.js

@@ -16,10 +16,13 @@ const uint256 = value => ethers.Typed.uint256(value);
 bytes.zero = '0x';
 uint256.zero = 0n;
 
-async function testCommutative(fn, lhs, rhs, expected, ...extra) {
-  expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
-  expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
-}
+const testCommutative = (fn, lhs, rhs, expected, ...extra) =>
+  Promise.all([
+    expect(fn(lhs, rhs, ...extra)).to.eventually.deep.equal(expected),
+    expect(fn(rhs, lhs, ...extra)).to.eventually.deep.equal(expected),
+  ]);
+
+const splitHighLow = n => [n / (1n << 256n), n % (1n << 256n)];
 
 async function fixture() {
   const mock = await ethers.deployContract('$Math');
@@ -39,6 +42,24 @@ describe('Math', function () {
     Object.assign(this, await loadFixture(fixture));
   });
 
+  describe('add512', function () {
+    it('adds correctly without reverting', async function () {
+      const values = [0n, 1n, 17n, 42n, ethers.MaxUint256 - 1n, ethers.MaxUint256];
+      for (const [a, b] of product(values, values)) {
+        await expect(this.mock.$add512(a, b)).to.eventually.deep.equal(splitHighLow(a + b));
+      }
+    });
+  });
+
+  describe('mul512', function () {
+    it('multiplies correctly without reverting', async function () {
+      const values = [0n, 1n, 17n, 42n, ethers.MaxUint256 - 1n, ethers.MaxUint256];
+      for (const [a, b] of product(values, values)) {
+        await expect(this.mock.$mul512(a, b)).to.eventually.deep.equal(splitHighLow(a * b));
+      }
+    });
+  });
+
   describe('tryAdd', function () {
     it('adds correctly', async function () {
       const a = 5678n;
@@ -57,13 +78,13 @@ describe('Math', function () {
     it('subtracts correctly', async function () {
       const a = 5678n;
       const b = 1234n;
-      expect(await this.mock.$trySub(a, b)).to.deep.equal([true, a - b]);
+      await expect(this.mock.$trySub(a, b)).to.eventually.deep.equal([true, a - b]);
     });
 
     it('reverts if subtraction result would be negative', async function () {
       const a = 1234n;
       const b = 5678n;
-      expect(await this.mock.$trySub(a, b)).to.deep.equal([false, 0n]);
+      await expect(this.mock.$trySub(a, b)).to.eventually.deep.equal([false, 0n]);
     });
   });
 
@@ -91,25 +112,25 @@ describe('Math', function () {
     it('divides correctly', async function () {
       const a = 5678n;
       const b = 5678n;
-      expect(await this.mock.$tryDiv(a, b)).to.deep.equal([true, a / b]);
+      await expect(this.mock.$tryDiv(a, b)).to.eventually.deep.equal([true, a / b]);
     });
 
     it('divides zero correctly', async function () {
       const a = 0n;
       const b = 5678n;
-      expect(await this.mock.$tryDiv(a, b)).to.deep.equal([true, a / b]);
+      await expect(this.mock.$tryDiv(a, b)).to.eventually.deep.equal([true, a / b]);
     });
 
     it('returns complete number result on non-even division', async function () {
       const a = 7000n;
       const b = 5678n;
-      expect(await this.mock.$tryDiv(a, b)).to.deep.equal([true, a / b]);
+      await expect(this.mock.$tryDiv(a, b)).to.eventually.deep.equal([true, a / b]);
     });
 
     it('reverts on division by zero', async function () {
       const a = 5678n;
       const b = 0n;
-      expect(await this.mock.$tryDiv(a, b)).to.deep.equal([false, 0n]);
+      await expect(this.mock.$tryDiv(a, b)).to.eventually.deep.equal([false, 0n]);
     });
   });
 
@@ -118,32 +139,88 @@ describe('Math', function () {
       it('when the dividend is smaller than the divisor', async function () {
         const a = 284n;
         const b = 5678n;
-        expect(await this.mock.$tryMod(a, b)).to.deep.equal([true, a % b]);
+        await expect(this.mock.$tryMod(a, b)).to.eventually.deep.equal([true, a % b]);
       });
 
       it('when the dividend is equal to the divisor', async function () {
         const a = 5678n;
         const b = 5678n;
-        expect(await this.mock.$tryMod(a, b)).to.deep.equal([true, a % b]);
+        await expect(this.mock.$tryMod(a, b)).to.eventually.deep.equal([true, a % b]);
       });
 
       it('when the dividend is larger than the divisor', async function () {
         const a = 7000n;
         const b = 5678n;
-        expect(await this.mock.$tryMod(a, b)).to.deep.equal([true, a % b]);
+        await expect(this.mock.$tryMod(a, b)).to.eventually.deep.equal([true, a % b]);
       });
 
       it('when the dividend is a multiple of the divisor', async function () {
         const a = 17034n; // 17034 == 5678 * 3
         const b = 5678n;
-        expect(await this.mock.$tryMod(a, b)).to.deep.equal([true, a % b]);
+        await expect(this.mock.$tryMod(a, b)).to.eventually.deep.equal([true, a % b]);
       });
     });
 
     it('reverts with a 0 divisor', async function () {
       const a = 5678n;
       const b = 0n;
-      expect(await this.mock.$tryMod(a, b)).to.deep.equal([false, 0n]);
+      await expect(this.mock.$tryMod(a, b)).to.eventually.deep.equal([false, 0n]);
+    });
+  });
+
+  describe('saturatingAdd', function () {
+    it('adds correctly', async function () {
+      const a = 5678n;
+      const b = 1234n;
+      await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
+      await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
+      await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
+    });
+
+    it('bounds on addition overflow', async function () {
+      await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
+      await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
+        ethers.MaxUint256,
+      );
+    });
+  });
+
+  describe('saturatingSub', function () {
+    it('subtracts correctly', async function () {
+      const a = 5678n;
+      const b = 1234n;
+      await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
+      await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
+      await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
+    });
+
+    it('bounds on subtraction overflow', async function () {
+      await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
+      await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
+    });
+  });
+
+  describe('saturatingMul', function () {
+    it('multiplies correctly', async function () {
+      const a = 1234n;
+      const b = 5678n;
+      await testCommutative(this.mock.$saturatingMul, a, b, a * b);
+    });
+
+    it('multiplies by zero correctly', async function () {
+      const a = 0n;
+      const b = 5678n;
+      await testCommutative(this.mock.$saturatingMul, a, b, 0n);
+    });
+
+    it('bounds on multiplication overflow', async function () {
+      const a = ethers.MaxUint256;
+      const b = 2n;
+      await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
     });
   });
 
@@ -163,24 +240,24 @@ describe('Math', function () {
     it('is correctly calculated with two odd numbers', async function () {
       const a = 57417n;
       const b = 95431n;
-      expect(await this.mock.$average(a, b)).to.equal((a + b) / 2n);
+      await expect(this.mock.$average(a, b)).to.eventually.equal((a + b) / 2n);
     });
 
     it('is correctly calculated with two even numbers', async function () {
       const a = 42304n;
       const b = 84346n;
-      expect(await this.mock.$average(a, b)).to.equal((a + b) / 2n);
+      await expect(this.mock.$average(a, b)).to.eventually.equal((a + b) / 2n);
     });
 
     it('is correctly calculated with one even and one odd number', async function () {
       const a = 57417n;
       const b = 84346n;
-      expect(await this.mock.$average(a, b)).to.equal((a + b) / 2n);
+      await expect(this.mock.$average(a, b)).to.eventually.equal((a + b) / 2n);
     });
 
     it('is correctly calculated with two max uint256 numbers', async function () {
       const a = ethers.MaxUint256;
-      expect(await this.mock.$average(a, a)).to.equal(a);
+      await expect(this.mock.$average(a, a)).to.eventually.equal(a);
     });
   });
 
@@ -196,35 +273,35 @@ describe('Math', function () {
       const a = 0n;
       const b = 2n;
       const r = 0n;
-      expect(await this.mock.$ceilDiv(a, b)).to.equal(r);
+      await expect(this.mock.$ceilDiv(a, b)).to.eventually.equal(r);
     });
 
     it('does not round up on exact division', async function () {
       const a = 10n;
       const b = 5n;
       const r = 2n;
-      expect(await this.mock.$ceilDiv(a, b)).to.equal(r);
+      await expect(this.mock.$ceilDiv(a, b)).to.eventually.equal(r);
     });
 
     it('rounds up on division with remainders', async function () {
       const a = 42n;
       const b = 13n;
       const r = 4n;
-      expect(await this.mock.$ceilDiv(a, b)).to.equal(r);
+      await expect(this.mock.$ceilDiv(a, b)).to.eventually.equal(r);
     });
 
     it('does not overflow', async function () {
       const a = ethers.MaxUint256;
       const b = 2n;
       const r = 1n << 255n;
-      expect(await this.mock.$ceilDiv(a, b)).to.equal(r);
+      await expect(this.mock.$ceilDiv(a, b)).to.eventually.equal(r);
     });
 
     it('correctly computes max uint256 divided by 1', async function () {
       const a = ethers.MaxUint256;
       const b = 1n;
       const r = ethers.MaxUint256;
-      expect(await this.mock.$ceilDiv(a, b)).to.equal(r);
+      await expect(this.mock.$ceilDiv(a, b)).to.eventually.equal(r);
     });
   });
 
@@ -248,28 +325,30 @@ describe('Math', function () {
     describe('does round down', function () {
       it('small values', async function () {
         for (const rounding of RoundingDown) {
-          expect(await this.mock.$mulDiv(3n, 4n, 5n, rounding)).to.equal(2n);
-          expect(await this.mock.$mulDiv(3n, 5n, 5n, rounding)).to.equal(3n);
+          await expect(this.mock.$mulDiv(3n, 4n, 5n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$mulDiv(3n, 5n, 5n, rounding)).to.eventually.equal(3n);
         }
       });
 
       it('large values', async function () {
         for (const rounding of RoundingDown) {
-          expect(await this.mock.$mulDiv(42n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding)).to.equal(41n);
+          await expect(this.mock.$mulDiv(42n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding)).to.eventually.equal(
+            41n,
+          );
 
-          expect(await this.mock.$mulDiv(17n, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.equal(17n);
+          await expect(this.mock.$mulDiv(17n, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.eventually.equal(17n);
 
-          expect(
-            await this.mock.$mulDiv(ethers.MaxUint256 - 1n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
-          ).to.equal(ethers.MaxUint256 - 2n);
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256 - 1n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256 - 2n);
 
-          expect(
-            await this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
-          ).to.equal(ethers.MaxUint256 - 1n);
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256 - 1n);
 
-          expect(await this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.equal(
-            ethers.MaxUint256,
-          );
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256);
         }
       });
     });
@@ -277,28 +356,91 @@ describe('Math', function () {
     describe('does round up', function () {
       it('small values', async function () {
         for (const rounding of RoundingUp) {
-          expect(await this.mock.$mulDiv(3n, 4n, 5n, rounding)).to.equal(3n);
-          expect(await this.mock.$mulDiv(3n, 5n, 5n, rounding)).to.equal(3n);
+          await expect(this.mock.$mulDiv(3n, 4n, 5n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$mulDiv(3n, 5n, 5n, rounding)).to.eventually.equal(3n);
         }
       });
 
       it('large values', async function () {
         for (const rounding of RoundingUp) {
-          expect(await this.mock.$mulDiv(42n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding)).to.equal(42n);
+          await expect(this.mock.$mulDiv(42n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding)).to.eventually.equal(
+            42n,
+          );
 
-          expect(await this.mock.$mulDiv(17n, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.equal(17n);
+          await expect(this.mock.$mulDiv(17n, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.eventually.equal(17n);
 
-          expect(
-            await this.mock.$mulDiv(ethers.MaxUint256 - 1n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
-          ).to.equal(ethers.MaxUint256 - 1n);
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256 - 1n, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256 - 1n);
 
-          expect(
-            await this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
-          ).to.equal(ethers.MaxUint256 - 1n);
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256 - 1n, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256 - 1n);
 
-          expect(await this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256, ethers.MaxUint256, rounding)).to.equal(
+          await expect(
+            this.mock.$mulDiv(ethers.MaxUint256, ethers.MaxUint256, ethers.MaxUint256, rounding),
+          ).to.eventually.equal(ethers.MaxUint256);
+        }
+      });
+    });
+  });
+
+  describe('mulShr', function () {
+    it('reverts with result higher than 2 ^ 256', async function () {
+      const a = 5n;
+      const b = ethers.MaxUint256;
+      const c = 1n;
+      await expect(this.mock.$mulShr(a, b, c, Rounding.Floor)).to.be.revertedWithPanic(
+        PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
+      );
+    });
+
+    describe('does round down', function () {
+      it('small values', async function () {
+        for (const rounding of RoundingDown) {
+          await expect(this.mock.$mulShr(3n, 5n, 1n, rounding)).to.eventually.equal(7n);
+          await expect(this.mock.$mulShr(3n, 5n, 2n, rounding)).to.eventually.equal(3n);
+        }
+      });
+
+      it('large values', async function () {
+        for (const rounding of RoundingDown) {
+          await expect(this.mock.$mulShr(42n, ethers.MaxUint256, 255n, rounding)).to.eventually.equal(83n);
+
+          await expect(this.mock.$mulShr(17n, ethers.MaxUint256, 255n, rounding)).to.eventually.equal(33n);
+
+          await expect(this.mock.$mulShr(ethers.MaxUint256, ethers.MaxInt256 + 1n, 255n, rounding)).to.eventually.equal(
+            ethers.MaxUint256,
+          );
+
+          await expect(this.mock.$mulShr(ethers.MaxUint256, ethers.MaxInt256, 255n, rounding)).to.eventually.equal(
+            ethers.MaxUint256 - 2n,
+          );
+        }
+      });
+    });
+
+    describe('does round up', function () {
+      it('small values', async function () {
+        for (const rounding of RoundingUp) {
+          await expect(this.mock.$mulShr(3n, 5n, 1n, rounding)).to.eventually.equal(8n);
+          await expect(this.mock.$mulShr(3n, 5n, 2n, rounding)).to.eventually.equal(4n);
+        }
+      });
+
+      it('large values', async function () {
+        for (const rounding of RoundingUp) {
+          await expect(this.mock.$mulShr(42n, ethers.MaxUint256, 255n, rounding)).to.eventually.equal(84n);
+
+          await expect(this.mock.$mulShr(17n, ethers.MaxUint256, 255n, rounding)).to.eventually.equal(34n);
+
+          await expect(this.mock.$mulShr(ethers.MaxUint256, ethers.MaxInt256 + 1n, 255n, rounding)).to.eventually.equal(
             ethers.MaxUint256,
           );
+
+          await expect(this.mock.$mulShr(ethers.MaxUint256, ethers.MaxInt256, 255n, rounding)).to.eventually.equal(
+            ethers.MaxUint256 - 1n,
+          );
         }
       });
     });
@@ -320,8 +462,8 @@ describe('Math', function () {
 
       describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () {
         it('trying to inverse 0 returns 0', async function () {
-          expect(await this.mock.$invMod(0, p)).to.equal(0n);
-          expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p
+          await expect(this.mock.$invMod(0, p)).to.eventually.equal(0n);
+          await expect(this.mock.$invMod(p, p)).to.eventually.equal(0n); // p is 0 mod p
         });
 
         if (p != 0) {
@@ -349,7 +491,7 @@ describe('Math', function () {
           const e = 200n;
           const m = 50n;
 
-          expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
+          await expect(this.mock.$modExp(type(b), type(e), type(m))).to.eventually.equal(type(b ** e % m).value);
         });
 
         it('is correctly reverting when modulus is zero', async function () {
@@ -373,7 +515,9 @@ describe('Math', function () {
         it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
           const mLength = ethers.dataLength(ethers.toBeHex(m));
 
-          expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
+          await expect(this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.eventually.equal(
+            bytes(modExp(b, e, m), mLength).value,
+          );
         });
       }
     });
@@ -387,7 +531,10 @@ describe('Math', function () {
           const e = 200n;
           const m = 50n;
 
-          expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
+          await expect(this.mock.$tryModExp(type(b), type(e), type(m))).to.eventually.deep.equal([
+            true,
+            type(b ** e % m).value,
+          ]);
         });
 
         it('is correctly reverting when modulus is zero', async function () {
@@ -395,7 +542,7 @@ describe('Math', function () {
           const e = 200n;
           const m = 0n;
 
-          expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
+          await expect(this.mock.$tryModExp(type(b), type(e), type(m))).to.eventually.deep.equal([false, type.zero]);
         });
       });
     }
@@ -409,7 +556,7 @@ describe('Math', function () {
         it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
           const mLength = ethers.dataLength(ethers.toBeHex(m));
 
-          expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
+          await expect(this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.eventually.deep.equal([
             true,
             bytes(modExp(b, e, m), mLength).value,
           ]);
@@ -421,35 +568,39 @@ describe('Math', function () {
   describe('sqrt', function () {
     it('rounds down', async function () {
       for (const rounding of RoundingDown) {
-        expect(await this.mock.$sqrt(0n, rounding)).to.equal(0n);
-        expect(await this.mock.$sqrt(1n, rounding)).to.equal(1n);
-        expect(await this.mock.$sqrt(2n, rounding)).to.equal(1n);
-        expect(await this.mock.$sqrt(3n, rounding)).to.equal(1n);
-        expect(await this.mock.$sqrt(4n, rounding)).to.equal(2n);
-        expect(await this.mock.$sqrt(144n, rounding)).to.equal(12n);
-        expect(await this.mock.$sqrt(999999n, rounding)).to.equal(999n);
-        expect(await this.mock.$sqrt(1000000n, rounding)).to.equal(1000n);
-        expect(await this.mock.$sqrt(1000001n, rounding)).to.equal(1000n);
-        expect(await this.mock.$sqrt(1002000n, rounding)).to.equal(1000n);
-        expect(await this.mock.$sqrt(1002001n, rounding)).to.equal(1001n);
-        expect(await this.mock.$sqrt(ethers.MaxUint256, rounding)).to.equal(340282366920938463463374607431768211455n);
+        await expect(this.mock.$sqrt(0n, rounding)).to.eventually.equal(0n);
+        await expect(this.mock.$sqrt(1n, rounding)).to.eventually.equal(1n);
+        await expect(this.mock.$sqrt(2n, rounding)).to.eventually.equal(1n);
+        await expect(this.mock.$sqrt(3n, rounding)).to.eventually.equal(1n);
+        await expect(this.mock.$sqrt(4n, rounding)).to.eventually.equal(2n);
+        await expect(this.mock.$sqrt(144n, rounding)).to.eventually.equal(12n);
+        await expect(this.mock.$sqrt(999999n, rounding)).to.eventually.equal(999n);
+        await expect(this.mock.$sqrt(1000000n, rounding)).to.eventually.equal(1000n);
+        await expect(this.mock.$sqrt(1000001n, rounding)).to.eventually.equal(1000n);
+        await expect(this.mock.$sqrt(1002000n, rounding)).to.eventually.equal(1000n);
+        await expect(this.mock.$sqrt(1002001n, rounding)).to.eventually.equal(1001n);
+        await expect(this.mock.$sqrt(ethers.MaxUint256, rounding)).to.eventually.equal(
+          340282366920938463463374607431768211455n,
+        );
       }
     });
 
     it('rounds up', async function () {
       for (const rounding of RoundingUp) {
-        expect(await this.mock.$sqrt(0n, rounding)).to.equal(0n);
-        expect(await this.mock.$sqrt(1n, rounding)).to.equal(1n);
-        expect(await this.mock.$sqrt(2n, rounding)).to.equal(2n);
-        expect(await this.mock.$sqrt(3n, rounding)).to.equal(2n);
-        expect(await this.mock.$sqrt(4n, rounding)).to.equal(2n);
-        expect(await this.mock.$sqrt(144n, rounding)).to.equal(12n);
-        expect(await this.mock.$sqrt(999999n, rounding)).to.equal(1000n);
-        expect(await this.mock.$sqrt(1000000n, rounding)).to.equal(1000n);
-        expect(await this.mock.$sqrt(1000001n, rounding)).to.equal(1001n);
-        expect(await this.mock.$sqrt(1002000n, rounding)).to.equal(1001n);
-        expect(await this.mock.$sqrt(1002001n, rounding)).to.equal(1001n);
-        expect(await this.mock.$sqrt(ethers.MaxUint256, rounding)).to.equal(340282366920938463463374607431768211456n);
+        await expect(this.mock.$sqrt(0n, rounding)).to.eventually.equal(0n);
+        await expect(this.mock.$sqrt(1n, rounding)).to.eventually.equal(1n);
+        await expect(this.mock.$sqrt(2n, rounding)).to.eventually.equal(2n);
+        await expect(this.mock.$sqrt(3n, rounding)).to.eventually.equal(2n);
+        await expect(this.mock.$sqrt(4n, rounding)).to.eventually.equal(2n);
+        await expect(this.mock.$sqrt(144n, rounding)).to.eventually.equal(12n);
+        await expect(this.mock.$sqrt(999999n, rounding)).to.eventually.equal(1000n);
+        await expect(this.mock.$sqrt(1000000n, rounding)).to.eventually.equal(1000n);
+        await expect(this.mock.$sqrt(1000001n, rounding)).to.eventually.equal(1001n);
+        await expect(this.mock.$sqrt(1002000n, rounding)).to.eventually.equal(1001n);
+        await expect(this.mock.$sqrt(1002001n, rounding)).to.eventually.equal(1001n);
+        await expect(this.mock.$sqrt(ethers.MaxUint256, rounding)).to.eventually.equal(
+          340282366920938463463374607431768211456n,
+        );
       }
     });
   });
@@ -458,33 +609,33 @@ describe('Math', function () {
     describe('log2', function () {
       it('rounds down', async function () {
         for (const rounding of RoundingDown) {
-          expect(await this.mock.$log2(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log2(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log2(2n, rounding)).to.equal(1n);
-          expect(await this.mock.$log2(3n, rounding)).to.equal(1n);
-          expect(await this.mock.$log2(4n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(5n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(6n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(7n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(8n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(9n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(ethers.MaxUint256, rounding)).to.equal(255n);
+          await expect(this.mock.$log2(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log2(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log2(2n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log2(3n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log2(4n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(5n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(6n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(7n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(8n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(9n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(ethers.MaxUint256, rounding)).to.eventually.equal(255n);
         }
       });
 
       it('rounds up', async function () {
         for (const rounding of RoundingUp) {
-          expect(await this.mock.$log2(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log2(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log2(2n, rounding)).to.equal(1n);
-          expect(await this.mock.$log2(3n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(4n, rounding)).to.equal(2n);
-          expect(await this.mock.$log2(5n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(6n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(7n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(8n, rounding)).to.equal(3n);
-          expect(await this.mock.$log2(9n, rounding)).to.equal(4n);
-          expect(await this.mock.$log2(ethers.MaxUint256, rounding)).to.equal(256n);
+          await expect(this.mock.$log2(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log2(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log2(2n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log2(3n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(4n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log2(5n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(6n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(7n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(8n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log2(9n, rounding)).to.eventually.equal(4n);
+          await expect(this.mock.$log2(ethers.MaxUint256, rounding)).to.eventually.equal(256n);
         }
       });
     });
@@ -492,37 +643,37 @@ describe('Math', function () {
     describe('log10', function () {
       it('rounds down', async function () {
         for (const rounding of RoundingDown) {
-          expect(await this.mock.$log10(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(2n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(9n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(10n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(11n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(99n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(100n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(101n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(999n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(1000n, rounding)).to.equal(3n);
-          expect(await this.mock.$log10(1001n, rounding)).to.equal(3n);
-          expect(await this.mock.$log10(ethers.MaxUint256, rounding)).to.equal(77n);
+          await expect(this.mock.$log10(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(2n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(9n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(10n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(11n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(99n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(100n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(101n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(999n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(1000n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log10(1001n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log10(ethers.MaxUint256, rounding)).to.eventually.equal(77n);
         }
       });
 
       it('rounds up', async function () {
         for (const rounding of RoundingUp) {
-          expect(await this.mock.$log10(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log10(2n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(9n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(10n, rounding)).to.equal(1n);
-          expect(await this.mock.$log10(11n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(99n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(100n, rounding)).to.equal(2n);
-          expect(await this.mock.$log10(101n, rounding)).to.equal(3n);
-          expect(await this.mock.$log10(999n, rounding)).to.equal(3n);
-          expect(await this.mock.$log10(1000n, rounding)).to.equal(3n);
-          expect(await this.mock.$log10(1001n, rounding)).to.equal(4n);
-          expect(await this.mock.$log10(ethers.MaxUint256, rounding)).to.equal(78n);
+          await expect(this.mock.$log10(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log10(2n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(9n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(10n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log10(11n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(99n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(100n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log10(101n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log10(999n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log10(1000n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log10(1001n, rounding)).to.eventually.equal(4n);
+          await expect(this.mock.$log10(ethers.MaxUint256, rounding)).to.eventually.equal(78n);
         }
       });
     });
@@ -530,31 +681,31 @@ describe('Math', function () {
     describe('log256', function () {
       it('rounds down', async function () {
         for (const rounding of RoundingDown) {
-          expect(await this.mock.$log256(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(2n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(255n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(256n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(257n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(65535n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(65536n, rounding)).to.equal(2n);
-          expect(await this.mock.$log256(65537n, rounding)).to.equal(2n);
-          expect(await this.mock.$log256(ethers.MaxUint256, rounding)).to.equal(31n);
+          await expect(this.mock.$log256(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(2n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(255n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(256n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(257n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(65535n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(65536n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log256(65537n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log256(ethers.MaxUint256, rounding)).to.eventually.equal(31n);
         }
       });
 
       it('rounds up', async function () {
         for (const rounding of RoundingUp) {
-          expect(await this.mock.$log256(0n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(1n, rounding)).to.equal(0n);
-          expect(await this.mock.$log256(2n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(255n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(256n, rounding)).to.equal(1n);
-          expect(await this.mock.$log256(257n, rounding)).to.equal(2n);
-          expect(await this.mock.$log256(65535n, rounding)).to.equal(2n);
-          expect(await this.mock.$log256(65536n, rounding)).to.equal(2n);
-          expect(await this.mock.$log256(65537n, rounding)).to.equal(3n);
-          expect(await this.mock.$log256(ethers.MaxUint256, rounding)).to.equal(32n);
+          await expect(this.mock.$log256(0n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(1n, rounding)).to.eventually.equal(0n);
+          await expect(this.mock.$log256(2n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(255n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(256n, rounding)).to.eventually.equal(1n);
+          await expect(this.mock.$log256(257n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log256(65535n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log256(65536n, rounding)).to.eventually.equal(2n);
+          await expect(this.mock.$log256(65537n, rounding)).to.eventually.equal(3n);
+          await expect(this.mock.$log256(ethers.MaxUint256, rounding)).to.eventually.equal(32n);
         }
       });
     });

+ 108 - 28
test/utils/structs/MerkleTree.test.js

@@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
 const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');
 
 const { generators } = require('../../helpers/random');
+const { range } = require('../../helpers/iterate');
 
-const makeTree = (leaves = [ethers.ZeroHash]) =>
+const DEPTH = 4; // 16 slots
+
+const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) =>
   StandardMerkleTree.of(
-    leaves.map(leaf => [leaf]),
+    []
+      .concat(
+        leaves,
+        Array.from({ length: length - leaves.length }, () => zero),
+      )
+      .map(leaf => [leaf]),
     ['bytes32'],
     { sortLeaves: false },
   );
 
-const hashLeaf = leaf => makeTree().leafHash([leaf]);
-
-const DEPTH = 4n; // 16 slots
-const ZERO = hashLeaf(ethers.ZeroHash);
+const ZERO = makeTree().leafHash([ethers.ZeroHash]);
 
 async function fixture() {
   const mock = await ethers.deployContract('MerkleTreeMock');
@@ -30,57 +35,132 @@ describe('MerkleTree', function () {
   });
 
   it('sets initial values at setup', async function () {
-    const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
+    const merkleTree = makeTree();
 
-    expect(await this.mock.root()).to.equal(merkleTree.root);
-    expect(await this.mock.depth()).to.equal(DEPTH);
-    expect(await this.mock.nextLeafIndex()).to.equal(0n);
+    await expect(this.mock.root()).to.eventually.equal(merkleTree.root);
+    await expect(this.mock.depth()).to.eventually.equal(DEPTH);
+    await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n);
   });
 
   describe('push', function () {
-    it('tree is correctly updated', async function () {
-      const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
+    it('pushing correctly updates the tree', async function () {
+      const leaves = [];
 
       // for each leaf slot
-      for (const i in leaves) {
-        // generate random leaf and hash it
-        const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32()));
+      for (const i in range(2 ** DEPTH)) {
+        // generate random leaf
+        leaves.push(generators.bytes32());
 
-        // update leaf list and rebuild tree.
+        // rebuild tree.
         const tree = makeTree(leaves);
+        const hash = tree.leafHash(tree.at(i));
 
         // push value to tree
-        await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
+        await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root);
 
         // check tree
-        expect(await this.mock.root()).to.equal(tree.root);
-        expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
+        await expect(this.mock.root()).to.eventually.equal(tree.root);
+        await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n);
       }
     });
 
-    it('revert when tree is full', async function () {
+    it('pushing to a full tree reverts', async function () {
       await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));
 
       await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
     });
   });
 
+  describe('update', function () {
+    for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount =>
+      range(leafCount).map(leafIndex => ({ leafCount, leafIndex })),
+    ))
+      it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () {
+        // initial tree
+        const leaves = Array.from({ length: leafCount }, generators.bytes32);
+        const oldTree = makeTree(leaves);
+
+        // fill tree and verify root
+        for (const i in leaves) {
+          await this.mock.push(oldTree.leafHash(oldTree.at(i)));
+        }
+        await expect(this.mock.root()).to.eventually.equal(oldTree.root);
+
+        // create updated tree
+        leaves[leafIndex] = generators.bytes32();
+        const newTree = makeTree(leaves);
+
+        const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex));
+        const newLeafHash = newTree.leafHash(newTree.at(leafIndex));
+
+        // perform update
+        await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex)))
+          .to.emit(this.mock, 'LeafUpdated')
+          .withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root);
+
+        // verify updated root
+        await expect(this.mock.root()).to.eventually.equal(newTree.root);
+
+        // if there is still room in the tree, fill it
+        for (const i of range(leafCount, 2 ** DEPTH)) {
+          // push new value and rebuild tree
+          leaves.push(generators.bytes32());
+          const nextTree = makeTree(leaves);
+
+          // push and verify root
+          await this.mock.push(nextTree.leafHash(nextTree.at(i)));
+          await expect(this.mock.root()).to.eventually.equal(nextTree.root);
+        }
+      });
+
+    it('replacing a leaf that was not previously pushed reverts', async function () {
+      // changing leaf 0 on an empty tree
+      await expect(this.mock.update(1, ZERO, ZERO, []))
+        .to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex')
+        .withArgs(1, 0);
+    });
+
+    it('replacing a leaf using an invalid proof reverts', async function () {
+      const leafCount = 4;
+      const leafIndex = 2;
+
+      const leaves = Array.from({ length: leafCount }, generators.bytes32);
+      const tree = makeTree(leaves);
+
+      // fill tree and verify root
+      for (const i in leaves) {
+        await this.mock.push(tree.leafHash(tree.at(i)));
+      }
+      await expect(this.mock.root()).to.eventually.equal(tree.root);
+
+      const oldLeafHash = tree.leafHash(tree.at(leafIndex));
+      const newLeafHash = generators.bytes32();
+      const proof = tree.getProof(leafIndex);
+      // invalid proof (tamper)
+      proof[1] = generators.bytes32();
+
+      await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError(
+        this.mock,
+        'MerkleTreeUpdateInvalidProof',
+      );
+    });
+  });
+
   it('reset', async function () {
     // empty tree
-    const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
-    const zeroTree = makeTree(zeroLeaves);
+    const emptyTree = makeTree();
 
     // tree with one element
-    const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
-    const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it
+    const leaves = [generators.bytes32()];
     const tree = makeTree(leaves);
+    const hash = tree.leafHash(tree.at(0));
 
     // root should be that of a zero tree
-    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.root()).to.equal(emptyTree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(0n);
 
     // push leaf and check root
-    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+    await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
 
     expect(await this.mock.root()).to.equal(tree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(1n);
@@ -88,11 +168,11 @@ describe('MerkleTree', function () {
     // reset tree
     await this.mock.setup(DEPTH, ZERO);
 
-    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.root()).to.equal(emptyTree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(0n);
 
     // re-push leaf and check root
-    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+    await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
 
     expect(await this.mock.root()).to.equal(tree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(1n);