Browse Source

Polkadot: Implement `caller_is_root` runtime API (#1620)

Recent versions of the contracts pallet expose a runtime API for
contracts to check whether their caller is of root origin. The PR
exposes this API as a builtin.
Cyrill Leutwiler 1 year ago
parent
commit
cda387d7c1

+ 5 - 0
docs/language/builtins.rst

@@ -343,6 +343,11 @@ is_contract(address AccountId) returns (bool)
 
 Only available on Polkadot. Checks whether the given address is a contract address.
 
+caller_is_root() returns (bool)
++++++++++++++++++++++++++++++++
+
+Only available on Polkadot. Returns true if the caller of the contract is `root <https://docs.substrate.io/build/origins/>`_.
+
 set_code_hash(uint8[32] hash) returns (uint32)
 ++++++++++++++++++++++++++++++++++++++++++++++
 

+ 14 - 0
integration/polkadot/caller_is_root.sol

@@ -0,0 +1,14 @@
+import "polkadot";
+
+contract CallerIsRoot {
+    uint public balance;
+
+    function covert() public payable {
+        if (caller_is_root()) {
+            balance = 0xdeadbeef;
+        } else {
+            print("burn more gas");
+            balance = 1;
+        }
+    }
+}

+ 42 - 0
integration/polkadot/caller_is_root.spec.ts

@@ -0,0 +1,42 @@
+import expect from 'expect';
+import { createConnection, deploy, aliceKeypair, query, weight, transaction } from './index';
+import { ContractPromise } from '@polkadot/api-contract';
+import { ApiPromise } from '@polkadot/api';
+import { KeyringPair } from '@polkadot/keyring/types';
+
+describe('Deploy the caller_is_root contract and test it', () => {
+    let conn: ApiPromise;
+    let contract: ContractPromise;
+    let alice: KeyringPair;
+
+    before(async function () {
+        conn = await createConnection();
+        alice = aliceKeypair();
+        const instance = await deploy(conn, alice, 'CallerIsRoot.contract', 0n);
+        contract = new ContractPromise(conn, instance.abi, instance.address);
+    });
+
+    after(async function () {
+        await conn.disconnect();
+    });
+
+    it('is correct on a non-root caller', async function () {
+        // Without sudo the caller should not be root
+        const gasLimit = await weight(conn, contract, "covert");
+        await transaction(contract.tx.covert({ gasLimit }), alice);
+
+        // Calling `covert` as non-root sets the balance to 1
+        const balance = await query(conn, alice, contract, "balance", []);
+        expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(1n);
+    });
+
+    it('is correct on a root caller', async function () {
+        // Alice has sudo rights on --dev nodes
+        const gasLimit = await weight(conn, contract, "covert");
+        await transaction(conn.tx.sudo.sudo(contract.tx.covert({ gasLimit })), alice);
+
+        // Calling `covert` as root sets the balance to 0xdeadbeef
+        const balance = await query(conn, alice, contract, "balance", []);
+        expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(0xdeadbeefn);
+    });
+});

+ 8 - 4
src/emit/instructions.rs

@@ -519,10 +519,14 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
                 }
             }
 
-            let first_arg_type = bin.llvm_type(&args[0].ty(), ns);
-            if let Some(ret) =
-                target.builtin_function(bin, function, callee, &parms, first_arg_type, ns)
-            {
+            if let Some(ret) = target.builtin_function(
+                bin,
+                function,
+                callee,
+                &parms,
+                args.first().map(|arg| bin.llvm_type(&arg.ty(), ns)),
+                ns,
+            ) {
                 let success = bin.builder.build_int_compare(
                     IntPredicate::EQ,
                     ret.into_int_value(),

+ 1 - 1
src/emit/mod.rs

@@ -231,7 +231,7 @@ pub trait TargetRuntime<'a> {
         function: FunctionValue<'a>,
         builtin_func: &Function,
         args: &[BasicMetadataValueEnum<'a>],
-        first_arg_type: BasicTypeEnum,
+        first_arg_type: Option<BasicTypeEnum>,
         ns: &Namespace,
     ) -> Option<BasicValueEnum<'a>>;
 

+ 2 - 0
src/emit/polkadot/mod.rs

@@ -118,6 +118,7 @@ impl PolkadotTarget {
             "transfer",
             "is_contract",
             "set_code_hash",
+            "caller_is_root",
         ]);
 
         binary
@@ -266,6 +267,7 @@ impl PolkadotTarget {
         external!("deposit_event", void_type, u8_ptr, u32_val, u8_ptr, u32_val);
         external!("is_contract", i32_type, u8_ptr);
         external!("set_code_hash", i32_type, u8_ptr);
+        external!("caller_is_root", i32_type,);
     }
 
     /// Emits the "deploy" function if `storage_initializer` is `Some`, otherwise emits the "call" function.

+ 12 - 1
src/emit/polkadot/target.rs

@@ -1501,7 +1501,7 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
         _function: FunctionValue<'a>,
         builtin_func: &Function,
         args: &[BasicMetadataValueEnum<'a>],
-        _first_arg_type: BasicTypeEnum,
+        _first_arg_type: Option<BasicTypeEnum>,
         ns: &Namespace,
     ) -> Option<BasicValueEnum<'a>> {
         emit_context!(binary);
@@ -1579,6 +1579,17 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
                     .build_store(args[1].into_pointer_value(), ret);
                 None
             }
+            "caller_is_root" => {
+                let is_root = call!("caller_is_root", &[], "seal_caller_is_root")
+                    .try_as_basic_value()
+                    .left()
+                    .unwrap()
+                    .into_int_value();
+                binary
+                    .builder
+                    .build_store(args[0].into_pointer_value(), is_root);
+                None
+            }
             _ => unimplemented!(),
         }
     }

+ 4 - 1
src/emit/solana/target.rs

@@ -1251,9 +1251,12 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
         function: FunctionValue<'a>,
         builtin_func: &ast::Function,
         args: &[BasicMetadataValueEnum<'a>],
-        first_arg_type: BasicTypeEnum,
+        first_arg_type: Option<BasicTypeEnum>,
         ns: &ast::Namespace,
     ) -> Option<BasicValueEnum<'a>> {
+        let first_arg_type =
+            first_arg_type.expect("solana does not have builtin without any parameter");
+
         if builtin_func.id.name == "create_program_address" {
             let func = binary
                 .module

+ 1 - 1
src/emit/soroban/target.rs

@@ -218,7 +218,7 @@ impl<'a> TargetRuntime<'a> for SorobanTarget {
         function: FunctionValue<'a>,
         builtin_func: &Function,
         args: &[BasicMetadataValueEnum<'a>],
-        first_arg_type: BasicTypeEnum,
+        first_arg_type: Option<BasicTypeEnum>,
         ns: &Namespace,
     ) -> Option<BasicValueEnum<'a>> {
         unimplemented!()

+ 27 - 0
src/sema/builtin.rs

@@ -1826,6 +1826,33 @@ impl Namespace {
                 }],
                 self,
             ),
+            // caller_is_root API
+            Function::new(
+                loc,
+                loc,
+                pt::Identifier {
+                    name: "caller_is_root".to_string(),
+                    loc,
+                },
+                None,
+                Vec::new(),
+                pt::FunctionTy::Function,
+                Some(pt::Mutability::View(loc)),
+                pt::Visibility::Public(Some(loc)),
+                vec![],
+                vec![Parameter {
+                    loc,
+                    id: Some(identifier("caller_is_root")),
+                    ty: Type::Bool,
+                    ty_loc: Some(loc),
+                    readonly: false,
+                    indexed: false,
+                    infinite_size: false,
+                    recursive: false,
+                    annotation: None,
+                }],
+                self,
+            ),
         ] {
             func.has_body = true;
             let func_no = self.functions.len();

+ 7 - 7
tests/lir_tests/convert_lir.rs

@@ -660,7 +660,7 @@ fn test_assertion_using_require() {
     assert_polkadot_lir_str_eq(
         src,
         0,
-        r#"public function sol#3 Test::Test::function::test__int32 (int32):
+        r#"public function sol#4 Test::Test::function::test__int32 (int32):
 block#0 entry:
     int32 %num = int32(arg#0);
     bool %temp.ssa_ir.1 = int32(%num) > int32(10);
@@ -690,7 +690,7 @@ fn test_call_1() {
     assert_polkadot_lir_str_eq(
         src,
         0,
-        r#"public function sol#3 Test::Test::function::test__int32 (int32):
+        r#"public function sol#4 Test::Test::function::test__int32 (int32):
 block#0 entry:
     int32 %num = int32(arg#0);
      = call function#1(int32(%num));
@@ -754,7 +754,7 @@ fn test_value_transfer() {
     assert_polkadot_lir_str_eq(
         src,
         0,
-        r#"public function sol#3 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
+        r#"public function sol#4 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
 block#0 entry:
     uint8[32] %addr = uint8[32](arg#0);
     uint128 %amount = uint128(arg#1);
@@ -928,7 +928,7 @@ fn test_keccak256() {
     assert_polkadot_lir_str_eq(
         src,
         0,
-        r#"public function sol#3 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
+        r#"public function sol#4 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
 block#0 entry:
     ptr<struct.vector<uint8>> %name = ptr<struct.vector<uint8>>(arg#0);
     uint8[32] %addr = uint8[32](arg#1);
@@ -960,7 +960,7 @@ fn test_internal_function_cfg() {
     assert_polkadot_lir_str_eq(
         src,
         1,
-        r#"public function sol#4 A::A::function::bar__uint256 (uint256) returns (uint256):
+        r#"public function sol#5 A::A::function::bar__uint256 (uint256) returns (uint256):
 block#0 entry:
     uint256 %b = uint256(arg#0);
     ptr<function (uint256) returns (uint256)> %temp.ssa_ir.6 = function#0;
@@ -1124,14 +1124,14 @@ fn test_constructor() {
     assert_polkadot_lir_str_eq(
         src,
         0,
-        r#"public function sol#3 B::B::function::test__uint256 (uint256):
+        r#"public function sol#4 B::B::function::test__uint256 (uint256):
 block#0 entry:
     uint256 %a = uint256(arg#0);
     ptr<struct.vector<uint8>> %abi_encoded.temp.18 = alloc ptr<struct.vector<uint8>>[uint32(36)];
     uint32 %temp.ssa_ir.20 = uint32 hex"58_16_c4_25";
     write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(0) value:uint32(%temp.ssa_ir.20);
     write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(4) value:uint256(%a);
-    uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 5, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
+    uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 6, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
     switch uint32(%success.temp.17):
     case:    uint32(0) => block#1, 
     case:    uint32(2) => block#2

+ 12 - 0
tests/polkadot.rs

@@ -357,6 +357,8 @@ fn read_hash(mem: &[u8], ptr: u32) -> Hash {
 /// Host functions mock the original implementation, refer to the [pallet docs][1] for more information.
 ///
 /// [1]: https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/index.html
+///
+/// Address `[0; u8]` is considered the root account.
 #[wasm_host]
 impl Runtime {
     #[seal(0)]
@@ -787,6 +789,11 @@ impl Runtime {
             .into())
     }
 
+    #[seal(0)]
+    fn caller_is_root() -> Result<u32, Trap> {
+        Ok((vm.accounts[vm.caller_account].address == [0; 32]).into())
+    }
+
     #[seal(0)]
     fn set_code_hash(code_hash_ptr: u32) -> Result<u32, Trap> {
         let hash = read_hash(mem, code_hash_ptr);
@@ -818,6 +825,11 @@ impl MockSubstrate {
         Ok(())
     }
 
+    /// Overwrites the address at asssociated `account` index with the given `address`.
+    pub fn set_account_address(&mut self, account: usize, address: [u8; 32]) {
+        self.0.data_mut().accounts[account].address = address;
+    }
+
     /// Specify the caller account index for the next function or constructor call.
     pub fn set_account(&mut self, index: usize) {
         self.0.data_mut().account = index;

+ 21 - 0
tests/polkadot_tests/builtins.rs

@@ -845,3 +845,24 @@ fn set_code_hash() {
     runtime.function("count", vec![]);
     assert_eq!(runtime.output(), 1u32.encode());
 }
+
+#[test]
+fn caller_is_root() {
+    let mut runtime = build_solidity(
+        r#"
+        import { caller_is_root } from "polkadot";
+        contract Test {
+            function test() public view returns (bool) {
+                return caller_is_root();
+            }
+        }"#,
+    );
+
+    runtime.function("test", runtime.0.data().accounts[0].address.to_vec());
+    assert_eq!(runtime.output(), false.encode());
+
+    // Set the caller address to [0; 32] which is the mock VM root account
+    runtime.set_account_address(0, [0; 32]);
+    runtime.function("test", [0; 32].to_vec());
+    assert_eq!(runtime.output(), true.encode());
+}