فهرست منبع

Allow the re-initialization of contracts (#3232)

* allow re-initialization of contracts

* fix lint

* use a private function to avoid code duplication

* use oz-retyped-from syntax

* add documentation

* rephrase

* documentation

* Update contracts/proxy/utils/Initializable.sol

Co-authored-by: Francisco Giordano <frangio.1@gmail.com>

* reinitialize test

* lint

* typos and style

* add note about relation between initializer and reinitializer

* lint

* set _initializing in the modifier

* remove unnecessary variable set

* rename _preventInitialize -> _disableInitializers

* rename preventInitialize -> disableInitializers

* test nested reinitializers in reverse order

* docs typos and style

* edit docs for consistency between initializer and reinitializer

Co-authored-by: Francisco Giordano <frangio.1@gmail.com>
Hadrien Croubois 3 سال پیش
والد
کامیت
0eba5112c8
3فایلهای تغییر یافته به همراه176 افزوده شده و 30 حذف شده
  1. 29 0
      contracts/mocks/InitializableMock.sol
  2. 78 16
      contracts/proxy/utils/Initializable.sol
  3. 69 14
      test/proxy/utils/Initializable.test.js

+ 29 - 0
contracts/mocks/InitializableMock.sol

@@ -59,3 +59,32 @@ contract ConstructorInitializableMock is Initializable {
         onlyInitializingRan = true;
     }
 }
+
+contract ReinitializerMock is Initializable {
+    uint256 public counter;
+
+    function initialize() public initializer {
+        doStuff();
+    }
+
+    function reinitialize(uint8 i) public reinitializer(i) {
+        doStuff();
+    }
+
+    function nestedReinitialize(uint8 i, uint8 j) public reinitializer(i) {
+        reinitialize(j);
+    }
+
+    function chainReinitialize(uint8 i, uint8 j) public {
+        reinitialize(i);
+        reinitialize(j);
+    }
+
+    function disableInitializers() public {
+        _disableInitializers();
+    }
+
+    function doStuff() public onlyInitializing {
+        counter++;
+    }
+}

+ 78 - 16
contracts/proxy/utils/Initializable.sol

@@ -11,6 +11,26 @@ import "../../utils/Address.sol";
  * external initializer function, usually called `initialize`. It then becomes necessary to protect this initializer
  * function so it can only be called once. The {initializer} modifier provided by this contract will have this effect.
  *
+ * The initialization functions use a version number. Once a version number is used, it is consumed and cannot be
+ * reused. This mechanism prevents re-execution of each "step" but allows the creation of new initialization steps in
+ * case an upgrade adds a module that needs to be initialized.
+ *
+ * For example:
+ *
+ * [.hljs-theme-light.nopadding]
+ * ```
+ * contract MyToken is ERC20Upgradeable {
+ *     function initialize() initializer public {
+ *         __ERC20_init("MyToken", "MTK");
+ *     }
+ * }
+ * contract MyTokenV2 is MyToken, ERC20PermitUpgradeable {
+ *     function initializeV2() reinitializer(2) public {
+ *         __ERC20Permit_init("MyToken");
+ *     }
+ * }
+ * ```
+ *
  * TIP: To avoid leaving the proxy in an uninitialized state, the initializer function should be called as early as
  * possible by providing the encoded function call as the `_data` argument to {ERC1967Proxy-constructor}.
  *
@@ -22,21 +42,24 @@ import "../../utils/Address.sol";
  * Avoid leaving a contract uninitialized.
  *
  * An uninitialized contract can be taken over by an attacker. This applies to both a proxy and its implementation
- * contract, which may impact the proxy. To initialize the implementation contract, you can either invoke the
- * initializer manually, or you can include a constructor to automatically mark it as initialized when it is deployed:
+ * contract, which may impact the proxy. To prevent the implementation contract from being used, you should invoke
+ * the {_disableInitializers} function in the constructor to automatically lock it when it is deployed:
  *
  * [.hljs-theme-light.nopadding]
  * ```
  * /// @custom:oz-upgrades-unsafe-allow constructor
- * constructor() initializer {}
+ * constructor() {
+ *     _disableInitializers();
+ * }
  * ```
  * ====
  */
 abstract contract Initializable {
     /**
      * @dev Indicates that the contract has been initialized.
+     * @custom:oz-retyped-from bool
      */
-    bool private _initialized;
+    uint8 private _initialized;
 
     /**
      * @dev Indicates that the contract is in the process of being initialized.
@@ -44,22 +67,38 @@ abstract contract Initializable {
     bool private _initializing;
 
     /**
-     * @dev Modifier to protect an initializer function from being invoked twice.
+     * @dev A modifier that defines a protected initializer function that can be invoked at most once. In its scope,
+     * `onlyInitializing` functions can be used to initialize parent contracts. Equivalent to `reinitializer(1)`.
      */
     modifier initializer() {
-        // If the contract is initializing we ignore whether _initialized is set in order to support multiple
-        // inheritance patterns, but we only do this in the context of a constructor, because in other contexts the
-        // contract may have been reentered.
-        require(_initializing ? _isConstructor() : !_initialized, "Initializable: contract is already initialized");
-
-        bool isTopLevelCall = !_initializing;
+        bool isTopLevelCall = _setInitializedVersion(1);
         if (isTopLevelCall) {
             _initializing = true;
-            _initialized = true;
         }
-
         _;
+        if (isTopLevelCall) {
+            _initializing = false;
+        }
+    }
 
+    /**
+     * @dev A modifier that defines a protected reinitializer function that can be invoked at most once, and only if the
+     * contract hasn't been initialized to a greater version before. In its scope, `onlyInitializing` functions can be
+     * used to initialize parent contracts.
+     *
+     * `initializer` is equivalent to `reinitializer(1)`, so a reinitializer may be used after the original
+     * initialization step. This is essential to configure modules that are added through upgrades and that require
+     * initialization.
+     *
+     * Note that versions can jump in increments greater than 1; this implies that if multiple reinitializers coexist in
+     * a contract, executing them in the right order is up to the developer or operator.
+     */
+    modifier reinitializer(uint8 version) {
+        bool isTopLevelCall = _setInitializedVersion(version);
+        if (isTopLevelCall) {
+            _initializing = true;
+        }
+        _;
         if (isTopLevelCall) {
             _initializing = false;
         }
@@ -67,14 +106,37 @@ abstract contract Initializable {
 
     /**
      * @dev Modifier to protect an initialization function so that it can only be invoked by functions with the
-     * {initializer} modifier, directly or indirectly.
+     * {initializer} and {reinitializer} modifiers, directly or indirectly.
      */
     modifier onlyInitializing() {
         require(_initializing, "Initializable: contract is not initializing");
         _;
     }
 
-    function _isConstructor() private view returns (bool) {
-        return !Address.isContract(address(this));
+    /**
+     * @dev Locks the contract, preventing any future reinitialization. This cannot be part of an initializer call.
+     * Calling this in the constructor of a contract will prevent that contract from being initialized or reinitialized
+     * to any version. It is recommended to use this to lock implementation contracts that are designed to be called
+     * through proxies.
+     */
+    function _disableInitializers() internal virtual {
+        _setInitializedVersion(type(uint8).max);
+    }
+
+    function _setInitializedVersion(uint8 version) private returns (bool) {
+        // If the contract is initializing we ignore whether _initialized is set in order to support multiple
+        // inheritance patterns, but we only do this in the context of a constructor, and for the lowest level
+        // of initializers, because in other contexts the contract may have been reentered.
+        if (_initializing) {
+            require(
+                version == 1 && !Address.isContract(address(this)),
+                "Initializable: contract is already initialized"
+            );
+            return false;
+        } else {
+            require(_initialized < version, "Initializable: contract is already initialized");
+            _initialized = version;
+            return true;
+        }
     }
 }

+ 69 - 14
test/proxy/utils/Initializable.test.js

@@ -1,8 +1,9 @@
 const { expectRevert } = require('@openzeppelin/test-helpers');
-const { assert } = require('chai');
+const { expect } = require('chai');
 
 const InitializableMock = artifacts.require('InitializableMock');
 const ConstructorInitializableMock = artifacts.require('ConstructorInitializableMock');
+const ReinitializerMock = artifacts.require('ReinitializerMock');
 const SampleChild = artifacts.require('SampleChild');
 
 contract('Initializable', function (accounts) {
@@ -13,7 +14,7 @@ contract('Initializable', function (accounts) {
 
     context('before initialize', function () {
       it('initializer has not run', async function () {
-        assert.isFalse(await this.contract.initializerRan());
+        expect(await this.contract.initializerRan()).to.equal(false);
       });
     });
 
@@ -23,7 +24,7 @@ contract('Initializable', function (accounts) {
       });
 
       it('initializer has run', async function () {
-        assert.isTrue(await this.contract.initializerRan());
+        expect(await this.contract.initializerRan()).to.equal(true);
       });
 
       it('initializer does not run again', async function () {
@@ -38,7 +39,7 @@ contract('Initializable', function (accounts) {
 
       it('onlyInitializing modifier succeeds', async function () {
         await this.contract.onlyInitializingNested();
-        assert.isTrue(await this.contract.onlyInitializingRan());
+        expect(await this.contract.onlyInitializingRan()).to.equal(true);
       });
     });
 
@@ -49,15 +50,69 @@ contract('Initializable', function (accounts) {
 
   it('nested initializer can run during construction', async function () {
     const contract2 = await ConstructorInitializableMock.new();
-    assert.isTrue(await contract2.initializerRan());
-    assert.isTrue(await contract2.onlyInitializingRan());
+    expect(await contract2.initializerRan()).to.equal(true);
+    expect(await contract2.onlyInitializingRan()).to.equal(true);
+  });
+
+  describe('reinitialization', function () {
+    beforeEach('deploying', async function () {
+      this.contract = await ReinitializerMock.new();
+    });
+
+    it('can reinitialize', async function () {
+      expect(await this.contract.counter()).to.be.bignumber.equal('0');
+      await this.contract.initialize();
+      expect(await this.contract.counter()).to.be.bignumber.equal('1');
+      await this.contract.reinitialize(2);
+      expect(await this.contract.counter()).to.be.bignumber.equal('2');
+      await this.contract.reinitialize(3);
+      expect(await this.contract.counter()).to.be.bignumber.equal('3');
+    });
+
+    it('can jump multiple steps', async function () {
+      expect(await this.contract.counter()).to.be.bignumber.equal('0');
+      await this.contract.initialize();
+      expect(await this.contract.counter()).to.be.bignumber.equal('1');
+      await this.contract.reinitialize(128);
+      expect(await this.contract.counter()).to.be.bignumber.equal('2');
+    });
+
+    it('cannot nest reinitializers', async function () {
+      expect(await this.contract.counter()).to.be.bignumber.equal('0');
+      await expectRevert(this.contract.nestedReinitialize(2, 3), 'Initializable: contract is already initialized');
+      await expectRevert(this.contract.nestedReinitialize(3, 2), 'Initializable: contract is already initialized');
+    });
+
+    it('can chain reinitializers', async function () {
+      expect(await this.contract.counter()).to.be.bignumber.equal('0');
+      await this.contract.chainReinitialize(2, 3);
+      expect(await this.contract.counter()).to.be.bignumber.equal('2');
+    });
+
+    describe('contract locking', function () {
+      it('prevents initialization', async function () {
+        await this.contract.disableInitializers();
+        await expectRevert(this.contract.initialize(), 'Initializable: contract is already initialized');
+      });
+
+      it('prevents re-initialization', async function () {
+        await this.contract.disableInitializers();
+        await expectRevert(this.contract.reinitialize(255), 'Initializable: contract is already initialized');
+      });
+
+      it('can lock contract after initialization', async function () {
+        await this.contract.initialize();
+        await this.contract.disableInitializers();
+        await expectRevert(this.contract.reinitialize(255), 'Initializable: contract is already initialized');
+      });
+    });
   });
 
   describe('complex testing with inheritance', function () {
-    const mother = 12;
+    const mother = '12';
     const gramps = '56';
-    const father = 34;
-    const child = 78;
+    const father = '34';
+    const child = '78';
 
     beforeEach('deploying', async function () {
       this.contract = await SampleChild.new();
@@ -68,23 +123,23 @@ contract('Initializable', function (accounts) {
     });
 
     it('initializes human', async function () {
-      assert.equal(await this.contract.isHuman(), true);
+      expect(await this.contract.isHuman()).to.be.equal(true);
     });
 
     it('initializes mother', async function () {
-      assert.equal(await this.contract.mother(), mother);
+      expect(await this.contract.mother()).to.be.bignumber.equal(mother);
     });
 
     it('initializes gramps', async function () {
-      assert.equal(await this.contract.gramps(), gramps);
+      expect(await this.contract.gramps()).to.be.bignumber.equal(gramps);
     });
 
     it('initializes father', async function () {
-      assert.equal(await this.contract.father(), father);
+      expect(await this.contract.father()).to.be.bignumber.equal(father);
     });
 
     it('initializes child', async function () {
-      assert.equal(await this.contract.child(), child);
+      expect(await this.contract.child()).to.be.bignumber.equal(child);
     });
   });
 });