Эх сурвалжийг харах

Add nonReentrantView modifier (#5800)

Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
Hadrien Croubois 1 сар өмнө
parent
commit
0134b00956

+ 5 - 0
.changeset/quick-pianos-press.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`ReentrancyGuard` and `ReentrancyGuardTransient`: Add `nonReentrantView`, a read-only version of the `nonReentrant` modifier.

+ 5 - 0
contracts/mocks/ReentrancyAttack.sol

@@ -9,4 +9,9 @@ contract ReentrancyAttack is Context {
         (bool success, ) = _msgSender().call(data);
         require(success, "ReentrancyAttack: failed call");
     }
+
+    function staticcallSender(bytes calldata data) public view {
+        (bool success, ) = _msgSender().staticcall(data);
+        require(success, "ReentrancyAttack: failed call");
+    }
 }

+ 9 - 0
contracts/mocks/ReentrancyMock.sol

@@ -16,6 +16,10 @@ contract ReentrancyMock is ReentrancyGuard {
         _count();
     }
 
+    function viewCallback() external view nonReentrantView returns (uint256) {
+        return counter;
+    }
+
     function countLocalRecursive(uint256 n) public nonReentrant {
         if (n > 0) {
             _count();
@@ -36,6 +40,11 @@ contract ReentrancyMock is ReentrancyGuard {
         attacker.callSender(abi.encodeCall(this.callback, ()));
     }
 
+    function countAndCallView(ReentrancyAttack attacker) public nonReentrant {
+        _count();
+        attacker.staticcallSender(abi.encodeCall(this.viewCallback, ()));
+    }
+
     function _count() private {
         counter += 1;
     }

+ 9 - 0
contracts/mocks/ReentrancyTransientMock.sol

@@ -16,6 +16,10 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient {
         _count();
     }
 
+    function viewCallback() external view nonReentrantView returns (uint256) {
+        return counter;
+    }
+
     function countLocalRecursive(uint256 n) public nonReentrant {
         if (n > 0) {
             _count();
@@ -36,6 +40,11 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient {
         attacker.callSender(abi.encodeCall(this.callback, ()));
     }
 
+    function countAndCallView(ReentrancyAttack attacker) public nonReentrant {
+        _count();
+        attacker.staticcallSender(abi.encodeCall(this.viewCallback, ()));
+    }
+
     function _count() private {
         counter += 1;
     }

+ 19 - 2
contracts/utils/ReentrancyGuard.sol

@@ -61,11 +61,28 @@ abstract contract ReentrancyGuard {
         _nonReentrantAfter();
     }
 
-    function _nonReentrantBefore() private {
-        // On the first call to nonReentrant, _status will be NOT_ENTERED
+    /**
+     * @dev A `view` only version of {nonReentrant}. Use to block view functions
+     * from being called, preventing reading from inconsistent contract state.
+     *
+     * CAUTION: This is a "view" modifier and does not change the reentrancy
+     * status. Use it only on view functions. For payable or non-payable functions,
+     * use the standard {nonReentrant} modifier instead.
+     */
+    modifier nonReentrantView() {
+        _nonReentrantBeforeView();
+        _;
+    }
+
+    function _nonReentrantBeforeView() private view {
         if (_status == ENTERED) {
             revert ReentrancyGuardReentrantCall();
         }
+    }
+
+    function _nonReentrantBefore() private {
+        // On the first call to nonReentrant, _status will be NOT_ENTERED
+        _nonReentrantBeforeView();
 
         // Any calls to nonReentrant after this point will fail
         _status = ENTERED;

+ 19 - 2
contracts/utils/ReentrancyGuardTransient.sol

@@ -37,11 +37,28 @@ abstract contract ReentrancyGuardTransient {
         _nonReentrantAfter();
     }
 
-    function _nonReentrantBefore() private {
-        // On the first call to nonReentrant, REENTRANCY_GUARD_STORAGE.asBoolean().tload() will be false
+    /**
+     * @dev A `view` only version of {nonReentrant}. Use to block view functions
+     * from being called, preventing reading from inconsistent contract state.
+     *
+     * CAUTION: This is a "view" modifier and does not change the reentrancy
+     * status. Use it only on view functions. For payable or non-payable functions,
+     * use the standard {nonReentrant} modifier instead.
+     */
+    modifier nonReentrantView() {
+        _nonReentrantBeforeView();
+        _;
+    }
+
+    function _nonReentrantBeforeView() private view {
         if (_reentrancyGuardEntered()) {
             revert ReentrancyGuardReentrantCall();
         }
+    }
+
+    function _nonReentrantBefore() private {
+        // On the first call to nonReentrant, REENTRANCY_GUARD_STORAGE.asBoolean().tload() will be false
+        _nonReentrantBeforeView();
 
         // Any calls to nonReentrant after this point will fail
         REENTRANCY_GUARD_STORAGE.asBoolean().tstore(true);

+ 12 - 4
test/utils/ReentrancyGuard.test.js

@@ -7,7 +7,8 @@ for (const variant of ['', 'Transient']) {
     async function fixture() {
       const name = `Reentrancy${variant}Mock`;
       const mock = await ethers.deployContract(name);
-      return { name, mock };
+      const attacker = await ethers.deployContract('ReentrancyAttack');
+      return { name, mock, attacker };
     }
 
     beforeEach(async function () {
@@ -20,9 +21,16 @@ for (const variant of ['', 'Transient']) {
       expect(await this.mock.counter()).to.equal(1n);
     });
 
-    it('does not allow remote callback', async function () {
-      const attacker = await ethers.deployContract('ReentrancyAttack');
-      await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
+    it('nonReentrantView function can be called', async function () {
+      await this.mock.viewCallback();
+    });
+
+    it('does not allow remote callback to nonReentrant function', async function () {
+      await expect(this.mock.countAndCall(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
+    });
+
+    it('does not allow remote callback to nonReentrantView function', async function () {
+      await expect(this.mock.countAndCallView(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
     });
 
     it('_reentrancyGuardEntered should be true when guarded', async function () {