Ver Fonte

Get leaves from memory in processMultiProofCalldata (#5140)

Hadrien Croubois há 1 ano atrás
pai
commit
24a641d9c9

+ 32 - 28
contracts/utils/cryptography/MerkleProof.sol

@@ -105,7 +105,7 @@ library MerkleProof {
      * This version handles proofs in calldata with the default hashing function.
      */
     function verifyCalldata(bytes32[] calldata proof, bytes32 root, bytes32 leaf) internal pure returns (bool) {
-        return processProof(proof, leaf) == root;
+        return processProofCalldata(proof, leaf) == root;
     }
 
     /**
@@ -138,7 +138,7 @@ library MerkleProof {
         bytes32 leaf,
         function(bytes32, bytes32) view returns (bytes32) hasher
     ) internal view returns (bool) {
-        return processProof(proof, leaf, hasher) == root;
+        return processProofCalldata(proof, leaf, hasher) == root;
     }
 
     /**
@@ -200,15 +200,16 @@ library MerkleProof {
         // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
         // the Merkle tree.
         uint256 leavesLen = leaves.length;
+        uint256 proofFlagsLen = proofFlags.length;
 
         // Check proof validity.
-        if (leavesLen + proof.length != proofFlags.length + 1) {
+        if (leavesLen + proof.length != proofFlagsLen + 1) {
             revert MerkleProofInvalidMultiproof();
         }
 
         // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
         // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
-        bytes32[] memory hashes = new bytes32[](proofFlags.length);
+        bytes32[] memory hashes = new bytes32[](proofFlagsLen);
         uint256 leafPos = 0;
         uint256 hashPos = 0;
         uint256 proofPos = 0;
@@ -217,7 +218,7 @@ library MerkleProof {
         //   get the next hash.
         // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
         //   `proof` array.
-        for (uint256 i = 0; i < proofFlags.length; i++) {
+        for (uint256 i = 0; i < proofFlagsLen; i++) {
             bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@@ -225,12 +226,12 @@ library MerkleProof {
             hashes[i] = Hashes.commutativeKeccak256(a, b);
         }
 
-        if (proofFlags.length > 0) {
+        if (proofFlagsLen > 0) {
             if (proofPos != proof.length) {
                 revert MerkleProofInvalidMultiproof();
             }
             unchecked {
-                return hashes[proofFlags.length - 1];
+                return hashes[proofFlagsLen - 1];
             }
         } else if (leavesLen > 0) {
             return leaves[0];
@@ -280,15 +281,16 @@ library MerkleProof {
         // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
         // the Merkle tree.
         uint256 leavesLen = leaves.length;
+        uint256 proofFlagsLen = proofFlags.length;
 
         // Check proof validity.
-        if (leavesLen + proof.length != proofFlags.length + 1) {
+        if (leavesLen + proof.length != proofFlagsLen + 1) {
             revert MerkleProofInvalidMultiproof();
         }
 
         // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
         // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
-        bytes32[] memory hashes = new bytes32[](proofFlags.length);
+        bytes32[] memory hashes = new bytes32[](proofFlagsLen);
         uint256 leafPos = 0;
         uint256 hashPos = 0;
         uint256 proofPos = 0;
@@ -297,7 +299,7 @@ library MerkleProof {
         //   get the next hash.
         // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
         //   `proof` array.
-        for (uint256 i = 0; i < proofFlags.length; i++) {
+        for (uint256 i = 0; i < proofFlagsLen; i++) {
             bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@@ -305,12 +307,12 @@ library MerkleProof {
             hashes[i] = hasher(a, b);
         }
 
-        if (proofFlags.length > 0) {
+        if (proofFlagsLen > 0) {
             if (proofPos != proof.length) {
                 revert MerkleProofInvalidMultiproof();
             }
             unchecked {
-                return hashes[proofFlags.length - 1];
+                return hashes[proofFlagsLen - 1];
             }
         } else if (leavesLen > 0) {
             return leaves[0];
@@ -331,9 +333,9 @@ library MerkleProof {
         bytes32[] calldata proof,
         bool[] calldata proofFlags,
         bytes32 root,
-        bytes32[] calldata leaves
+        bytes32[] memory leaves
     ) internal pure returns (bool) {
-        return processMultiProof(proof, proofFlags, leaves) == root;
+        return processMultiProofCalldata(proof, proofFlags, leaves) == root;
     }
 
     /**
@@ -351,22 +353,23 @@ library MerkleProof {
     function processMultiProofCalldata(
         bytes32[] calldata proof,
         bool[] calldata proofFlags,
-        bytes32[] calldata leaves
+        bytes32[] memory leaves
     ) internal pure returns (bytes32 merkleRoot) {
         // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
         // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the
         // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
         // the Merkle tree.
         uint256 leavesLen = leaves.length;
+        uint256 proofFlagsLen = proofFlags.length;
 
         // Check proof validity.
-        if (leavesLen + proof.length != proofFlags.length + 1) {
+        if (leavesLen + proof.length != proofFlagsLen + 1) {
             revert MerkleProofInvalidMultiproof();
         }
 
         // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
         // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
-        bytes32[] memory hashes = new bytes32[](proofFlags.length);
+        bytes32[] memory hashes = new bytes32[](proofFlagsLen);
         uint256 leafPos = 0;
         uint256 hashPos = 0;
         uint256 proofPos = 0;
@@ -375,7 +378,7 @@ library MerkleProof {
         //   get the next hash.
         // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
         //   `proof` array.
-        for (uint256 i = 0; i < proofFlags.length; i++) {
+        for (uint256 i = 0; i < proofFlagsLen; i++) {
             bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@@ -383,12 +386,12 @@ library MerkleProof {
             hashes[i] = Hashes.commutativeKeccak256(a, b);
         }
 
-        if (proofFlags.length > 0) {
+        if (proofFlagsLen > 0) {
             if (proofPos != proof.length) {
                 revert MerkleProofInvalidMultiproof();
             }
             unchecked {
-                return hashes[proofFlags.length - 1];
+                return hashes[proofFlagsLen - 1];
             }
         } else if (leavesLen > 0) {
             return leaves[0];
@@ -409,10 +412,10 @@ library MerkleProof {
         bytes32[] calldata proof,
         bool[] calldata proofFlags,
         bytes32 root,
-        bytes32[] calldata leaves,
+        bytes32[] memory leaves,
         function(bytes32, bytes32) view returns (bytes32) hasher
     ) internal view returns (bool) {
-        return processMultiProof(proof, proofFlags, leaves, hasher) == root;
+        return processMultiProofCalldata(proof, proofFlags, leaves, hasher) == root;
     }
 
     /**
@@ -430,7 +433,7 @@ library MerkleProof {
     function processMultiProofCalldata(
         bytes32[] calldata proof,
         bool[] calldata proofFlags,
-        bytes32[] calldata leaves,
+        bytes32[] memory leaves,
         function(bytes32, bytes32) view returns (bytes32) hasher
     ) internal view returns (bytes32 merkleRoot) {
         // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
@@ -438,15 +441,16 @@ library MerkleProof {
         // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
         // the Merkle tree.
         uint256 leavesLen = leaves.length;
+        uint256 proofFlagsLen = proofFlags.length;
 
         // Check proof validity.
-        if (leavesLen + proof.length != proofFlags.length + 1) {
+        if (leavesLen + proof.length != proofFlagsLen + 1) {
             revert MerkleProofInvalidMultiproof();
         }
 
         // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
         // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
-        bytes32[] memory hashes = new bytes32[](proofFlags.length);
+        bytes32[] memory hashes = new bytes32[](proofFlagsLen);
         uint256 leafPos = 0;
         uint256 hashPos = 0;
         uint256 proofPos = 0;
@@ -455,7 +459,7 @@ library MerkleProof {
         //   get the next hash.
         // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
         //   `proof` array.
-        for (uint256 i = 0; i < proofFlags.length; i++) {
+        for (uint256 i = 0; i < proofFlagsLen; i++) {
             bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
             bytes32 b = proofFlags[i]
                 ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@@ -463,12 +467,12 @@ library MerkleProof {
             hashes[i] = hasher(a, b);
         }
 
-        if (proofFlags.length > 0) {
+        if (proofFlagsLen > 0) {
             if (proofPos != proof.length) {
                 revert MerkleProofInvalidMultiproof();
             }
             unchecked {
-                return hashes[proofFlags.length - 1];
+                return hashes[proofFlagsLen - 1];
             }
         } else if (leavesLen > 0) {
             return leaves[0];

+ 10 - 9
scripts/generate/templates/MerkleProof.js

@@ -56,7 +56,7 @@ function verify${suffix}(${(hash ? formatArgsMultiline : formatArgsSingleLine)(
   'bytes32 leaf',
   hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
 )}) internal ${visibility} returns (bool) {
-    return processProof(proof, leaf${hash ? `, ${hash}` : ''}) == root;
+    return processProof${suffix}(proof, leaf${hash ? `, ${hash}` : ''}) == root;
 }
 
 /**
@@ -93,10 +93,10 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
   `bytes32[] ${location} proof`,
   `bool[] ${location} proofFlags`,
   'bytes32 root',
-  `bytes32[] ${location} leaves`,
+  `bytes32[] memory leaves`,
   hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
 )}) internal ${visibility} returns (bool) {
-    return processMultiProof(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root;
+    return processMultiProof${suffix}(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root;
 }
 
 /**
@@ -114,7 +114,7 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
 function processMultiProof${suffix}(${formatArgsMultiline(
   `bytes32[] ${location} proof`,
   `bool[] ${location} proofFlags`,
-  `bytes32[] ${location} leaves`,
+  `bytes32[] memory leaves`,
   hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
 )}) internal ${visibility} returns (bytes32 merkleRoot) {
     // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
@@ -122,15 +122,16 @@ function processMultiProof${suffix}(${formatArgsMultiline(
     // \`hashes\` array. At the end of the process, the last hash in the \`hashes\` array should contain the root of
     // the Merkle tree.
     uint256 leavesLen = leaves.length;
+    uint256 proofFlagsLen = proofFlags.length;
 
     // Check proof validity.
-    if (leavesLen + proof.length != proofFlags.length + 1) {
+    if (leavesLen + proof.length != proofFlagsLen + 1) {
         revert MerkleProofInvalidMultiproof();
     }
 
     // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
     // \`xxx[xxxPos++]\`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
-    bytes32[] memory hashes = new bytes32[](proofFlags.length);
+    bytes32[] memory hashes = new bytes32[](proofFlagsLen);
     uint256 leafPos = 0;
     uint256 hashPos = 0;
     uint256 proofPos = 0;
@@ -139,7 +140,7 @@ function processMultiProof${suffix}(${formatArgsMultiline(
     //   get the next hash.
     // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
     //   \`proof\` array.
-    for (uint256 i = 0; i < proofFlags.length; i++) {
+    for (uint256 i = 0; i < proofFlagsLen; i++) {
         bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
         bytes32 b = proofFlags[i]
             ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@@ -147,12 +148,12 @@ function processMultiProof${suffix}(${formatArgsMultiline(
         hashes[i] = ${hash ?? DEFAULT_HASH}(a, b);
     }
 
-    if (proofFlags.length > 0) {
+    if (proofFlagsLen > 0) {
         if (proofPos != proof.length) {
             revert MerkleProofInvalidMultiproof();
         }
         unchecked {
-            return hashes[proofFlags.length - 1];
+            return hashes[proofFlagsLen - 1];
         }
     } else if (leavesLen > 0) {
         return leaves[0];