Просмотр исходного кода

Ensure no data is returned when a function has no return values (#1103)

* Ensure no data is returned when a function has no return values

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 2 лет назад
Родитель
Сommit
3234e18e35
2 измененных файлов с 38 добавлено и 11 удалено
  1. 19 2
      src/codegen/dispatch.rs
  2. 19 9
      tests/solana.rs

+ 19 - 2
src/codegen/dispatch.rs

@@ -256,10 +256,27 @@ fn add_dispatch_case(
                 data_len: zext_len,
             },
         );
+    } else {
+        // TODO: On Solana, we could elide setting the return data if this function calls no external functions
+        // and replace this with a simple Instr::Return, which does not set any return data.
+        //
+        // The return data buffer is empty when Solana VM first executes a program, but if another program is
+        // called via CPI then that program may set return data. We must clear this buffer, else return data
+        // from the CPI callee will be visible to this program's callee.
+        cfg.add(
+            vartab,
+            Instr::ReturnData {
+                data: Expression::AllocDynamicBytes(
+                    Loc::Codegen,
+                    Type::DynamicBytes,
+                    Expression::NumberLiteral(Loc::Codegen, Type::Uint(32), 0.into()).into(),
+                    None,
+                ),
+                data_len: Expression::NumberLiteral(Loc::Codegen, Type::Uint(64), 0.into()),
+            },
+        );
     }
 
-    cfg.add(vartab, Instr::Return { value: vec![] });
-
     cases.push((
         Expression::NumberLiteral(
             Loc::Codegen,

+ 19 - 9
tests/solana.rs

@@ -875,6 +875,8 @@ impl<'a> SyscallObject<UserError> for SyscallSetReturnData<'a> {
 
         let buf = question_mark!(translate_slice::<u8>(memory_mapping, addr, len), result);
 
+        println!("sol_set_return_data: {}", hex::encode(buf));
+
         if let Ok(mut vm) = self.context.vm.try_borrow_mut() {
             if len == 0 {
                 vm.return_data = None;
@@ -1597,6 +1599,8 @@ impl VirtualMachine {
     }
 
     fn constructor_expected(&mut self, expected: u64, name: &str, args: &[BorshToken]) {
+        self.return_data = None;
+
         let program = &self.stack[0];
         println!("constructor for {}", hex::encode(program.data));
 
@@ -1612,6 +1616,9 @@ impl VirtualMachine {
 
         println!("res:{:?}", res);
         assert_eq!(res, Ok(expected));
+        if let Some((_, return_data)) = &self.return_data {
+            assert_eq!(return_data.len(), 0);
+        }
     }
 
     fn function(&mut self, name: &str, args: &[BorshToken]) -> Vec<BorshToken> {
@@ -1626,14 +1633,14 @@ impl VirtualMachine {
         name: &str,
         args: &[BorshToken],
     ) -> Vec<BorshToken> {
+        self.return_data = None;
+
         let program = &self.stack[0];
 
-        println!("function for {}", hex::encode(program.data));
+        println!("function {} for {}", name, hex::encode(program.data));
 
         let mut calldata = VirtualMachine::input(&program.data, name);
 
-        println!("input: {} ", hex::encode(&calldata));
-
         let selector = discriminator("global", name);
         calldata.extend_from_slice(&selector);
         let mut encoded_args = encode_arguments(args);
@@ -1648,14 +1655,17 @@ impl VirtualMachine {
             Err(e) => panic!("error: {:?}", e),
         };
 
-        if let Some((_, return_data)) = &self.return_data {
-            println!("return: {}", hex::encode(return_data));
-
-            let func = &self.stack[0].abi.as_ref().unwrap().functions[name][0];
-            decode_output(&func.outputs, return_data)
+        let return_data = if let Some((_, return_data)) = &self.return_data {
+            return_data.as_slice()
         } else {
-            Vec::new()
+            &[]
+        };
+
+        let func = &self.stack[0].abi.as_ref().unwrap().functions[name][0];
+        if func.outputs.is_empty() {
+            assert_eq!(return_data.len(), 0);
         }
+        decode_output(&func.outputs, return_data)
     }
 
     fn function_must_fail(