Prechádzať zdrojové kódy

Add non-value types in EnumerableSet and EnumerableMap (#5658)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Ernesto García 4 mesiacov pred
rodič
commit
784d4f71b1

+ 5 - 0
.changeset/long-hornets-mate.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`EnumerableMap`: Add support for `BytesToBytesMap` type.

+ 5 - 0
.changeset/pink-dolls-shop.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`EnumerableSet`: Add support for `StringSet` and `BytesSet` types.

+ 120 - 1
contracts/utils/structs/EnumerableMap.sol

@@ -39,6 +39,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
  * - `address -> address` (`AddressToAddressMap`) since v5.1.0
  * - `address -> bytes32` (`AddressToBytes32Map`) since v5.1.0
  * - `bytes32 -> address` (`Bytes32ToAddressMap`) since v5.1.0
+ * - `bytes -> bytes` (`BytesToBytesMap`) since v5.4.0
  *
  * [WARNING]
  * ====
@@ -51,7 +52,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
  * ====
  */
 library EnumerableMap {
-    using EnumerableSet for EnumerableSet.Bytes32Set;
+    using EnumerableSet for *;
 
     // To implement this library for multiple types with as little code repetition as possible, we write it in
     // terms of a generic Map type with bytes32 keys and values. The Map implementation uses private functions,
@@ -997,4 +998,122 @@ library EnumerableMap {
 
         return result;
     }
+
+    /**
+     * @dev Query for a nonexistent map key.
+     */
+    error EnumerableMapNonexistentBytesKey(bytes key);
+
+    struct BytesToBytesMap {
+        // Storage of keys
+        EnumerableSet.BytesSet _keys;
+        mapping(bytes key => bytes) _values;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(BytesToBytesMap storage map, bytes memory key, bytes memory value) internal returns (bool) {
+        map._values[key] = value;
+        return map._keys.add(key);
+    }
+
+    /**
+     * @dev Removes a key-value pair from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(BytesToBytesMap storage map, bytes memory key) internal returns (bool) {
+        delete map._values[key];
+        return map._keys.remove(key);
+    }
+
+    /**
+     * @dev Removes all the entries from a map. O(n).
+     *
+     * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
+     * function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block.
+     */
+    function clear(BytesToBytesMap storage map) internal {
+        uint256 len = length(map);
+        for (uint256 i = 0; i < len; ++i) {
+            delete map._values[map._keys.at(i)];
+        }
+        map._keys.clear();
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(BytesToBytesMap storage map, bytes memory key) internal view returns (bool) {
+        return map._keys.contains(key);
+    }
+
+    /**
+     * @dev Returns the number of key-value pairs in the map. O(1).
+     */
+    function length(BytesToBytesMap storage map) internal view returns (uint256) {
+        return map._keys.length();
+    }
+
+    /**
+     * @dev Returns the key-value pair stored at position `index` in the map. O(1).
+     *
+     * Note that there are no guarantees on the ordering of entries inside the
+     * array, and it may change when more entries are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(
+        BytesToBytesMap storage map,
+        uint256 index
+    ) internal view returns (bytes memory key, bytes memory value) {
+        key = map._keys.at(index);
+        value = map._values[key];
+    }
+
+    /**
+     * @dev Tries to returns the value associated with `key`. O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(
+        BytesToBytesMap storage map,
+        bytes memory key
+    ) internal view returns (bool exists, bytes memory value) {
+        value = map._values[key];
+        exists = bytes(value).length != 0 || contains(map, key);
+    }
+
+    /**
+     * @dev Returns the value associated with `key`. O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(BytesToBytesMap storage map, bytes memory key) internal view returns (bytes memory value) {
+        bool exists;
+        (exists, value) = tryGet(map, key);
+        if (!exists) {
+            revert EnumerableMapNonexistentBytesKey(key);
+        }
+    }
+
+    /**
+     * @dev Return the an array containing all the keys
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function keys(BytesToBytesMap storage map) internal view returns (bytes[] memory) {
+        return map._keys.values();
+    }
 }

+ 247 - 2
contracts/utils/structs/EnumerableSet.sol

@@ -28,8 +28,13 @@ import {Arrays} from "../Arrays.sol";
  * }
  * ```
  *
- * As of v3.3.0, sets of type `bytes32` (`Bytes32Set`), `address` (`AddressSet`)
- * and `uint256` (`UintSet`) are supported.
+ * The following types are supported:
+ *
+ * - `bytes32` (`Bytes32Set`) since v3.3.0
+ * - `address` (`AddressSet`) since v3.3.0
+ * - `uint256` (`UintSet`) since v3.3.0
+ * - `string` (`StringSet`) since v5.4.0
+ * - `bytes` (`BytesSet`) since v5.4.0
  *
  * [WARNING]
  * ====
@@ -419,4 +424,244 @@ library EnumerableSet {
 
         return result;
     }
+
+    struct StringSet {
+        // Storage of set values
+        string[] _values;
+        // Position is the index of the value in the `values` array plus 1.
+        // Position 0 is used to mean a value is not in the set.
+        mapping(string value => uint256) _positions;
+    }
+
+    /**
+     * @dev Add a value to a set. O(1).
+     *
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
+     */
+    function add(StringSet storage self, string memory value) internal returns (bool) {
+        if (!contains(self, value)) {
+            self._values.push(value);
+            // The value is stored at length-1, but we add 1 to all indexes
+            // and use 0 as a sentinel value
+            self._positions[value] = self._values.length;
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes a value from a set. O(1).
+     *
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
+     */
+    function remove(StringSet storage self, string memory value) internal returns (bool) {
+        // We cache the value's position to prevent multiple reads from the same storage slot
+        uint256 position = self._positions[value];
+
+        if (position != 0) {
+            // Equivalent to contains(self, value)
+            // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+            // the array, and then remove the last element (sometimes called as 'swap and pop').
+            // This modifies the order of the array, as noted in {at}.
+
+            uint256 valueIndex = position - 1;
+            uint256 lastIndex = self._values.length - 1;
+
+            if (valueIndex != lastIndex) {
+                string memory lastValue = self._values[lastIndex];
+
+                // Move the lastValue to the index where the value to delete is
+                self._values[valueIndex] = lastValue;
+                // Update the tracked position of the lastValue (that was just moved)
+                self._positions[lastValue] = position;
+            }
+
+            // Delete the slot where the moved value was stored
+            self._values.pop();
+
+            // Delete the tracked position for the deleted slot
+            delete self._positions[value];
+
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes all the values from a set. O(n).
+     *
+     * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
+     * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
+     */
+    function clear(StringSet storage set) internal {
+        uint256 len = length(set);
+        for (uint256 i = 0; i < len; ++i) {
+            delete set._positions[set._values[i]];
+        }
+        Arrays.unsafeSetLength(set._values, 0);
+    }
+
+    /**
+     * @dev Returns true if the value is in the set. O(1).
+     */
+    function contains(StringSet storage self, string memory value) internal view returns (bool) {
+        return self._positions[value] != 0;
+    }
+
+    /**
+     * @dev Returns the number of values on the set. O(1).
+     */
+    function length(StringSet storage self) internal view returns (uint256) {
+        return self._values.length;
+    }
+
+    /**
+     * @dev Returns the value stored at position `index` in the set. O(1).
+     *
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(StringSet storage self, uint256 index) internal view returns (string memory) {
+        return self._values[index];
+    }
+
+    /**
+     * @dev Return the entire set in an array
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function values(StringSet storage self) internal view returns (string[] memory) {
+        return self._values;
+    }
+
+    struct BytesSet {
+        // Storage of set values
+        bytes[] _values;
+        // Position is the index of the value in the `values` array plus 1.
+        // Position 0 is used to mean a value is not in the set.
+        mapping(bytes value => uint256) _positions;
+    }
+
+    /**
+     * @dev Add a value to a set. O(1).
+     *
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
+     */
+    function add(BytesSet storage self, bytes memory value) internal returns (bool) {
+        if (!contains(self, value)) {
+            self._values.push(value);
+            // The value is stored at length-1, but we add 1 to all indexes
+            // and use 0 as a sentinel value
+            self._positions[value] = self._values.length;
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes a value from a set. O(1).
+     *
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
+     */
+    function remove(BytesSet storage self, bytes memory value) internal returns (bool) {
+        // We cache the value's position to prevent multiple reads from the same storage slot
+        uint256 position = self._positions[value];
+
+        if (position != 0) {
+            // Equivalent to contains(self, value)
+            // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+            // the array, and then remove the last element (sometimes called as 'swap and pop').
+            // This modifies the order of the array, as noted in {at}.
+
+            uint256 valueIndex = position - 1;
+            uint256 lastIndex = self._values.length - 1;
+
+            if (valueIndex != lastIndex) {
+                bytes memory lastValue = self._values[lastIndex];
+
+                // Move the lastValue to the index where the value to delete is
+                self._values[valueIndex] = lastValue;
+                // Update the tracked position of the lastValue (that was just moved)
+                self._positions[lastValue] = position;
+            }
+
+            // Delete the slot where the moved value was stored
+            self._values.pop();
+
+            // Delete the tracked position for the deleted slot
+            delete self._positions[value];
+
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes all the values from a set. O(n).
+     *
+     * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
+     * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
+     */
+    function clear(BytesSet storage set) internal {
+        uint256 len = length(set);
+        for (uint256 i = 0; i < len; ++i) {
+            delete set._positions[set._values[i]];
+        }
+        Arrays.unsafeSetLength(set._values, 0);
+    }
+
+    /**
+     * @dev Returns true if the value is in the set. O(1).
+     */
+    function contains(BytesSet storage self, bytes memory value) internal view returns (bool) {
+        return self._positions[value] != 0;
+    }
+
+    /**
+     * @dev Returns the number of values on the set. O(1).
+     */
+    function length(BytesSet storage self) internal view returns (uint256) {
+        return self._values.length;
+    }
+
+    /**
+     * @dev Returns the value stored at position `index` in the set. O(1).
+     *
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(BytesSet storage self, uint256 index) internal view returns (bytes memory) {
+        return self._values[index];
+    }
+
+    /**
+     * @dev Return the entire set in an array
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function values(BytesSet storage self) internal view returns (bytes[] memory) {
+        return self._values;
+    }
 }

+ 8 - 5
scripts/generate/run.js

@@ -1,6 +1,6 @@
 #!/usr/bin/env node
 
-// const cp = require('child_process');
+const cp = require('child_process');
 const fs = require('fs');
 const path = require('path');
 const format = require('./format-lines');
@@ -13,7 +13,7 @@ function getVersion(path) {
   }
 }
 
-function generateFromTemplate(file, template, outputPrefix = '') {
+function generateFromTemplate(file, template, outputPrefix = '', lint = false) {
   const script = path.relative(path.join(__dirname, '../..'), __filename);
   const input = path.join(path.dirname(script), template);
   const output = path.join(outputPrefix, file);
@@ -27,9 +27,12 @@ function generateFromTemplate(file, template, outputPrefix = '') {
   );
 
   fs.writeFileSync(output, content);
-  // cp.execFileSync('prettier', ['--write', output]);
+  lint && cp.execFileSync('prettier', ['--write', output]);
 }
 
+// Some templates needs to go through the linter after generation
+const needsLinter = ['utils/structs/EnumerableMap.sol'];
+
 // Contracts
 for (const [file, template] of Object.entries({
   'utils/cryptography/MerkleProof.sol': './templates/MerkleProof.js',
@@ -45,7 +48,7 @@ for (const [file, template] of Object.entries({
   'mocks/StorageSlotMock.sol': './templates/StorageSlotMock.js',
   'mocks/TransientSlotMock.sol': './templates/TransientSlotMock.js',
 })) {
-  generateFromTemplate(file, template, './contracts/');
+  generateFromTemplate(file, template, './contracts/', needsLinter.includes(file));
 }
 
 // Tests
@@ -54,5 +57,5 @@ for (const [file, template] of Object.entries({
   'utils/Packing.t.sol': './templates/Packing.t.js',
   'utils/SlotDerivation.t.sol': './templates/SlotDerivation.t.js',
 })) {
-  generateFromTemplate(file, template, './test/');
+  generateFromTemplate(file, template, './test/', needsLinter.includes(file));
 }

+ 53 - 0
scripts/generate/templates/Enumerable.opts.js

@@ -0,0 +1,53 @@
+const { capitalize, mapValues } = require('../../helpers');
+
+const typeDescr = ({ type, size = 0, memory = false }) => {
+  memory |= size > 0;
+
+  const name = [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x');
+  const base = size ? type : undefined;
+  const typeFull = size ? `${type}[${size}]` : type;
+  const typeLoc = memory ? `${typeFull} memory` : typeFull;
+  return { name, type: typeFull, typeLoc, base, size, memory };
+};
+
+const toSetTypeDescr = value => ({
+  name: value.name + 'Set',
+  value,
+});
+
+const toMapTypeDescr = ({ key, value }) => ({
+  name: `${key.name}To${value.name}Map`,
+  keySet: toSetTypeDescr(key),
+  key,
+  value,
+});
+
+const SET_TYPES = [
+  { type: 'bytes32' },
+  { type: 'address' },
+  { type: 'uint256' },
+  { type: 'string', memory: true },
+  { type: 'bytes', memory: true },
+]
+  .map(typeDescr)
+  .map(toSetTypeDescr);
+
+const MAP_TYPES = []
+  .concat(
+    // value type maps
+    ['uint256', 'address', 'bytes32']
+      .flatMap((keyType, _, array) => array.map(valueType => ({ key: { type: keyType }, value: { type: valueType } })))
+      .slice(0, -1), // remove bytes32 → bytes32 (last one) that is already defined
+    // non-value type maps
+    { key: { type: 'bytes', memory: true }, value: { type: 'bytes', memory: true } },
+  )
+  .map(entry => mapValues(entry, typeDescr))
+  .map(toMapTypeDescr);
+
+module.exports = {
+  SET_TYPES,
+  MAP_TYPES,
+  typeDescr,
+  toSetTypeDescr,
+  toMapTypeDescr,
+};

+ 141 - 19
scripts/generate/templates/EnumerableMap.js

@@ -1,6 +1,6 @@
 const format = require('../format-lines');
 const { fromBytes32, toBytes32 } = require('./conversion');
-const { TYPES } = require('./EnumerableMap.opts');
+const { MAP_TYPES } = require('./Enumerable.opts');
 
 const header = `\
 pragma solidity ^0.8.20;
@@ -40,6 +40,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
  * - \`address -> address\` (\`AddressToAddressMap\`) since v5.1.0
  * - \`address -> bytes32\` (\`AddressToBytes32Map\`) since v5.1.0
  * - \`bytes32 -> address\` (\`Bytes32ToAddressMap\`) since v5.1.0
+ * - \`bytes -> bytes\` (\`BytesToBytesMap\`) since v5.4.0
  *
  * [WARNING]
  * ====
@@ -176,7 +177,7 @@ function keys(Bytes32ToBytes32Map storage map) internal view returns (bytes32[]
 }
 `;
 
-const customMap = ({ name, keyType, valueType }) => `\
+const customMap = ({ name, key, value }) => `\
 // ${name}
 
 struct ${name} {
@@ -190,8 +191,8 @@ struct ${name} {
  * Returns true if the key was added to the map, that is if it was not
  * already present.
  */
-function set(${name} storage map, ${keyType} key, ${valueType} value) internal returns (bool) {
-    return set(map._inner, ${toBytes32(keyType, 'key')}, ${toBytes32(valueType, 'value')});
+function set(${name} storage map, ${key.type} key, ${value.type} value) internal returns (bool) {
+    return set(map._inner, ${toBytes32(key.type, 'key')}, ${toBytes32(value.type, 'value')});
 }
 
 /**
@@ -199,8 +200,8 @@ function set(${name} storage map, ${keyType} key, ${valueType} value) internal r
  *
  * Returns true if the key was removed from the map, that is if it was present.
  */
-function remove(${name} storage map, ${keyType} key) internal returns (bool) {
-    return remove(map._inner, ${toBytes32(keyType, 'key')});
+function remove(${name} storage map, ${key.type} key) internal returns (bool) {
+    return remove(map._inner, ${toBytes32(key.type, 'key')});
 }
 
 /**
@@ -216,8 +217,8 @@ function clear(${name} storage map) internal {
 /**
  * @dev Returns true if the key is in the map. O(1).
  */
-function contains(${name} storage map, ${keyType} key) internal view returns (bool) {
-    return contains(map._inner, ${toBytes32(keyType, 'key')});
+function contains(${name} storage map, ${key.type} key) internal view returns (bool) {
+    return contains(map._inner, ${toBytes32(key.type, 'key')});
 }
 
 /**
@@ -236,18 +237,18 @@ function length(${name} storage map) internal view returns (uint256) {
  *
  * - \`index\` must be strictly less than {length}.
  */
-function at(${name} storage map, uint256 index) internal view returns (${keyType} key, ${valueType} value) {
+function at(${name} storage map, uint256 index) internal view returns (${key.type} key, ${value.type} value) {
     (bytes32 atKey, bytes32 val) = at(map._inner, index);
-    return (${fromBytes32(keyType, 'atKey')}, ${fromBytes32(valueType, 'val')});
+    return (${fromBytes32(key.type, 'atKey')}, ${fromBytes32(value.type, 'val')});
 }
 
 /**
  * @dev Tries to returns the value associated with \`key\`. O(1).
  * Does not revert if \`key\` is not in the map.
  */
-function tryGet(${name} storage map, ${keyType} key) internal view returns (bool exists, ${valueType} value) {
-    (bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(keyType, 'key')});
-    return (success, ${fromBytes32(valueType, 'val')});
+function tryGet(${name} storage map, ${key.type} key) internal view returns (bool exists, ${value.type} value) {
+    (bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(key.type, 'key')});
+    return (success, ${fromBytes32(value.type, 'val')});
 }
 
 /**
@@ -257,8 +258,8 @@ function tryGet(${name} storage map, ${keyType} key) internal view returns (bool
  *
  * - \`key\` must be in the map.
  */
-function get(${name} storage map, ${keyType} key) internal view returns (${valueType}) {
-    return ${fromBytes32(valueType, `get(map._inner, ${toBytes32(keyType, 'key')})`)};
+function get(${name} storage map, ${key.type} key) internal view returns (${value.type}) {
+    return ${fromBytes32(value.type, `get(map._inner, ${toBytes32(key.type, 'key')})`)};
 }
 
 /**
@@ -269,9 +270,9 @@ function get(${name} storage map, ${keyType} key) internal view returns (${value
  * this function has an unbounded cost, and using it as part of a state-changing function may render the function
  * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
  */
-function keys(${name} storage map) internal view returns (${keyType}[] memory) {
+function keys(${name} storage map) internal view returns (${key.type}[] memory) {
     bytes32[] memory store = keys(map._inner);
-    ${keyType}[] memory result;
+    ${key.type}[] memory result;
 
     assembly ("memory-safe") {
         result := store
@@ -281,16 +282,137 @@ function keys(${name} storage map) internal view returns (${keyType}[] memory) {
 }
 `;
 
+const memoryMap = ({ name, keySet, key, value }) => `\
+/**
+ * @dev Query for a nonexistent map key.
+ */
+error EnumerableMapNonexistent${key.name}Key(${key.type} key);
+
+struct ${name} {
+    // Storage of keys
+    EnumerableSet.${keySet.name} _keys;
+    mapping(${key.type} key => ${value.type}) _values;
+}
+
+/**
+ * @dev Adds a key-value pair to a map, or updates the value for an existing
+ * key. O(1).
+ *
+ * Returns true if the key was added to the map, that is if it was not
+ * already present.
+ */
+function set(${name} storage map, ${key.typeLoc} key, ${value.typeLoc} value) internal returns (bool) {
+    map._values[key] = value;
+    return map._keys.add(key);
+}
+
+/**
+ * @dev Removes a key-value pair from a map. O(1).
+ *
+ * Returns true if the key was removed from the map, that is if it was present.
+ */
+function remove(${name} storage map, ${key.typeLoc} key) internal returns (bool) {
+    delete map._values[key];
+    return map._keys.remove(key);
+}
+
+/**
+ * @dev Removes all the entries from a map. O(n).
+ *
+ * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
+ * function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block.
+ */
+function clear(${name} storage map) internal {
+    uint256 len = length(map);
+    for (uint256 i = 0; i < len; ++i) {
+        delete map._values[map._keys.at(i)];
+    }
+    map._keys.clear();
+}
+
+/**
+ * @dev Returns true if the key is in the map. O(1).
+ */
+function contains(${name} storage map, ${key.typeLoc} key) internal view returns (bool) {
+    return map._keys.contains(key);
+}
+
+/**
+ * @dev Returns the number of key-value pairs in the map. O(1).
+ */
+function length(${name} storage map) internal view returns (uint256) {
+    return map._keys.length();
+}
+
+/**
+ * @dev Returns the key-value pair stored at position \`index\` in the map. O(1).
+ *
+ * Note that there are no guarantees on the ordering of entries inside the
+ * array, and it may change when more entries are added or removed.
+ *
+ * Requirements:
+ *
+ * - \`index\` must be strictly less than {length}.
+ */
+function at(
+    ${name} storage map,
+    uint256 index
+) internal view returns (${key.typeLoc} key, ${value.typeLoc} value) {
+    key = map._keys.at(index);
+    value = map._values[key];
+}
+
+/**
+ * @dev Tries to returns the value associated with \`key\`. O(1).
+ * Does not revert if \`key\` is not in the map.
+ */
+function tryGet(
+    ${name} storage map,
+    ${key.typeLoc} key
+) internal view returns (bool exists, ${value.typeLoc} value) {
+    value = map._values[key];
+    exists = ${value.memory ? 'bytes(value).length != 0' : `value != ${value.type}(0)`} || contains(map, key);
+}
+
+/**
+ * @dev Returns the value associated with \`key\`. O(1).
+ *
+ * Requirements:
+ *
+ * - \`key\` must be in the map.
+ */
+function get(${name} storage map, ${key.typeLoc} key) internal view returns (${value.typeLoc} value) {
+    bool exists;
+    (exists, value) = tryGet(map, key);
+    if (!exists) {
+        revert EnumerableMapNonexistent${key.name}Key(key);
+    }
+}
+
+/**
+ * @dev Return the an array containing all the keys
+ *
+ * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+ * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+ * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+ * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+ */
+function keys(${name} storage map) internal view returns (${key.type}[] memory) {
+    return map._keys.values();
+}
+`;
+
 // GENERATE
 module.exports = format(
   header.trimEnd(),
   'library EnumerableMap {',
   format(
     [].concat(
-      'using EnumerableSet for EnumerableSet.Bytes32Set;',
+      'using EnumerableSet for *;',
       '',
       defaultMap,
-      TYPES.map(details => customMap(details)),
+      MAP_TYPES.filter(({ key, value }) => !(key.memory || value.memory)).map(customMap),
+      MAP_TYPES.filter(({ key, value }) => key.memory || value.memory).map(memoryMap),
     ),
   ).trimEnd(),
   '}',

+ 0 - 19
scripts/generate/templates/EnumerableMap.opts.js

@@ -1,19 +0,0 @@
-const { capitalize } = require('../../helpers');
-
-const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
-
-const formatType = (keyType, valueType) => ({
-  name: `${mapType(keyType)}To${mapType(valueType)}Map`,
-  keyType,
-  valueType,
-});
-
-const TYPES = ['uint256', 'address', 'bytes32']
-  .flatMap((key, _, array) => array.map(value => [key, value]))
-  .slice(0, -1) // remove bytes32 → byte32 (last one) that is already defined
-  .map(args => formatType(...args));
-
-module.exports = {
-  TYPES,
-  formatType,
-};

+ 135 - 5
scripts/generate/templates/EnumerableSet.js

@@ -1,6 +1,6 @@
 const format = require('../format-lines');
 const { fromBytes32, toBytes32 } = require('./conversion');
-const { TYPES } = require('./EnumerableSet.opts');
+const { SET_TYPES } = require('./Enumerable.opts');
 
 const header = `\
 pragma solidity ^0.8.20;
@@ -29,8 +29,13 @@ import {Arrays} from "../Arrays.sol";
  * }
  * \`\`\`
  *
- * As of v3.3.0, sets of type \`bytes32\` (\`Bytes32Set\`), \`address\` (\`AddressSet\`)
- * and \`uint256\` (\`UintSet\`) are supported.
+ * The following types are supported:
+ *
+ * - \`bytes32\` (\`Bytes32Set\`) since v3.3.0
+ * - \`address\` (\`AddressSet\`) since v3.3.0
+ * - \`uint256\` (\`UintSet\`) since v3.3.0
+ * - \`string\` (\`StringSet\`) since v5.4.0
+ * - \`bytes\` (\`BytesSet\`) since v5.4.0
  *
  * [WARNING]
  * ====
@@ -44,6 +49,7 @@ import {Arrays} from "../Arrays.sol";
  */
 `;
 
+// NOTE: this should be deprecated in favor of a more native construction in v6.0
 const defaultSet = `\
 // To implement this library for multiple types with as little code
 // repetition as possible, we write it in terms of a generic Set type with
@@ -175,7 +181,8 @@ function _values(Set storage set) private view returns (bytes32[] memory) {
 }
 `;
 
-const customSet = ({ name, type }) => `\
+// NOTE: this should be deprecated in favor of a more native construction in v6.0
+const customSet = ({ name, value: { type } }) => `\
 // ${name}
 
 struct ${name} {
@@ -260,6 +267,128 @@ function values(${name} storage set) internal view returns (${type}[] memory) {
 }
 `;
 
+const memorySet = ({ name, value }) => `\
+struct ${name} {
+    // Storage of set values
+    ${value.type}[] _values;
+    // Position is the index of the value in the \`values\` array plus 1.
+    // Position 0 is used to mean a value is not in the set.
+    mapping(${value.type} value => uint256) _positions;
+}
+
+/**
+ * @dev Add a value to a set. O(1).
+ *
+ * Returns true if the value was added to the set, that is if it was not
+ * already present.
+ */
+function add(${name} storage self, ${value.type} memory value) internal returns (bool) {
+    if (!contains(self, value)) {
+        self._values.push(value);
+        // The value is stored at length-1, but we add 1 to all indexes
+        // and use 0 as a sentinel value
+        self._positions[value] = self._values.length;
+        return true;
+    } else {
+        return false;
+    }
+}
+
+/**
+ * @dev Removes a value from a set. O(1).
+ *
+ * Returns true if the value was removed from the set, that is if it was
+ * present.
+ */
+function remove(${name} storage self, ${value.type} memory value) internal returns (bool) {
+    // We cache the value's position to prevent multiple reads from the same storage slot
+    uint256 position = self._positions[value];
+
+    if (position != 0) {
+        // Equivalent to contains(self, value)
+        // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+        // the array, and then remove the last element (sometimes called as 'swap and pop').
+        // This modifies the order of the array, as noted in {at}.
+
+        uint256 valueIndex = position - 1;
+        uint256 lastIndex = self._values.length - 1;
+
+        if (valueIndex != lastIndex) {
+            ${value.type} memory lastValue = self._values[lastIndex];
+
+            // Move the lastValue to the index where the value to delete is
+            self._values[valueIndex] = lastValue;
+            // Update the tracked position of the lastValue (that was just moved)
+            self._positions[lastValue] = position;
+        }
+
+        // Delete the slot where the moved value was stored
+        self._values.pop();
+
+        // Delete the tracked position for the deleted slot
+        delete self._positions[value];
+
+        return true;
+    } else {
+        return false;
+    }
+}
+
+/**
+ * @dev Removes all the values from a set. O(n).
+ *
+ * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
+ * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
+ */
+function clear(${name} storage set) internal {
+    uint256 len = length(set);
+    for (uint256 i = 0; i < len; ++i) {
+        delete set._positions[set._values[i]];
+    }
+    Arrays.unsafeSetLength(set._values, 0);
+}
+
+/**
+ * @dev Returns true if the value is in the set. O(1).
+ */
+function contains(${name} storage self, ${value.type} memory value) internal view returns (bool) {
+    return self._positions[value] != 0;
+}
+
+/**
+ * @dev Returns the number of values on the set. O(1).
+ */
+function length(${name} storage self) internal view returns (uint256) {
+    return self._values.length;
+}
+
+/**
+ * @dev Returns the value stored at position \`index\` in the set. O(1).
+ *
+ * Note that there are no guarantees on the ordering of values inside the
+ * array, and it may change when more values are added or removed.
+ *
+ * Requirements:
+ *
+ * - \`index\` must be strictly less than {length}.
+ */
+function at(${name} storage self, uint256 index) internal view returns (${value.type} memory) {
+    return self._values[index];
+}
+
+/**
+ * @dev Return the entire set in an array
+ *
+ * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+ * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+ * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+ * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
+ */
+function values(${name} storage self) internal view returns (${value.type}[] memory) {
+    return self._values;
+}
+`;
+
 // GENERATE
 module.exports = format(
   header.trimEnd(),
@@ -267,7 +396,8 @@ module.exports = format(
   format(
     [].concat(
       defaultSet,
-      TYPES.map(details => customSet(details)),
+      SET_TYPES.filter(({ value }) => !value.memory).map(customSet),
+      SET_TYPES.filter(({ value }) => value.memory).map(memorySet),
     ),
   ).trimEnd(),
   '}',

+ 0 - 12
scripts/generate/templates/EnumerableSet.opts.js

@@ -1,12 +0,0 @@
-const { capitalize } = require('../../helpers');
-
-const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
-
-const formatType = type => ({
-  name: `${mapType(type)}Set`,
-  type,
-});
-
-const TYPES = ['bytes32', 'address', 'uint256'].map(formatType);
-
-module.exports = { TYPES, formatType };

+ 1 - 1
test/utils/structs/EnumerableMap.behavior.js

@@ -176,7 +176,7 @@ function shouldBehaveLikeMap() {
           .withArgs(
             this.key?.memory || this.value?.memory
               ? this.keyB
-              : ethers.AbiCoder.defaultAbiCoder().encode([this.keyType], [this.keyB]),
+              : ethers.AbiCoder.defaultAbiCoder().encode([this.key.type], [this.keyB]),
           );
       });
     });

+ 37 - 22
test/utils/structs/EnumerableMap.test.js

@@ -3,43 +3,58 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { mapValues } = require('../../helpers/iterate');
 const { generators } = require('../../helpers/random');
-const { TYPES, formatType } = require('../../../scripts/generate/templates/EnumerableMap.opts');
+const { MAP_TYPES, typeDescr, toMapTypeDescr } = require('../../../scripts/generate/templates/Enumerable.opts');
 
 const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior');
 
 // Add Bytes32ToBytes32Map that must be tested but is not part of the generated types.
-TYPES.unshift(formatType('bytes32', 'bytes32'));
+MAP_TYPES.unshift(toMapTypeDescr({ key: typeDescr({ type: 'bytes32' }), value: typeDescr({ type: 'bytes32' }) }));
 
 async function fixture() {
   const mock = await ethers.deployContract('$EnumerableMap');
+
   const env = Object.fromEntries(
-    TYPES.map(({ name, keyType, valueType }) => [
+    MAP_TYPES.map(({ name, key, value }) => [
       name,
       {
-        keyType,
-        keys: Array.from({ length: 3 }, generators[keyType]),
-        values: Array.from({ length: 3 }, generators[valueType]),
-        zeroValue: generators[valueType].zero,
+        key,
+        value,
+        keys: Array.from({ length: 3 }, generators[key.type]),
+        values: Array.from({ length: 3 }, generators[value.type]),
+        zeroValue: generators[value.type].zero,
         methods: mapValues(
-          {
-            set: `$set(uint256,${keyType},${valueType})`,
-            get: `$get_EnumerableMap_${name}(uint256,${keyType})`,
-            tryGet: `$tryGet_EnumerableMap_${name}(uint256,${keyType})`,
-            remove: `$remove_EnumerableMap_${name}(uint256,${keyType})`,
-            clear: `$clear_EnumerableMap_${name}(uint256)`,
-            length: `$length_EnumerableMap_${name}(uint256)`,
-            at: `$at_EnumerableMap_${name}(uint256,uint256)`,
-            contains: `$contains_EnumerableMap_${name}(uint256,${keyType})`,
-            keys: `$keys_EnumerableMap_${name}(uint256)`,
-          },
+          MAP_TYPES.filter(map => map.key.name == key.name).length == 1
+            ? {
+                set: `$set(uint256,${key.type},${value.type})`,
+                get: `$get(uint256,${key.type})`,
+                tryGet: `$tryGet(uint256,${key.type})`,
+                remove: `$remove(uint256,${key.type})`,
+                contains: `$contains(uint256,${key.type})`,
+                clear: `$clear_EnumerableMap_${name}(uint256)`,
+                length: `$length_EnumerableMap_${name}(uint256)`,
+                at: `$at_EnumerableMap_${name}(uint256,uint256)`,
+                keys: `$keys_EnumerableMap_${name}(uint256)`,
+              }
+            : {
+                set: `$set(uint256,${key.type},${value.type})`,
+                get: `$get_EnumerableMap_${name}(uint256,${key.type})`,
+                tryGet: `$tryGet_EnumerableMap_${name}(uint256,${key.type})`,
+                remove: `$remove_EnumerableMap_${name}(uint256,${key.type})`,
+                contains: `$contains_EnumerableMap_${name}(uint256,${key.type})`,
+                clear: `$clear_EnumerableMap_${name}(uint256)`,
+                length: `$length_EnumerableMap_${name}(uint256)`,
+                at: `$at_EnumerableMap_${name}(uint256,uint256)`,
+                keys: `$keys_EnumerableMap_${name}(uint256)`,
+              },
           fnSig =>
             (...args) =>
               mock.getFunction(fnSig)(0, ...args),
         ),
         events: {
-          setReturn: `return$set_EnumerableMap_${name}_${keyType}_${valueType}`,
-          removeReturn: `return$remove_EnumerableMap_${name}_${keyType}`,
+          setReturn: `return$set_EnumerableMap_${name}_${key.type}_${value.type}`,
+          removeReturn: `return$remove_EnumerableMap_${name}_${key.type}`,
         },
+        error: key.memory || value.memory ? `EnumerableMapNonexistent${key.name}Key` : `EnumerableMapNonexistentKey`,
       },
     ]),
   );
@@ -52,8 +67,8 @@ describe('EnumerableMap', function () {
     Object.assign(this, await loadFixture(fixture));
   });
 
-  for (const { name } of TYPES) {
-    describe(name, function () {
+  for (const { name, key, value } of MAP_TYPES) {
+    describe(`${name} (enumerable map from ${key.type} to ${value.type})`, function () {
       beforeEach(async function () {
         Object.assign(this, this.env[name]);
         [this.keyA, this.keyB, this.keyC] = this.keys;

+ 18 - 15
test/utils/structs/EnumerableSet.test.js

@@ -3,39 +3,42 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { mapValues } = require('../../helpers/iterate');
 const { generators } = require('../../helpers/random');
-const { TYPES } = require('../../../scripts/generate/templates/EnumerableSet.opts');
+const { SET_TYPES } = require('../../../scripts/generate/templates/Enumerable.opts');
 
 const { shouldBehaveLikeSet } = require('./EnumerableSet.behavior');
 
-const getMethods = (mock, fnSigs) => {
-  return mapValues(
+const getMethods = (mock, fnSigs) =>
+  mapValues(
     fnSigs,
     fnSig =>
       (...args) =>
         mock.getFunction(fnSig)(0, ...args),
   );
-};
 
 async function fixture() {
   const mock = await ethers.deployContract('$EnumerableSet');
 
   const env = Object.fromEntries(
-    TYPES.map(({ name, type }) => [
-      type,
+    SET_TYPES.map(({ name, value }) => [
+      name,
       {
-        values: Array.from({ length: 3 }, generators[type]),
+        value,
+        values: Array.from(
+          { length: 3 },
+          value.size ? () => Array.from({ length: value.size }, generators[value.base]) : generators[value.type],
+        ),
         methods: getMethods(mock, {
-          add: `$add(uint256,${type})`,
-          remove: `$remove(uint256,${type})`,
+          add: `$add(uint256,${value.type})`,
+          remove: `$remove(uint256,${value.type})`,
+          contains: `$contains(uint256,${value.type})`,
           clear: `$clear_EnumerableSet_${name}(uint256)`,
-          contains: `$contains(uint256,${type})`,
           length: `$length_EnumerableSet_${name}(uint256)`,
           at: `$at_EnumerableSet_${name}(uint256,uint256)`,
           values: `$values_EnumerableSet_${name}(uint256)`,
         }),
         events: {
-          addReturn: `return$add_EnumerableSet_${name}_${type}`,
-          removeReturn: `return$remove_EnumerableSet_${name}_${type}`,
+          addReturn: `return$add_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`,
+          removeReturn: `return$remove_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`,
         },
       },
     ]),
@@ -49,10 +52,10 @@ describe('EnumerableSet', function () {
     Object.assign(this, await loadFixture(fixture));
   });
 
-  for (const { type } of TYPES) {
-    describe(type, function () {
+  for (const { name, value } of SET_TYPES) {
+    describe(`${name} (enumerable set of ${value.type})`, function () {
       beforeEach(function () {
-        Object.assign(this, this.env[type]);
+        Object.assign(this, this.env[name]);
         [this.valueA, this.valueB, this.valueC] = this.values;
       });