Browse Source

Add new EnumerableMap types (#4843)

Co-authored-by: ernestognw <ernestognw@gmail.com>
Hadrien Croubois 1 year ago
parent
commit
a5c4cd8182

+ 5 - 0
.changeset/yellow-deers-walk.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`EnumerableMap`: add `UintToBytes32Map`, `AddressToAddressMap`, `AddressToBytes32Map` and `Bytes32ToAddressMap`.

+ 380 - 0
contracts/utils/structs/EnumerableMap.sol

@@ -34,6 +34,10 @@ import {EnumerableSet} from "./EnumerableSet.sol";
  * - `bytes32 -> bytes32` (`Bytes32ToBytes32Map`) since v4.6.0
  * - `uint256 -> uint256` (`UintToUintMap`) since v4.7.0
  * - `bytes32 -> uint256` (`Bytes32ToUintMap`) since v4.7.0
+ * - `uint256 -> bytes32` (`UintToBytes32Map`) since v5.1.0
+ * - `address -> address` (`AddressToAddressMap`) since v5.1.0
+ * - `address -> bytes32` (`AddressToBytes32Map`) since v5.1.0
+ * - `bytes32 -> address` (`Bytes32ToAddressMap`) since v5.1.0
  *
  * [WARNING]
  * ====
@@ -343,6 +347,100 @@ library EnumerableMap {
         return result;
     }
 
+    // UintToBytes32Map
+
+    struct UintToBytes32Map {
+        Bytes32ToBytes32Map _inner;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(UintToBytes32Map storage map, uint256 key, bytes32 value) internal returns (bool) {
+        return set(map._inner, bytes32(key), value);
+    }
+
+    /**
+     * @dev Removes a value from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(UintToBytes32Map storage map, uint256 key) internal returns (bool) {
+        return remove(map._inner, bytes32(key));
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(UintToBytes32Map storage map, uint256 key) internal view returns (bool) {
+        return contains(map._inner, bytes32(key));
+    }
+
+    /**
+     * @dev Returns the number of elements in the map. O(1).
+     */
+    function length(UintToBytes32Map storage map) internal view returns (uint256) {
+        return length(map._inner);
+    }
+
+    /**
+     * @dev Returns the element stored at position `index` in the map. O(1).
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(UintToBytes32Map storage map, uint256 index) internal view returns (uint256, bytes32) {
+        (bytes32 key, bytes32 value) = at(map._inner, index);
+        return (uint256(key), value);
+    }
+
+    /**
+     * @dev Tries to returns the value associated with `key`. O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(UintToBytes32Map storage map, uint256 key) internal view returns (bool, bytes32) {
+        (bool success, bytes32 value) = tryGet(map._inner, bytes32(key));
+        return (success, value);
+    }
+
+    /**
+     * @dev Returns the value associated with `key`. O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(UintToBytes32Map storage map, uint256 key) internal view returns (bytes32) {
+        return get(map._inner, bytes32(key));
+    }
+
+    /**
+     * @dev Return the an array containing all the keys
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function keys(UintToBytes32Map storage map) internal view returns (uint256[] memory) {
+        bytes32[] memory store = keys(map._inner);
+        uint256[] memory result;
+
+        /// @solidity memory-safe-assembly
+        assembly {
+            result := store
+        }
+
+        return result;
+    }
+
     // AddressToUintMap
 
     struct AddressToUintMap {
@@ -437,6 +535,194 @@ library EnumerableMap {
         return result;
     }
 
+    // AddressToAddressMap
+
+    struct AddressToAddressMap {
+        Bytes32ToBytes32Map _inner;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(AddressToAddressMap storage map, address key, address value) internal returns (bool) {
+        return set(map._inner, bytes32(uint256(uint160(key))), bytes32(uint256(uint160(value))));
+    }
+
+    /**
+     * @dev Removes a value from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(AddressToAddressMap storage map, address key) internal returns (bool) {
+        return remove(map._inner, bytes32(uint256(uint160(key))));
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(AddressToAddressMap storage map, address key) internal view returns (bool) {
+        return contains(map._inner, bytes32(uint256(uint160(key))));
+    }
+
+    /**
+     * @dev Returns the number of elements in the map. O(1).
+     */
+    function length(AddressToAddressMap storage map) internal view returns (uint256) {
+        return length(map._inner);
+    }
+
+    /**
+     * @dev Returns the element stored at position `index` in the map. O(1).
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(AddressToAddressMap storage map, uint256 index) internal view returns (address, address) {
+        (bytes32 key, bytes32 value) = at(map._inner, index);
+        return (address(uint160(uint256(key))), address(uint160(uint256(value))));
+    }
+
+    /**
+     * @dev Tries to returns the value associated with `key`. O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(AddressToAddressMap storage map, address key) internal view returns (bool, address) {
+        (bool success, bytes32 value) = tryGet(map._inner, bytes32(uint256(uint160(key))));
+        return (success, address(uint160(uint256(value))));
+    }
+
+    /**
+     * @dev Returns the value associated with `key`. O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(AddressToAddressMap storage map, address key) internal view returns (address) {
+        return address(uint160(uint256(get(map._inner, bytes32(uint256(uint160(key)))))));
+    }
+
+    /**
+     * @dev Return the an array containing all the keys
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function keys(AddressToAddressMap storage map) internal view returns (address[] memory) {
+        bytes32[] memory store = keys(map._inner);
+        address[] memory result;
+
+        /// @solidity memory-safe-assembly
+        assembly {
+            result := store
+        }
+
+        return result;
+    }
+
+    // AddressToBytes32Map
+
+    struct AddressToBytes32Map {
+        Bytes32ToBytes32Map _inner;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(AddressToBytes32Map storage map, address key, bytes32 value) internal returns (bool) {
+        return set(map._inner, bytes32(uint256(uint160(key))), value);
+    }
+
+    /**
+     * @dev Removes a value from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(AddressToBytes32Map storage map, address key) internal returns (bool) {
+        return remove(map._inner, bytes32(uint256(uint160(key))));
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(AddressToBytes32Map storage map, address key) internal view returns (bool) {
+        return contains(map._inner, bytes32(uint256(uint160(key))));
+    }
+
+    /**
+     * @dev Returns the number of elements in the map. O(1).
+     */
+    function length(AddressToBytes32Map storage map) internal view returns (uint256) {
+        return length(map._inner);
+    }
+
+    /**
+     * @dev Returns the element stored at position `index` in the map. O(1).
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(AddressToBytes32Map storage map, uint256 index) internal view returns (address, bytes32) {
+        (bytes32 key, bytes32 value) = at(map._inner, index);
+        return (address(uint160(uint256(key))), value);
+    }
+
+    /**
+     * @dev Tries to returns the value associated with `key`. O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(AddressToBytes32Map storage map, address key) internal view returns (bool, bytes32) {
+        (bool success, bytes32 value) = tryGet(map._inner, bytes32(uint256(uint160(key))));
+        return (success, value);
+    }
+
+    /**
+     * @dev Returns the value associated with `key`. O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(AddressToBytes32Map storage map, address key) internal view returns (bytes32) {
+        return get(map._inner, bytes32(uint256(uint160(key))));
+    }
+
+    /**
+     * @dev Return the an array containing all the keys
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function keys(AddressToBytes32Map storage map) internal view returns (address[] memory) {
+        bytes32[] memory store = keys(map._inner);
+        address[] memory result;
+
+        /// @solidity memory-safe-assembly
+        assembly {
+            result := store
+        }
+
+        return result;
+    }
+
     // Bytes32ToUintMap
 
     struct Bytes32ToUintMap {
@@ -530,4 +816,98 @@ library EnumerableMap {
 
         return result;
     }
+
+    // Bytes32ToAddressMap
+
+    struct Bytes32ToAddressMap {
+        Bytes32ToBytes32Map _inner;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(Bytes32ToAddressMap storage map, bytes32 key, address value) internal returns (bool) {
+        return set(map._inner, key, bytes32(uint256(uint160(value))));
+    }
+
+    /**
+     * @dev Removes a value from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(Bytes32ToAddressMap storage map, bytes32 key) internal returns (bool) {
+        return remove(map._inner, key);
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(Bytes32ToAddressMap storage map, bytes32 key) internal view returns (bool) {
+        return contains(map._inner, key);
+    }
+
+    /**
+     * @dev Returns the number of elements in the map. O(1).
+     */
+    function length(Bytes32ToAddressMap storage map) internal view returns (uint256) {
+        return length(map._inner);
+    }
+
+    /**
+     * @dev Returns the element stored at position `index` in the map. O(1).
+     * Note that there are no guarantees on the ordering of values inside the
+     * array, and it may change when more values are added or removed.
+     *
+     * Requirements:
+     *
+     * - `index` must be strictly less than {length}.
+     */
+    function at(Bytes32ToAddressMap storage map, uint256 index) internal view returns (bytes32, address) {
+        (bytes32 key, bytes32 value) = at(map._inner, index);
+        return (key, address(uint160(uint256(value))));
+    }
+
+    /**
+     * @dev Tries to returns the value associated with `key`. O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(Bytes32ToAddressMap storage map, bytes32 key) internal view returns (bool, address) {
+        (bool success, bytes32 value) = tryGet(map._inner, key);
+        return (success, address(uint160(uint256(value))));
+    }
+
+    /**
+     * @dev Returns the value associated with `key`. O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(Bytes32ToAddressMap storage map, bytes32 key) internal view returns (address) {
+        return address(uint160(uint256(get(map._inner, key))));
+    }
+
+    /**
+     * @dev Return the an array containing all the keys
+     *
+     * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
+     * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
+     * this function has an unbounded cost, and using it as part of a state-changing function may render the function
+     * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
+     */
+    function keys(Bytes32ToAddressMap storage map) internal view returns (bytes32[] memory) {
+        bytes32[] memory store = keys(map._inner);
+        bytes32[] memory result;
+
+        /// @solidity memory-safe-assembly
+        assembly {
+            result := store
+        }
+
+        return result;
+    }
 }

+ 4 - 0
scripts/generate/templates/EnumerableMap.js

@@ -36,6 +36,10 @@ import {EnumerableSet} from "./EnumerableSet.sol";
  * - \`bytes32 -> bytes32\` (\`Bytes32ToBytes32Map\`) since v4.6.0
  * - \`uint256 -> uint256\` (\`UintToUintMap\`) since v4.7.0
  * - \`bytes32 -> uint256\` (\`Bytes32ToUintMap\`) since v4.7.0
+ * - \`uint256 -> bytes32\` (\`UintToBytes32Map\`) since v5.1.0
+ * - \`address -> address\` (\`AddressToAddressMap\`) since v5.1.0
+ * - \`address -> bytes32\` (\`AddressToBytes32Map\`) since v5.1.0
+ * - \`bytes32 -> address\` (\`Bytes32ToAddressMap\`) since v5.1.0
  *
  * [WARNING]
  * ====

+ 4 - 6
scripts/generate/templates/EnumerableMap.opts.js

@@ -8,12 +8,10 @@ const formatType = (keyType, valueType) => ({
   valueType,
 });
 
-const TYPES = [
-  ['uint256', 'uint256'],
-  ['uint256', 'address'],
-  ['address', 'uint256'],
-  ['bytes32', 'uint256'],
-].map(args => formatType(...args));
+const TYPES = ['uint256', 'address', 'bytes32']
+  .flatMap((key, _, array) => array.map(value => [key, value]))
+  .slice(0, -1) // remove bytes32 → byte32 (last one) that is already defined
+  .map(args => formatType(...args));
 
 module.exports = {
   TYPES,

+ 5 - 0
test/helpers/random.js

@@ -9,6 +9,11 @@ const generators = {
   hexBytes: length => ethers.hexlify(ethers.randomBytes(length)),
 };
 
+generators.address.zero = ethers.ZeroAddress;
+generators.bytes32.zero = ethers.ZeroHash;
+generators.uint256.zero = 0n;
+generators.hexBytes.zero = '0x';
+
 module.exports = {
   randomArray,
   generators,

+ 19 - 46
test/utils/structs/EnumerableMap.test.js

@@ -7,60 +7,34 @@ const { TYPES, formatType } = require('../../../scripts/generate/templates/Enume
 
 const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior');
 
-const getMethods = (mock, fnSigs) => {
-  return mapValues(
-    fnSigs,
-    fnSig =>
-      (...args) =>
-        mock.getFunction(fnSig)(0, ...args),
-  );
-};
-
-const testTypes = [formatType('bytes32', 'bytes32'), ...TYPES];
+// Add Bytes32ToBytes32Map that must be tested but is not part of the generated types.
+TYPES.unshift(formatType('bytes32', 'bytes32'));
 
 async function fixture() {
   const mock = await ethers.deployContract('$EnumerableMap');
-
-  const zeroValue = {
-    uint256: 0n,
-    address: ethers.ZeroAddress,
-    bytes32: ethers.ZeroHash,
-  };
-
   const env = Object.fromEntries(
-    testTypes.map(({ name, keyType, valueType }) => [
+    TYPES.map(({ name, keyType, valueType }) => [
       name,
       {
         keyType,
         keys: randomArray(generators[keyType]),
         values: randomArray(generators[valueType]),
-
-        methods: getMethods(
-          mock,
-          testTypes.filter(t => keyType == t.keyType).length == 1
-            ? {
-                set: `$set(uint256,${keyType},${valueType})`,
-                get: `$get(uint256,${keyType})`,
-                tryGet: `$tryGet(uint256,${keyType})`,
-                remove: `$remove(uint256,${keyType})`,
-                length: `$length_EnumerableMap_${name}(uint256)`,
-                at: `$at_EnumerableMap_${name}(uint256,uint256)`,
-                contains: `$contains(uint256,${keyType})`,
-                keys: `$keys_EnumerableMap_${name}(uint256)`,
-              }
-            : {
-                set: `$set(uint256,${keyType},${valueType})`,
-                get: `$get_EnumerableMap_${name}(uint256,${keyType})`,
-                tryGet: `$tryGet_EnumerableMap_${name}(uint256,${keyType})`,
-                remove: `$remove_EnumerableMap_${name}(uint256,${keyType})`,
-                length: `$length_EnumerableMap_${name}(uint256)`,
-                at: `$at_EnumerableMap_${name}(uint256,uint256)`,
-                contains: `$contains_EnumerableMap_${name}(uint256,${keyType})`,
-                keys: `$keys_EnumerableMap_${name}(uint256)`,
-              },
+        zeroValue: generators[valueType].zero,
+        methods: mapValues(
+          {
+            set: `$set(uint256,${keyType},${valueType})`,
+            get: `$get_EnumerableMap_${name}(uint256,${keyType})`,
+            tryGet: `$tryGet_EnumerableMap_${name}(uint256,${keyType})`,
+            remove: `$remove_EnumerableMap_${name}(uint256,${keyType})`,
+            length: `$length_EnumerableMap_${name}(uint256)`,
+            at: `$at_EnumerableMap_${name}(uint256,uint256)`,
+            contains: `$contains_EnumerableMap_${name}(uint256,${keyType})`,
+            keys: `$keys_EnumerableMap_${name}(uint256)`,
+          },
+          fnSig =>
+            (...args) =>
+              mock.getFunction(fnSig)(0, ...args),
         ),
-
-        zeroValue: zeroValue[valueType],
         events: {
           setReturn: `return$set_EnumerableMap_${name}_${keyType}_${valueType}`,
           removeReturn: `return$remove_EnumerableMap_${name}_${keyType}`,
@@ -77,8 +51,7 @@ describe('EnumerableMap', function () {
     Object.assign(this, await loadFixture(fixture));
   });
 
-  // UintToAddressMap
-  for (const { name } of testTypes) {
+  for (const { name } of TYPES) {
     describe(name, function () {
       beforeEach(async function () {
         Object.assign(this, this.env[name]);