Sfoglia il codice sorgente

Add EnumerableMap, refactor ERC721 (#2160)

* Implement AddressSet in terms of a generic Set

* Add Uint256Set

* Add EnumerableMap

* Fix wording on EnumerableSet docs and tests

* Refactor ERC721 using EnumerableSet and EnumerableMap

* Fix tests

* Fix linter error

* Gas optimization for EnumerableMap

* Gas optimization for EnumerableSet

* Remove often not-taken if from Enumerable data structures

* Fix failing test

* Gas optimization for EnumerableMap

* Fix linter errors

* Add comment for clarification

* Improve test naming

* Rename EnumerableMap.add to set

* Add overload for EnumerableMap.get with custom error message

* Improve Enumerable docs

* Rename Uint256Set to UintSet

* Add changelog entry
Nicolás Venturo 5 anni fa
parent
commit
bd0778461d

+ 3 - 0
CHANGELOG.md

@@ -5,6 +5,7 @@
 ### New features
  * `AccessControl`: new contract for managing permissions in a system, replacement for `Ownable` and `Roles`. ([#2112](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2112))
  * `SafeCast`: new functions to convert to and from signed and unsigned values: `toUint256` and `toInt256`. ([#2123](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2123))
+ * `EnumerableMap`: a new data structure for key-value pairs (like `mapping`) that can be iterated over. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
 
 ### Breaking changes
  * `ERC721`: `burn(owner, tokenId)` was removed, use `burn(tokenId)` instead. ([#2125](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2125))
@@ -30,6 +31,8 @@
  * `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
  * `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151))
  * `ERC165Checker`: functions no longer have a leading underscore. ([#2150](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2150))
+ * `ERC721Metadata`, `ERC721Enumerable`: these contracts were removed, and their functionality merged into `ERC721`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
+ * `ERC721`: added a constructor for `name` and `symbol`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
  * `ERC20Detailed`: this contract was removed and its functionality merged into `ERC20`. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161))
  * `ERC20`: added a constructor for `name` and `symbol`. `decimals` now defaults to 18. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161))
 

+ 0 - 4
contracts/mocks/ERC721Mock.sol

@@ -13,10 +13,6 @@ contract ERC721Mock is ERC721 {
         return _exists(tokenId);
     }
 
-    function tokensOfOwner(address owner) public view returns (uint256[] memory) {
-        return _tokensOfOwner(owner);
-    }
-
     function setTokenURI(uint256 tokenId, string memory uri) public {
         _setTokenURI(tokenId, uri);
     }

+ 38 - 0
contracts/mocks/EnumerableMapMock.sol

@@ -0,0 +1,38 @@
+pragma solidity ^0.6.0;
+
+import "../utils/EnumerableMap.sol";
+
+contract EnumerableMapMock {
+    using EnumerableMap for EnumerableMap.UintToAddressMap;
+
+    event OperationResult(bool result);
+
+    EnumerableMap.UintToAddressMap private _map;
+
+    function contains(uint256 key) public view returns (bool) {
+        return _map.contains(key);
+    }
+
+    function set(uint256 key, address value) public {
+        bool result = _map.set(key, value);
+        emit OperationResult(result);
+    }
+
+    function remove(uint256 key) public {
+        bool result = _map.remove(key);
+        emit OperationResult(result);
+    }
+
+    function length() public view returns (uint256) {
+        return _map.length();
+    }
+
+    function at(uint256 index) public view returns (uint256 key, address value) {
+        return _map.at(index);
+    }
+
+
+    function get(uint256 key) public view returns (address) {
+        return _map.get(key);
+    }
+}

+ 3 - 7
contracts/mocks/EnumerableSetMock.sol

@@ -5,7 +5,7 @@ import "../utils/EnumerableSet.sol";
 contract EnumerableSetMock {
     using EnumerableSet for EnumerableSet.AddressSet;
 
-    event TransactionResult(bool result);
+    event OperationResult(bool result);
 
     EnumerableSet.AddressSet private _set;
 
@@ -15,16 +15,12 @@ contract EnumerableSetMock {
 
     function add(address value) public {
         bool result = _set.add(value);
-        emit TransactionResult(result);
+        emit OperationResult(result);
     }
 
     function remove(address value) public {
         bool result = _set.remove(value);
-        emit TransactionResult(result);
-    }
-
-    function enumerate() public view returns (address[] memory) {
-        return _set.enumerate();
+        emit OperationResult(result);
     }
 
     function length() public view returns (uint256) {

+ 28 - 136
contracts/token/ERC721/ERC721.sol

@@ -8,7 +8,8 @@ import "./IERC721Receiver.sol";
 import "../../introspection/ERC165.sol";
 import "../../math/SafeMath.sol";
 import "../../utils/Address.sol";
-import "../../utils/Counters.sol";
+import "../../utils/EnumerableSet.sol";
+import "../../utils/EnumerableMap.sol";
 
 /**
  * @title ERC721 Non-Fungible Token Standard basic implementation
@@ -17,21 +18,22 @@ import "../../utils/Counters.sol";
 contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable {
     using SafeMath for uint256;
     using Address for address;
-    using Counters for Counters.Counter;
+    using EnumerableSet for EnumerableSet.UintSet;
+    using EnumerableMap for EnumerableMap.UintToAddressMap;
 
     // Equals to `bytes4(keccak256("onERC721Received(address,address,uint256,bytes)"))`
     // which can be also obtained as `IERC721Receiver(0).onERC721Received.selector`
     bytes4 private constant _ERC721_RECEIVED = 0x150b7a02;
 
-    // Mapping from token ID to owner
-    mapping (uint256 => address) private _tokenOwner;
+    // Mapping from holder address to their (enumerable) set of owned tokens
+    mapping (address => EnumerableSet.UintSet) private _holderTokens;
+
+    // Enumerable mapping from token ids to their owners
+    EnumerableMap.UintToAddressMap private _tokenOwners;
 
     // Mapping from token ID to approved address
     mapping (uint256 => address) private _tokenApprovals;
 
-    // Mapping from owner to number of owned token
-    mapping (address => Counters.Counter) private _ownedTokensCount;
-
     // Mapping from owner to operator approvals
     mapping (address => mapping (address => bool)) private _operatorApprovals;
 
@@ -47,18 +49,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
     // Base URI
     string private _baseURI;
 
-    // Mapping from owner to list of owned token IDs
-    mapping(address => uint256[]) private _ownedTokens;
-
-    // Mapping from token ID to index of the owner tokens list
-    mapping(uint256 => uint256) private _ownedTokensIndex;
-
-    // Array with all token ids, used for enumeration
-    uint256[] private _allTokens;
-
-    // Mapping from token id to position in the allTokens array
-    mapping(uint256 => uint256) private _allTokensIndex;
-
     /*
      *     bytes4(keccak256('balanceOf(address)')) == 0x70a08231
      *     bytes4(keccak256('ownerOf(uint256)')) == 0x6352211e
@@ -111,7 +101,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
     function balanceOf(address owner) public view override returns (uint256) {
         require(owner != address(0), "ERC721: balance query for the zero address");
 
-        return _ownedTokensCount[owner].current();
+        return _holderTokens[owner].length();
     }
 
     /**
@@ -120,10 +110,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
      * @return address currently marked as the owner of the given token ID
      */
     function ownerOf(uint256 tokenId) public view override returns (address) {
-        address owner = _tokenOwner[tokenId];
-        require(owner != address(0), "ERC721: owner query for nonexistent token");
-
-        return owner;
+        return _tokenOwners.get(tokenId, "ERC721: owner query for nonexistent token");
     }
 
     /**
@@ -180,8 +167,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
      * @return uint256 token ID at the given index of the tokens list owned by the requested address
      */
     function tokenOfOwnerByIndex(address owner, uint256 index) public view override returns (uint256) {
-        require(index < balanceOf(owner), "ERC721Enumerable: owner index out of bounds");
-        return _ownedTokens[owner][index];
+        return _holderTokens[owner].at(index);
     }
 
     /**
@@ -189,7 +175,8 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
      * @return uint256 representing the total amount of tokens
      */
     function totalSupply() public view override returns (uint256) {
-        return _allTokens.length;
+        // _tokenOwners are indexed by tokenIds, so .length() returns the number of tokenIds
+        return _tokenOwners.length();
     }
 
     /**
@@ -199,8 +186,8 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
      * @return uint256 token ID at the given index of the tokens list
      */
     function tokenByIndex(uint256 index) public view override returns (uint256) {
-        require(index < totalSupply(), "ERC721Enumerable: global index out of bounds");
-        return _allTokens[index];
+        (uint256 tokenId, ) = _tokenOwners.at(index);
+        return tokenId;
     }
 
     /**
@@ -327,8 +314,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
      * @return bool whether the token exists
      */
     function _exists(uint256 tokenId) internal view returns (bool) {
-        address owner = _tokenOwner[tokenId];
-        return owner != address(0);
+        return _tokenOwners.contains(tokenId);
     }
 
     /**
@@ -386,11 +372,9 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
 
         _beforeTokenTransfer(address(0), to, tokenId);
 
-        _addTokenToOwnerEnumeration(to, tokenId);
-        _addTokenToAllTokensEnumeration(tokenId);
+        _holderTokens[to].add(tokenId);
 
-        _tokenOwner[tokenId] = to;
-        _ownedTokensCount[to].increment();
+        _tokenOwners.set(tokenId, to);
 
         emit Transfer(address(0), to, tokenId);
     }
@@ -405,22 +389,17 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
 
         _beforeTokenTransfer(owner, address(0), tokenId);
 
+        // Clear approvals
+        _approve(address(0), tokenId);
+
         // Clear metadata (if any)
         if (bytes(_tokenURIs[tokenId]).length != 0) {
             delete _tokenURIs[tokenId];
         }
 
-        _removeTokenFromOwnerEnumeration(owner, tokenId);
-        // Since tokenId will be deleted, we can clear its slot in _ownedTokensIndex to trigger a gas refund
-        _ownedTokensIndex[tokenId] = 0;
-
-        _removeTokenFromAllTokensEnumeration(tokenId);
-
-        // Clear approvals
-        _approve(address(0), tokenId);
+        _holderTokens[owner].remove(tokenId);
 
-        _ownedTokensCount[owner].decrement();
-        _tokenOwner[tokenId] = address(0);
+        _tokenOwners.remove(tokenId);
 
         emit Transfer(owner, address(0), tokenId);
     }
@@ -438,16 +417,13 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
 
         _beforeTokenTransfer(from, to, tokenId);
 
-        _removeTokenFromOwnerEnumeration(from, tokenId);
-        _addTokenToOwnerEnumeration(to, tokenId);
-
-        // Clear approvals
+        // Clear approvals from the previous owner
         _approve(address(0), tokenId);
 
-        _ownedTokensCount[from].decrement();
-        _ownedTokensCount[to].increment();
+        _holderTokens[from].remove(tokenId);
+        _holderTokens[to].add(tokenId);
 
-        _tokenOwner[tokenId] = to;
+        _tokenOwners.set(tokenId, to);
 
         emit Transfer(from, to, tokenId);
     }
@@ -474,15 +450,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
         _baseURI = baseURI_;
     }
 
-    /**
-     * @dev Gets the list of token IDs of the requested owner.
-     * @param owner address owning the tokens
-     * @return uint256[] List of token IDs owned by the requested address
-     */
-    function _tokensOfOwner(address owner) internal view returns (uint256[] storage) {
-        return _ownedTokens[owner];
-    }
-
     /**
      * @dev Internal function to invoke {IERC721Receiver-onERC721Received} on a target address.
      * The call is not executed if the target address is not a contract.
@@ -528,81 +495,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
         emit Approval(ownerOf(tokenId), to, tokenId);
     }
 
-    /**
-     * @dev Private function to add a token to this extension's ownership-tracking data structures.
-     * @param to address representing the new owner of the given token ID
-     * @param tokenId uint256 ID of the token to be added to the tokens list of the given address
-     */
-    function _addTokenToOwnerEnumeration(address to, uint256 tokenId) private {
-        _ownedTokensIndex[tokenId] = _ownedTokens[to].length;
-        _ownedTokens[to].push(tokenId);
-    }
-
-    /**
-     * @dev Private function to add a token to this extension's token tracking data structures.
-     * @param tokenId uint256 ID of the token to be added to the tokens list
-     */
-    function _addTokenToAllTokensEnumeration(uint256 tokenId) private {
-        _allTokensIndex[tokenId] = _allTokens.length;
-        _allTokens.push(tokenId);
-    }
-
-    /**
-     * @dev Private function to remove a token from this extension's ownership-tracking data structures. Note that
-     * while the token is not assigned a new owner, the `_ownedTokensIndex` mapping is _not_ updated: this allows for
-     * gas optimizations e.g. when performing a transfer operation (avoiding double writes).
-     * This has O(1) time complexity, but alters the order of the _ownedTokens array.
-     * @param from address representing the previous owner of the given token ID
-     * @param tokenId uint256 ID of the token to be removed from the tokens list of the given address
-     */
-    function _removeTokenFromOwnerEnumeration(address from, uint256 tokenId) private {
-        // To prevent a gap in from's tokens array, we store the last token in the index of the token to delete, and
-        // then delete the last slot (swap and pop).
-
-        uint256 lastTokenIndex = _ownedTokens[from].length.sub(1);
-        uint256 tokenIndex = _ownedTokensIndex[tokenId];
-
-        // When the token to delete is the last token, the swap operation is unnecessary
-        if (tokenIndex != lastTokenIndex) {
-            uint256 lastTokenId = _ownedTokens[from][lastTokenIndex];
-
-            _ownedTokens[from][tokenIndex] = lastTokenId; // Move the last token to the slot of the to-delete token
-            _ownedTokensIndex[lastTokenId] = tokenIndex; // Update the moved token's index
-        }
-
-        // Deletes the contents at the last position of the array
-        _ownedTokens[from].pop();
-
-        // Note that _ownedTokensIndex[tokenId] hasn't been cleared: it still points to the old slot (now occupied by
-        // lastTokenId, or just over the end of the array if the token was the last one).
-    }
-
-    /**
-     * @dev Private function to remove a token from this extension's token tracking data structures.
-     * This has O(1) time complexity, but alters the order of the _allTokens array.
-     * @param tokenId uint256 ID of the token to be removed from the tokens list
-     */
-    function _removeTokenFromAllTokensEnumeration(uint256 tokenId) private {
-        // To prevent a gap in the tokens array, we store the last token in the index of the token to delete, and
-        // then delete the last slot (swap and pop).
-
-        uint256 lastTokenIndex = _allTokens.length.sub(1);
-        uint256 tokenIndex = _allTokensIndex[tokenId];
-
-        // When the token to delete is the last token, the swap operation is unnecessary. However, since this occurs so
-        // rarely (when the last minted token is burnt) that we still do the swap here to avoid the gas cost of adding
-        // an 'if' statement (like in _removeTokenFromOwnerEnumeration)
-        uint256 lastTokenId = _allTokens[lastTokenIndex];
-
-        _allTokens[tokenIndex] = lastTokenId; // Move the last token to the slot of the to-delete token
-        _allTokensIndex[lastTokenId] = tokenIndex; // Update the moved token's index
-
-        // Delete the contents at the last position of the array
-        _allTokens.pop();
-
-        _allTokensIndex[tokenId] = 0;
-    }
-
     /**
      * @dev Hook that is called before any token transfer. This includes minting
      * and burning.

+ 211 - 0
contracts/utils/EnumerableMap.sol

@@ -0,0 +1,211 @@
+pragma solidity ^0.6.0;
+
+library EnumerableMap {
+    // To implement this library for multiple types with as little code
+    // repetition as possible, we write it in terms of a generic Map type with
+    // bytes32 keys and values.
+    // The Map implementation uses private functions, and user-facing
+    // implementations (such as Uint256ToAddressMap) are just wrappers around
+    // the underlying Map.
+    // This means that we can only create new EnumerableMaps for types that fit
+    // in bytes32.
+
+    struct MapEntry {
+        bytes32 _key;
+        bytes32 _value;
+    }
+
+    struct Map {
+        // Storage of map keys and values
+        MapEntry[] _entries;
+
+        // Position of the entry defined by a key in the `entries` array, plus 1
+        // because index 0 means a key is not in the map.
+        mapping (bytes32 => uint256) _indexes;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function _set(Map storage map, bytes32 key, bytes32 value) private returns (bool) {
+        // We read and store the key's index to prevent multiple reads from the same storage slot
+        uint256 keyIndex = map._indexes[key];
+
+        if (keyIndex == 0) { // Equivalent to !contains(map, key)
+            map._entries.push(MapEntry({ _key: key, _value: value }));
+            // The entry is stored at length-1, but we add 1 to all indexes
+            // and use 0 as a sentinel value
+            map._indexes[key] = map._entries.length;
+            return true;
+        } else {
+            map._entries[keyIndex - 1]._value = value;
+            return false;
+        }
+    }
+
+    /**
+     * @dev Removes a key-value pair from a map. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function _remove(Map storage map, bytes32 key) private returns (bool) {
+        // We read and store the key's index to prevent multiple reads from the same storage slot
+        uint256 keyIndex = map._indexes[key];
+
+        if (keyIndex != 0) { // Equivalent to contains(map, key)
+            // To delete a key-value pair from the _entries array in O(1), we swap the entry to delete with the last one
+            // in the array, and then remove the last entry (sometimes called as 'swap and pop').
+            // This modifies the order of the array, as noted in {at}.
+
+            uint256 toDeleteIndex = keyIndex - 1;
+            uint256 lastIndex = map._entries.length - 1;
+
+            // When the entry to delete is the last one, the swap operation is unnecessary. However, since this occurs
+            // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement.
+
+            MapEntry storage lastEntry = map._entries[lastIndex];
+
+            // Move the last entry to the index where the entry to delete is
+            map._entries[toDeleteIndex] = lastEntry;
+            // Update the index for the moved entry
+            map._indexes[lastEntry._key] = toDeleteIndex + 1; // All indexes are 1-based
+
+            // Delete the slot where the moved entry was stored
+            map._entries.pop();
+
+            // Delete the index for the deleted slot
+            delete map._indexes[key];
+
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function _contains(Map storage map, bytes32 key) private view returns (bool) {
+        return map._indexes[key] != 0;
+    }
+
+    /**
+     * @dev Returns the number of key-value pairs in the map. O(1).
+     */
+    function _length(Map storage map) private view returns (uint256) {
+        return map._entries.length;
+    }
+
+   /**
+    * @dev Returns the key-value pair stored at position `index` in the map. O(1).
+    *
+    * Note that there are no guarantees on the ordering of entries inside the
+    * array, and it may change when more entries are added or removed.
+    *
+    * Requirements:
+    *
+    * - `index` must be strictly less than {length}.
+    */
+    function _at(Map storage map, uint256 index) private view returns (bytes32, bytes32) {
+        require(map._entries.length > index, "EnumerableMap: index out of bounds");
+
+        MapEntry storage entry = map._entries[index];
+        return (entry._key, entry._value);
+    }
+
+    /**
+     * @dev Returns the value associated with `key`.  O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function _get(Map storage map, bytes32 key) private view returns (bytes32) {
+        return _get(map, key, "EnumerableMap: nonexistent key");
+    }
+
+    /**
+     * @dev Same as {_get}, with a custom error message when `key` is not in the map.
+     */
+    function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) {
+        uint256 keyIndex = map._indexes[key];
+        require(keyIndex != 0, errorMessage); // Equivalent to contains(map, key)
+        return map._entries[keyIndex - 1]._value; // All indexes are 1-based
+    }
+
+    // UintToAddressMap
+
+    struct UintToAddressMap {
+        Map _inner;
+    }
+
+    /**
+     * @dev Adds a key-value pair to a map, or updates the value for an existing
+     * key. O(1).
+     *
+     * Returns true if the key was added to the map, that is if it was not
+     * already present.
+     */
+    function set(UintToAddressMap storage map, uint256 key, address value) internal returns (bool) {
+        return _set(map._inner, bytes32(key), bytes32(uint256(value)));
+    }
+
+    /**
+     * @dev Removes a value from a set. O(1).
+     *
+     * Returns true if the key was removed from the map, that is if it was present.
+     */
+    function remove(UintToAddressMap storage map, uint256 key) internal returns (bool) {
+        return _remove(map._inner, bytes32(key));
+    }
+
+    /**
+     * @dev Returns true if the key is in the map. O(1).
+     */
+    function contains(UintToAddressMap storage map, uint256 key) internal view returns (bool) {
+        return _contains(map._inner, bytes32(key));
+    }
+
+    /**
+     * @dev Returns the number of elements in the map. O(1).
+     */
+    function length(UintToAddressMap storage map) internal view returns (uint256) {
+        return _length(map._inner);
+    }
+
+   /**
+    * @dev Returns the element stored at position `index` in the set. O(1).
+    * Note that there are no guarantees on the ordering of values inside the
+    * array, and it may change when more values are added or removed.
+    *
+    * Requirements:
+    *
+    * - `index` must be strictly less than {length}.
+    */
+    function at(UintToAddressMap storage map, uint256 index) internal view returns (uint256, address) {
+        (bytes32 key, bytes32 value) = _at(map._inner, index);
+        return (uint256(key), address(uint256(value)));
+    }
+
+    /**
+     * @dev Returns the value associated with `key`.  O(1).
+     *
+     * Requirements:
+     *
+     * - `key` must be in the map.
+     */
+    function get(UintToAddressMap storage map, uint256 key) internal view returns (address) {
+        return address(uint256(_get(map._inner, bytes32(key))));
+    }
+
+    /**
+     * @dev Same as {get}, with a custom error message when `key` is not in the map.
+     */
+    function get(UintToAddressMap storage map, uint256 key, string memory errorMessage) internal view returns (address) {
+        return address(uint256(_get(map._inner, bytes32(key), errorMessage)));
+    }
+}

+ 146 - 57
contracts/utils/EnumerableSet.sol

@@ -18,24 +18,32 @@ pragma solidity ^0.6.0;
  * @author Alberto Cuesta Cañada
  */
 library EnumerableSet {
+    // To implement this library for multiple types with as little code
+    // repetition as possible, we write it in terms of a generic Set type with
+    // bytes32 values.
+    // The Set implementation uses private functions, and user-facing
+    // implementations (such as AddressSet) are just wrappers around the
+    // underlying Set.
+    // This means that we can only create new EnumerableSets for types that fit
+    // in bytes32.
+
+    struct Set {
+        // Storage of set values
+        bytes32[] _values;
 
-    struct AddressSet {
-        address[] _values;
         // Position of the value in the `values` array, plus 1 because index 0
         // means a value is not in the set.
-        mapping (address => uint256) _indexes;
+        mapping (bytes32 => uint256) _indexes;
     }
 
     /**
      * @dev Add a value to a set. O(1).
      *
-     * Returns false if the value was already in the set.
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
      */
-    function add(AddressSet storage set, address value)
-        internal
-        returns (bool)
-    {
-        if (!contains(set, value)) {
+    function _add(Set storage set, bytes32 value) private returns (bool) {
+        if (!_contains(set, value)) {
             set._values.push(value);
             // The value is stored at length-1, but we add 1 to all indexes
             // and use 0 as a sentinel value
@@ -49,25 +57,30 @@ library EnumerableSet {
     /**
      * @dev Removes a value from a set. O(1).
      *
-     * Returns false if the value was not present in the set.
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
      */
-    function remove(AddressSet storage set, address value)
-        internal
-        returns (bool)
-    {
-        if (contains(set, value)){
-            uint256 toDeleteIndex = set._indexes[value] - 1;
+    function _remove(Set storage set, bytes32 value) private returns (bool) {
+        // We read and store the value's index to prevent multiple reads from the same storage slot
+        uint256 valueIndex = set._indexes[value];
+
+        if (valueIndex != 0) { // Equivalent to contains(set, value)
+            // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
+            // the array, and then remove the last element (sometimes called as 'swap and pop').
+            // This modifies the order of the array, as noted in {at}.
+
+            uint256 toDeleteIndex = valueIndex - 1;
             uint256 lastIndex = set._values.length - 1;
 
-            // If the value we're deleting is the last one, we can just remove it without doing a swap
-            if (lastIndex != toDeleteIndex) {
-                address lastvalue = set._values[lastIndex];
+            // When the value to delete is the last one, the swap operation is unnecessary. However, since this occurs
+            // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement.
+
+            bytes32 lastvalue = set._values[lastIndex];
 
-                // Move the last value to the index where the deleted value is
-                set._values[toDeleteIndex] = lastvalue;
-                // Update the index for the moved value
-                set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based
-            }
+            // Move the last value to the index where the value to delete is
+            set._values[toDeleteIndex] = lastvalue;
+            // Update the index for the moved value
+            set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based
 
             // Delete the slot where the moved value was stored
             set._values.pop();
@@ -84,44 +97,125 @@ library EnumerableSet {
     /**
      * @dev Returns true if the value is in the set. O(1).
      */
-    function contains(AddressSet storage set, address value)
-        internal
-        view
-        returns (bool)
-    {
+    function _contains(Set storage set, bytes32 value) private view returns (bool) {
         return set._indexes[value] != 0;
     }
 
     /**
-     * @dev Returns an array with all values in the set. O(N).
+     * @dev Returns the number of values on the set. O(1).
+     */
+    function _length(Set storage set) private view returns (uint256) {
+        return set._values.length;
+    }
+
+   /**
+    * @dev Returns the value stored at position `index` in the set. O(1).
+    *
+    * Note that there are no guarantees on the ordering of values inside the
+    * array, and it may change when more values are added or removed.
+    *
+    * Requirements:
+    *
+    * - `index` must be strictly less than {length}.
+    */
+    function _at(Set storage set, uint256 index) private view returns (bytes32) {
+        require(set._values.length > index, "EnumerableSet: index out of bounds");
+        return set._values[index];
+    }
+
+    // AddressSet
+
+    struct AddressSet {
+        Set _inner;
+    }
+
+    /**
+     * @dev Add a value to a set. O(1).
+     *
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
+     */
+    function add(AddressSet storage set, address value) internal returns (bool) {
+        return _add(set._inner, bytes32(uint256(value)));
+    }
+
+    /**
+     * @dev Removes a value from a set. O(1).
      *
-     * Note that there are no guarantees on the ordering of values inside the
-     * array, and it may change when more values are added or removed.
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
+     */
+    function remove(AddressSet storage set, address value) internal returns (bool) {
+        return _remove(set._inner, bytes32(uint256(value)));
+    }
 
-     * WARNING: This function may run out of gas on large sets: use {length} and
-     * {at} instead in these cases.
+    /**
+     * @dev Returns true if the value is in the set. O(1).
      */
-    function enumerate(AddressSet storage set)
-        internal
-        view
-        returns (address[] memory)
-    {
-        address[] memory output = new address[](set._values.length);
-        for (uint256 i; i < set._values.length; i++){
-            output[i] = set._values[i];
-        }
-        return output;
+    function contains(AddressSet storage set, address value) internal view returns (bool) {
+        return _contains(set._inner, bytes32(uint256(value)));
+    }
+
+    /**
+     * @dev Returns the number of values in the set. O(1).
+     */
+    function length(AddressSet storage set) internal view returns (uint256) {
+        return _length(set._inner);
+    }
+
+   /**
+    * @dev Returns the value stored at position `index` in the set. O(1).
+    *
+    * Note that there are no guarantees on the ordering of values inside the
+    * array, and it may change when more values are added or removed.
+    *
+    * Requirements:
+    *
+    * - `index` must be strictly less than {length}.
+    */
+    function at(AddressSet storage set, uint256 index) internal view returns (address) {
+        return address(uint256(_at(set._inner, index)));
+    }
+
+
+    // UintSet
+
+    struct UintSet {
+        Set _inner;
+    }
+
+    /**
+     * @dev Add a value to a set. O(1).
+     *
+     * Returns true if the value was added to the set, that is if it was not
+     * already present.
+     */
+    function add(UintSet storage set, uint256 value) internal returns (bool) {
+        return _add(set._inner, bytes32(value));
+    }
+
+    /**
+     * @dev Removes a value from a set. O(1).
+     *
+     * Returns true if the value was removed from the set, that is if it was
+     * present.
+     */
+    function remove(UintSet storage set, uint256 value) internal returns (bool) {
+        return _remove(set._inner, bytes32(value));
+    }
+
+    /**
+     * @dev Returns true if the value is in the set. O(1).
+     */
+    function contains(UintSet storage set, uint256 value) internal view returns (bool) {
+        return _contains(set._inner, bytes32(value));
     }
 
     /**
      * @dev Returns the number of values on the set. O(1).
      */
-    function length(AddressSet storage set)
-        internal
-        view
-        returns (uint256)
-    {
-        return set._values.length;
+    function length(UintSet storage set) internal view returns (uint256) {
+        return _length(set._inner);
     }
 
    /**
@@ -134,12 +228,7 @@ library EnumerableSet {
     *
     * - `index` must be strictly less than {length}.
     */
-    function at(AddressSet storage set, uint256 index)
-        internal
-        view
-        returns (address)
-    {
-        require(set._values.length > index, "EnumerableSet: index out of bounds");
-        return set._values[index];
+    function at(UintSet storage set, uint256 index) internal view returns (uint256) {
+        return uint256(_at(set._inner, index));
     }
 }

+ 6 - 0
package-lock.json

@@ -31087,6 +31087,12 @@
         "lodash._reinterpolate": "^3.0.0"
       }
     },
+    "lodash.zip": {
+      "version": "4.2.0",
+      "resolved": "https://registry.npmjs.org/lodash.zip/-/lodash.zip-4.2.0.tgz",
+      "integrity": "sha1-7GZi5IlkCO1KtsVCo5kLcswIACA=",
+      "dev": true
+    },
     "log-symbols": {
       "version": "3.0.0",
       "resolved": "https://registry.npmjs.org/log-symbols/-/log-symbols-3.0.0.tgz",

+ 1 - 0
package.json

@@ -61,6 +61,7 @@
     "ethereumjs-util": "^6.2.0",
     "ganache-core-coverage": "https://github.com/OpenZeppelin/ganache-core-coverage/releases/download/2.5.3-coverage/ganache-core-coverage-2.5.3.tgz",
     "lodash.startcase": "^4.4.0",
+    "lodash.zip": "^4.2.0",
     "micromatch": "^4.0.2",
     "mocha": "^7.1.1",
     "solhint": "^3.0.0-rc.6",

+ 17 - 32
test/token/ERC721/ERC721.test.js

@@ -176,27 +176,17 @@ describe('ERC721', function () {
           expect(await this.token.ownerOf(tokenId)).to.be.equal(this.toWhom);
         });
 
+        it('emits a Transfer event', async function () {
+          expectEvent.inLogs(logs, 'Transfer', { from: owner, to: this.toWhom, tokenId: tokenId });
+        });
+
         it('clears the approval for the token ID', async function () {
           expect(await this.token.getApproved(tokenId)).to.be.equal(ZERO_ADDRESS);
         });
 
-        if (approved) {
-          it('emit only a transfer event', async function () {
-            expectEvent.inLogs(logs, 'Transfer', {
-              from: owner,
-              to: this.toWhom,
-              tokenId: tokenId,
-            });
-          });
-        } else {
-          it('emits only a transfer event', async function () {
-            expectEvent.inLogs(logs, 'Transfer', {
-              from: owner,
-              to: this.toWhom,
-              tokenId: tokenId,
-            });
-          });
-        }
+        it('emits an Approval event', async function () {
+          expectEvent.inLogs(logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: tokenId });
+        });
 
         it('adjusts owners balances', async function () {
           expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1');
@@ -708,15 +698,6 @@ describe('ERC721', function () {
       });
     });
 
-    describe('tokensOfOwner', function () {
-      it('returns total tokens of owner', async function () {
-        const tokenIds = await this.token.tokensOfOwner(owner);
-        expect(tokenIds.length).to.equal(2);
-        expect(tokenIds[0]).to.be.bignumber.equal(firstTokenId);
-        expect(tokenIds[1]).to.be.bignumber.equal(secondTokenId);
-      });
-    });
-
     describe('totalSupply', function () {
       it('returns total token supply', async function () {
         expect(await this.token.totalSupply()).to.be.bignumber.equal('2');
@@ -733,7 +714,7 @@ describe('ERC721', function () {
       describe('when the index is greater than or equal to the total tokens owned by the given address', function () {
         it('reverts', async function () {
           await expectRevert(
-            this.token.tokenOfOwnerByIndex(owner, 2), 'ERC721Enumerable: owner index out of bounds'
+            this.token.tokenOfOwnerByIndex(owner, 2), 'EnumerableSet: index out of bounds'
           );
         });
       });
@@ -741,7 +722,7 @@ describe('ERC721', function () {
       describe('when the given address does not own any token', function () {
         it('reverts', async function () {
           await expectRevert(
-            this.token.tokenOfOwnerByIndex(other, 0), 'ERC721Enumerable: owner index out of bounds'
+            this.token.tokenOfOwnerByIndex(other, 0), 'EnumerableSet: index out of bounds'
           );
         });
       });
@@ -764,7 +745,7 @@ describe('ERC721', function () {
         it('returns empty collection for original owner', async function () {
           expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('0');
           await expectRevert(
-            this.token.tokenOfOwnerByIndex(owner, 0), 'ERC721Enumerable: owner index out of bounds'
+            this.token.tokenOfOwnerByIndex(owner, 0), 'EnumerableSet: index out of bounds'
           );
         });
       });
@@ -781,7 +762,7 @@ describe('ERC721', function () {
 
       it('should revert if index is greater than supply', async function () {
         await expectRevert(
-          this.token.tokenByIndex(2), 'ERC721Enumerable: global index out of bounds'
+          this.token.tokenByIndex(2), 'EnumerableMap: index out of bounds'
         );
       });
 
@@ -790,7 +771,7 @@ describe('ERC721', function () {
           const newTokenId = new BN(300);
           const anotherNewTokenId = new BN(400);
 
-          await this.token.burn(tokenId, { from: owner });
+          await this.token.burn(tokenId);
           await this.token.mint(newOwner, newTokenId);
           await this.token.mint(newOwner, anotherNewTokenId);
 
@@ -865,6 +846,10 @@ describe('ERC721', function () {
           expectEvent.inLogs(this.logs, 'Transfer', { from: owner, to: ZERO_ADDRESS, tokenId: firstTokenId });
         });
 
+        it('emits an Approval event', function () {
+          expectEvent.inLogs(this.logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: firstTokenId });
+        });
+
         it('deletes the token', async function () {
           expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1');
           await expectRevert(
@@ -884,7 +869,7 @@ describe('ERC721', function () {
           await this.token.burn(secondTokenId, { from: owner });
           expect(await this.token.totalSupply()).to.be.bignumber.equal('0');
           await expectRevert(
-            this.token.tokenByIndex(0), 'ERC721Enumerable: global index out of bounds'
+            this.token.tokenByIndex(0), 'EnumerableMap: index out of bounds'
           );
         });
 

+ 139 - 0
test/utils/EnumerableMap.test.js

@@ -0,0 +1,139 @@
+const { accounts, contract } = require('@openzeppelin/test-environment');
+const { BN, expectEvent } = require('@openzeppelin/test-helpers');
+const { expect } = require('chai');
+
+const zip = require('lodash.zip');
+
+const EnumerableMapMock = contract.fromArtifact('EnumerableMapMock');
+
+describe('EnumerableMap', function () {
+  const [ accountA, accountB, accountC ] = accounts;
+
+  const keyA = new BN('7891');
+  const keyB = new BN('451');
+  const keyC = new BN('9592328');
+
+  beforeEach(async function () {
+    this.map = await EnumerableMapMock.new();
+  });
+
+  async function expectMembersMatch (map, keys, values) {
+    expect(keys.length).to.equal(values.length);
+
+    await Promise.all(keys.map(async key =>
+      expect(await map.contains(key)).to.equal(true)
+    ));
+
+    expect(await map.length()).to.bignumber.equal(keys.length.toString());
+
+    expect(await Promise.all(keys.map(key =>
+      map.get(key)
+    ))).to.have.same.members(values);
+
+    // To compare key-value pairs, we zip keys and values, and convert BNs to
+    // strings to workaround Chai limitations when dealing with nested arrays
+    expect(await Promise.all([...Array(keys.length).keys()].map(async (index) => {
+      const entry = await map.at(index);
+      return [entry.key.toString(), entry.value];
+    }))).to.have.same.deep.members(
+      zip(keys.map(k => k.toString()), values)
+    );
+  }
+
+  it('starts empty', async function () {
+    expect(await this.map.contains(keyA)).to.equal(false);
+
+    await expectMembersMatch(this.map, [], []);
+  });
+
+  it('adds a key', async function () {
+    const receipt = await this.map.set(keyA, accountA);
+    expectEvent(receipt, 'OperationResult', { result: true });
+
+    await expectMembersMatch(this.map, [keyA], [accountA]);
+  });
+
+  it('adds several keys', async function () {
+    await this.map.set(keyA, accountA);
+    await this.map.set(keyB, accountB);
+
+    await expectMembersMatch(this.map, [keyA, keyB], [accountA, accountB]);
+    expect(await this.map.contains(keyC)).to.equal(false);
+  });
+
+  it('returns false when adding keys already in the set', async function () {
+    await this.map.set(keyA, accountA);
+
+    const receipt = (await this.map.set(keyA, accountA));
+    expectEvent(receipt, 'OperationResult', { result: false });
+
+    await expectMembersMatch(this.map, [keyA], [accountA]);
+  });
+
+  it('updates values for keys already in the set', async function () {
+    await this.map.set(keyA, accountA);
+
+    await this.map.set(keyA, accountB);
+
+    await expectMembersMatch(this.map, [keyA], [accountB]);
+  });
+
+  it('removes added keys', async function () {
+    await this.map.set(keyA, accountA);
+
+    const receipt = await this.map.remove(keyA);
+    expectEvent(receipt, 'OperationResult', { result: true });
+
+    expect(await this.map.contains(keyA)).to.equal(false);
+    await expectMembersMatch(this.map, [], []);
+  });
+
+  it('returns false when removing keys not in the set', async function () {
+    const receipt = await this.map.remove(keyA);
+    expectEvent(receipt, 'OperationResult', { result: false });
+
+    expect(await this.map.contains(keyA)).to.equal(false);
+  });
+
+  it('adds and removes multiple keys', async function () {
+    // []
+
+    await this.map.set(keyA, accountA);
+    await this.map.set(keyC, accountC);
+
+    // [A, C]
+
+    await this.map.remove(keyA);
+    await this.map.remove(keyB);
+
+    // [C]
+
+    await this.map.set(keyB, accountB);
+
+    // [C, B]
+
+    await this.map.set(keyA, accountA);
+    await this.map.remove(keyC);
+
+    // [A, B]
+
+    await this.map.set(keyA, accountA);
+    await this.map.set(keyB, accountB);
+
+    // [A, B]
+
+    await this.map.set(keyC, accountC);
+    await this.map.remove(keyA);
+
+    // [B, C]
+
+    await this.map.set(keyA, accountA);
+    await this.map.remove(keyB);
+
+    // [A, C]
+
+    await expectMembersMatch(this.map, [keyA, keyC], [accountA, accountC]);
+
+    expect(await this.map.contains(keyB)).to.equal(false);
+  });
+});

+ 11 - 13
test/utils/EnumerableSet.test.js

@@ -11,18 +11,16 @@ describe('EnumerableSet', function () {
     this.set = await EnumerableSetMock.new();
   });
 
-  async function expectMembersMatch (set, members) {
-    await Promise.all(members.map(async account =>
+  async function expectMembersMatch (set, values) {
+    await Promise.all(values.map(async account =>
       expect(await set.contains(account)).to.equal(true)
     ));
 
-    expect(await set.enumerate()).to.have.same.members(members);
+    expect(await set.length()).to.bignumber.equal(values.length.toString());
 
-    expect(await set.length()).to.bignumber.equal(members.length.toString());
-
-    expect(await Promise.all([...Array(members.length).keys()].map(index =>
+    expect(await Promise.all([...Array(values.length).keys()].map(index =>
       set.at(index)
-    ))).to.have.same.members(members);
+    ))).to.have.same.members(values);
   }
 
   it('starts empty', async function () {
@@ -33,7 +31,7 @@ describe('EnumerableSet', function () {
 
   it('adds a value', async function () {
     const receipt = await this.set.add(accountA);
-    expectEvent(receipt, 'TransactionResult', { result: true });
+    expectEvent(receipt, 'OperationResult', { result: true });
 
     await expectMembersMatch(this.set, [accountA]);
   });
@@ -46,11 +44,11 @@ describe('EnumerableSet', function () {
     expect(await this.set.contains(accountC)).to.equal(false);
   });
 
-  it('returns false when adding elements already in the set', async function () {
+  it('returns false when adding values already in the set', async function () {
     await this.set.add(accountA);
 
     const receipt = (await this.set.add(accountA));
-    expectEvent(receipt, 'TransactionResult', { result: false });
+    expectEvent(receipt, 'OperationResult', { result: false });
 
     await expectMembersMatch(this.set, [accountA]);
   });
@@ -63,15 +61,15 @@ describe('EnumerableSet', function () {
     await this.set.add(accountA);
 
     const receipt = await this.set.remove(accountA);
-    expectEvent(receipt, 'TransactionResult', { result: true });
+    expectEvent(receipt, 'OperationResult', { result: true });
 
     expect(await this.set.contains(accountA)).to.equal(false);
     await expectMembersMatch(this.set, []);
   });
 
-  it('returns false when removing elements not in the set', async function () {
+  it('returns false when removing values not in the set', async function () {
     const receipt = await this.set.remove(accountA);
-    expectEvent(receipt, 'TransactionResult', { result: false });
+    expectEvent(receipt, 'OperationResult', { result: false });
 
     expect(await this.set.contains(accountA)).to.equal(false);
   });