Browse Source

Remove unnecessary SafeMath call (#1610)

* Refactor Counter to support increment and decrement.

* Move Counter out of drafts.

* Refactor ERC721 to use Counter.

* Rollback Counter returning the current value in increment and decrement.

* Update test/drafts/Counter.test.js

Co-Authored-By: nventuro <nicolas.venturo@gmail.com>

* Improve Counter documentation.

* Move Counter.test to utils.

* Move back Counter to drafts.
Nicolás Venturo 6 years ago
parent
commit
07603d5875

+ 2 - 0
CHANGELOG.md

@@ -7,6 +7,8 @@
 
 ### Improvements:
  * Upgraded the minimum compiler version to v0.5.2: this removes many Solidity warnings that were false positives.
+ * `Counter`'s API has been improved, and is now used by `ERC721` (though it is still in `drafts`).
+ * `ERC721`'s transfers are now more gas efficient due to removal of unnecessary `SafeMath` calls.
  * Fixed variable shadowing issues.
 
 ### Bugfixes:

+ 24 - 11
contracts/drafts/Counters.sol

@@ -1,24 +1,37 @@
 pragma solidity ^0.5.2;
 
+import "../math/SafeMath.sol";
+
 /**
  * @title Counters
  * @author Matt Condon (@shrugs)
- * @dev Provides an incrementing uint256 id acquired by the `Counter#next` getter.
- * Use this for issuing ERC721 ids or keeping track of request ids, anything you want, really.
+ * @dev Provides counters that can only be incremented or decremented by one. This can be used e.g. to track the number
+ * of elements in a mapping, issuing ERC721 ids, or counting request ids
  *
- * Include with `using Counters` for Counters.Counter;`
- * @notice Does not allow an Id of 0, which is popularly used to signify a null state in solidity.
- * Does not protect from overflows, but if you have 2^256 ids, you have other problems.
- * (But actually, it's generally impossible to increment a counter this many times, energy wise
- * so it's not something you have to worry about.)
+ * Include with `using Counter for Counter.Counter;`
+ * Since it is not possible to overflow a 256 bit integer with increments of one, `increment` can skip the SafeMath
+ * overflow check, thereby saving gas. This does assume however correct usage, in that the underlying `_value` is never
+ * directly accessed.
  */
 library Counters {
+    using SafeMath for uint256;
+
     struct Counter {
-        uint256 current; // default: 0
+        // This variable should never be directly accessed by users of the library: interactions must be restricted to
+        // the library's function. As of Solidity v0.5.2, this cannot be enforced, though there is a proposal to add
+        // this feature: see https://github.com/ethereum/solidity/issues/4637
+        uint256 _value; // default: 0
+    }
+
+    function current(Counter storage counter) internal view returns (uint256) {
+        return counter._value;
+    }
+
+    function increment(Counter storage counter) internal {
+        counter._value += 1;
     }
 
-    function next(Counter storage index) internal returns (uint256) {
-        index.current += 1;
-        return index.current;
+    function decrement(Counter storage counter) internal {
+        counter._value = counter._value.sub(1);
     }
 }

+ 10 - 6
contracts/mocks/CountersImpl.sol

@@ -5,13 +5,17 @@ import "../drafts/Counters.sol";
 contract CountersImpl {
     using Counters for Counters.Counter;
 
-    uint256 public theId;
+    Counters.Counter private _counter;
 
-    // use whatever key you want to track your counters
-    mapping(string => Counters.Counter) private _counters;
+    function current() public view returns (uint256) {
+        return _counter.current();
+    }
+
+    function increment() public {
+        _counter.increment();
+    }
 
-    function doThing(string memory key) public returns (uint256) {
-        theId = _counters[key].next();
-        return theId;
+    function decrement() public {
+        _counter.decrement();
     }
 }

+ 8 - 6
contracts/token/ERC721/ERC721.sol

@@ -4,6 +4,7 @@ import "./IERC721.sol";
 import "./IERC721Receiver.sol";
 import "../../math/SafeMath.sol";
 import "../../utils/Address.sol";
+import "../../drafts/Counters.sol";
 import "../../introspection/ERC165.sol";
 
 /**
@@ -13,6 +14,7 @@ import "../../introspection/ERC165.sol";
 contract ERC721 is ERC165, IERC721 {
     using SafeMath for uint256;
     using Address for address;
+    using Counters for Counters.Counter;
 
     // Equals to `bytes4(keccak256("onERC721Received(address,address,uint256,bytes)"))`
     // which can be also obtained as `IERC721Receiver(0).onERC721Received.selector`
@@ -25,7 +27,7 @@ contract ERC721 is ERC165, IERC721 {
     mapping (uint256 => address) private _tokenApprovals;
 
     // Mapping from owner to number of owned token
-    mapping (address => uint256) private _ownedTokensCount;
+    mapping (address => Counters.Counter) private _ownedTokensCount;
 
     // Mapping from owner to operator approvals
     mapping (address => mapping (address => bool)) private _operatorApprovals;
@@ -56,7 +58,7 @@ contract ERC721 is ERC165, IERC721 {
      */
     function balanceOf(address owner) public view returns (uint256) {
         require(owner != address(0));
-        return _ownedTokensCount[owner];
+        return _ownedTokensCount[owner].current();
     }
 
     /**
@@ -200,7 +202,7 @@ contract ERC721 is ERC165, IERC721 {
         require(!_exists(tokenId));
 
         _tokenOwner[tokenId] = to;
-        _ownedTokensCount[to] = _ownedTokensCount[to].add(1);
+        _ownedTokensCount[to].increment();
 
         emit Transfer(address(0), to, tokenId);
     }
@@ -217,7 +219,7 @@ contract ERC721 is ERC165, IERC721 {
 
         _clearApproval(tokenId);
 
-        _ownedTokensCount[owner] = _ownedTokensCount[owner].sub(1);
+        _ownedTokensCount[owner].decrement();
         _tokenOwner[tokenId] = address(0);
 
         emit Transfer(owner, address(0), tokenId);
@@ -245,8 +247,8 @@ contract ERC721 is ERC165, IERC721 {
 
         _clearApproval(tokenId);
 
-        _ownedTokensCount[from] = _ownedTokensCount[from].sub(1);
-        _ownedTokensCount[to] = _ownedTokensCount[to].add(1);
+        _ownedTokensCount[from].decrement();
+        _ownedTokensCount[to].increment();
 
         _tokenOwner[tokenId] = to;
 

+ 46 - 25
test/drafts/Counters.test.js

@@ -1,37 +1,58 @@
-const { BN } = require('openzeppelin-test-helpers');
+const { shouldFail } = require('openzeppelin-test-helpers');
 
 const CountersImpl = artifacts.require('CountersImpl');
 
-const EXPECTED = [new BN(1), new BN(2), new BN(3), new BN(4)];
-const KEY1 = web3.utils.sha3('key1');
-const KEY2 = web3.utils.sha3('key2');
-
-contract('Counters', function ([_, owner]) {
+contract('Counters', function () {
   beforeEach(async function () {
-    this.mock = await CountersImpl.new({ from: owner });
+    this.counter = await CountersImpl.new();
+  });
+
+  it('starts at zero', async function () {
+    (await this.counter.current()).should.be.bignumber.equal('0');
   });
 
-  context('custom key', async function () {
-    it('should return expected values', async function () {
-      for (const expectedId of EXPECTED) {
-        await this.mock.doThing(KEY1, { from: owner });
-        const actualId = await this.mock.theId();
-        actualId.should.be.bignumber.equal(expectedId);
-      }
+  describe('increment', function () {
+    it('increments the current value by one', async function () {
+      await this.counter.increment();
+      (await this.counter.current()).should.be.bignumber.equal('1');
+    });
+
+    it('can be called multiple times', async function () {
+      await this.counter.increment();
+      await this.counter.increment();
+      await this.counter.increment();
+
+      (await this.counter.current()).should.be.bignumber.equal('3');
     });
   });
 
-  context('parallel keys', async function () {
-    it('should return expected values for each counter', async function () {
-      for (const expectedId of EXPECTED) {
-        await this.mock.doThing(KEY1, { from: owner });
-        let actualId = await this.mock.theId();
-        actualId.should.be.bignumber.equal(expectedId);
-
-        await this.mock.doThing(KEY2, { from: owner });
-        actualId = await this.mock.theId();
-        actualId.should.be.bignumber.equal(expectedId);
-      }
+  describe('decrement', function () {
+    beforeEach(async function () {
+      await this.counter.increment();
+      (await this.counter.current()).should.be.bignumber.equal('1');
+    });
+
+    it('decrements the current value by one', async function () {
+      await this.counter.decrement();
+      (await this.counter.current()).should.be.bignumber.equal('0');
+    });
+
+    it('reverts if the current value is 0', async function () {
+      await this.counter.decrement();
+      await shouldFail.reverting(this.counter.decrement());
+    });
+
+    it('can be called multiple times', async function () {
+      await this.counter.increment();
+      await this.counter.increment();
+
+      (await this.counter.current()).should.be.bignumber.equal('3');
+
+      await this.counter.decrement();
+      await this.counter.decrement();
+      await this.counter.decrement();
+
+      (await this.counter.current()).should.be.bignumber.equal('0');
     });
   });
 });