Kaynağa Gözat

Use new solana mechanism for return data

Requires https://github.com/solana-labs/solana/pull/19548

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 4 yıl önce
ebeveyn
işleme
ccd6217794

+ 1 - 1
.github/workflows/test.yml

@@ -143,7 +143,7 @@ jobs:
     needs: linux
     services:
       solana:
-        image: solanalabs/solana:v1.7.3
+        image: solanalabs/solana:edge
         ports:
           - 8899
           - 8900

+ 0 - 1
integration/solana/calls.spec.ts

@@ -1,6 +1,5 @@
 import expect from 'expect';
 import { establishConnection } from './index';
-import crypto from 'crypto';
 
 describe('Deploy solang contract and test', () => {
     it('external_call', async function () {

+ 9 - 0
integration/solana/errors.sol

@@ -0,0 +1,9 @@
+contract errors {
+    function do_revert(bool yes) pure public returns (int) {
+        if (yes) {
+            revert("Do the revert thing");
+        } else {
+            return 3124445;
+        }
+    }
+}

+ 24 - 0
integration/solana/errors.spec.ts

@@ -0,0 +1,24 @@
+import expect from 'expect';
+import { establishConnection } from './index';
+
+describe('Deploy solang contract and test', () => {
+    it('errors', async function () {
+        this.timeout(50000);
+
+        let conn = await establishConnection();
+
+        let errors = await conn.loadProgram("bundle.so", "errors.abi");
+
+        // call the constructor
+        await errors.call_constructor(conn, 'errors', []);
+
+        let res = await errors.call_function(conn, "do_revert", [false]);
+
+        expect(res["0"]).toBe("3124445");
+
+        let revert_res = await errors.call_function_expect_revert(conn, "do_revert", [true]);
+
+        expect(revert_res).toBe("Do the revert thing");
+
+    });
+});

+ 92 - 10
integration/solana/index.ts

@@ -14,13 +14,15 @@ import fs from 'fs';
 import { AbiItem } from 'web3-utils';
 import { utils } from 'ethers';
 import crypto from 'crypto';
+import { encode } from 'querystring';
 const Web3EthAbi = require('web3-eth-abi');
 
 const default_url: string = "http://localhost:8899";
+const return_data_prefix = 'Program return data: ';
 
 export async function establishConnection(): Promise<TestConnection> {
     let url = process.env.RPC_URL || default_url;
-    let connection = new Connection(url, 'recent');
+    let connection = new Connection(url, 'confirmed');
     const version = await connection.getVersion();
     console.log('Connection to cluster established:', url, version);
 
@@ -101,7 +103,7 @@ class TestConnection {
             [this.payerAccount, account],
             {
                 skipPreflight: false,
-                commitment: 'recent',
+                commitment: 'confirmed',
                 preflightCommitment: undefined,
             },
         );
@@ -190,7 +192,7 @@ class Program {
             [test.payerAccount],
             {
                 skipPreflight: false,
-                commitment: 'recent',
+                commitment: 'confirmed',
                 preflightCommitment: undefined,
             },
         );
@@ -221,7 +223,8 @@ class Program {
         keys.push({ pubkey: PublicKey.default, isSigner: false, isWritable: false });
 
         for (let i = 0; i < pubkeys.length; i++) {
-            keys.push({ pubkey: pubkeys[i], isSigner: false, isWritable: true });
+            // make each 2nd key writable (will be account storage for contract)
+            keys.push({ pubkey: pubkeys[i], isSigner: false, isWritable: (i & 1) == 1 });
         }
 
         const instruction = new TransactionInstruction({
@@ -232,24 +235,37 @@ class Program {
 
         signers.unshift(test.payerAccount);
 
-        await sendAndConfirmTransaction(
+        let signature = await sendAndConfirmTransaction(
             test.connection,
             new Transaction().add(instruction),
             signers,
             {
                 skipPreflight: false,
-                commitment: 'recent',
+                commitment: 'confirmed',
                 preflightCommitment: undefined,
             },
         );
 
         if (abi.outputs?.length) {
-            const accountInfo = await test.connection.getAccountInfo(this.contractStorageAccount.publicKey);
+            const parsedTx = await test.connection.getParsedConfirmedTransaction(
+                signature,
+            );
+
+            let encoded = Buffer.from([]);
 
-            let length = Number(accountInfo!.data.readUInt32LE(4));
-            let offset = Number(accountInfo!.data.readUInt32LE(8));
+            let seen = 0;
 
-            let encoded = accountInfo!.data.slice(offset, length + offset);
+            for (let message of parsedTx!.meta?.logMessages!) {
+                if (message.startsWith(return_data_prefix)) {
+                    let [program_id, return_data] = message.slice(return_data_prefix.length).split(" ");
+                    encoded = Buffer.from(return_data, 'base64')
+                    seen += 1;
+                }
+            }
+
+            if (seen == 0) {
+                throw 'return data not set';
+            }
 
             let returns = Web3EthAbi.decodeParameters(abi.outputs!, encoded.toString('hex'));
 
@@ -267,6 +283,72 @@ class Program {
         }
     }
 
+    async call_function_expect_revert(test: TestConnection, name: string, params: any[], pubkeys: PublicKey[] = [], seeds: any[] = [], signers: Keypair[] = []): Promise<string> {
+        let abi: AbiItem = JSON.parse(this.abi).find((e: AbiItem) => e.name == name);
+
+        const input: string = Web3EthAbi.encodeFunctionCall(abi, params);
+        const data = Buffer.concat([
+            this.contractStorageAccount.publicKey.toBuffer(),
+            test.payerAccount.publicKey.toBuffer(),
+            Buffer.from('00000000', 'hex'),
+            this.encode_seeds(seeds),
+            Buffer.from(input.replace('0x', ''), 'hex')
+        ]);
+
+        let debug = 'calling function ' + name + ' [' + params + ']';
+
+        let keys = [];
+
+        seeds.forEach((seed) => {
+            keys.push({ pubkey: seed.address, isSigner: false, isWritable: true });
+        });
+
+        keys.push({ pubkey: this.contractStorageAccount.publicKey, isSigner: false, isWritable: true });
+        keys.push({ pubkey: SYSVAR_CLOCK_PUBKEY, isSigner: false, isWritable: false });
+        keys.push({ pubkey: PublicKey.default, isSigner: false, isWritable: false });
+
+        for (let i = 0; i < pubkeys.length; i++) {
+            // make each 2nd key writable (will be account storage for contract)
+            keys.push({ pubkey: pubkeys[i], isSigner: false, isWritable: (i & 1) == 1 });
+        }
+
+        const instruction = new TransactionInstruction({
+            keys,
+            programId: this.programId,
+            data,
+        });
+
+        signers.unshift(test.payerAccount);
+
+        const { err, logs } = (await test.connection.simulateTransaction(new Transaction().add(instruction),
+            signers)).value;
+
+        if (!err) {
+            throw 'error is not falsy';
+        }
+
+        let encoded;
+        let seen = 0;
+
+        for (let message of logs!) {
+            if (message.startsWith(return_data_prefix)) {
+                let [program_id, return_data] = message.slice(return_data_prefix.length).split(" ");
+                encoded = Buffer.from(return_data, 'base64')
+                seen += 1;
+            }
+        }
+
+        if (seen == 0) {
+            throw 'return data not set';
+        }
+
+        if (encoded?.readUInt32BE(0) != 0x08c379a0) {
+            throw 'signature not correct';
+        }
+
+        return Web3EthAbi.decodeParameter('string', encoded.subarray(4).toString('hex'));
+    }
+
     async contract_storage(test: TestConnection, upto: number): Promise<Buffer> {
         const accountInfo = await test.connection.getAccountInfo(this.contractStorageAccount.publicKey);
 

+ 0 - 18
integration/solana/simple.spec.ts

@@ -407,24 +407,6 @@ describe('Deploy solang contract and test', () => {
             .toThrowError(new Error('failed to send transaction: Transaction simulation failed: Error processing Instruction 0: account data too small for instruction'));
     });
 
-    it('returndata too small', async function () {
-        this.timeout(50000);
-
-        let conn = await establishConnection();
-
-        // storage.sol needs 168 byes
-        let prog = await conn.loadProgram("bundle.so", "store.abi", 512);
-
-        await prog.call_constructor(conn, 'store', []);
-
-        await prog.call_function(conn, "set_foo1", []);
-
-        // get foo1
-        await expect(prog.call_function(conn, "get_both_foos", []))
-            .rejects
-            .toThrowError(new Error('failed to send transaction: Transaction simulation failed: Error processing Instruction 0: account data too small for instruction'));
-    });
-
     it('account storage too small dynamic alloc', async function () {
         this.timeout(50000);
 

+ 1 - 1
src/emit/ewasm.rs

@@ -1527,7 +1527,7 @@ impl<'a> TargetRuntime<'a> for EwasmTarget {
         }
     }
 
-    fn return_data<'b>(&self, binary: &Binary<'b>) -> PointerValue<'b> {
+    fn return_data<'b>(&self, binary: &Binary<'b>, _function: FunctionValue) -> PointerValue<'b> {
         let length = binary
             .builder
             .build_call(

+ 1 - 1
src/emit/generic.rs

@@ -638,7 +638,7 @@ impl<'a> TargetRuntime<'a> for GenericTarget {
     }
 
     /// Get return buffer for external call
-    fn return_data<'b>(&self, _binary: &Binary<'b>) -> PointerValue<'b> {
+    fn return_data<'b>(&self, _binary: &Binary<'b>, _function: FunctionValue) -> PointerValue<'b> {
         panic!("generic cannot call other contracts");
     }
 

+ 15 - 61
src/emit/mod.rs

@@ -291,7 +291,7 @@ pub trait TargetRuntime<'a> {
     ) -> BasicValueEnum<'b>;
 
     /// Return the return data from an external call (either revert error or return values)
-    fn return_data<'b>(&self, bin: &Binary<'b>) -> PointerValue<'b>;
+    fn return_data<'b>(&self, bin: &Binary<'b>, function: FunctionValue<'b>) -> PointerValue<'b>;
 
     /// Return the value we received
     fn value_transferred<'b>(&self, bin: &Binary<'b>, ns: &ast::Namespace) -> IntValue<'b>;
@@ -2537,7 +2537,7 @@ pub trait TargetRuntime<'a> {
                     )
                     .into()
             }
-            Expression::ReturnData(_) => self.return_data(bin).into(),
+            Expression::ReturnData(_) => self.return_data(bin, function).into(),
             Expression::StorageArrayLength { array, elem_ty, .. } => {
                 let slot = self
                     .expression(bin, array, vartab, function, ns)
@@ -6167,68 +6167,22 @@ impl<'a> Binary<'a> {
                 .unwrap()
                 .into_pointer_value()
         } else {
-            // Get the type name of the struct we are point to
-            let struct_ty = vector
-                .into_pointer_value()
-                .get_type()
-                .get_element_type()
-                .into_struct_type();
-            let name = struct_ty.get_name().unwrap();
-
-            if name == CStr::from_bytes_with_nul(b"struct.SolAccountInfo\0").unwrap() {
-                // load the data pointer
-                let data = self
-                    .builder
-                    .build_load(
-                        self.builder
-                            .build_struct_gep(vector.into_pointer_value(), 3, "data")
-                            .unwrap(),
-                        "data",
-                    )
-                    .into_pointer_value();
-
-                // get the offset of the return data
-                let header_ptr = self.builder.build_pointer_cast(
-                    data,
-                    self.context.i32_type().ptr_type(AddressSpace::Generic),
-                    "header_ptr",
-                );
-
-                let data_ptr = unsafe {
-                    self.builder.build_gep(
-                        header_ptr,
-                        &[self.context.i64_type().const_int(2, false)],
-                        "data_ptr",
-                    )
-                };
-
-                let offset = self.builder.build_load(data_ptr, "offset").into_int_value();
-
-                let v = unsafe { self.builder.build_gep(data, &[offset], "data") };
-
-                self.builder.build_pointer_cast(
-                    v,
-                    self.context.i8_type().ptr_type(AddressSpace::Generic),
+            let data = unsafe {
+                self.builder.build_gep(
+                    vector.into_pointer_value(),
+                    &[
+                        self.context.i32_type().const_zero(),
+                        self.context.i32_type().const_int(2, false),
+                    ],
                     "data",
                 )
-            } else {
-                let data = unsafe {
-                    self.builder.build_gep(
-                        vector.into_pointer_value(),
-                        &[
-                            self.context.i32_type().const_zero(),
-                            self.context.i32_type().const_int(2, false),
-                        ],
-                        "data",
-                    )
-                };
+            };
 
-                self.builder.build_pointer_cast(
-                    data,
-                    self.context.i8_type().ptr_type(AddressSpace::Generic),
-                    "data",
-                )
-            }
+            self.builder.build_pointer_cast(
+                data,
+                self.context.i8_type().ptr_type(AddressSpace::Generic),
+                "data",
+            )
         }
     }
 

+ 1 - 1
src/emit/sabre.rs

@@ -719,7 +719,7 @@ impl<'a> TargetRuntime<'a> for SabreTarget {
     }
 
     /// Get return buffer for external call
-    fn return_data<'b>(&self, _binary: &Binary<'b>) -> PointerValue<'b> {
+    fn return_data<'b>(&self, _binary: &Binary<'b>, _function: FunctionValue) -> PointerValue<'b> {
         panic!("Sabre cannot call other binarys");
     }
 

+ 180 - 142
src/emit/solana.rs

@@ -243,6 +243,24 @@ impl SolanaTarget {
         function
             .as_global_value()
             .set_unnamed_address(UnnamedAddress::Local);
+
+        let function = binary.module.add_function(
+            "sol_set_return_data",
+            void_ty.fn_type(&[u8_ptr.into(), u64_ty.into()], false),
+            None,
+        );
+        function
+            .as_global_value()
+            .set_unnamed_address(UnnamedAddress::Local);
+
+        let function = binary.module.add_function(
+            "sol_get_return_data",
+            u64_ty.fn_type(&[u8_ptr.into(), u64_ty.into(), u8_ptr.into()], false),
+            None,
+        );
+        function
+            .as_global_value()
+            .set_unnamed_address(UnnamedAddress::Local);
     }
 
     /// Returns the SolAccountInfo of the executing binary
@@ -2373,57 +2391,25 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
     }
 
     fn return_empty_abi(&self, binary: &Binary) {
-        let data = self.contract_storage_data(binary);
-
-        let header_ptr = binary.builder.build_pointer_cast(
-            data,
-            binary.context.i32_type().ptr_type(AddressSpace::Generic),
-            "header_ptr",
-        );
-
-        let data_len_ptr = unsafe {
-            binary.builder.build_gep(
-                header_ptr,
-                &[binary.context.i64_type().const_int(1, false)],
-                "data_len_ptr",
-            )
-        };
-
-        let data_ptr = unsafe {
-            binary.builder.build_gep(
-                header_ptr,
-                &[binary.context.i64_type().const_int(2, false)],
-                "data_ptr",
-            )
-        };
-
-        let offset = binary
-            .builder
-            .build_load(data_ptr, "offset")
-            .into_int_value();
-
-        binary.builder.build_call(
-            binary.module.get_function("account_data_free").unwrap(),
-            &[data.into(), offset.into()],
-            "",
-        );
-
-        binary
-            .builder
-            .build_store(data_len_ptr, binary.context.i32_type().const_zero());
-
-        binary
-            .builder
-            .build_store(data_ptr, binary.context.i32_type().const_zero());
-
         // return 0 for success
         binary
             .builder
             .build_return(Some(&binary.context.i64_type().const_int(0, false)));
     }
 
-    fn return_abi<'b>(&self, binary: &'b Binary, _data: PointerValue<'b>, _length: IntValue) {
-        // return data already filled in output binary
+    fn return_abi<'b>(&self, binary: &'b Binary, data: PointerValue<'b>, length: IntValue) {
+        // set return data
+        binary.builder.build_call(
+            binary.module.get_function("sol_set_return_data").unwrap(),
+            &[
+                data.into(),
+                binary
+                    .builder
+                    .build_int_z_extend(length, binary.context.i64_type(), "length")
+                    .into(),
+            ],
+            "",
+        );
 
         // return 0 for success
         binary
@@ -2431,8 +2417,19 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
             .build_return(Some(&binary.context.i64_type().const_int(0, false)));
     }
 
-    fn assert_failure<'b>(&self, binary: &'b Binary, _data: PointerValue, _length: IntValue) {
+    fn assert_failure<'b>(&self, binary: &'b Binary, data: PointerValue, length: IntValue) {
         // the reason code should be null (and already printed)
+        binary.builder.build_call(
+            binary.module.get_function("sol_set_return_data").unwrap(),
+            &[
+                data.into(),
+                binary
+                    .builder
+                    .build_int_z_extend(length, binary.context.i64_type(), "length")
+                    .into(),
+            ],
+            "",
+        );
 
         // return 1 for failure
         binary.builder.build_return(Some(
@@ -2480,101 +2477,21 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
 
         let length = encoder.encoded_length();
 
-        let data = self.contract_storage_data(binary);
-        let account = self.contract_storage_account(binary);
-
-        let header_ptr = binary.builder.build_pointer_cast(
-            data,
-            binary.context.i32_type().ptr_type(AddressSpace::Generic),
-            "header_ptr",
-        );
-
-        let data_len_ptr = unsafe {
-            binary.builder.build_gep(
-                header_ptr,
-                &[binary.context.i64_type().const_int(1, false)],
-                "data_len_ptr",
-            )
-        };
-
-        let data_offset_ptr = unsafe {
-            binary.builder.build_gep(
-                header_ptr,
-                &[binary.context.i64_type().const_int(2, false)],
-                "data_offset_ptr",
-            )
-        };
-
-        let offset = binary
-            .builder
-            .build_load(data_offset_ptr, "offset")
-            .into_int_value();
-
-        let account_data_realloc = binary.module.get_function("account_data_realloc").unwrap();
-
-        let arg1 = binary.builder.build_pointer_cast(
-            account,
-            account_data_realloc.get_type().get_param_types()[0].into_pointer_type(),
-            "",
-        );
-
-        let rc = binary
+        let encoded_data = binary
             .builder
             .build_call(
-                account_data_realloc,
-                &[
-                    arg1.into(),
-                    offset.into(),
-                    length.into(),
-                    data_offset_ptr.into(),
-                ],
+                binary.module.get_function("__malloc").unwrap(),
+                &[length.into()],
                 "",
             )
             .try_as_basic_value()
             .left()
             .unwrap()
-            .into_int_value();
-
-        let is_rc_zero = binary.builder.build_int_compare(
-            IntPredicate::EQ,
-            rc,
-            binary.context.i64_type().const_zero(),
-            "is_rc_zero",
-        );
-
-        let rc_not_zero = binary.context.append_basic_block(function, "rc_not_zero");
-        let rc_zero = binary.context.append_basic_block(function, "rc_zero");
-
-        binary
-            .builder
-            .build_conditional_branch(is_rc_zero, rc_zero, rc_not_zero);
-
-        binary.builder.position_at_end(rc_not_zero);
-
-        self.return_code(
-            binary,
-            binary.context.i64_type().const_int(5u64 << 32, false),
-        );
-
-        binary.builder.position_at_end(rc_zero);
-
-        binary.builder.build_store(data_len_ptr, length);
-
-        let offset = binary
-            .builder
-            .build_load(data_offset_ptr, "offset")
-            .into_int_value();
-
-        // step over that field, and cast to u8* for the buffer itself
-        let output = binary.builder.build_pointer_cast(
-            unsafe { binary.builder.build_gep(data, &[offset], "data_ptr") },
-            binary.context.i8_type().ptr_type(AddressSpace::Generic),
-            "data_ptr",
-        );
+            .into_pointer_value();
 
-        encoder.finish(binary, function, output, ns);
+        encoder.finish(binary, function, encoded_data, ns);
 
-        (output, length)
+        (encoded_data, length)
     }
 
     fn abi_decode<'b>(
@@ -2849,20 +2766,141 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
     }
 
     /// Get return buffer for external call
-    fn return_data<'b>(&self, binary: &Binary<'b>) -> PointerValue<'b> {
-        let parameters = self.sol_parameters(binary);
+    fn return_data<'b>(
+        &self,
+        binary: &Binary<'b>,
+        function: FunctionValue<'b>,
+    ) -> PointerValue<'b> {
+        let null_u8_ptr = binary
+            .context
+            .i8_type()
+            .ptr_type(AddressSpace::Generic)
+            .const_zero();
 
-        // return the account that returned the value
-        binary
+        let length_as_64 = binary
             .builder
-            .build_load(
-                binary
-                    .builder
-                    .build_struct_gep(parameters, 3, "ka_last_called")
-                    .unwrap(),
+            .build_call(
+                binary.module.get_function("sol_get_return_data").unwrap(),
+                &[
+                    null_u8_ptr.into(),
+                    binary.context.i64_type().const_zero().into(),
+                    null_u8_ptr.into(),
+                ],
+                "returndatasize",
+            )
+            .try_as_basic_value()
+            .left()
+            .unwrap()
+            .into_int_value();
+
+        let length =
+            binary
+                .builder
+                .build_int_truncate(length_as_64, binary.context.i32_type(), "length");
+
+        let malloc_length = binary.builder.build_int_add(
+            length,
+            binary
+                .module
+                .get_struct_type("struct.vector")
+                .unwrap()
+                .size_of()
+                .unwrap()
+                .const_cast(binary.context.i32_type(), false),
+            "size",
+        );
+
+        let p = binary
+            .builder
+            .build_call(
+                binary.module.get_function("__malloc").unwrap(),
+                &[malloc_length.into()],
+                "",
+            )
+            .try_as_basic_value()
+            .left()
+            .unwrap()
+            .into_pointer_value();
+
+        let v = binary.builder.build_pointer_cast(
+            p,
+            binary
+                .module
+                .get_struct_type("struct.vector")
+                .unwrap()
+                .ptr_type(AddressSpace::Generic),
+            "string",
+        );
+
+        let data_len = unsafe {
+            binary.builder.build_gep(
+                v,
+                &[
+                    binary.context.i32_type().const_zero(),
+                    binary.context.i32_type().const_zero(),
+                ],
+                "data_len",
+            )
+        };
+
+        binary.builder.build_store(data_len, length);
+
+        let data_size = unsafe {
+            binary.builder.build_gep(
+                v,
+                &[
+                    binary.context.i32_type().const_zero(),
+                    binary.context.i32_type().const_int(1, false),
+                ],
+                "data_size",
+            )
+        };
+
+        binary.builder.build_store(data_size, length);
+
+        let data = unsafe {
+            binary.builder.build_gep(
+                v,
+                &[
+                    binary.context.i32_type().const_zero(),
+                    binary.context.i32_type().const_int(2, false),
+                ],
                 "data",
             )
-            .into_pointer_value()
+        };
+
+        let program_id = binary.build_array_alloca(
+            function,
+            binary.context.i8_type(),
+            binary.context.i32_type().const_int(32, false),
+            "program_id",
+        );
+
+        binary.builder.build_call(
+            binary.module.get_function("sol_get_return_data").unwrap(),
+            &[
+                binary
+                    .builder
+                    .build_pointer_cast(
+                        data,
+                        binary.context.i8_type().ptr_type(AddressSpace::Generic),
+                        "",
+                    )
+                    .into(),
+                length_as_64.into(),
+                binary
+                    .builder
+                    .build_pointer_cast(
+                        program_id,
+                        binary.context.i8_type().ptr_type(AddressSpace::Generic),
+                        "",
+                    )
+                    .into(),
+            ],
+            "",
+        );
+
+        v
     }
 
     fn return_code<'b>(&self, binary: &'b Binary, ret: IntValue<'b>) {

+ 1 - 1
src/emit/substrate.rs

@@ -3672,7 +3672,7 @@ impl<'a> TargetRuntime<'a> for SubstrateTarget {
         }
     }
 
-    fn return_data<'b>(&self, binary: &Binary<'b>) -> PointerValue<'b> {
+    fn return_data<'b>(&self, binary: &Binary<'b>, _function: FunctionValue) -> PointerValue<'b> {
         let scratch_buf = binary.builder.build_pointer_cast(
             binary.scratch.unwrap().as_pointer_value(),
             binary.context.i8_type().ptr_type(AddressSpace::Generic),

+ 119 - 12
tests/solana.rs

@@ -59,7 +59,7 @@ struct VirtualMachine {
     programs: Vec<Contract>,
     stack: Vec<Contract>,
     printbuf: String,
-    output: Vec<u8>,
+    return_data: Option<(Account, Vec<u8>)>,
 }
 
 #[derive(Clone)]
@@ -208,7 +208,7 @@ fn build_solidity(src: &str) -> VirtualMachine {
         programs,
         stack: vec![cur],
         printbuf: String::new(),
-        output: Vec::new(),
+        return_data: None,
     }
 }
 
@@ -491,6 +491,85 @@ impl SyscallObject<UserError> for SolKeccak256 {
     }
 }
 
+struct SyscallSetReturnData<'a> {
+    context: Rc<RefCell<&'a mut VirtualMachine>>,
+}
+
+impl<'a> SyscallObject<UserError> for SyscallSetReturnData<'a> {
+    fn call(
+        &mut self,
+        addr: u64,
+        len: u64,
+        _arg3: u64,
+        _arg4: u64,
+        _arg5: u64,
+        memory_mapping: &MemoryMapping,
+        result: &mut Result<u64, EbpfError<UserError>>,
+    ) {
+        if len > 1024 {
+            panic!("sol_set_return_data: length is {}", len);
+        }
+        let buf = question_mark!(translate_slice::<u8>(memory_mapping, addr, len), result);
+
+        if let Ok(mut context) = self.context.try_borrow_mut() {
+            if len == 0 {
+                context.return_data = None;
+            } else {
+                context.return_data = Some((context.stack[0].program, buf.to_vec()));
+            }
+
+            *result = Ok(0);
+        } else {
+            panic!();
+        }
+    }
+}
+
+struct SyscallGetReturnData<'a> {
+    context: Rc<RefCell<&'a mut VirtualMachine>>,
+}
+
+impl<'a> SyscallObject<UserError> for SyscallGetReturnData<'a> {
+    fn call(
+        &mut self,
+        addr: u64,
+        len: u64,
+        program_id_addr: u64,
+        _arg4: u64,
+        _arg5: u64,
+        memory_mapping: &MemoryMapping,
+        result: &mut Result<u64, EbpfError<UserError>>,
+    ) {
+        if let Ok(context) = self.context.try_borrow() {
+            if let Some((program_id, return_data)) = &context.return_data {
+                let length = std::cmp::min(len, return_data.len() as u64);
+
+                if len > 0 {
+                    let set_result = question_mark!(
+                        translate_slice_mut::<u8>(memory_mapping, addr, length),
+                        result
+                    );
+
+                    set_result.copy_from_slice(&return_data[..length as usize]);
+
+                    let program_id_result = question_mark!(
+                        translate_slice_mut::<u8>(memory_mapping, program_id_addr, 32),
+                        result
+                    );
+
+                    program_id_result.copy_from_slice(program_id);
+                }
+
+                *result = Ok(return_data.len() as u64);
+            } else {
+                *result = Ok(0);
+            }
+        } else {
+            panic!();
+        }
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Ed25519SigCheckError {
     InvalidPublicKey,
@@ -886,6 +965,8 @@ impl<'a> SyscallObject<UserError> for SyscallInvokeSignedC<'a> {
                 })
                 .collect();
 
+            context.return_data = None;
+
             if instruction.program_id.is_system_instruction() {
                 match bincode::deserialize::<u32>(&instruction.data).unwrap() {
                     0 => {
@@ -1088,6 +1169,14 @@ impl VirtualMachine {
             .register_syscall_by_name(b"sol_ed25519_sig_check", SyscallEd25519SigCheck::call)
             .unwrap();
 
+        syscall_registry
+            .register_syscall_by_name(b"sol_set_return_data", SyscallSetReturnData::call)
+            .unwrap();
+
+        syscall_registry
+            .register_syscall_by_name(b"sol_get_return_data", SyscallGetReturnData::call)
+            .unwrap();
+
         let executable = <dyn Executable<UserError, TestInstructionMeter>>::from_elf(
             &self.account_data[&program.program].data,
             None,
@@ -1140,6 +1229,22 @@ impl VirtualMachine {
         )
         .unwrap();
 
+        vm.bind_syscall_context_object(
+            Box::new(SyscallSetReturnData {
+                context: context.clone(),
+            }),
+            None,
+        )
+        .unwrap();
+
+        vm.bind_syscall_context_object(
+            Box::new(SyscallGetReturnData {
+                context: context.clone(),
+            }),
+            None,
+        )
+        .unwrap();
+
         let res = vm
             .execute_program_interpreted(&mut TestInstructionMeter { remaining: 1000000 })
             .unwrap();
@@ -1153,11 +1258,9 @@ impl VirtualMachine {
 
         VirtualMachine::validate_heap(output);
 
-        let len = LittleEndian::read_u32(&output[4..]) as usize;
-        let offset = LittleEndian::read_u32(&output[8..]) as usize;
-        elf.output = output[offset..offset + len].to_vec();
-
-        println!("return: {}", hex::encode(&elf.output));
+        if let Some((_, return_data)) = &elf.return_data {
+            println!("return: {}", hex::encode(&return_data));
+        }
 
         assert_eq!(res, 0);
     }
@@ -1199,13 +1302,17 @@ impl VirtualMachine {
 
         self.execute(&calldata, seeds);
 
-        println!("output: {}", hex::encode(&self.output));
+        if let Some((_, return_data)) = &self.return_data {
+            println!("return: {}", hex::encode(&return_data));
 
-        let program = &self.stack[0];
+            let program = &self.stack[0];
 
-        program.abi.as_ref().unwrap().functions[name][0]
-            .decode_output(&self.output)
-            .unwrap()
+            program.abi.as_ref().unwrap().functions[name][0]
+                .decode_output(return_data)
+                .unwrap()
+        } else {
+            Vec::new()
+        }
     }
 
     fn input(

+ 5 - 5
tests/solana_tests/storage.rs

@@ -33,10 +33,10 @@ fn string() {
 
     assert_eq!(
         vm.data()[0..20].to_vec(),
-        vec![65, 177, 160, 100, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 120, 0, 0, 0]
+        vec![65, 177, 160, 100, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 40, 0, 0, 0]
     );
 
-    assert_eq!(vm.data()[120..133].to_vec(), b"Hello, World!");
+    assert_eq!(vm.data()[40..53].to_vec(), b"Hello, World!");
 
     let returns = vm.function("get", &[], &[]);
 
@@ -52,7 +52,7 @@ fn string() {
 
     assert_eq!(
         vm.data()[0..20].to_vec(),
-        vec![65, 177, 160, 100, 96, 0, 0, 0, 152, 0, 0, 0, 24, 0, 0, 0, 120, 0, 0, 0]
+        vec![65, 177, 160, 100, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 40, 0, 0, 0]
     );
 
     // Try setting this to an empty string. This is also a special case where
@@ -65,7 +65,7 @@ fn string() {
 
     assert_eq!(
         vm.data()[0..20].to_vec(),
-        vec![65, 177, 160, 100, 64, 0, 0, 0, 40, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0]
+        vec![65, 177, 160, 100, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0]
     );
 }
 
@@ -115,7 +115,7 @@ fn bytes() {
 
     assert_eq!(
         vm.data()[0..20].to_vec(),
-        vec![11, 66, 182, 57, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 88, 0, 0, 0]
+        vec![11, 66, 182, 57, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 40, 0, 0, 0]
     );
 
     for (i, b) in b"The shoemaker always wears the worst shoes"