فهرست منبع

Substrate: Implement builtin for `set_code_hash` API (#1397)

Cyrill Leutwiler 2 سال پیش
والد
کامیت
d7f62e8d60

+ 24 - 0
docs/language/builtins.rst

@@ -346,6 +346,30 @@ is_contract(address AccountId) returns (bool)
 
 Only available on Substrate. Checks whether the given address is a contract address. 
 
+set_code_hash(bytes hash) returns (uint32)
+++++++++++++++++++++++++++++++++++++++++++
+
+Only available on Substrate. Replace the contract's code with the code corresponding to ``hash``.
+Assumes that the new code was already uploaded, otherwise the operation fails.
+A return value of 0 indicates success; a return value of 7 indicates that there was no corresponding code found.
+
+.. note::
+
+    This is a low level function. We strongly advise consulting the underlying 
+    `API documentation <https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/trait.Version0.html#tymethod.set_code_hash>`_ 
+    to obtain a full understanding of its implications.
+
+This functionality is intended to be used for implementing upgradeable contracts. 
+Pitfalls generally applying to writing
+`upgradeable contracts <https://docs.openzeppelin.com/upgrades-plugins/1.x/writing-upgradeable>`_ 
+must be considered whenever using this builtin function, most notably:
+
+* The contract must safeguard access to this functionality, so that it is only callable by priviledged users.
+* The code you are upgrading to must be 
+  `storage compatible <https://docs.openzeppelin.com/upgrades-plugins/1.x/proxies#storage-collisions-between-implementation-versions>`_
+  with the existing code.
+* Constructors and any other initializers, including initial storage value definitions, won't be executed.
+
 Cryptography
 ____________
 

+ 31 - 0
integration/substrate/set_code_hash.sol

@@ -0,0 +1,31 @@
+import "substrate";
+
+abstract contract Upgradeable {
+    function set_code(uint8[32] code) external {
+        require(set_code_hash(code) == 0);
+    }
+}
+
+contract SetCodeCounterV1 is Upgradeable {
+    uint public count;
+
+    constructor(uint _count) {
+        count = _count;
+    }
+
+    function inc() external {
+        count += 1;
+    }
+}
+
+contract SetCodeCounterV2 is Upgradeable {
+    uint public count;
+
+    constructor(uint _count) {
+        count = _count;
+    }
+
+    function inc() external {
+        count -= 1;
+    }
+}

+ 56 - 0
integration/substrate/set_code_hash.spec.ts

@@ -0,0 +1,56 @@
+import expect from 'expect';
+import { createConnection, deploy, aliceKeypair, query, debug_buffer, weight, transaction, } from './index';
+import { ContractPromise } from '@polkadot/api-contract';
+import { ApiPromise } from '@polkadot/api';
+import { KeyringPair } from '@polkadot/keyring/types';
+import { U8aFixed } from '@polkadot/types';
+
+describe('Deploy the SetCodeCounter contracts and test for the upgrade to work', () => {
+    let conn: ApiPromise;
+    let counter: ContractPromise;
+    let hashes: [U8aFixed, U8aFixed];
+    let alice: KeyringPair;
+
+    before(async function () {
+        alice = aliceKeypair();
+        conn = await createConnection();
+
+        const counterV1 = await deploy(conn, alice, 'SetCodeCounterV1.contract', 0n, 1336n);
+        const counterV2 = await deploy(conn, alice, 'SetCodeCounterV2.contract', 0n, 0n);
+        hashes = [counterV1.abi.info.source.wasmHash, counterV2.abi.info.source.wasmHash];
+        counter = new ContractPromise(conn, counterV1.abi, counterV1.address);
+    });
+
+    after(async function () {
+        await conn.disconnect();
+    });
+
+    it('can switch out implementation using set_code_hash', async function () {
+        // Code hash should be V1, expect to increment
+        let gasLimit = await weight(conn, counter, 'inc', []);
+        await transaction(counter.tx.inc({ gasLimit }), alice);
+        let count = await query(conn, alice, counter, "count");
+        expect(BigInt(count.output?.toString() ?? "")).toStrictEqual(1337n);
+
+        // Switching to V2
+        gasLimit = await weight(conn, counter, 'set_code', [hashes[1]]);
+        await transaction(counter.tx.setCode({ gasLimit }, hashes[1]), alice);
+
+        // Code hash should be V2, expect to decrement
+        gasLimit = await weight(conn, counter, 'inc', []);
+        await transaction(counter.tx.inc({ gasLimit }), alice);
+        count = await query(conn, alice, counter, "count");
+        expect(BigInt(count.output?.toString() ?? "")).toStrictEqual(1336n);
+
+        // Switching to V1
+        gasLimit = await weight(conn, counter, 'set_code', [hashes[0]]);
+        await transaction(counter.tx.setCode({ gasLimit }, hashes[0]), alice);
+
+        // Code hash should be V1, expect to increment
+        gasLimit = await weight(conn, counter, 'inc', []);
+        await transaction(counter.tx.inc({ gasLimit }), alice);
+        count = await query(conn, alice, counter, "count");
+        expect(BigInt(count.output?.toString() ?? "")).toStrictEqual(1337n);
+    });
+
+});

+ 2 - 0
src/emit/substrate/mod.rs

@@ -179,6 +179,7 @@ impl SubstrateTarget {
             "deposit_event",
             "transfer",
             "is_contract",
+            "set_code_hash",
         ]);
 
         binary
@@ -320,6 +321,7 @@ impl SubstrateTarget {
         external!("terminate", void_type, u8_ptr);
         external!("deposit_event", void_type, u8_ptr, u32_val, u8_ptr, u32_val);
         external!("is_contract", i32_type, u8_ptr);
+        external!("set_code_hash", i32_type, u8_ptr);
     }
 
     /// Emits the "deploy" function if `init` is `Some`, otherwise emits the "call" function.

+ 13 - 0
src/emit/substrate/target.rs

@@ -1719,6 +1719,19 @@ impl<'a> TargetRuntime<'a> for SubstrateTarget {
                     .build_store(args[1].into_pointer_value(), is_contract);
                 None
             }
+            "set_code_hash" => {
+                let ptr = args[0].into_pointer_value();
+                let ret = call!("set_code_hash", &[ptr.into()], "seal_set_code_hash")
+                    .try_as_basic_value()
+                    .left()
+                    .unwrap()
+                    .into_int_value();
+                binary
+                    .builder
+                    .build_store(args[1].into_pointer_value(), ret);
+                log_return_code(binary, "seal_set_code_hash", ret);
+                None
+            }
             _ => unimplemented!(),
         }
     }

+ 105 - 78
src/sema/builtin.rs

@@ -1669,111 +1669,138 @@ impl Namespace {
         assert!(self.add_symbol(file_no, None, &identifier("Hash"), symbol));
 
         // Chain extensions
-        let mut func = Function::new(
-            loc,
-            "chain_extension".to_string(),
-            None,
-            Vec::new(),
-            pt::FunctionTy::Function,
-            None,
-            pt::Visibility::Public(Some(loc)),
-            vec![
-                Parameter {
+        for mut func in [
+            Function::new(
+                loc,
+                "chain_extension".to_string(),
+                None,
+                Vec::new(),
+                pt::FunctionTy::Function,
+                None,
+                pt::Visibility::Public(Some(loc)),
+                vec![
+                    Parameter {
+                        loc,
+                        id: Some(identifier("id")),
+                        ty: Type::Uint(32),
+                        ty_loc: Some(loc),
+                        readonly: false,
+                        indexed: false,
+                        infinite_size: false,
+                        recursive: false,
+                        annotation: None,
+                    },
+                    Parameter {
+                        loc,
+                        id: Some(identifier("input")),
+                        ty: Type::DynamicBytes,
+                        ty_loc: Some(loc),
+                        readonly: false,
+                        indexed: false,
+                        infinite_size: false,
+                        recursive: false,
+                        annotation: None,
+                    },
+                ],
+                vec![
+                    Parameter {
+                        loc,
+                        id: Some(identifier("return_value")),
+                        ty: Type::Uint(32),
+                        ty_loc: Some(loc),
+                        readonly: false,
+                        indexed: false,
+                        infinite_size: false,
+                        recursive: false,
+                        annotation: None,
+                    },
+                    Parameter {
+                        loc,
+                        id: Some(identifier("output")),
+                        ty: Type::DynamicBytes,
+                        ty_loc: Some(loc),
+                        readonly: false,
+                        indexed: false,
+                        infinite_size: false,
+                        recursive: false,
+                        annotation: None,
+                    },
+                ],
+                self,
+            ),
+            // is_contract API
+            Function::new(
+                loc,
+                "is_contract".to_string(),
+                None,
+                Vec::new(),
+                pt::FunctionTy::Function,
+                Some(pt::Mutability::View(loc)),
+                pt::Visibility::Public(Some(loc)),
+                vec![Parameter {
                     loc,
-                    id: Some(identifier("id")),
-                    ty: Type::Uint(32),
+                    id: Some(identifier("address")),
+                    ty: Type::Address(false),
                     ty_loc: Some(loc),
                     readonly: false,
                     indexed: false,
                     infinite_size: false,
                     recursive: false,
                     annotation: None,
-                },
-                Parameter {
+                }],
+                vec![Parameter {
                     loc,
-                    id: Some(identifier("input")),
-                    ty: Type::DynamicBytes,
+                    id: Some(identifier("is_contract")),
+                    ty: Type::Bool,
                     ty_loc: Some(loc),
                     readonly: false,
                     indexed: false,
                     infinite_size: false,
                     recursive: false,
                     annotation: None,
-                },
-            ],
-            vec![
-                Parameter {
+                }],
+                self,
+            ),
+            // set_code_hash API
+            Function::new(
+                loc,
+                "set_code_hash".to_string(),
+                None,
+                Vec::new(),
+                pt::FunctionTy::Function,
+                None,
+                pt::Visibility::Public(Some(loc)),
+                vec![Parameter {
                     loc,
-                    id: Some(identifier("return_value")),
-                    ty: Type::Uint(32),
+                    id: Some(identifier("code_hash_ptr")),
+                    // FIXME: The hash length should be configurable
+                    ty: Type::Array(Type::Uint(8).into(), vec![ArrayLength::Fixed(32.into())]),
                     ty_loc: Some(loc),
                     readonly: false,
                     indexed: false,
                     infinite_size: false,
                     recursive: false,
                     annotation: None,
-                },
-                Parameter {
+                }],
+                vec![Parameter {
                     loc,
-                    id: Some(identifier("output")),
-                    ty: Type::DynamicBytes,
+                    id: Some(identifier("return_code")),
+                    ty: Type::Uint(32),
                     ty_loc: Some(loc),
                     readonly: false,
                     indexed: false,
                     infinite_size: false,
                     recursive: false,
                     annotation: None,
-                },
-            ],
-            self,
-        );
-
-        func.has_body = true;
-        let func_no = self.functions.len();
-        let id = identifier(&func.name);
-        self.functions.push(func);
-
-        assert!(self.add_symbol(file_no, None, &id, Symbol::Function(vec![(loc, func_no)])));
-
-        // is_contract API
-        let mut func = Function::new(
-            loc,
-            "is_contract".to_string(),
-            None,
-            Vec::new(),
-            pt::FunctionTy::Function,
-            Some(pt::Mutability::View(loc)),
-            pt::Visibility::Public(Some(loc)),
-            vec![Parameter {
-                loc,
-                id: Some(identifier("address")),
-                ty: Type::Address(false),
-                ty_loc: Some(loc),
-                readonly: false,
-                indexed: false,
-                infinite_size: false,
-                recursive: false,
-                annotation: None,
-            }],
-            vec![Parameter {
-                loc,
-                id: Some(identifier("is_contract")),
-                ty: Type::Bool,
-                ty_loc: Some(loc),
-                readonly: false,
-                indexed: false,
-                infinite_size: false,
-                recursive: false,
-                annotation: None,
-            }],
-            self,
-        );
-
-        func.has_body = true;
-        let func_no = self.functions.len();
-        let id = identifier(&func.name);
-        self.functions.push(func);
-
-        assert!(self.add_symbol(file_no, None, &id, Symbol::Function(vec![(loc, func_no)])));
+                }],
+                self,
+            ),
+        ] {
+            func.has_body = true;
+            let func_no = self.functions.len();
+            let id = identifier(&func.name);
+            self.functions.push(func);
+            assert!(self.add_symbol(file_no, None, &id, Symbol::Function(vec![(loc, func_no)])));
+        }
     }
 }

+ 24 - 5
tests/substrate.rs

@@ -70,12 +70,12 @@ impl HostError for HostReturn {}
 
 /// Represents a contract code artifact.
 #[derive(Clone)]
-struct WasmCode {
+pub struct WasmCode {
     /// A mapping from function names to selectors.
     messages: HashMap<String, Vec<u8>>,
     /// A list of the selectors of the constructors.
     constructors: Vec<Vec<u8>>,
-    hash: [u8; 32],
+    hash: Hash,
     blob: Vec<u8>,
 }
 
@@ -235,7 +235,7 @@ impl Runtime {
         Self {
             accounts: blobs
                 .iter()
-                .map(|blob| Account::with_contract(&blob.hash, blob))
+                .map(|blob| Account::with_contract(blob.hash.as_ref(), blob))
                 .collect(),
             blobs,
             ..Default::default()
@@ -309,7 +309,7 @@ impl Runtime {
     /// Returns `None` if there is no contract corresponding to the given `code_hash`.
     fn deploy(
         &mut self,
-        code_hash: [u8; 32],
+        code_hash: Hash,
         value: u128,
         salt: &[u8],
         input: Vec<u8>,
@@ -349,6 +349,10 @@ fn read_account(mem: &[u8], ptr: u32) -> Address {
     Address::try_from(&mem[ptr as usize..(ptr + 32) as usize]).unwrap()
 }
 
+fn read_hash(mem: &[u8], ptr: u32) -> Hash {
+    Hash::try_from(&mem[ptr as usize..(ptr + 32) as usize]).unwrap()
+}
+
 /// Host functions mock the original implementation, refer to the [pallet docs][1] for more information.
 ///
 /// [1]: https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/index.html
@@ -572,7 +576,7 @@ impl Runtime {
         salt_ptr: u32,
         salt_len: u32,
     ) -> Result<u32, Trap> {
-        let code_hash = read_account(mem, code_hash_ptr);
+        let code_hash = read_hash(mem, code_hash_ptr);
         let salt = read_buf(mem, salt_ptr, salt_len);
         let input = read_buf(mem, input_data_ptr, input_data_len);
         let value = read_value(mem, value_ptr);
@@ -782,6 +786,16 @@ impl Runtime {
             .any(|account| account.contract.is_some() && account.address == address)
             .into())
     }
+
+    #[seal(0)]
+    fn set_code_hash(code_hash_ptr: u32) -> Result<u32, Trap> {
+        let hash = read_hash(mem, code_hash_ptr);
+        if let Some(code) = vm.blobs.iter().find(|code| code.hash == hash) {
+            vm.accounts[vm.account].contract.as_mut().unwrap().code = code.clone();
+            return Ok(0);
+        }
+        Ok(7) // ReturnCode::CodeNoteFound
+    }
 }
 
 /// Provides a mock implementation of substrates [contracts pallet][1]
@@ -873,6 +887,11 @@ impl MockSubstrate {
         self.raw_constructor(input);
     }
 
+    /// Get a list of all uploaded cotracts
+    pub fn blobs(&self) -> Vec<WasmCode> {
+        self.0.data().blobs.clone()
+    }
+
     /// Call the "deploy" function with the given `input`.
     ///
     /// `input` must contain the selector fo the constructor.

+ 51 - 0
tests/substrate_tests/builtins.rs

@@ -794,3 +794,54 @@ fn is_contract() {
     runtime.function("test", [0; 32].to_vec());
     assert_eq!(runtime.output(), vec![0]);
 }
+
+#[test]
+fn set_code_hash() {
+    let mut runtime = build_solidity(
+        r##"
+        import "substrate";
+
+        abstract contract SetCode {
+            function set_code(uint8[32] code_hash) external {
+                require(set_code_hash(code_hash) == 0);
+            }
+        }
+        
+        contract CounterV1 is SetCode {
+            uint32 public count;
+        
+            function inc() external {
+                count += 1;
+            }
+        }
+        
+        contract CounterV2 is SetCode {
+            uint32 public count;
+        
+            function inc() external {
+                count -= 1;
+            }
+        }"##,
+    );
+
+    runtime.function("inc", vec![]);
+    runtime.function("count", vec![]);
+    assert_eq!(runtime.output(), 1u32.encode());
+
+    let v2_code_hash = ink_primitives::Hash::default().as_ref().to_vec();
+    runtime.function_expect_failure("set_code", v2_code_hash);
+
+    let v2_code_hash = runtime.blobs()[1].hash;
+    runtime.function("set_code", v2_code_hash.as_ref().to_vec());
+
+    runtime.function("inc", vec![]);
+    runtime.function("count", vec![]);
+    assert_eq!(runtime.output(), 0u32.encode());
+
+    let v1_code_hash = runtime.blobs()[0].hash;
+    runtime.function("set_code", v1_code_hash.as_ref().to_vec());
+
+    runtime.function("inc", vec![]);
+    runtime.function("count", vec![]);
+    assert_eq!(runtime.output(), 1u32.encode());
+}