Browse Source

Transient version of ReentrancyGuard (#4988)

Co-authored-by: ernestognw <ernestognw@gmail.com>
Hadrien Croubois 1 year ago
parent
commit
b6e07917eb

+ 5 - 0
.changeset/witty-chicken-smile.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`ReentrancyGuardTransient`: Added a variant of `ReentrancyGuard` that uses transient storage.

+ 50 - 0
contracts/mocks/ReentrancyTransientMock.sol

@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.24;
+
+import {ReentrancyGuardTransient} from "../utils/ReentrancyGuardTransient.sol";
+import {ReentrancyAttack} from "./ReentrancyAttack.sol";
+
+contract ReentrancyTransientMock is ReentrancyGuardTransient {
+    uint256 public counter;
+
+    constructor() {
+        counter = 0;
+    }
+
+    function callback() external nonReentrant {
+        _count();
+    }
+
+    function countLocalRecursive(uint256 n) public nonReentrant {
+        if (n > 0) {
+            _count();
+            countLocalRecursive(n - 1);
+        }
+    }
+
+    function countThisRecursive(uint256 n) public nonReentrant {
+        if (n > 0) {
+            _count();
+            (bool success, ) = address(this).call(abi.encodeCall(this.countThisRecursive, (n - 1)));
+            require(success, "ReentrancyTransientMock: failed call");
+        }
+    }
+
+    function countAndCall(ReentrancyAttack attacker) public nonReentrant {
+        _count();
+        attacker.callSender(abi.encodeCall(this.callback, ()));
+    }
+
+    function _count() private {
+        counter += 1;
+    }
+
+    function guardedCheckEntered() public nonReentrant {
+        require(_reentrancyGuardEntered());
+    }
+
+    function unguardedCheckNotEntered() public view {
+        require(!_reentrancyGuardEntered());
+    }
+}

+ 3 - 0
contracts/utils/README.adoc

@@ -13,6 +13,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t
  * {MerkleProof}: Functions for verifying https://en.wikipedia.org/wiki/Merkle_tree[Merkle Tree] proofs.
  * {EIP712}: Contract with functions to allow processing signed typed structure data according to https://eips.ethereum.org/EIPS/eip-712[EIP-712].
  * {ReentrancyGuard}: A modifier that can prevent reentrancy during certain functions.
+ * {ReentrancyGuardTransient}: Variant of {ReentrancyGuard} that uses transient storage (https://eips.ethereum.org/EIPS/eip-1153[EIP-1153]).
  * {Pausable}: A common emergency response mechanism that can pause functionality while a remediation is pending.
  * {Nonces}: Utility for tracking and verifying address nonces that only increment.
  * {ERC165, ERC165Checker}: Utilities for inspecting interfaces supported by contracts.
@@ -65,6 +66,8 @@ Because Solidity does not support generic types, {EnumerableMap} and {Enumerable
 
 {{ReentrancyGuard}}
 
+{{ReentrancyGuardTransient}}
+
 {{Pausable}}
 
 {{Nonces}}

+ 3 - 0
contracts/utils/ReentrancyGuard.sol

@@ -15,6 +15,9 @@ pragma solidity ^0.8.20;
  * those functions `private`, and then adding `external` `nonReentrant` entry
  * points to them.
  *
+ * TIP: If EIP-1153 (transient storage) is available on the chain you're deploying at,
+ * consider using {ReentrancyGuardTransient} instead.
+ *
  * TIP: If you would like to learn more about reentrancy and alternative ways
  * to protect against it, check out our blog post
  * https://blog.openzeppelin.com/reentrancy-after-istanbul/[Reentrancy After Istanbul].

+ 58 - 0
contracts/utils/ReentrancyGuardTransient.sol

@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.24;
+
+import {StorageSlot} from "./StorageSlot.sol";
+
+/**
+ * @dev Variant of {ReentrancyGuard} that uses transient storage.
+ *
+ * NOTE: This variant only works on networks where EIP-1153 is available.
+ */
+abstract contract ReentrancyGuardTransient {
+    using StorageSlot for *;
+
+    // keccak256(abi.encode(uint256(keccak256("openzeppelin.storage.ReentrancyGuard")) - 1)) & ~bytes32(uint256(0xff))
+    bytes32 private constant REENTRANCY_GUARD_STORAGE =
+        0x9b779b17422d0df92223018b32b4d1fa46e071723d6817e2486d003becc55f00;
+
+    /**
+     * @dev Unauthorized reentrant call.
+     */
+    error ReentrancyGuardReentrantCall();
+
+    /**
+     * @dev Prevents a contract from calling itself, directly or indirectly.
+     * Calling a `nonReentrant` function from another `nonReentrant`
+     * function is not supported. It is possible to prevent this from happening
+     * by making the `nonReentrant` function external, and making it call a
+     * `private` function that does the actual work.
+     */
+    modifier nonReentrant() {
+        _nonReentrantBefore();
+        _;
+        _nonReentrantAfter();
+    }
+
+    function _nonReentrantBefore() private {
+        // On the first call to nonReentrant, _status will be NOT_ENTERED
+        if (_reentrancyGuardEntered()) {
+            revert ReentrancyGuardReentrantCall();
+        }
+
+        // Any calls to nonReentrant after this point will fail
+        REENTRANCY_GUARD_STORAGE.asBoolean().tstore(true);
+    }
+
+    function _nonReentrantAfter() private {
+        REENTRANCY_GUARD_STORAGE.asBoolean().tstore(false);
+    }
+
+    /**
+     * @dev Returns true if the reentrancy guard is currently set to "entered", which indicates there is a
+     * `nonReentrant` function in the call stack.
+     */
+    function _reentrancyGuardEntered() internal view returns (bool) {
+        return REENTRANCY_GUARD_STORAGE.asBoolean().tload();
+    }
+}

+ 45 - 42
test/utils/ReentrancyGuard.test.js

@@ -2,46 +2,49 @@ const { ethers } = require('hardhat');
 const { expect } = require('chai');
 const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
-async function fixture() {
-  const mock = await ethers.deployContract('ReentrancyMock');
-  return { mock };
-}
-
-describe('ReentrancyGuard', function () {
-  beforeEach(async function () {
-    Object.assign(this, await loadFixture(fixture));
-  });
-
-  it('nonReentrant function can be called', async function () {
-    expect(await this.mock.counter()).to.equal(0n);
-    await this.mock.callback();
-    expect(await this.mock.counter()).to.equal(1n);
-  });
-
-  it('does not allow remote callback', async function () {
-    const attacker = await ethers.deployContract('ReentrancyAttack');
-    await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
-  });
-
-  it('_reentrancyGuardEntered should be true when guarded', async function () {
-    await this.mock.guardedCheckEntered();
+for (const variant of ['', 'Transient']) {
+  describe(`Reentrancy${variant}Guard`, function () {
+    async function fixture() {
+      const name = `Reentrancy${variant}Mock`;
+      const mock = await ethers.deployContract(name);
+      return { name, mock };
+    }
+
+    beforeEach(async function () {
+      Object.assign(this, await loadFixture(fixture));
+    });
+
+    it('nonReentrant function can be called', async function () {
+      expect(await this.mock.counter()).to.equal(0n);
+      await this.mock.callback();
+      expect(await this.mock.counter()).to.equal(1n);
+    });
+
+    it('does not allow remote callback', async function () {
+      const attacker = await ethers.deployContract('ReentrancyAttack');
+      await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call');
+    });
+
+    it('_reentrancyGuardEntered should be true when guarded', async function () {
+      await this.mock.guardedCheckEntered();
+    });
+
+    it('_reentrancyGuardEntered should be false when unguarded', async function () {
+      await this.mock.unguardedCheckNotEntered();
+    });
+
+    // The following are more side-effects than intended behavior:
+    // I put them here as documentation, and to monitor any changes
+    // in the side-effects.
+    it('does not allow local recursion', async function () {
+      await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError(
+        this.mock,
+        'ReentrancyGuardReentrantCall',
+      );
+    });
+
+    it('does not allow indirect local recursion', async function () {
+      await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith(`${this.name}: failed call`);
+    });
   });
-
-  it('_reentrancyGuardEntered should be false when unguarded', async function () {
-    await this.mock.unguardedCheckNotEntered();
-  });
-
-  // The following are more side-effects than intended behavior:
-  // I put them here as documentation, and to monitor any changes
-  // in the side-effects.
-  it('does not allow local recursion', async function () {
-    await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError(
-      this.mock,
-      'ReentrancyGuardReentrantCall',
-    );
-  });
-
-  it('does not allow indirect local recursion', async function () {
-    await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith('ReentrancyMock: failed call');
-  });
-});
+}