Selaa lähdekoodia

Improve gas efficiency of EnumerableMap (#2518)

Co-authored-by: Francisco Giordano <frangio.1@gmail.com>
Hadrien Croubois 4 vuotta sitten
vanhempi
sitoutus
e66e3ca523
3 muutettua tiedostoa jossa 30 lisäystä ja 71 poistoa
  1. 1 0
      CHANGELOG.md
  2. 27 69
      contracts/utils/EnumerableMap.sol
  3. 2 2
      test/token/ERC721/ERC721.test.js

+ 1 - 0
CHANGELOG.md

@@ -6,6 +6,7 @@
  * `Context`: making `_msgData` return `bytes calldata` instead of `bytes memory` ([#2492](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2492))
  * `ERC20`: Removed the `_setDecimals` function and the storage slot associated to decimals. ([#2502](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2502))
  * `Strings`: addition of a `toHexString` function.  ([#2504](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2504))
+ * `EnumerableMap`: change implementation to optimize for `key → value` lookups instead of enumeration. ([#2518](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2518))
 
 ## 3.4.0 (2021-02-02)
 

+ 27 - 69
contracts/utils/EnumerableMap.sol

@@ -2,6 +2,8 @@
 
 pragma solidity ^0.8.0;
 
+import "./EnumerableSet.sol";
+
 /**
  * @dev Library for managing an enumerable variant of Solidity's
  * https://solidity.readthedocs.io/en/latest/types.html#mapping-types[`mapping`]
@@ -27,6 +29,8 @@ pragma solidity ^0.8.0;
  * supported.
  */
 library EnumerableMap {
+    using EnumerableSet for EnumerableSet.Bytes32Set;
+
     // To implement this library for multiple types with as little code
     // repetition as possible, we write it in terms of a generic Map type with
     // bytes32 keys and values.
@@ -36,18 +40,11 @@ library EnumerableMap {
     // This means that we can only create new EnumerableMaps for types that fit
     // in bytes32.
 
-    struct MapEntry {
-        bytes32 _key;
-        bytes32 _value;
-    }
-
     struct Map {
-        // Storage of map keys and values
-        MapEntry[] _entries;
+        // Storage of keys
+        EnumerableSet.Bytes32Set _keys;
 
-        // Position of the entry defined by a key in the `entries` array, plus 1
-        // because index 0 means a key is not in the map.
-        mapping (bytes32 => uint256) _indexes;
+        mapping (bytes32 => bytes32) _values;
     }
 
     /**
@@ -58,19 +55,8 @@ library EnumerableMap {
      * already present.
      */
     function _set(Map storage map, bytes32 key, bytes32 value) private returns (bool) {
-        // We read and store the key's index to prevent multiple reads from the same storage slot
-        uint256 keyIndex = map._indexes[key];
-
-        if (keyIndex == 0) { // Equivalent to !contains(map, key)
-            map._entries.push(MapEntry({ _key: key, _value: value }));
-            // The entry is stored at length-1, but we add 1 to all indexes
-            // and use 0 as a sentinel value
-            map._indexes[key] = map._entries.length;
-            return true;
-        } else {
-            map._entries[keyIndex - 1]._value = value;
-            return false;
-        }
+        map._values[key] = value;
+        return map._keys.add(key);
     }
 
     /**
@@ -79,51 +65,22 @@ library EnumerableMap {
      * Returns true if the key was removed from the map, that is if it was present.
      */
     function _remove(Map storage map, bytes32 key) private returns (bool) {
-        // We read and store the key's index to prevent multiple reads from the same storage slot
-        uint256 keyIndex = map._indexes[key];
-
-        if (keyIndex != 0) { // Equivalent to contains(map, key)
-            // To delete a key-value pair from the _entries array in O(1), we swap the entry to delete with the last one
-            // in the array, and then remove the last entry (sometimes called as 'swap and pop').
-            // This modifies the order of the array, as noted in {at}.
-
-            uint256 toDeleteIndex = keyIndex - 1;
-            uint256 lastIndex = map._entries.length - 1;
-
-            // When the entry to delete is the last one, the swap operation is unnecessary. However, since this occurs
-            // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement.
-
-            MapEntry storage lastEntry = map._entries[lastIndex];
-
-            // Move the last entry to the index where the entry to delete is
-            map._entries[toDeleteIndex] = lastEntry;
-            // Update the index for the moved entry
-            map._indexes[lastEntry._key] = toDeleteIndex + 1; // All indexes are 1-based
-
-            // Delete the slot where the moved entry was stored
-            map._entries.pop();
-
-            // Delete the index for the deleted slot
-            delete map._indexes[key];
-
-            return true;
-        } else {
-            return false;
-        }
+        delete map._values[key];
+        return map._keys.remove(key);
     }
 
     /**
      * @dev Returns true if the key is in the map. O(1).
      */
     function _contains(Map storage map, bytes32 key) private view returns (bool) {
-        return map._indexes[key] != 0;
+        return map._keys.contains(key);
     }
 
     /**
      * @dev Returns the number of key-value pairs in the map. O(1).
      */
     function _length(Map storage map) private view returns (uint256) {
-        return map._entries.length;
+        return map._keys.length();
     }
 
    /**
@@ -137,10 +94,8 @@ library EnumerableMap {
     * - `index` must be strictly less than {length}.
     */
     function _at(Map storage map, uint256 index) private view returns (bytes32, bytes32) {
-        require(map._entries.length > index, "EnumerableMap: index out of bounds");
-
-        MapEntry storage entry = map._entries[index];
-        return (entry._key, entry._value);
+        bytes32 key = map._keys.at(index);
+        return (key, map._values[key]);
     }
 
     /**
@@ -148,9 +103,12 @@ library EnumerableMap {
      * Does not revert if `key` is not in the map.
      */
     function _tryGet(Map storage map, bytes32 key) private view returns (bool, bytes32) {
-        uint256 keyIndex = map._indexes[key];
-        if (keyIndex == 0) return (false, 0); // Equivalent to contains(map, key)
-        return (true, map._entries[keyIndex - 1]._value); // All indexes are 1-based
+        bytes32 value = map._values[key];
+        if (value == bytes32(0)) {
+            return (_contains(map, key), bytes32(0));
+        } else {
+            return (true, value);
+        }
     }
 
     /**
@@ -161,9 +119,9 @@ library EnumerableMap {
      * - `key` must be in the map.
      */
     function _get(Map storage map, bytes32 key) private view returns (bytes32) {
-        uint256 keyIndex = map._indexes[key];
-        require(keyIndex != 0, "EnumerableMap: nonexistent key"); // Equivalent to contains(map, key)
-        return map._entries[keyIndex - 1]._value; // All indexes are 1-based
+        bytes32 value = map._values[key];
+        require(value != 0 || _contains(map, key), "EnumerableMap: nonexistent key");
+        return value;
     }
 
     /**
@@ -173,9 +131,9 @@ library EnumerableMap {
      * message unnecessarily. For custom revert reasons use {_tryGet}.
      */
     function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) {
-        uint256 keyIndex = map._indexes[key];
-        require(keyIndex != 0, errorMessage); // Equivalent to contains(map, key)
-        return map._entries[keyIndex - 1]._value; // All indexes are 1-based
+        bytes32 value = map._values[key];
+        require(value != 0 || _contains(map, key), errorMessage);
+        return value;
     }
 
     // UintToAddressMap

+ 2 - 2
test/token/ERC721/ERC721.test.js

@@ -801,7 +801,7 @@ contract('ERC721', function (accounts) {
 
       it('reverts if index is greater than supply', async function () {
         await expectRevert(
-          this.token.tokenByIndex(2), 'EnumerableMap: index out of bounds',
+          this.token.tokenByIndex(2), 'EnumerableSet: index out of bounds',
         );
       });
 
@@ -908,7 +908,7 @@ contract('ERC721', function (accounts) {
           await this.token.burn(secondTokenId, { from: owner });
           expect(await this.token.totalSupply()).to.be.bignumber.equal('0');
           await expectRevert(
-            this.token.tokenByIndex(0), 'EnumerableMap: index out of bounds',
+            this.token.tokenByIndex(0), 'EnumerableSet: index out of bounds',
           );
         });