Browse Source

After external function call, retrieve return data from contract storage

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 4 years ago
parent
commit
d6483a6420

+ 54 - 0
integration/solana/calls.spec.ts

@@ -0,0 +1,54 @@
+import expect from 'expect';
+import { establishConnection } from './index';
+import crypto from 'crypto';
+
+describe('Deploy solang contract and test', () => {
+    it('external_call', async function () {
+        this.timeout(50000);
+
+        let conn = await establishConnection();
+
+        let caller = await conn.loadProgram("caller.so", "caller.abi");
+        let callee = await conn.loadProgram("callee.so", "callee.abi");
+        let callee2 = await conn.loadProgram("callee2.so", "callee2.abi");
+
+        // call the constructor
+        await caller.call_constructor(conn, []);
+        await callee.call_constructor(conn, []);
+        await callee2.call_constructor(conn, []);
+
+        await callee.call_function(conn, "set_x", ["102"]);
+
+        let res = await callee.call_function(conn, "get_x", []);
+
+        expect(res["0"]).toBe("102");
+
+        let address_callee = '0x' + callee.get_storage_key().toBuffer().toString('hex');
+        let address_callee2 = '0x' + callee2.get_storage_key().toBuffer().toString('hex');
+        console.log("addres: " + address_callee);
+
+        await caller.call_function(conn, "do_call", [address_callee, "13123"], callee.all_keys());
+
+        res = await callee.call_function(conn, "get_x", []);
+
+        expect(res["0"]).toBe("13123");
+
+        res = await caller.call_function(conn, "do_call2", [address_callee, "20000"], callee.all_keys());
+
+        expect(res["0"]).toBe("33123");
+
+        let all_keys = callee.all_keys()
+
+        all_keys.push(...callee2.all_keys());
+
+        res = await caller.call_function(conn, "do_call3", [address_callee, address_callee2, ["3", "5", "7", "9"], "yo"], all_keys);
+
+        expect(res["0"]).toBe("24");
+        expect(res["1"]).toBe("my name is callee");
+
+        res = await caller.call_function(conn, "do_call4", [address_callee, address_callee2, ["1", "2", "3", "4"], "asda"], all_keys);
+
+        expect(res["0"]).toBe("10");
+        expect(res["1"]).toBe("x:asda");
+    });
+});

+ 38 - 0
integration/solana/external_call.sol

@@ -3,6 +3,20 @@ contract caller {
     function do_call(callee e, int64 v) public {
     function do_call(callee e, int64 v) public {
         e.set_x(v);
         e.set_x(v);
     }
     }
+
+    function do_call2(callee e, int64 v) view public returns (int64) {
+        return v + e.get_x();
+    }
+
+    // call two different functions
+    function do_call3(callee e, callee2 e2, int64[4] memory x, string memory y) pure public returns (int64, string memory) {
+        return (e2.do_stuff(x), e.get_name());
+    }
+
+    // call two different functions
+    function do_call4(callee e, callee2 e2, int64[4] memory x, string memory y) public returns (int64, string memory) {
+        return (e2.do_stuff(x), e.call2(e2, y));
+    }
 }
 }
 
 
 contract callee {
 contract callee {
@@ -15,4 +29,28 @@ contract callee {
     function get_x() public view returns (int64) {
     function get_x() public view returns (int64) {
         return x;
         return x;
     }
     }
+
+    function call2(callee2 e2, string s) public returns (string) {
+        return e2.do_stuff2(s);
+    }
+
+    function get_name() public pure returns (string) {
+        return "my name is callee";
+    }
+}
+
+contract callee2 {
+    function do_stuff(int64[4] memory x) public pure returns (int64) {
+        int64 total = 0;
+
+        for (uint i=0; i< x.length; i++)  {
+            total += x[i];
+        }
+
+        return total;
+    }
+
+    function do_stuff2(string x) public pure returns (string) {
+        return "x:" + x;
+    }
 }
 }

+ 2 - 2
src/bin/solang.rs

@@ -251,8 +251,6 @@ fn process_filename(
         "aggressive" => inkwell::OptimizationLevel::Aggressive,
         "aggressive" => inkwell::OptimizationLevel::Aggressive,
         _ => unreachable!(),
         _ => unreachable!(),
     };
     };
-    let context = inkwell::context::Context::create();
-
     let mut json_contracts = HashMap::new();
     let mut json_contracts = HashMap::new();
 
 
     // resolve phase
     // resolve phase
@@ -307,6 +305,8 @@ fn process_filename(
             );
             );
         }
         }
 
 
+        let context = inkwell::context::Context::create();
+
         let contract =
         let contract =
             resolved_contract.emit(&ns, &context, &filename, llvm_opt, math_overflow_check);
             resolved_contract.emit(&ns, &context, &filename, llvm_opt, math_overflow_check);
 
 

+ 123 - 64
src/emit/mod.rs

@@ -2,6 +2,7 @@ use crate::parser::pt;
 use crate::sema::ast;
 use crate::sema::ast;
 use crate::sema::ast::{Builtin, Expression, FormatArg, StringLocation};
 use crate::sema::ast::{Builtin, Expression, FormatArg, StringLocation};
 use std::cell::RefCell;
 use std::cell::RefCell;
+use std::ffi::CStr;
 use std::fmt;
 use std::fmt;
 use std::path::Path;
 use std::path::Path;
 use std::str;
 use std::str;
@@ -19,8 +20,7 @@ use inkwell::memory_buffer::MemoryBuffer;
 use inkwell::module::{Linkage, Module};
 use inkwell::module::{Linkage, Module};
 use inkwell::passes::PassManager;
 use inkwell::passes::PassManager;
 use inkwell::targets::{CodeModel, FileType, RelocMode, TargetTriple};
 use inkwell::targets::{CodeModel, FileType, RelocMode, TargetTriple};
-use inkwell::types::BasicTypeEnum;
-use inkwell::types::{BasicType, FunctionType, IntType, StringRadix};
+use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType, StringRadix};
 use inkwell::values::{
 use inkwell::values::{
     ArrayValue, BasicValueEnum, FunctionValue, GlobalValue, IntValue, PhiValue, PointerValue,
     ArrayValue, BasicValueEnum, FunctionValue, GlobalValue, IntValue, PhiValue, PointerValue,
 };
 };
@@ -3949,37 +3949,11 @@ pub trait TargetRuntime<'a> {
                         tys,
                         tys,
                         data,
                         data,
                     } => {
                     } => {
-                        let v = self
-                            .expression(contract, data, &w.vars, function)
-                            .into_pointer_value();
+                        let v = self.expression(contract, data, &w.vars, function);
 
 
-                        let mut data = unsafe {
-                            contract.builder.build_gep(
-                                v,
-                                &[
-                                    contract.context.i32_type().const_zero(),
-                                    contract.context.i32_type().const_int(2, false),
-                                ],
-                                "data",
-                            )
-                        };
+                        let mut data = contract.vector_bytes(v);
 
 
-                        let mut data_len = contract
-                            .builder
-                            .build_load(
-                                unsafe {
-                                    contract.builder.build_gep(
-                                        v,
-                                        &[
-                                            contract.context.i32_type().const_zero(),
-                                            contract.context.i32_type().const_zero(),
-                                        ],
-                                        "data_len",
-                                    )
-                                },
-                                "data_len",
-                            )
-                            .into_int_value();
+                        let mut data_len = contract.vector_len(v);
 
 
                         if let Some(selector) = selector {
                         if let Some(selector) = selector {
                             let exception = exception.unwrap();
                             let exception = exception.unwrap();
@@ -6132,28 +6106,67 @@ impl<'a> Contract<'a> {
                 .unwrap()
                 .unwrap()
                 .into_int_value()
                 .into_int_value()
         } else {
         } else {
-            // field 0 is the length
-            let vector = vector.into_pointer_value();
+            let struct_ty = vector
+                .into_pointer_value()
+                .get_type()
+                .get_element_type()
+                .into_struct_type();
+            let name = struct_ty.get_name().unwrap();
 
 
-            let len = unsafe {
-                self.builder.build_gep(
-                    vector,
-                    &[
-                        self.context.i32_type().const_zero(),
-                        self.context.i32_type().const_zero(),
-                    ],
-                    "vector_len",
-                )
-            };
+            if name == CStr::from_bytes_with_nul(b"struct.SolAccountInfo\0").unwrap() {
+                // load the data pointer
+                let data = self
+                    .builder
+                    .build_load(
+                        self.builder
+                            .build_struct_gep(vector.into_pointer_value(), 3, "data")
+                            .unwrap(),
+                        "data",
+                    )
+                    .into_pointer_value();
 
 
-            self.builder
-                .build_select(
-                    self.builder.build_is_null(vector, "vector_is_null"),
-                    self.context.i32_type().const_zero(),
-                    self.builder.build_load(len, "vector_len").into_int_value(),
-                    "length",
-                )
-                .into_int_value()
+                // get the offset of the return data
+                let header_ptr = self.builder.build_pointer_cast(
+                    data,
+                    self.context.i32_type().ptr_type(AddressSpace::Generic),
+                    "header_ptr",
+                );
+
+                let data_len_ptr = unsafe {
+                    self.builder.build_gep(
+                        header_ptr,
+                        &[self.context.i64_type().const_int(1, false)],
+                        "data_len_ptr",
+                    )
+                };
+
+                self.builder
+                    .build_load(data_len_ptr, "len")
+                    .into_int_value()
+            } else {
+                // field 0 is the length
+                let vector = vector.into_pointer_value();
+
+                let len = unsafe {
+                    self.builder.build_gep(
+                        vector,
+                        &[
+                            self.context.i32_type().const_zero(),
+                            self.context.i32_type().const_zero(),
+                        ],
+                        "vector_len",
+                    )
+                };
+
+                self.builder
+                    .build_select(
+                        self.builder.build_is_null(vector, "vector_is_null"),
+                        self.context.i32_type().const_zero(),
+                        self.builder.build_load(len, "vector_len").into_int_value(),
+                        "length",
+                    )
+                    .into_int_value()
+            }
         }
         }
     }
     }
 
 
@@ -6168,22 +6181,68 @@ impl<'a> Contract<'a> {
                 .unwrap()
                 .unwrap()
                 .into_pointer_value()
                 .into_pointer_value()
         } else {
         } else {
-            let data = unsafe {
-                self.builder.build_gep(
-                    vector.into_pointer_value(),
-                    &[
-                        self.context.i32_type().const_zero(),
-                        self.context.i32_type().const_int(2, false),
-                    ],
+            // Get the type name of the struct we are point to
+            let struct_ty = vector
+                .into_pointer_value()
+                .get_type()
+                .get_element_type()
+                .into_struct_type();
+            let name = struct_ty.get_name().unwrap();
+
+            if name == CStr::from_bytes_with_nul(b"struct.SolAccountInfo\0").unwrap() {
+                // load the data pointer
+                let data = self
+                    .builder
+                    .build_load(
+                        self.builder
+                            .build_struct_gep(vector.into_pointer_value(), 3, "data")
+                            .unwrap(),
+                        "data",
+                    )
+                    .into_pointer_value();
+
+                // get the offset of the return data
+                let header_ptr = self.builder.build_pointer_cast(
+                    data,
+                    self.context.i32_type().ptr_type(AddressSpace::Generic),
+                    "header_ptr",
+                );
+
+                let data_ptr = unsafe {
+                    self.builder.build_gep(
+                        header_ptr,
+                        &[self.context.i64_type().const_int(2, false)],
+                        "data_ptr",
+                    )
+                };
+
+                let offset = self.builder.build_load(data_ptr, "offset").into_int_value();
+
+                let v = unsafe { self.builder.build_gep(data, &[offset], "data") };
+
+                self.builder.build_pointer_cast(
+                    v,
+                    self.context.i8_type().ptr_type(AddressSpace::Generic),
                     "data",
                     "data",
                 )
                 )
-            };
+            } else {
+                let data = unsafe {
+                    self.builder.build_gep(
+                        vector.into_pointer_value(),
+                        &[
+                            self.context.i32_type().const_zero(),
+                            self.context.i32_type().const_int(2, false),
+                        ],
+                        "data",
+                    )
+                };
 
 
-            self.builder.build_pointer_cast(
-                data,
-                self.context.i8_type().ptr_type(AddressSpace::Generic),
-                "data",
-            )
+                self.builder.build_pointer_cast(
+                    data,
+                    self.context.i8_type().ptr_type(AddressSpace::Generic),
+                    "data",
+                )
+            }
         }
         }
     }
     }
 
 

+ 27 - 5
src/emit/solana.rs

@@ -247,7 +247,7 @@ impl SolanaTarget {
             .build_load(
             .build_load(
                 contract
                 contract
                     .builder
                     .builder
-                    .build_struct_gep(sol_params, 4, "input")
+                    .build_struct_gep(sol_params, 5, "input")
                     .unwrap(),
                     .unwrap(),
                 "data",
                 "data",
             )
             )
@@ -258,7 +258,7 @@ impl SolanaTarget {
             .build_load(
             .build_load(
                 contract
                 contract
                     .builder
                     .builder
-                    .build_struct_gep(sol_params, 5, "input_len")
+                    .build_struct_gep(sol_params, 6, "input_len")
                     .unwrap(),
                     .unwrap(),
                 "data_len",
                 "data_len",
             )
             )
@@ -2402,11 +2402,13 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
         success: Option<&mut BasicValueEnum<'b>>,
         success: Option<&mut BasicValueEnum<'b>>,
         payload: PointerValue<'b>,
         payload: PointerValue<'b>,
         payload_len: IntValue<'b>,
         payload_len: IntValue<'b>,
-        _address: Option<PointerValue<'b>>,
+        address: Option<PointerValue<'b>>,
         _gas: IntValue<'b>,
         _gas: IntValue<'b>,
         _value: IntValue<'b>,
         _value: IntValue<'b>,
         _ty: ast::CallTy,
         _ty: ast::CallTy,
     ) {
     ) {
+        debug_assert!(address.is_none());
+
         let parameters = contract
         let parameters = contract
             .builder
             .builder
             .get_insert_block()
             .get_insert_block()
@@ -2465,8 +2467,28 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
     }
     }
 
 
     /// Get return buffer for external call
     /// Get return buffer for external call
-    fn return_data<'b>(&self, _contract: &Contract<'b>) -> PointerValue<'b> {
-        unimplemented!();
+    fn return_data<'b>(&self, contract: &Contract<'b>) -> PointerValue<'b> {
+        let parameters = contract
+            .builder
+            .get_insert_block()
+            .unwrap()
+            .get_parent()
+            .unwrap()
+            .get_last_param()
+            .unwrap()
+            .into_pointer_value();
+
+        // return the account that returned the value
+        contract
+            .builder
+            .build_load(
+                contract
+                    .builder
+                    .build_struct_gep(parameters, 3, "ka_last_called")
+                    .unwrap(),
+                "data",
+            )
+            .into_pointer_value()
     }
     }
 
 
     fn return_code<'b>(&self, contract: &'b Contract, ret: IntValue<'b>) {
     fn return_code<'b>(&self, contract: &'b Contract, ret: IntValue<'b>) {

BIN
stdlib/bpf/solana.bc


+ 25 - 30
stdlib/solana.c

@@ -42,7 +42,7 @@ void *__malloc(uint32_t size)
     return sol_alloc_free_(size, NULL);
     return sol_alloc_free_(size, NULL);
 }
 }
 
 
-uint64_t external_call(uint8_t *input, uint32_t input_len, const SolParameters *params)
+uint64_t external_call(uint8_t *input, uint32_t input_len, SolParameters *params)
 {
 {
     uint64_t sol_invoke_signed_c(
     uint64_t sol_invoke_signed_c(
         const SolInstruction *instruction,
         const SolInstruction *instruction,
@@ -54,45 +54,40 @@ uint64_t external_call(uint8_t *input, uint32_t input_len, const SolParameters *
     // The first 32 bytes of the input is the destination address
     // The first 32 bytes of the input is the destination address
     const SolPubkey *dest = (const SolPubkey *)input;
     const SolPubkey *dest = (const SolPubkey *)input;
 
 
+    SolAccountMeta metas[10];
+    SolInstruction instruction = {
+        .program_id = NULL,
+        .accounts = metas,
+        .account_len = params->ka_num,
+        .data = input,
+        .data_len = input_len,
+    };
+
     for (int account_no = 0; account_no < params->ka_num; account_no++)
     for (int account_no = 0; account_no < params->ka_num; account_no++)
     {
     {
         const SolAccountInfo *acc = &params->ka[account_no];
         const SolAccountInfo *acc = &params->ka[account_no];
 
 
         if (SolPubkey_same(dest, acc->key))
         if (SolPubkey_same(dest, acc->key))
         {
         {
-            // found it
-            SolAccountMeta metas[3] = {
-                {
-                    .pubkey = params->ka[0].key,
-                    .is_writable = true,
-                    .is_signer = false,
-                },
-                {
-                    .pubkey = acc->key,
-                    .is_writable = true,
-                    .is_signer = false,
-                },
-                {
-                    .pubkey = acc->owner,
-                    .is_writable = false,
-                    .is_signer = false,
-                }};
-
-            SolInstruction instruction = {
-                .program_id = acc->owner,
-                .accounts = metas,
-                .account_len = SOL_ARRAY_SIZE(metas),
-                .data = input,
-                .data_len = input_len,
-            };
-
-            return sol_invoke_signed_c(&instruction, params->ka, params->ka_num, NULL, 0);
+            instruction.program_id = acc->owner;
+            params->ka_last_called = acc;
         }
         }
+
+        metas[account_no].pubkey = acc->key;
+        metas[account_no].is_writable = acc->is_writable;
+        metas[account_no].is_signer = acc->is_signer;
     }
     }
 
 
-    sol_log("call to account not in transaction");
+    if (instruction.program_id)
+    {
+        return sol_invoke_signed_c(&instruction, params->ka, params->ka_num, NULL, 0);
+    }
+    else
+    {
+        sol_log("call to account not in transaction");
 
 
-    return ERROR_INVALID_ACCOUNT_DATA;
+        return ERROR_INVALID_ACCOUNT_DATA;
+    }
 }
 }
 
 
 struct account_data_header
 struct account_data_header

+ 1 - 0
stdlib/solana_sdk.h

@@ -221,6 +221,7 @@ typedef struct
                           point to an array of SolAccountInfos */
                           point to an array of SolAccountInfos */
   uint64_t ka_num;       /** Number of SolAccountInfo entries in `ka` */
   uint64_t ka_num;       /** Number of SolAccountInfo entries in `ka` */
   uint64_t ka_cur;
   uint64_t ka_cur;
+  const SolAccountInfo *ka_last_called;
   const SolPubkey *account_id;
   const SolPubkey *account_id;
   const uint8_t *input;        /** pointer to the instruction data */
   const uint8_t *input;        /** pointer to the instruction data */
   uint64_t input_len;          /** Length in bytes of the instruction data */
   uint64_t input_len;          /** Length in bytes of the instruction data */

+ 16 - 0
tests/solana.rs

@@ -399,6 +399,8 @@ fn translate_slice_inner<'a, T>(
 
 
 struct SyscallInvokeSignedC<'a> {
 struct SyscallInvokeSignedC<'a> {
     context: Rc<RefCell<&'a mut VirtualMachine>>,
     context: Rc<RefCell<&'a mut VirtualMachine>>,
+    input: &'a [u8],
+    calldata: &'a [u8],
 }
 }
 
 
 impl<'a> SyscallInvokeSignedC<'a> {
 impl<'a> SyscallInvokeSignedC<'a> {
@@ -467,6 +469,18 @@ impl<'a> SyscallObject<UserError> for SyscallInvokeSignedC<'a> {
 
 
             context.execute(&instruction.data);
             context.execute(&instruction.data);
 
 
+            let parameter_bytes = serialize_parameters(&self.calldata, &context);
+
+            assert_eq!(parameter_bytes.len(), self.input.len());
+
+            unsafe {
+                std::ptr::copy(
+                    parameter_bytes.as_ptr(),
+                    self.input.as_ptr() as *mut u8,
+                    parameter_bytes.len(),
+                );
+            }
+
             context.stack.remove(0);
             context.stack.remove(0);
         }
         }
 
 
@@ -509,6 +523,8 @@ impl VirtualMachine {
             "sol_invoke_signed_c",
             "sol_invoke_signed_c",
             Syscall::Object(Box::new(SyscallInvokeSignedC {
             Syscall::Object(Box::new(SyscallInvokeSignedC {
                 context: context.clone(),
                 context: context.clone(),
+                input: &parameter_bytes,
+                calldata: &calldata,
             })),
             })),
         )
         )
         .unwrap();
         .unwrap();

+ 71 - 1
tests/solana_tests/call.rs

@@ -57,7 +57,7 @@ fn calltys() {
 }
 }
 
 
 #[test]
 #[test]
-fn two_contracts() {
+fn simple_external_call() {
     let mut vm = build_solidity(
     let mut vm = build_solidity(
         r#"
         r#"
         contract bar0 {
         contract bar0 {
@@ -101,3 +101,73 @@ fn two_contracts() {
 
 
     assert_eq!(vm.printbuf, "bar1 says: cross contract call");
     assert_eq!(vm.printbuf, "bar1 says: cross contract call");
 }
 }
+
+#[test]
+fn external_call_with_returns() {
+    let mut vm = build_solidity(
+        r#"
+        contract bar0 {
+            function test_other(bar1 x) public returns (int64) {
+                return x.test_bar(7) + 5;
+            }
+        }
+
+        contract bar1 {
+            function test_bar(int64 y) public returns (int64) {
+                return 3 + y;
+            }
+        }"#,
+    );
+
+    vm.constructor(&[]);
+
+    let res = vm.function("test_bar", &[Token::Int(ethereum_types::U256::from(21))]);
+
+    assert_eq!(res, vec![Token::Int(ethereum_types::U256::from(24))]);
+
+    let bar1_account = vm.stack[0].data;
+
+    vm.set_program(0);
+
+    vm.constructor(&[]);
+
+    let res = vm.function("test_other", &[Token::FixedBytes(bar1_account.to_vec())]);
+
+    assert_eq!(res, vec![Token::Int(ethereum_types::U256::from(15))]);
+}
+
+#[test]
+fn external_call_with_string_returns() {
+    let mut vm = build_solidity(
+        r#"
+        contract bar0 {
+            function test_other(bar1 x) public returns (string) {
+                string y = x.test_bar(7);
+                print(y);
+                return y;
+            }
+        }
+
+        contract bar1 {
+            function test_bar(int64 y) public returns (string) {
+                return "foo:{}".format(y);
+            }
+        }"#,
+    );
+
+    vm.constructor(&[]);
+
+    let res = vm.function("test_bar", &[Token::Int(ethereum_types::U256::from(22))]);
+
+    assert_eq!(res, vec![Token::String(String::from("foo:22"))]);
+
+    let bar1_account = vm.stack[0].data;
+
+    vm.set_program(0);
+
+    vm.constructor(&[]);
+
+    let res = vm.function("test_other", &[Token::FixedBytes(bar1_account.to_vec())]);
+
+    assert_eq!(res, vec![Token::String(String::from("foo:7"))]);
+}