Browse Source

Add SafeERC20.forceApprove() (#4067)

Hadrien Croubois 2 years ago
parent
commit
8b47e96af1

+ 5 - 0
.changeset/small-terms-sleep.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`SafeERC20`: Add a `forceApprove` function to improve compatibility with tokens behaving like USDT.

+ 13 - 0
contracts/mocks/token/ERC20ForceApproveMock.sol

@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.0;
+
+import "../../token/ERC20/ERC20.sol";
+
+// contract that replicate USDT (0xdac17f958d2ee523a2206206994597c13d831ec7) approval beavior
+abstract contract ERC20ForceApproveMock is ERC20 {
+    function approve(address spender, uint256 amount) public virtual override returns (bool) {
+        require(amount == 0 || allowance(msg.sender, spender) == 0, "USDT approval failure");
+        return super.approve(spender, amount);
+    }
+}

+ 18 - 11
contracts/mocks/token/ERC20NoReturnMock.sol

@@ -2,20 +2,27 @@
 
 pragma solidity ^0.8.0;
 
-contract ERC20NoReturnMock {
-    mapping(address => uint256) private _allowances;
+import "../../token/ERC20/ERC20.sol";
 
-    function transfer(address, uint256) public {}
-
-    function transferFrom(address, address, uint256) public {}
-
-    function approve(address, uint256) public {}
+abstract contract ERC20NoReturnMock is ERC20 {
+    function transfer(address to, uint256 amount) public override returns (bool) {
+        super.transfer(to, amount);
+        assembly {
+            return(0, 0)
+        }
+    }
 
-    function setAllowance(address account, uint256 allowance_) public {
-        _allowances[account] = allowance_;
+    function transferFrom(address from, address to, uint256 amount) public override returns (bool) {
+        super.transferFrom(from, to, amount);
+        assembly {
+            return(0, 0)
+        }
     }
 
-    function allowance(address owner, address) public view returns (uint256) {
-        return _allowances[owner];
+    function approve(address spender, uint256 amount) public override returns (bool) {
+        super.approve(spender, amount);
+        assembly {
+            return(0, 0)
+        }
     }
 }

+ 1 - 3
contracts/mocks/token/ERC20PermitNoRevertMock.sol

@@ -5,9 +5,7 @@ pragma solidity ^0.8.0;
 import "../../token/ERC20/ERC20.sol";
 import "../../token/ERC20/extensions/draft-ERC20Permit.sol";
 
-contract ERC20PermitNoRevertMock is ERC20, ERC20Permit {
-    constructor() ERC20("ERC20PermitNoRevertMock", "ERC20PermitNoRevertMock") ERC20Permit("ERC20PermitNoRevertMock") {}
-
+abstract contract ERC20PermitNoRevertMock is ERC20Permit {
     function permitThatMayRevert(
         address owner,
         address spender,

+ 5 - 13
contracts/mocks/token/ERC20ReturnFalseMock.sol

@@ -2,26 +2,18 @@
 
 pragma solidity ^0.8.0;
 
-contract ERC20ReturnFalseMock {
-    mapping(address => uint256) private _allowances;
+import "../../token/ERC20/ERC20.sol";
 
-    function transfer(address, uint256) public pure returns (bool) {
+abstract contract ERC20ReturnFalseMock is ERC20 {
+    function transfer(address, uint256) public pure override returns (bool) {
         return false;
     }
 
-    function transferFrom(address, address, uint256) public pure returns (bool) {
+    function transferFrom(address, address, uint256) public pure override returns (bool) {
         return false;
     }
 
-    function approve(address, uint256) public pure returns (bool) {
+    function approve(address, uint256) public pure override returns (bool) {
         return false;
     }
-
-    function setAllowance(address account, uint256 allowance_) public {
-        _allowances[account] = allowance_;
-    }
-
-    function allowance(address owner, address) public view returns (uint256) {
-        return _allowances[owner];
-    }
 }

+ 0 - 27
contracts/mocks/token/ERC20ReturnTrueMock.sol

@@ -1,27 +0,0 @@
-// SPDX-License-Identifier: MIT
-
-pragma solidity ^0.8.0;
-
-contract ERC20ReturnTrueMock {
-    mapping(address => uint256) private _allowances;
-
-    function transfer(address, uint256) public pure returns (bool) {
-        return true;
-    }
-
-    function transferFrom(address, address, uint256) public pure returns (bool) {
-        return true;
-    }
-
-    function approve(address, uint256) public pure returns (bool) {
-        return true;
-    }
-
-    function setAllowance(address account, uint256 allowance_) public {
-        _allowances[account] = allowance_;
-    }
-
-    function allowance(address owner, address) public view returns (uint256) {
-        return _allowances[owner];
-    }
-}

+ 56 - 8
contracts/token/ERC20/utils/SafeERC20.sol

@@ -19,10 +19,18 @@ import "../../../utils/Address.sol";
 library SafeERC20 {
     using Address for address;
 
+    /**
+     * @dev Transfer `value` amount of `token` from the calling contract to `to`. If `token` returns no value,
+     * non-reverting calls are assumed to be successful.
+     */
     function safeTransfer(IERC20 token, address to, uint256 value) internal {
         _callOptionalReturn(token, abi.encodeWithSelector(token.transfer.selector, to, value));
     }
 
+    /**
+     * @dev Transfer `value` amount of `token` from `from` to `to`, spending the approval given by `from` to the
+     * calling contract. If `token` returns no value, non-reverting calls are assumed to be successful.
+     */
     function safeTransferFrom(IERC20 token, address from, address to, uint256 value) internal {
         _callOptionalReturn(token, abi.encodeWithSelector(token.transferFrom.selector, from, to, value));
     }
@@ -45,20 +53,45 @@ library SafeERC20 {
         _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, value));
     }
 
+    /**
+     * @dev Increase the calling contract's allowance toward `spender` by `value`. If `token` returns no value,
+     * non-reverting calls are assumed to be successful.
+     */
     function safeIncreaseAllowance(IERC20 token, address spender, uint256 value) internal {
-        uint256 newAllowance = token.allowance(address(this), spender) + value;
-        _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, newAllowance));
+        uint256 oldAllowance = token.allowance(address(this), spender);
+        _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, oldAllowance + value));
     }
 
+    /**
+     * @dev Decrease the calling contract's allowance toward `spender` by `value`. If `token` returns no value,
+     * non-reverting calls are assumed to be successful.
+     */
     function safeDecreaseAllowance(IERC20 token, address spender, uint256 value) internal {
         unchecked {
             uint256 oldAllowance = token.allowance(address(this), spender);
             require(oldAllowance >= value, "SafeERC20: decreased allowance below zero");
-            uint256 newAllowance = oldAllowance - value;
-            _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, newAllowance));
+            _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, oldAllowance - value));
+        }
+    }
+
+    /**
+     * @dev Set the calling contract's allowance toward `spender` to `value`. If `token` returns no value,
+     * non-reverting calls are assumed to be successful. Compatible with tokens that require the approval to be set to
+     * 0 before setting it to a non-zero value.
+     */
+    function forceApprove(IERC20 token, address spender, uint256 value) internal {
+        bytes memory approvalCall = abi.encodeWithSelector(token.approve.selector, spender, value);
+
+        if (!_callOptionalReturnBool(token, approvalCall)) {
+            _callOptionalReturn(token, abi.encodeWithSelector(token.approve.selector, spender, 0));
+            _callOptionalReturn(token, approvalCall);
         }
     }
 
+    /**
+     * @dev Use a ERC-2612 signature to set the `owner` approval toward `spender` on `token`.
+     * Revert on invalid signature.
+     */
     function safePermit(
         IERC20Permit token,
         address owner,
@@ -87,9 +120,24 @@ library SafeERC20 {
         // the target address contains contract code and also asserts for success in the low-level call.
 
         bytes memory returndata = address(token).functionCall(data, "SafeERC20: low-level call failed");
-        if (returndata.length > 0) {
-            // Return data is optional
-            require(abi.decode(returndata, (bool)), "SafeERC20: ERC20 operation did not succeed");
-        }
+        require(returndata.length == 0 || abi.decode(returndata, (bool)), "SafeERC20: ERC20 operation did not succeed");
+    }
+
+    /**
+     * @dev Imitates a Solidity high-level call (i.e. a regular function call to a contract), relaxing the requirement
+     * on the return value: the return value is optional (but if data is returned, it must not be false).
+     * @param token The token targeted by the call.
+     * @param data The call data (encoded using abi.encode or one of its variants).
+     *
+     * This is a variant of {_callOptionalReturn} that silents catches all reverts and returns a bool instead.
+     */
+    function _callOptionalReturnBool(IERC20 token, bytes memory data) private returns (bool) {
+        // We need to perform a low level call here, to bypass Solidity's return data size checking mechanism, since
+        // we're implementing it ourselves. We cannot use {Address-functionCall} here since this should return false
+        // and not revert is the subcall reverts.
+
+        (bool success, bytes memory returndata) = address(token).call(data);
+        return
+            success && (returndata.length == 0 || abi.decode(returndata, (bool))) && Address.isContract(address(token));
     }
 }

+ 1 - 1
package-lock.json

@@ -31,7 +31,7 @@
         "glob": "^8.0.3",
         "graphlib": "^2.1.8",
         "hardhat": "^2.9.1",
-        "hardhat-exposed": "^0.3.1",
+        "hardhat-exposed": "^0.3.2",
         "hardhat-gas-reporter": "^1.0.4",
         "hardhat-ignore-warnings": "^0.2.0",
         "keccak256": "^1.0.2",

+ 1 - 1
package.json

@@ -72,7 +72,7 @@
     "glob": "^8.0.3",
     "graphlib": "^2.1.8",
     "hardhat": "^2.9.1",
-    "hardhat-exposed": "^0.3.1",
+    "hardhat-exposed": "^0.3.2",
     "hardhat-gas-reporter": "^1.0.4",
     "hardhat-ignore-warnings": "^0.2.0",
     "keccak256": "^1.0.2",

+ 125 - 39
test/token/ERC20/utils/SafeERC20.test.js

@@ -1,10 +1,11 @@
-const { constants, expectRevert } = require('@openzeppelin/test-helpers');
+const { constants, expectEvent, expectRevert } = require('@openzeppelin/test-helpers');
 
 const SafeERC20 = artifacts.require('$SafeERC20');
-const ERC20ReturnFalseMock = artifacts.require('ERC20ReturnFalseMock');
-const ERC20ReturnTrueMock = artifacts.require('ERC20ReturnTrueMock');
-const ERC20NoReturnMock = artifacts.require('ERC20NoReturnMock');
-const ERC20PermitNoRevertMock = artifacts.require('ERC20PermitNoRevertMock');
+const ERC20ReturnFalseMock = artifacts.require('$ERC20ReturnFalseMock');
+const ERC20ReturnTrueMock = artifacts.require('$ERC20'); // default implementation returns true
+const ERC20NoReturnMock = artifacts.require('$ERC20NoReturnMock');
+const ERC20PermitNoRevertMock = artifacts.require('$ERC20PermitNoRevertMock');
+const ERC20ForceApproveMock = artifacts.require('$ERC20ForceApproveMock');
 
 const { getDomain, domainType, Permit } = require('../../../helpers/eip712');
 
@@ -12,6 +13,9 @@ const { fromRpcSig } = require('ethereumjs-util');
 const ethSigUtil = require('eth-sig-util');
 const Wallet = require('ethereumjs-wallet').default;
 
+const name = 'ERC20Mock';
+const symbol = 'ERC20Mock';
+
 contract('SafeERC20', function (accounts) {
   const [hasNoCode] = accounts;
 
@@ -24,31 +28,31 @@ contract('SafeERC20', function (accounts) {
       this.token = { address: hasNoCode };
     });
 
-    shouldRevertOnAllCalls('Address: call to non-contract');
+    shouldRevertOnAllCalls(accounts, 'Address: call to non-contract');
   });
 
   describe('with token that returns false on all calls', function () {
     beforeEach(async function () {
-      this.token = await ERC20ReturnFalseMock.new();
+      this.token = await ERC20ReturnFalseMock.new(name, symbol);
     });
 
-    shouldRevertOnAllCalls('SafeERC20: ERC20 operation did not succeed');
+    shouldRevertOnAllCalls(accounts, 'SafeERC20: ERC20 operation did not succeed');
   });
 
   describe('with token that returns true on all calls', function () {
     beforeEach(async function () {
-      this.token = await ERC20ReturnTrueMock.new();
+      this.token = await ERC20ReturnTrueMock.new(name, symbol);
     });
 
-    shouldOnlyRevertOnErrors();
+    shouldOnlyRevertOnErrors(accounts);
   });
 
   describe('with token that returns no boolean values', function () {
     beforeEach(async function () {
-      this.token = await ERC20NoReturnMock.new();
+      this.token = await ERC20NoReturnMock.new(name, symbol);
     });
 
-    shouldOnlyRevertOnErrors();
+    shouldOnlyRevertOnErrors(accounts);
   });
 
   describe("with token that doesn't revert on invalid permit", function () {
@@ -57,7 +61,7 @@ contract('SafeERC20', function (accounts) {
     const spender = hasNoCode;
 
     beforeEach(async function () {
-      this.token = await ERC20PermitNoRevertMock.new();
+      this.token = await ERC20PermitNoRevertMock.new(name, symbol, name);
 
       this.data = await getDomain(this.token).then(domain => ({
         primaryType: 'Permit',
@@ -165,65 +169,134 @@ contract('SafeERC20', function (accounts) {
       );
     });
   });
+
+  describe('with usdt approval beaviour', function () {
+    const spender = hasNoCode;
+
+    beforeEach(async function () {
+      this.token = await ERC20ForceApproveMock.new(name, symbol);
+    });
+
+    describe('with initial approval', function () {
+      beforeEach(async function () {
+        await this.token.$_approve(this.mock.address, spender, 100);
+      });
+
+      it('safeApprove fails to update approval to non-zero', async function () {
+        await expectRevert(
+          this.mock.$safeApprove(this.token.address, spender, 200),
+          'SafeERC20: approve from non-zero to non-zero allowance',
+        );
+      });
+
+      it('safeApprove can update approval to zero', async function () {
+        await this.mock.$safeApprove(this.token.address, spender, 0);
+      });
+
+      it('safeApprove can increase approval', async function () {
+        await expectRevert(this.mock.$safeIncreaseAllowance(this.token.address, spender, 10), 'USDT approval failure');
+      });
+
+      it('safeApprove can decrease approval', async function () {
+        await expectRevert(this.mock.$safeDecreaseAllowance(this.token.address, spender, 10), 'USDT approval failure');
+      });
+
+      it('forceApprove works', async function () {
+        await this.mock.$forceApprove(this.token.address, spender, 200);
+      });
+    });
+  });
 });
 
-function shouldRevertOnAllCalls(reason) {
+function shouldRevertOnAllCalls([receiver, spender], reason) {
   it('reverts on transfer', async function () {
-    await expectRevert(this.mock.$safeTransfer(this.token.address, constants.ZERO_ADDRESS, 0), reason);
+    await expectRevert(this.mock.$safeTransfer(this.token.address, receiver, 0), reason);
   });
 
   it('reverts on transferFrom', async function () {
-    await expectRevert(
-      this.mock.$safeTransferFrom(this.token.address, this.mock.address, constants.ZERO_ADDRESS, 0),
-      reason,
-    );
+    await expectRevert(this.mock.$safeTransferFrom(this.token.address, this.mock.address, receiver, 0), reason);
   });
 
   it('reverts on approve', async function () {
-    await expectRevert(this.mock.$safeApprove(this.token.address, constants.ZERO_ADDRESS, 0), reason);
+    await expectRevert(this.mock.$safeApprove(this.token.address, spender, 0), reason);
   });
 
   it('reverts on increaseAllowance', async function () {
     // [TODO] make sure it's reverting for the right reason
-    await expectRevert.unspecified(this.mock.$safeIncreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 0));
+    await expectRevert.unspecified(this.mock.$safeIncreaseAllowance(this.token.address, spender, 0));
   });
 
   it('reverts on decreaseAllowance', async function () {
     // [TODO] make sure it's reverting for the right reason
-    await expectRevert.unspecified(this.mock.$safeDecreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 0));
+    await expectRevert.unspecified(this.mock.$safeDecreaseAllowance(this.token.address, spender, 0));
   });
-}
 
-function shouldOnlyRevertOnErrors() {
-  it("doesn't revert on transfer", async function () {
-    await this.mock.$safeTransfer(this.token.address, constants.ZERO_ADDRESS, 0);
+  it('reverts on forceApprove', async function () {
+    await expectRevert(this.mock.$forceApprove(this.token.address, spender, 0), reason);
   });
+}
+
+function shouldOnlyRevertOnErrors([owner, receiver, spender]) {
+  describe('transfers', function () {
+    beforeEach(async function () {
+      await this.token.$_mint(owner, 100);
+      await this.token.$_mint(this.mock.address, 100);
+      await this.token.approve(this.mock.address, constants.MAX_UINT256, { from: owner });
+    });
+
+    it("doesn't revert on transfer", async function () {
+      const { tx } = await this.mock.$safeTransfer(this.token.address, receiver, 10);
+      await expectEvent.inTransaction(tx, this.token, 'Transfer', {
+        from: this.mock.address,
+        to: receiver,
+        value: '10',
+      });
+    });
 
-  it("doesn't revert on transferFrom", async function () {
-    await this.mock.$safeTransferFrom(this.token.address, this.mock.address, constants.ZERO_ADDRESS, 0);
+    it("doesn't revert on transferFrom", async function () {
+      const { tx } = await this.mock.$safeTransferFrom(this.token.address, owner, receiver, 10);
+      await expectEvent.inTransaction(tx, this.token, 'Transfer', {
+        from: owner,
+        to: receiver,
+        value: '10',
+      });
+    });
   });
 
   describe('approvals', function () {
     context('with zero allowance', function () {
       beforeEach(async function () {
-        await this.token.setAllowance(this.mock.address, 0);
+        await this.token.$_approve(this.mock.address, spender, 0);
       });
 
       it("doesn't revert when approving a non-zero allowance", async function () {
-        await this.mock.$safeApprove(this.token.address, constants.ZERO_ADDRESS, 100);
+        await this.mock.$safeApprove(this.token.address, spender, 100);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('100');
       });
 
       it("doesn't revert when approving a zero allowance", async function () {
-        await this.mock.$safeApprove(this.token.address, constants.ZERO_ADDRESS, 0);
+        await this.mock.$safeApprove(this.token.address, spender, 0);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('0');
+      });
+
+      it("doesn't revert when force approving a non-zero allowance", async function () {
+        await this.mock.$forceApprove(this.token.address, spender, 100);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('100');
+      });
+
+      it("doesn't revert when force approving a zero allowance", async function () {
+        await this.mock.$forceApprove(this.token.address, spender, 0);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('0');
       });
 
       it("doesn't revert when increasing the allowance", async function () {
-        await this.mock.$safeIncreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 10);
+        await this.mock.$safeIncreaseAllowance(this.token.address, spender, 10);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('10');
       });
 
       it('reverts when decreasing the allowance', async function () {
         await expectRevert(
-          this.mock.$safeDecreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 10),
+          this.mock.$safeDecreaseAllowance(this.token.address, spender, 10),
           'SafeERC20: decreased allowance below zero',
         );
       });
@@ -231,31 +304,44 @@ function shouldOnlyRevertOnErrors() {
 
     context('with non-zero allowance', function () {
       beforeEach(async function () {
-        await this.token.setAllowance(this.mock.address, 100);
+        await this.token.$_approve(this.mock.address, spender, 100);
       });
 
       it('reverts when approving a non-zero allowance', async function () {
         await expectRevert(
-          this.mock.$safeApprove(this.token.address, constants.ZERO_ADDRESS, 20),
+          this.mock.$safeApprove(this.token.address, spender, 20),
           'SafeERC20: approve from non-zero to non-zero allowance',
         );
       });
 
       it("doesn't revert when approving a zero allowance", async function () {
-        await this.mock.$safeApprove(this.token.address, constants.ZERO_ADDRESS, 0);
+        await this.mock.$safeApprove(this.token.address, spender, 0);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('0');
+      });
+
+      it("doesn't revert when force approving a non-zero allowance", async function () {
+        await this.mock.$forceApprove(this.token.address, spender, 20);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('20');
+      });
+
+      it("doesn't revert when force approving a zero allowance", async function () {
+        await this.mock.$forceApprove(this.token.address, spender, 0);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('0');
       });
 
       it("doesn't revert when increasing the allowance", async function () {
-        await this.mock.$safeIncreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 10);
+        await this.mock.$safeIncreaseAllowance(this.token.address, spender, 10);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('110');
       });
 
       it("doesn't revert when decreasing the allowance to a positive value", async function () {
-        await this.mock.$safeDecreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 50);
+        await this.mock.$safeDecreaseAllowance(this.token.address, spender, 50);
+        expect(await this.token.allowance(this.mock.address, spender)).to.be.bignumber.equal('50');
       });
 
       it('reverts when decreasing the allowance to a negative value', async function () {
         await expectRevert(
-          this.mock.$safeDecreaseAllowance(this.token.address, constants.ZERO_ADDRESS, 200),
+          this.mock.$safeDecreaseAllowance(this.token.address, spender, 200),
           'SafeERC20: decreased allowance below zero',
         );
       });