Explorar el Código

Add Bytes32x2Set (#5442)

Co-authored-by: Ernesto García <ernestognw@gmail.com>
Hadrien Croubois hace 8 meses
padre
commit
441dc141ac

+ 5 - 0
.changeset/lucky-teachers-sip.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`EnumerableSet`: Add `Bytes32x2Set` that handles (ordered) pairs of bytes32.

+ 5 - 0
.changeset/ten-peas-mix.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Hashes`: Expose `efficientKeccak256` for hashing non-commutative pairs of bytes32 without allocating extra memory.

+ 2 - 2
contracts/utils/cryptography/Hashes.sol

@@ -15,13 +15,13 @@ library Hashes {
      * NOTE: Equivalent to the `standardNodeHash` in our https://github.com/OpenZeppelin/merkle-tree[JavaScript library].
      */
     function commutativeKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32) {
-        return a < b ? _efficientKeccak256(a, b) : _efficientKeccak256(b, a);
+        return a < b ? efficientKeccak256(a, b) : efficientKeccak256(b, a);
     }
 
     /**
      * @dev Implementation of keccak256(abi.encode(a, b)) that doesn't allocate or expand memory.
      */
-    function _efficientKeccak256(bytes32 a, bytes32 b) private pure returns (bytes32 value) {
+    function efficientKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32 value) {
         assembly ("memory-safe") {
             mstore(0x00, a)
             mstore(0x20, b)

+ 113 - 0
contracts/utils/structs/EnumerableSet.sol

@@ -4,6 +4,8 @@
 
 pragma solidity ^0.8.20;
 
+import {Hashes} from "../cryptography/Hashes.sol";
+
 /**
  * @dev Library for managing
  * https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive
@@ -372,4 +374,115 @@ library EnumerableSet {
 
         return result;
     }
+
+    struct Bytes32x2Set {
+        // Storage of set values
+        bytes32[2][] _values;
+        // Position is the index of the value in the `values` array plus 1.
+        // Position 0 is used to mean a value is not in the self.
+        mapping(bytes32 valueHash => uint256) _positions;
+    }
+
+    /**
+     * @dev Add a value to a self. O(1).
+     *
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
+     */
+    function add(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) {
+        if (!contains(self, value)) {
+            self._values.push(value);
+            // The value is stored at length-1, but we add 1 to all indexes
+            // and use 0 as a sentinel value
+            self._positions[_hash(value)] = self._values.length;
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes a value from a self. O(1).
+     *
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
+     */
+    function remove(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) {
+        // We cache the value's position to prevent multiple reads from the same storage slot
+        bytes32 valueHash = _hash(value);
+        uint256 position = self._positions[valueHash];
+
+        if (position != 0) {
+            // Equivalent to contains(self, value)
+            // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+            // the array, and then remove the last element (sometimes called as 'swap and pop').
+            // This modifies the order of the array, as noted in {at}.
+
+            uint256 valueIndex = position - 1;
+            uint256 lastIndex = self._values.length - 1;
+
+            if (valueIndex != lastIndex) {
+                bytes32[2] memory lastValue = self._values[lastIndex];
+
+                // Move the lastValue to the index where the value to delete is
+                self._values[valueIndex] = lastValue;
+                // Update the tracked position of the lastValue (that was just moved)
+                self._positions[_hash(lastValue)] = position;
+            }
+
+            // Delete the slot where the moved value was stored
+            self._values.pop();
+
+            // Delete the tracked position for the deleted slot
+            delete self._positions[valueHash];
+
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Returns true if the value is in the self. O(1).
+     */
+    function contains(Bytes32x2Set storage self, bytes32[2] memory value) internal view returns (bool) {
+        return self._positions[_hash(value)] != 0;
+    }
+
+    /**
+     * @dev Returns the number of values on the self. O(1).
+     */
+    function length(Bytes32x2Set storage self) internal view returns (uint256) {
+        return self._values.length;
+    }
+
+    /**
+     * @dev Returns the value stored at position `index` in the self. O(1).
+     *
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(Bytes32x2Set storage self, uint256 index) internal view returns (bytes32[2] memory) {
+        return self._values[index];
+    }
+
+    /**
+     * @dev Return the entire set in an array
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function values(Bytes32x2Set storage self) internal view returns (bytes32[2][] memory) {
+        return self._values;
+    }
+
+    function _hash(bytes32[2] memory value) private pure returns (bytes32) {
+        return Hashes.efficientKeccak256(value[0], value[1]);
+    }
 }

+ 120 - 1
scripts/generate/templates/EnumerableSet.js

@@ -5,6 +5,8 @@ const { TYPES } = require('./EnumerableSet.opts');
 const header = `\
 pragma solidity ^0.8.20;
 
+import {Hashes} from "../cryptography/Hashes.sol";
+
 /**
  * @dev Library for managing
  * https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive
@@ -233,6 +235,121 @@ function values(${name} storage set) internal view returns (${type}[] memory) {
 }
 `;
 
+const memorySet = ({ name, type }) => `\
+struct ${name} {
+    // Storage of set values
+    ${type}[] _values;
+    // Position is the index of the value in the \`values\` array plus 1.
+    // Position 0 is used to mean a value is not in the self.
+    mapping(bytes32 valueHash => uint256) _positions;
+}
+
+/**
+ * @dev Add a value to a self. O(1).
+ *
+ * Returns true if the value was added to the set, that is if it was not
+ * already present.
+ */
+function add(${name} storage self, ${type} memory value) internal returns (bool) {
+    if (!contains(self, value)) {
+        self._values.push(value);
+        // The value is stored at length-1, but we add 1 to all indexes
+        // and use 0 as a sentinel value
+        self._positions[_hash(value)] = self._values.length;
+        return true;
+    } else {
+        return false;
+    }
+}
+
+/**
+ * @dev Removes a value from a self. O(1).
+ *
+ * Returns true if the value was removed from the set, that is if it was
+ * present.
+ */
+function remove(${name} storage self, ${type} memory value) internal returns (bool) {
+    // We cache the value's position to prevent multiple reads from the same storage slot
+    bytes32 valueHash = _hash(value);
+    uint256 position = self._positions[valueHash];
+
+    if (position != 0) {
+        // Equivalent to contains(self, value)
+        // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+        // the array, and then remove the last element (sometimes called as 'swap and pop').
+        // This modifies the order of the array, as noted in {at}.
+
+        uint256 valueIndex = position - 1;
+        uint256 lastIndex = self._values.length - 1;
+
+        if (valueIndex != lastIndex) {
+            ${type} memory lastValue = self._values[lastIndex];
+
+            // Move the lastValue to the index where the value to delete is
+            self._values[valueIndex] = lastValue;
+            // Update the tracked position of the lastValue (that was just moved)
+            self._positions[_hash(lastValue)] = position;
+        }
+
+        // Delete the slot where the moved value was stored
+        self._values.pop();
+
+        // Delete the tracked position for the deleted slot
+        delete self._positions[valueHash];
+
+        return true;
+    } else {
+        return false;
+    }
+}
+
+/**
+ * @dev Returns true if the value is in the self. O(1).
+ */
+function contains(${name} storage self, ${type} memory value) internal view returns (bool) {
+    return self._positions[_hash(value)] != 0;
+}
+
+/**
+ * @dev Returns the number of values on the self. O(1).
+ */
+function length(${name} storage self) internal view returns (uint256) {
+    return self._values.length;
+}
+
+/**
+ * @dev Returns the value stored at position \`index\` in the self. O(1).
+ *
+ * Note that there are no guarantees on the ordering of values inside the
+ * array, and it may change when more values are added or removed.
+ *
+ * Requirements:
+ *
+ * - \`index\` must be strictly less than {length}.
+ */
+function at(${name} storage self, uint256 index) internal view returns (${type} memory) {
+    return self._values[index];
+}
+
+/**
+ * @dev Return the entire set in an array
+ *
+ * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+ * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+ * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+ * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
+ */
+function values(${name} storage self) internal view returns (${type}[] memory) {
+    return self._values;
+}
+`;
+
+const hashes = `\
+function _hash(bytes32[2] memory value) private pure returns (bytes32) {
+    return Hashes.efficientKeccak256(value[0], value[1]);
+}
+`;
+
 // GENERATE
 module.exports = format(
   header.trimEnd(),
@@ -240,7 +357,9 @@ module.exports = format(
   format(
     [].concat(
       defaultSet,
-      TYPES.map(details => customSet(details)),
+      TYPES.filter(({ size }) => size == undefined).map(details => customSet(details)),
+      TYPES.filter(({ size }) => size != undefined).map(details => memorySet(details)),
+      hashes,
     ),
   ).trimEnd(),
   '}',

+ 9 - 5
scripts/generate/templates/EnumerableSet.opts.js

@@ -1,12 +1,16 @@
 const { capitalize } = require('../../helpers');
 
-const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
+const mapType = ({ type, size }) => [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x');
 
-const formatType = type => ({
-  name: `${mapType(type)}Set`,
-  type,
+const formatType = ({ type, size = undefined }) => ({
+  name: `${mapType({ type, size })}Set`,
+  type: size != undefined ? `${type}[${size}]` : type,
+  base: size != undefined ? type : undefined,
+  size,
 });
 
-const TYPES = ['bytes32', 'address', 'uint256'].map(formatType);
+const TYPES = [{ type: 'bytes32' }, { type: 'bytes32', size: 2 }, { type: 'address' }, { type: 'uint256' }].map(
+  formatType,
+);
 
 module.exports = { TYPES, formatType };

+ 7 - 4
test/utils/structs/EnumerableSet.test.js

@@ -20,10 +20,13 @@ async function fixture() {
   const mock = await ethers.deployContract('$EnumerableSet');
 
   const env = Object.fromEntries(
-    TYPES.map(({ name, type }) => [
+    TYPES.map(({ name, type, base, size }) => [
       type,
       {
-        values: Array.from({ length: 3 }, generators[type]),
+        values: Array.from(
+          { length: 3 },
+          size ? () => Array.from({ length: size }, generators[base]) : generators[type],
+        ),
         methods: getMethods(mock, {
           add: `$add(uint256,${type})`,
           remove: `$remove(uint256,${type})`,
@@ -33,8 +36,8 @@ async function fixture() {
           values: `$values_EnumerableSet_${name}(uint256)`,
         }),
         events: {
-          addReturn: `return$add_EnumerableSet_${name}_${type}`,
-          removeReturn: `return$remove_EnumerableSet_${name}_${type}`,
+          addReturn: `return$add_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`,
+          removeReturn: `return$remove_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`,
         },
       },
     ]),