Kaynağa Gözat

Add a MerkleTree builder (#3617)

Co-authored-by: Ernesto García <ernestognw@gmail.com>
Hadrien Croubois 1 yıl önce
ebeveyn
işleme
92ff025622

+ 5 - 0
.changeset/odd-files-protect.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Hashes`: A library with commonly used hash functions.

+ 5 - 0
.changeset/warm-sheep-cover.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`MerkleTree`: A data structure that allows inserting elements into a merkle tree and updating its root hash.

+ 24 - 0
contracts/mocks/ArraysMock.sol

@@ -48,6 +48,14 @@ contract Uint256ArraysMock {
     function _reverse(uint256 a, uint256 b) private pure returns (bool) {
         return a > b;
     }
+
+    function unsafeSetLength(uint256 newLength) external {
+        _array.unsafeSetLength(newLength);
+    }
+
+    function length() external view returns (uint256) {
+        return _array.length;
+    }
 }
 
 contract AddressArraysMock {
@@ -74,6 +82,14 @@ contract AddressArraysMock {
     function _reverse(address a, address b) private pure returns (bool) {
         return uint160(a) > uint160(b);
     }
+
+    function unsafeSetLength(uint256 newLength) external {
+        _array.unsafeSetLength(newLength);
+    }
+
+    function length() external view returns (uint256) {
+        return _array.length;
+    }
 }
 
 contract Bytes32ArraysMock {
@@ -100,4 +116,12 @@ contract Bytes32ArraysMock {
     function _reverse(bytes32 a, bytes32 b) private pure returns (bool) {
         return uint256(a) > uint256(b);
     }
+
+    function unsafeSetLength(uint256 newLength) external {
+        _array.unsafeSetLength(newLength);
+    }
+
+    function length() external view returns (uint256) {
+        return _array.length;
+    }
 }

+ 43 - 0
contracts/mocks/MerkleTreeMock.sol

@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.0;
+
+import {MerkleTree} from "../utils/structs/MerkleTree.sol";
+
+contract MerkleTreeMock {
+    using MerkleTree for MerkleTree.Bytes32PushTree;
+
+    MerkleTree.Bytes32PushTree private _tree;
+
+    event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
+
+    function setup(uint8 _depth, bytes32 _zero) public {
+        _tree.setup(_depth, _zero);
+    }
+
+    function push(bytes32 leaf) public {
+        (uint256 leafIndex, bytes32 currentRoot) = _tree.push(leaf);
+        emit LeafInserted(leaf, leafIndex, currentRoot);
+    }
+
+    function root() public view returns (bytes32) {
+        return _tree.root();
+    }
+
+    function depth() public view returns (uint256) {
+        return _tree.depth();
+    }
+
+    // internal state
+    function nextLeafIndex() public view returns (uint256) {
+        return _tree._nextLeafIndex;
+    }
+
+    function sides(uint256 i) public view returns (bytes32) {
+        return _tree._sides[i];
+    }
+
+    function zeros(uint256 i) public view returns (bytes32) {
+        return _tree._zeros[i];
+    }
+}

+ 33 - 0
contracts/utils/Arrays.sol

@@ -440,4 +440,37 @@ library Arrays {
             res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
         }
     }
+
+    /**
+     * @dev Helper to set the length of an dynamic array. Directly writing to `.length` is forbidden.
+     *
+     * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
+     */
+    function unsafeSetLength(address[] storage array, uint256 len) internal {
+        assembly {
+            sstore(array.slot, len)
+        }
+    }
+
+    /**
+     * @dev Helper to set the length of an dynamic array. Directly writing to `.length` is forbidden.
+     *
+     * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
+     */
+    function unsafeSetLength(bytes32[] storage array, uint256 len) internal {
+        assembly {
+            sstore(array.slot, len)
+        }
+    }
+
+    /**
+     * @dev Helper to set the length of an dynamic array. Directly writing to `.length` is forbidden.
+     *
+     * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
+     */
+    function unsafeSetLength(uint256[] storage array, uint256 len) internal {
+        assembly {
+            sstore(array.slot, len)
+        }
+    }
 }

+ 8 - 2
contracts/utils/README.adoc

@@ -9,6 +9,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t
  * {SafeCast}: Checked downcasting functions to avoid silent truncation.
  * {ECDSA}, {MessageHashUtils}: Libraries for interacting with ECDSA signatures.
  * {SignatureChecker}: A library helper to support regular ECDSA from EOAs as well as ERC-1271 signatures for smart contracts.
+ * {Hashes}: Commonly used hash functions.
  * {MerkleProof}: Functions for verifying https://en.wikipedia.org/wiki/Merkle_tree[Merkle Tree] proofs.
  * {EIP712}: Contract with functions to allow processing signed typed structure data according to https://eips.ethereum.org/EIPS/eip-712[EIP-712].
  * {ReentrancyGuard}: A modifier that can prevent reentrancy during certain functions.
@@ -20,6 +21,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t
  * {EnumerableSet}: Like {EnumerableMap}, but for https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets]. Can be used to store privileged accounts, issued IDs, etc.
  * {DoubleEndedQueue}: An implementation of a https://en.wikipedia.org/wiki/Double-ended_queue[double ended queue] whose values can be removed added or remove from both sides. Useful for FIFO and LIFO structures.
  * {Checkpoints}: A data structure to store values mapped to an strictly increasing key. Can be used for storing and accessing values over time.
+ * {MerkleTree}: A library with https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures and helper functions.
  * {Create2}: Wrapper around the https://blog.openzeppelin.com/getting-the-most-out-of-create2/[`CREATE2` EVM opcode] for safe use without having to deal with low-level assembly.
  * {Address}: Collection of functions for overloading Solidity's https://docs.soliditylang.org/en/latest/types.html#address[`address`] type.
  * {Arrays}: Collection of functions that operate on https://docs.soliditylang.org/en/latest/types.html#arrays[`arrays`].
@@ -48,13 +50,15 @@ Because Solidity does not support generic types, {EnumerableMap} and {Enumerable
 
 {{ECDSA}}
 
+{{EIP712}}
+
 {{MessageHashUtils}}
 
 {{SignatureChecker}}
 
-{{MerkleProof}}
+{{Hashes}}
 
-{{EIP712}}
+{{MerkleProof}}
 
 == Security
 
@@ -88,6 +92,8 @@ Ethereum contracts have no native concept of an interface, so applications must
 
 {{Checkpoints}}
 
+{{MerkleTree}}
+
 == Libraries
 
 {{Create2}}

+ 29 - 0
contracts/utils/cryptography/Hashes.sol

@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.0;
+
+/**
+ * @dev Library of standard hash functions.
+ */
+library Hashes {
+    /**
+     * @dev Commutative Keccak256 hash of a sorted pair of bytes32. Frequently used when working with merkle proofs.
+     *
+     * NOTE: Equivalent to the `standardNodeHash` in our https://github.com/OpenZeppelin/merkle-tree[JavaScript library].
+     */
+    function commutativeKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32) {
+        return a < b ? _efficientKeccak256(a, b) : _efficientKeccak256(b, a);
+    }
+
+    /**
+     * @dev Implementation of keccak256(abi.encode(a, b)) that doesn't allocate or expand memory.
+     */
+    function _efficientKeccak256(bytes32 a, bytes32 b) private pure returns (bytes32 value) {
+        /// @solidity memory-safe-assembly
+        assembly {
+            mstore(0x00, a)
+            mstore(0x20, b)
+            value := keccak256(0x00, 0x40)
+        }
+    }
+}

+ 6 - 23
contracts/utils/cryptography/MerkleProof.sol

@@ -3,6 +3,8 @@
 
 pragma solidity ^0.8.20;
 
+import {Hashes} from "./Hashes.sol";
+
 /**
  * @dev These functions deal with verification of Merkle Tree proofs.
  *
@@ -49,7 +51,7 @@ library MerkleProof {
     function processProof(bytes32[] memory proof, bytes32 leaf) internal pure returns (bytes32) {
         bytes32 computedHash = leaf;
         for (uint256 i = 0; i < proof.length; i++) {
-            computedHash = _hashPair(computedHash, proof[i]);
+            computedHash = Hashes.commutativeKeccak256(computedHash, proof[i]);
         }
         return computedHash;
     }
@@ -60,7 +62,7 @@ library MerkleProof {
     function processProofCalldata(bytes32[] calldata proof, bytes32 leaf) internal pure returns (bytes32) {
         bytes32 computedHash = leaf;
         for (uint256 i = 0; i < proof.length; i++) {
-            computedHash = _hashPair(computedHash, proof[i]);
+            computedHash = Hashes.commutativeKeccak256(computedHash, proof[i]);
         }
         return computedHash;
     }
@@ -138,7 +140,7 @@ library MerkleProof {
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
                 : proof[proofPos++];
-            hashes[i] = _hashPair(a, b);
+            hashes[i] = Hashes.commutativeKeccak256(a, b);
         }
 
         if (totalHashes > 0) {
@@ -194,7 +196,7 @@ library MerkleProof {
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
                 : proof[proofPos++];
-            hashes[i] = _hashPair(a, b);
+            hashes[i] = Hashes.commutativeKeccak256(a, b);
         }
 
         if (totalHashes > 0) {
@@ -210,23 +212,4 @@ library MerkleProof {
             return proof[0];
         }
     }
-
-    /**
-     * @dev Sorts the pair (a, b) and hashes the result.
-     */
-    function _hashPair(bytes32 a, bytes32 b) private pure returns (bytes32) {
-        return a < b ? _efficientHash(a, b) : _efficientHash(b, a);
-    }
-
-    /**
-     * @dev Implementation of keccak256(abi.encode(a, b)) that doesn't allocate or expand memory.
-     */
-    function _efficientHash(bytes32 a, bytes32 b) private pure returns (bytes32 value) {
-        /// @solidity memory-safe-assembly
-        assembly {
-            mstore(0x00, a)
-            mstore(0x20, b)
-            value := keccak256(0x00, 0x40)
-        }
-    }
 }

+ 154 - 0
contracts/utils/structs/MerkleTree.sol

@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.0;
+
+import {Hashes} from "../cryptography/Hashes.sol";
+import {Arrays} from "../Arrays.sol";
+import {Panic} from "../Panic.sol";
+
+/**
+ * @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
+ *
+ * Each tree is a complete binary tree with the ability to sequentially insert leaves, changing them from a zero to a
+ * non-zero value and updating its root. This structure allows inserting commitments (or other entries) that are not
+ * stored, but can be proven to be part of the tree at a later time. See {MerkleProof}.
+ *
+ * A tree is defined by the following parameters:
+ *
+ * * Depth: The number of levels in the tree, it also defines the maximum number of leaves as 2**depth.
+ * * Zero value: The value that represents an empty leaf. Used to avoid regular zero values to be part of the tree.
+ * * Hashing function: A cryptographic hash function used to produce internal nodes.
+ *
+ * _Available since v5.1._
+ */
+library MerkleTree {
+    /**
+     * @dev A complete `bytes32` Merkle tree.
+     *
+     * The `sides` and `zero` arrays are set to have a length equal to the depth of the tree during setup.
+     *
+     * The hashing function used during initialization to compute the `zeros` values (value of a node at a given depth
+     * for which the subtree is full of zero leaves). This function is kept in the structure for handling insertions.
+     *
+     * Struct members have an underscore prefix indicating that they are "private" and should not be read or written to
+     * directly. Use the functions provided below instead. Modifying the struct manually may violate assumptions and
+     * lead to unexpected behavior.
+     *
+     * NOTE: The `root` is kept up to date after each insertion without keeping track of its history. Consider
+     * using a secondary structure to store a list of historical roots (e.g. a mapping, {BitMaps} or {Checkpoints}).
+     *
+     * WARNING: Updating any of the tree's parameters after the first insertion will result in a corrupted tree.
+     */
+    struct Bytes32PushTree {
+        bytes32 _root;
+        uint256 _nextLeafIndex;
+        bytes32[] _sides;
+        bytes32[] _zeros;
+        function(bytes32, bytes32) view returns (bytes32) _fnHash;
+    }
+
+    /**
+     * @dev Initialize a {Bytes32PushTree} using {Hashes-commutativeKeccak256} to hash internal nodes.
+     * The capacity of the tree (i.e. number of leaves) is set to `2**levels`.
+     *
+     * Calling this function on MerkleTree that was already setup and used will reset it to a blank state.
+     *
+     * IMPORTANT: The zero value should be carefully chosen since it will be stored in the tree representing
+     * empty leaves. It should be a value that is not expected to be part of the tree.
+     */
+    function setup(Bytes32PushTree storage self, uint8 levels, bytes32 zero) internal {
+        return setup(self, levels, zero, Hashes.commutativeKeccak256);
+    }
+
+    /**
+     * @dev Same as {setup}, but allows to specify a custom hashing function.
+     *
+     * IMPORTANT: Providing a custom hashing function is a security-sensitive operation since it may
+     * compromise the soundness of the tree. Consider using functions from {Hashes}.
+     */
+    function setup(
+        Bytes32PushTree storage self,
+        uint8 levels,
+        bytes32 zero,
+        function(bytes32, bytes32) view returns (bytes32) fnHash
+    ) internal {
+        // Store depth in the dynamic array
+        Arrays.unsafeSetLength(self._sides, levels);
+        Arrays.unsafeSetLength(self._zeros, levels);
+
+        // Build each root of zero-filled subtrees
+        bytes32 currentZero = zero;
+        for (uint32 i = 0; i < levels; ++i) {
+            Arrays.unsafeAccess(self._zeros, i).value = currentZero;
+            currentZero = fnHash(currentZero, currentZero);
+        }
+
+        // Set the first root
+        self._root = currentZero;
+        self._nextLeafIndex = 0;
+        self._fnHash = fnHash;
+    }
+
+    /**
+     * @dev Insert a new leaf in the tree, and compute the new root. Returns the position of the inserted leaf in the
+     * tree, and the resulting root.
+     *
+     * Hashing the leaf before calling this function is recommended as a protection against
+     * second pre-image attacks.
+     */
+    function push(Bytes32PushTree storage self, bytes32 leaf) internal returns (uint256 index, bytes32 newRoot) {
+        // Cache read
+        uint256 levels = self._zeros.length;
+        function(bytes32, bytes32) view returns (bytes32) fnHash = self._fnHash;
+
+        // Get leaf index
+        uint256 leafIndex = self._nextLeafIndex++;
+
+        // Check if tree is full.
+        if (leafIndex >= 1 << levels) {
+            Panic.panic(Panic.RESOURCE_ERROR);
+        }
+
+        // Rebuild branch from leaf to root
+        uint256 currentIndex = leafIndex;
+        bytes32 currentLevelHash = leaf;
+        for (uint32 i = 0; i < levels; i++) {
+            // Reaching the parent node, is currentLevelHash the left child?
+            bool isLeft = currentIndex % 2 == 0;
+
+            // If so, next time we will come from the right, so we need to save it
+            if (isLeft) {
+                Arrays.unsafeAccess(self._sides, i).value = currentLevelHash;
+            }
+
+            // Compute the current node hash by using the hash function
+            // with either the its sibling (side) or the zero value for that level.
+            currentLevelHash = fnHash(
+                isLeft ? currentLevelHash : Arrays.unsafeAccess(self._sides, i).value,
+                isLeft ? Arrays.unsafeAccess(self._zeros, i).value : currentLevelHash
+            );
+
+            // Update node index
+            currentIndex >>= 1;
+        }
+
+        // Record new root
+        self._root = currentLevelHash;
+
+        return (leafIndex, currentLevelHash);
+    }
+
+    /**
+     * @dev Tree's current root
+     */
+    function root(Bytes32PushTree storage self) internal view returns (bytes32) {
+        return self._root;
+    }
+
+    /**
+     * @dev Tree's depth (set at initialization)
+     */
+    function depth(Bytes32PushTree storage self) internal view returns (uint256) {
+        return self._zeros.length;
+    }
+}

+ 4 - 4
package-lock.json

@@ -17,7 +17,7 @@
         "@nomicfoundation/hardhat-ethers": "^3.0.4",
         "@nomicfoundation/hardhat-network-helpers": "^1.0.3",
         "@openzeppelin/docs-utils": "^0.1.5",
-        "@openzeppelin/merkle-tree": "^1.0.5",
+        "@openzeppelin/merkle-tree": "^1.0.6",
         "@openzeppelin/upgrade-safe-transpiler": "^0.3.32",
         "@openzeppelin/upgrades-core": "^1.20.6",
         "chai": "^4.2.0",
@@ -2392,9 +2392,9 @@
       }
     },
     "node_modules/@openzeppelin/merkle-tree": {
-      "version": "1.0.5",
-      "resolved": "https://registry.npmjs.org/@openzeppelin/merkle-tree/-/merkle-tree-1.0.5.tgz",
-      "integrity": "sha512-JkwG2ysdHeIphrScNxYagPy6jZeNONgDRyqU6lbFgE8HKCZFSkcP8r6AjZs+3HZk4uRNV0kNBBzuWhKQ3YV7Kw==",
+      "version": "1.0.6",
+      "resolved": "https://registry.npmjs.org/@openzeppelin/merkle-tree/-/merkle-tree-1.0.6.tgz",
+      "integrity": "sha512-cGWOb2WBWbJhqvupzxjnKAwGLxxAEYPg51sk76yZ5nVe5D03mw7Vx5yo8llaIEqYhP5O39M8QlrNWclgLfKVrA==",
       "dev": true,
       "dependencies": {
         "@ethersproject/abi": "^5.7.0",

+ 1 - 1
package.json

@@ -57,7 +57,7 @@
     "@nomicfoundation/hardhat-ethers": "^3.0.4",
     "@nomicfoundation/hardhat-network-helpers": "^1.0.3",
     "@openzeppelin/docs-utils": "^0.1.5",
-    "@openzeppelin/merkle-tree": "^1.0.5",
+    "@openzeppelin/merkle-tree": "^1.0.6",
     "@openzeppelin/upgrade-safe-transpiler": "^0.3.32",
     "@openzeppelin/upgrades-core": "^1.20.6",
     "chai": "^4.2.0",

+ 32 - 16
test/utils/Arrays.test.js

@@ -17,6 +17,7 @@ const upperBound = (array, value) => {
 };
 
 const bigintSign = x => (x > 0n ? 1 : x < 0n ? -1 : 0);
+const comparator = (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b));
 const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
 
 describe('Arrays', function () {
@@ -116,23 +117,22 @@ describe('Arrays', function () {
     }
   });
 
-  for (const [type, { artifact, elements, comp }] of Object.entries({
+  for (const [type, { artifact, format }] of Object.entries({
     address: {
       artifact: 'AddressArraysMock',
-      elements: Array.from({ length: 10 }, generators.address),
-      comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)),
+      format: x => ethers.getAddress(ethers.toBeHex(x, 20)),
     },
     bytes32: {
       artifact: 'Bytes32ArraysMock',
-      elements: Array.from({ length: 10 }, generators.bytes32),
-      comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)),
+      format: x => ethers.toBeHex(x, 32),
     },
     uint256: {
       artifact: 'Uint256ArraysMock',
-      elements: Array.from({ length: 10 }, generators.uint256),
-      comp: (a, b) => bigintSign(a - b),
+      format: x => ethers.toBigInt(x),
     },
   })) {
+    const elements = Array.from({ length: 10 }, generators[type]);
+
     describe(type, function () {
       const fixture = async () => {
         return { instance: await ethers.deployContract(artifact, [elements]) };
@@ -146,14 +146,14 @@ describe('Arrays', function () {
         for (const length of [0, 1, 2, 8, 32, 128]) {
           describe(`${type}[] of length ${length}`, function () {
             beforeEach(async function () {
-              this.elements = Array.from({ length }, generators[type]);
+              this.array = Array.from({ length }, generators[type]);
             });
 
             afterEach(async function () {
-              const expected = Array.from(this.elements).sort(comp);
+              const expected = Array.from(this.array).sort(comparator);
               const reversed = Array.from(expected).reverse();
-              expect(await this.instance.sort(this.elements)).to.deep.equal(expected);
-              expect(await this.instance.sortReverse(this.elements)).to.deep.equal(reversed);
+              expect(await this.instance.sort(this.array)).to.deep.equal(expected);
+              expect(await this.instance.sortReverse(this.array)).to.deep.equal(reversed);
             });
 
             it('sort array', async function () {
@@ -163,23 +163,23 @@ describe('Arrays', function () {
             if (length > 1) {
               it('sort array for identical elements', async function () {
                 // duplicate the first value to all elements
-                this.elements.fill(this.elements.at(0));
+                this.array.fill(this.array.at(0));
               });
 
               it('sort already sorted array', async function () {
                 // pre-sort the elements
-                this.elements.sort(comp);
+                this.array.sort(comparator);
               });
 
               it('sort reversed array', async function () {
                 // pre-sort in reverse order
-                this.elements.sort(comp).reverse();
+                this.array.sort(comparator).reverse();
               });
 
               it('sort almost sorted array', async function () {
                 // pre-sort + rotate (move the last element to the front) for an almost sorted effect
-                this.elements.sort(comp);
-                this.elements.unshift(this.elements.pop());
+                this.array.sort(comparator);
+                this.array.unshift(this.array.pop());
               });
             }
           });
@@ -197,6 +197,14 @@ describe('Arrays', function () {
           it('unsafeAccess outside bounds', async function () {
             await expect(this.instance.unsafeAccess(elements.length)).to.not.be.rejected;
           });
+
+          it('unsafeSetLength changes the length or the array', async function () {
+            const newLength = generators.uint256();
+
+            expect(await this.instance.length()).to.equal(elements.length);
+            await expect(this.instance.unsafeSetLength(newLength)).to.not.be.rejected;
+            expect(await this.instance.length()).to.equal(newLength);
+          });
         });
 
         describe('memory', function () {
@@ -211,6 +219,14 @@ describe('Arrays', function () {
           it('unsafeMemoryAccess outside bounds', async function () {
             await expect(this.mock[fragment](elements, elements.length)).to.not.be.rejected;
           });
+
+          it('unsafeMemoryAccess loop around', async function () {
+            for (let i = 251n; i < 256n; ++i) {
+              expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(format(elements.length));
+              expect(await this.mock[fragment](elements, 2n ** i + 0n)).to.equal(elements[0]);
+              expect(await this.mock[fragment](elements, 2n ** i + 1n)).to.equal(elements[1]);
+            }
+          });
         });
       });
     });

+ 100 - 0
test/utils/structs/MerkleTree.test.js

@@ -0,0 +1,100 @@
+const { ethers } = require('hardhat');
+const { expect } = require('chai');
+const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
+const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
+const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');
+
+const { generators } = require('../../helpers/random');
+
+const makeTree = (leafs = [ethers.ZeroHash]) =>
+  StandardMerkleTree.of(
+    leafs.map(leaf => [leaf]),
+    ['bytes32'],
+    { sortLeaves: false },
+  );
+
+const hashLeaf = leaf => makeTree().leafHash([leaf]);
+
+const DEPTH = 4n; // 16 slots
+const ZERO = hashLeaf(ethers.ZeroHash);
+
+async function fixture() {
+  const mock = await ethers.deployContract('MerkleTreeMock');
+  await mock.setup(DEPTH, ZERO);
+  return { mock };
+}
+
+describe('MerkleTree', function () {
+  beforeEach(async function () {
+    Object.assign(this, await loadFixture(fixture));
+  });
+
+  it('sets initial values at setup', async function () {
+    const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
+
+    expect(await this.mock.root()).to.equal(merkleTree.root);
+    expect(await this.mock.depth()).to.equal(DEPTH);
+    expect(await this.mock.nextLeafIndex()).to.equal(0n);
+  });
+
+  describe('push', function () {
+    it('tree is correctly updated', async function () {
+      const leafs = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
+
+      // for each leaf slot
+      for (const i in leafs) {
+        // generate random leaf and hash it
+        const hashedLeaf = hashLeaf((leafs[i] = generators.bytes32()));
+
+        // update leaf list and rebuild tree.
+        const tree = makeTree(leafs);
+
+        // push value to tree
+        await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
+
+        // check tree
+        expect(await this.mock.root()).to.equal(tree.root);
+        expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
+      }
+    });
+
+    it('revert when tree is full', async function () {
+      await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));
+
+      await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
+    });
+  });
+
+  it('reset', async function () {
+    // empty tree
+    const zeroLeafs = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
+    const zeroTree = makeTree(zeroLeafs);
+
+    // tree with one element
+    const leafs = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
+    const hashedLeaf = hashLeaf((leafs[0] = generators.bytes32())); // fill first leaf and hash it
+    const tree = makeTree(leafs);
+
+    // root should be that of a zero tree
+    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.nextLeafIndex()).to.equal(0n);
+
+    // push leaf and check root
+    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+
+    expect(await this.mock.root()).to.equal(tree.root);
+    expect(await this.mock.nextLeafIndex()).to.equal(1n);
+
+    // reset tree
+    await this.mock.setup(DEPTH, ZERO);
+
+    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.nextLeafIndex()).to.equal(0n);
+
+    // re-push leaf and check root
+    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+
+    expect(await this.mock.root()).to.equal(tree.root);
+    expect(await this.mock.nextLeafIndex()).to.equal(1n);
+  });
+});