Parcourir la source

Simplify Initializable (#3450)

Francisco Giordano il y a 3 ans
Parent
commit
d506e3b1a5

+ 5 - 0
CHANGELOG.md

@@ -17,6 +17,11 @@
  * `ERC721`, `ERC1155`: simplified revert reasons. ([#3254](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3254), ([#3438](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3438)))
  * `ERC721`: removed redundant require statement. ([#3434](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3434))
  * `PaymentSplitter`: add `releasable` getters. ([#3350](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3350))
+ * `Initializable`: refactored implementation of modifiers for easier understanding. ([#3450](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3450))
+
+### Breaking changes
+
+ * `Initializable`: functions decorated with the modifier `reinitializer(1)` may no longer invoke each other.
 
 ## 4.6.0 (2022-04-26)
 

+ 20 - 0
contracts/mocks/InitializableMock.sol

@@ -100,3 +100,23 @@ contract ReinitializerMock is Initializable {
         counter++;
     }
 }
+
+contract DisableNew is Initializable {
+    constructor() {
+        _disableInitializers();
+    }
+}
+
+contract DisableOld is Initializable {
+    constructor() initializer {}
+}
+
+contract DisableBad1 is DisableNew, DisableOld {}
+
+contract DisableBad2 is Initializable {
+    constructor() initializer {
+        _disableInitializers();
+    }
+}
+
+contract DisableOk is DisableOld, DisableNew {}

+ 15 - 30
contracts/proxy/utils/Initializable.sol

@@ -76,7 +76,12 @@ abstract contract Initializable {
      * `onlyInitializing` functions can be used to initialize parent contracts. Equivalent to `reinitializer(1)`.
      */
     modifier initializer() {
-        bool isTopLevelCall = _setInitializedVersion(1);
+        bool isTopLevelCall = !_initializing;
+        require(
+            (isTopLevelCall && _initialized < 1) || (!Address.isContract(address(this)) && _initialized == 1),
+            "Initializable: contract is already initialized"
+        );
+        _initialized = 1;
         if (isTopLevelCall) {
             _initializing = true;
         }
@@ -100,15 +105,12 @@ abstract contract Initializable {
      * a contract, executing them in the right order is up to the developer or operator.
      */
     modifier reinitializer(uint8 version) {
-        bool isTopLevelCall = _setInitializedVersion(version);
-        if (isTopLevelCall) {
-            _initializing = true;
-        }
+        require(!_initializing && _initialized < version, "Initializable: contract is already initialized");
+        _initialized = version;
+        _initializing = true;
         _;
-        if (isTopLevelCall) {
-            _initializing = false;
-            emit Initialized(version);
-        }
+        _initializing = false;
+        emit Initialized(version);
     }
 
     /**
@@ -127,27 +129,10 @@ abstract contract Initializable {
      * through proxies.
      */
     function _disableInitializers() internal virtual {
-        _setInitializedVersion(type(uint8).max);
-    }
-
-    function _setInitializedVersion(uint8 version) private returns (bool) {
-        // If the contract is initializing we ignore whether _initialized is set in order to support multiple
-        // inheritance patterns, but we only do this in the context of a constructor, and for the lowest level
-        // of initializers, because in other contexts the contract may have been reentered.
-
-        bool isTopLevelCall = !_initializing; // cache sload
-        uint8 currentVersion = _initialized; // cache sload
-
-        require(
-            (isTopLevelCall && version > currentVersion) || // not nested with increasing version or
-                (!Address.isContract(address(this)) && (version == 1 || version == type(uint8).max)), // contract being constructed
-            "Initializable: contract is already initialized"
-        );
-
-        if (isTopLevelCall) {
-            _initialized = version;
+        require(!_initializing, "Initializable: contract is initializing");
+        if (_initialized < type(uint8).max) {
+            _initialized = type(uint8).max;
+            emit Initialized(type(uint8).max);
         }
-
-        return isTopLevelCall;
     }
 }

+ 1 - 1
package.json

@@ -29,7 +29,7 @@
     "release": "scripts/release/release.sh",
     "version": "scripts/release/version.sh",
     "test": "hardhat test",
-    "test:inheritance": "scripts/checks/inheritanceOrdering.js artifacts/build-info/*",
+    "test:inheritance": "scripts/checks/inheritance-ordering.js artifacts/build-info/*",
     "test:generation": "scripts/checks/generation.sh",
     "gas-report": "env ENABLE_GAS_REPORT=true npm run test",
     "slither": "npm run clean && slither . --detect reentrancy-eth,reentrancy-no-eth,reentrancy-unlimited-gas"

+ 4 - 0
scripts/checks/inheritanceOrdering.js → scripts/checks/inheritance-ordering.js

@@ -13,6 +13,10 @@ for (const artifact of artifacts) {
   const linearized = [];
 
   for (const source in solcOutput.contracts) {
+    if (source.includes('/mocks/')) {
+      continue;
+    }
+
     for (const contractDef of findAll('ContractDefinition', solcOutput.sources[source].ast)) {
       names[contractDef.id] = contractDef.name;
       linearized.push(contractDef.linearizedBaseContracts);

+ 16 - 0
test/proxy/utils/Initializable.test.js

@@ -6,6 +6,9 @@ const ConstructorInitializableMock = artifacts.require('ConstructorInitializable
 const ChildConstructorInitializableMock = artifacts.require('ChildConstructorInitializableMock');
 const ReinitializerMock = artifacts.require('ReinitializerMock');
 const SampleChild = artifacts.require('SampleChild');
+const DisableBad1 = artifacts.require('DisableBad1');
+const DisableBad2 = artifacts.require('DisableBad2');
+const DisableOk = artifacts.require('DisableOk');
 
 contract('Initializable', function (accounts) {
   describe('basic testing without inheritance', function () {
@@ -184,4 +187,17 @@ contract('Initializable', function (accounts) {
       expect(await this.contract.child()).to.be.bignumber.equal(child);
     });
   });
+
+  describe('disabling initialization', function () {
+    it('old and new patterns in bad sequence', async function () {
+      await expectRevert(DisableBad1.new(), 'Initializable: contract is already initialized');
+      await expectRevert(DisableBad2.new(), 'Initializable: contract is initializing');
+    });
+
+    it('old and new patterns in good sequence', async function () {
+      const ok = await DisableOk.new();
+      await expectEvent.inConstruction(ok, 'Initialized', { version: '1' });
+      await expectEvent.inConstruction(ok, 'Initialized', { version: '255' });
+    });
+  });
 });