Преглед на файлове

Add function to update a leaf in a MerkleTree structure (#5453)

Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com>
Hadrien Croubois преди 7 месеца
родител
ревизия
71bc0f7774
променени са 4 файла, в които са добавени 213 реда и са изтрити 28 реда
  1. 5 0
      .changeset/good-zebras-ring.md
  2. 8 0
      contracts/mocks/MerkleTreeMock.sol
  3. 92 0
      contracts/utils/structs/MerkleTree.sol
  4. 108 28
      test/utils/structs/MerkleTree.test.js

+ 5 - 0
.changeset/good-zebras-ring.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way.

+ 8 - 0
contracts/mocks/MerkleTreeMock.sol

@@ -14,6 +14,7 @@ contract MerkleTreeMock {
     bytes32 public root;
 
     event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
+    event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root);
 
     function setup(uint8 _depth, bytes32 _zero) public {
         root = _tree.setup(_depth, _zero);
@@ -25,6 +26,13 @@ contract MerkleTreeMock {
         root = currentRoot;
     }
 
+    function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public {
+        (bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof);
+        if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof();
+        emit LeafUpdated(oldValue, newValue, index, newRoot);
+        root = newRoot;
+    }
+
     function depth() public view returns (uint256) {
         return _tree.depth();
     }

+ 92 - 0
contracts/utils/structs/MerkleTree.sol

@@ -6,6 +6,7 @@ pragma solidity ^0.8.20;
 import {Hashes} from "../cryptography/Hashes.sol";
 import {Arrays} from "../Arrays.sol";
 import {Panic} from "../Panic.sol";
+import {StorageSlot} from "../StorageSlot.sol";
 
 /**
  * @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
@@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol";
  * _Available since v5.1._
  */
 library MerkleTree {
+    /// @dev Error emitted when trying to update a leaf that was not previously pushed.
+    error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length);
+
+    /// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side).
+    error MerkleTreeUpdateInvalidProof();
+
     /**
      * @dev A complete `bytes32` Merkle tree.
      *
@@ -166,6 +173,91 @@ library MerkleTree {
         return (index, currentLevelHash);
     }
 
+    /**
+     * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
+     * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
+     * root is the last known one.
+     *
+     * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
+     * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
+     * all "in flight" updates invalid.
+     *
+     * This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees
+     * that were setup using the same (default) hashing function (i.e. by calling
+     * {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function).
+     */
+    function update(
+        Bytes32PushTree storage self,
+        uint256 index,
+        bytes32 oldValue,
+        bytes32 newValue,
+        bytes32[] memory proof
+    ) internal returns (bytes32 oldRoot, bytes32 newRoot) {
+        return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256);
+    }
+
+    /**
+     * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
+     * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
+     * root is the last known one.
+     *
+     * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
+     * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
+     * all "in flight" updates invalid.
+     *
+     * This variant uses a custom hashing function to hash internal nodes. It should only be called with the same
+     * function as the one used during the initial setup of the merkle tree.
+     */
+    function update(
+        Bytes32PushTree storage self,
+        uint256 index,
+        bytes32 oldValue,
+        bytes32 newValue,
+        bytes32[] memory proof,
+        function(bytes32, bytes32) view returns (bytes32) fnHash
+    ) internal returns (bytes32 oldRoot, bytes32 newRoot) {
+        unchecked {
+            // Check index range
+            uint256 length = self._nextLeafIndex;
+            if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length);
+
+            // Cache read
+            uint256 treeDepth = depth(self);
+
+            // Workaround stack too deep
+            bytes32[] storage sides = self._sides;
+
+            // This cannot overflow because: 0 <= index < length
+            uint256 lastIndex = length - 1;
+            uint256 currentIndex = index;
+            bytes32 currentLevelHashOld = oldValue;
+            bytes32 currentLevelHashNew = newValue;
+            for (uint32 i = 0; i < treeDepth; i++) {
+                bool isLeft = currentIndex % 2 == 0;
+
+                lastIndex >>= 1;
+                currentIndex >>= 1;
+
+                if (isLeft && currentIndex == lastIndex) {
+                    StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i);
+                    if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof();
+                    side.value = currentLevelHashNew;
+                }
+
+                bytes32 sibling = proof[i];
+                currentLevelHashOld = fnHash(
+                    isLeft ? currentLevelHashOld : sibling,
+                    isLeft ? sibling : currentLevelHashOld
+                );
+                currentLevelHashNew = fnHash(
+                    isLeft ? currentLevelHashNew : sibling,
+                    isLeft ? sibling : currentLevelHashNew
+                );
+            }
+            return (currentLevelHashOld, currentLevelHashNew);
+        }
+    }
+
     /**
      * @dev Tree's depth (set at initialization)
      */

+ 108 - 28
test/utils/structs/MerkleTree.test.js

@@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
 const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');
 
 const { generators } = require('../../helpers/random');
+const { range } = require('../../helpers/iterate');
 
-const makeTree = (leaves = [ethers.ZeroHash]) =>
+const DEPTH = 4; // 16 slots
+
+const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) =>
   StandardMerkleTree.of(
-    leaves.map(leaf => [leaf]),
+    []
+      .concat(
+        leaves,
+        Array.from({ length: length - leaves.length }, () => zero),
+      )
+      .map(leaf => [leaf]),
     ['bytes32'],
     { sortLeaves: false },
   );
 
-const hashLeaf = leaf => makeTree().leafHash([leaf]);
-
-const DEPTH = 4n; // 16 slots
-const ZERO = hashLeaf(ethers.ZeroHash);
+const ZERO = makeTree().leafHash([ethers.ZeroHash]);
 
 async function fixture() {
   const mock = await ethers.deployContract('MerkleTreeMock');
@@ -30,57 +35,132 @@ describe('MerkleTree', function () {
   });
 
   it('sets initial values at setup', async function () {
-    const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
+    const merkleTree = makeTree();
 
-    expect(await this.mock.root()).to.equal(merkleTree.root);
-    expect(await this.mock.depth()).to.equal(DEPTH);
-    expect(await this.mock.nextLeafIndex()).to.equal(0n);
+    await expect(this.mock.root()).to.eventually.equal(merkleTree.root);
+    await expect(this.mock.depth()).to.eventually.equal(DEPTH);
+    await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n);
   });
 
   describe('push', function () {
-    it('tree is correctly updated', async function () {
-      const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
+    it('pushing correctly updates the tree', async function () {
+      const leaves = [];
 
       // for each leaf slot
-      for (const i in leaves) {
-        // generate random leaf and hash it
-        const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32()));
+      for (const i in range(2 ** DEPTH)) {
+        // generate random leaf
+        leaves.push(generators.bytes32());
 
-        // update leaf list and rebuild tree.
+        // rebuild tree.
         const tree = makeTree(leaves);
+        const hash = tree.leafHash(tree.at(i));
 
         // push value to tree
-        await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
+        await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root);
 
         // check tree
-        expect(await this.mock.root()).to.equal(tree.root);
-        expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
+        await expect(this.mock.root()).to.eventually.equal(tree.root);
+        await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n);
       }
     });
 
-    it('revert when tree is full', async function () {
+    it('pushing to a full tree reverts', async function () {
       await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));
 
       await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
     });
   });
 
+  describe('update', function () {
+    for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount =>
+      range(leafCount).map(leafIndex => ({ leafCount, leafIndex })),
+    ))
+      it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () {
+        // initial tree
+        const leaves = Array.from({ length: leafCount }, generators.bytes32);
+        const oldTree = makeTree(leaves);
+
+        // fill tree and verify root
+        for (const i in leaves) {
+          await this.mock.push(oldTree.leafHash(oldTree.at(i)));
+        }
+        await expect(this.mock.root()).to.eventually.equal(oldTree.root);
+
+        // create updated tree
+        leaves[leafIndex] = generators.bytes32();
+        const newTree = makeTree(leaves);
+
+        const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex));
+        const newLeafHash = newTree.leafHash(newTree.at(leafIndex));
+
+        // perform update
+        await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex)))
+          .to.emit(this.mock, 'LeafUpdated')
+          .withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root);
+
+        // verify updated root
+        await expect(this.mock.root()).to.eventually.equal(newTree.root);
+
+        // if there is still room in the tree, fill it
+        for (const i of range(leafCount, 2 ** DEPTH)) {
+          // push new value and rebuild tree
+          leaves.push(generators.bytes32());
+          const nextTree = makeTree(leaves);
+
+          // push and verify root
+          await this.mock.push(nextTree.leafHash(nextTree.at(i)));
+          await expect(this.mock.root()).to.eventually.equal(nextTree.root);
+        }
+      });
+
+    it('replacing a leaf that was not previously pushed reverts', async function () {
+      // changing leaf 0 on an empty tree
+      await expect(this.mock.update(1, ZERO, ZERO, []))
+        .to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex')
+        .withArgs(1, 0);
+    });
+
+    it('replacing a leaf using an invalid proof reverts', async function () {
+      const leafCount = 4;
+      const leafIndex = 2;
+
+      const leaves = Array.from({ length: leafCount }, generators.bytes32);
+      const tree = makeTree(leaves);
+
+      // fill tree and verify root
+      for (const i in leaves) {
+        await this.mock.push(tree.leafHash(tree.at(i)));
+      }
+      await expect(this.mock.root()).to.eventually.equal(tree.root);
+
+      const oldLeafHash = tree.leafHash(tree.at(leafIndex));
+      const newLeafHash = generators.bytes32();
+      const proof = tree.getProof(leafIndex);
+      // invalid proof (tamper)
+      proof[1] = generators.bytes32();
+
+      await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError(
+        this.mock,
+        'MerkleTreeUpdateInvalidProof',
+      );
+    });
+  });
+
   it('reset', async function () {
     // empty tree
-    const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
-    const zeroTree = makeTree(zeroLeaves);
+    const emptyTree = makeTree();
 
     // tree with one element
-    const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
-    const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it
+    const leaves = [generators.bytes32()];
     const tree = makeTree(leaves);
+    const hash = tree.leafHash(tree.at(0));
 
     // root should be that of a zero tree
-    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.root()).to.equal(emptyTree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(0n);
 
     // push leaf and check root
-    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+    await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
 
     expect(await this.mock.root()).to.equal(tree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(1n);
@@ -88,11 +168,11 @@ describe('MerkleTree', function () {
     // reset tree
     await this.mock.setup(DEPTH, ZERO);
 
-    expect(await this.mock.root()).to.equal(zeroTree.root);
+    expect(await this.mock.root()).to.equal(emptyTree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(0n);
 
     // re-push leaf and check root
-    await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
+    await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
 
     expect(await this.mock.root()).to.equal(tree.root);
     expect(await this.mock.nextLeafIndex()).to.equal(1n);