浏览代码

Enable initialization of arrays with array literals (#942)

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 3 年之前
父节点
当前提交
d058852707

+ 14 - 7
src/codegen/cfg.rs

@@ -53,7 +53,7 @@ pub enum Instr {
         false_block: usize,
     },
     /// Set array element in memory
-    Store { dest: Expression, pos: usize },
+    Store { dest: Expression, data: Expression },
     /// Abort execution
     AssertFailure { expr: Option<Expression> },
     /// Print to log message
@@ -177,7 +177,6 @@ impl Instr {
     ) {
         match self {
             Instr::BranchCond { cond: expr, .. }
-            | Instr::Store { dest: expr, .. }
             | Instr::LoadStorage { storage: expr, .. }
             | Instr::ClearStorage { storage: expr, .. }
             | Instr::Print { expr }
@@ -193,9 +192,17 @@ impl Instr {
                 expr.recurse(cx, f);
             }
 
-            Instr::SetStorage { value, storage, .. } => {
-                value.recurse(cx, f);
-                storage.recurse(cx, f);
+            Instr::SetStorage {
+                value: item_1,
+                storage: item_2,
+                ..
+            }
+            | Instr::Store {
+                dest: item_1,
+                data: item_2,
+            } => {
+                item_1.recurse(cx, f);
+                item_2.recurse(cx, f);
             }
             Instr::PushStorage { value, storage, .. } => {
                 if let Some(value) = value {
@@ -1055,10 +1062,10 @@ impl ControlFlowGraph {
                     .join(", "),
             ),
 
-            Instr::Store { dest, pos } => format!(
+            Instr::Store { dest, data } => format!(
                 "store {}, {}",
                 self.expr_to_string(contract, ns, dest),
-                self.vars[pos].id.name
+                self.expr_to_string(contract, ns, data),
             ),
             Instr::Print { expr } => format!("print {}", self.expr_to_string(contract, ns, expr)),
             Instr::Constructor {

+ 3 - 2
src/codegen/constant_folding.rs

@@ -79,10 +79,11 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         };
                     }
                 }
-                Instr::Store { dest, pos } => {
+                Instr::Store { dest, data } => {
                     let (dest, _) = expression(dest, Some(&vars), cfg, ns);
+                    let (data, _) = expression(data, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no] = Instr::Store { dest, pos: *pos };
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Store { dest, data };
                 }
                 Instr::AssertFailure { expr: Some(expr) } => {
                     let (expr, _) = expression(expr, Some(&vars), cfg, ns);

+ 95 - 4
src/codegen/expression.rs

@@ -25,7 +25,7 @@ use crate::Target;
 use num_bigint::BigInt;
 use num_traits::{FromPrimitive, One, ToPrimitive, Zero};
 use solang_parser::pt;
-use solang_parser::pt::CodeLocation;
+use solang_parser::pt::{CodeLocation, Loc};
 use std::ops::Mul;
 
 pub fn expression(
@@ -556,6 +556,12 @@ pub fn expression(
                 )
             }
         }
+        ast::Expression::Cast(loc, ty @ Type::Array(..), e)
+            if matches!(**e, ast::Expression::ArrayLiteral(..)) =>
+        {
+            let codegen_expr = expression(e, cfg, contract_no, func, ns, vartab, opt);
+            array_literal_to_memory_array(loc, &codegen_expr, ty, cfg, vartab)
+        }
         ast::Expression::Cast(loc, ty, e) => {
             if e.ty() == Type::Rational {
                 let (_, n) = eval_const_rational(e, ns).unwrap();
@@ -884,7 +890,13 @@ fn post_incdec(
                     );
                 }
                 Type::Ref(_) => {
-                    cfg.add(vartab, Instr::Store { pos: res, dest });
+                    cfg.add(
+                        vartab,
+                        Instr::Store {
+                            dest,
+                            data: Expression::Variable(Loc::Codegen, ty.clone(), res),
+                        },
+                    );
                 }
                 _ => unreachable!(),
             }
@@ -957,7 +969,13 @@ fn pre_incdec(
                     );
                 }
                 Type::Ref(_) => {
-                    cfg.add(vartab, Instr::Store { pos: res, dest });
+                    cfg.add(
+                        vartab,
+                        Instr::Store {
+                            dest,
+                            data: Expression::Variable(Loc::Codegen, ty.clone(), res),
+                        },
+                    );
                 }
                 _ => unreachable!(),
             }
@@ -1986,7 +2004,13 @@ pub fn assign_single(
                     );
                 }
                 Type::Ref(_) => {
-                    cfg.add(vartab, Instr::Store { pos, dest });
+                    cfg.add(
+                        vartab,
+                        Instr::Store {
+                            dest,
+                            data: Expression::Variable(Loc::Codegen, ty.clone(), pos),
+                        },
+                    );
                 }
                 _ => unreachable!(),
             }
@@ -2884,3 +2908,70 @@ pub fn load_storage(
 
     Expression::Variable(*loc, ty.clone(), res)
 }
+
+fn array_literal_to_memory_array(
+    loc: &pt::Loc,
+    expr: &Expression,
+    ty: &Type,
+    cfg: &mut ControlFlowGraph,
+    vartab: &mut Vartable,
+) -> Expression {
+    let memory_array = vartab.temp_anonymous(ty);
+    let elem_ty = ty.array_elem();
+    let dims = expr.ty().array_length().unwrap().clone();
+    let array_size = Expression::NumberLiteral(*loc, Type::Uint(32), dims);
+
+    cfg.add(
+        vartab,
+        Instr::Set {
+            loc: *loc,
+            res: memory_array,
+            expr: Expression::AllocDynamicArray(
+                *loc,
+                ty.clone(),
+                Box::new(array_size.clone()),
+                None,
+            ),
+        },
+    );
+
+    let elements = if let Expression::ArrayLiteral(_, _, _, items) = expr {
+        items
+    } else {
+        unreachable!()
+    };
+
+    for (item_no, item) in elements.iter().enumerate() {
+        cfg.add(
+            vartab,
+            Instr::Store {
+                dest: Expression::Subscript(
+                    *loc,
+                    Type::Ref(Box::new(elem_ty.clone())),
+                    ty.clone(),
+                    Box::new(Expression::Variable(*loc, ty.clone(), memory_array)),
+                    Box::new(Expression::NumberLiteral(
+                        *loc,
+                        Type::Uint(32),
+                        BigInt::from(item_no),
+                    )),
+                ),
+                data: item.clone(),
+            },
+        );
+    }
+
+    let temp_res = vartab.temp_name("array_length", &Type::Uint(32));
+    cfg.add(
+        vartab,
+        Instr::Set {
+            loc: Loc::Codegen,
+            res: temp_res,
+            expr: array_size,
+        },
+    );
+
+    cfg.array_lengths_temps.insert(memory_array, temp_res);
+
+    Expression::Variable(*loc, ty.clone(), memory_array)
+}

+ 2 - 1
src/codegen/strength_reduce/mod.rs

@@ -116,8 +116,9 @@ fn block_reduce(
                     .map(|e| expression_reduce(e, &vars, ns))
                     .collect();
             }
-            Instr::Store { dest, .. } => {
+            Instr::Store { dest, data } => {
                 *dest = expression_reduce(dest, &vars, ns);
+                *data = expression_reduce(data, &vars, ns);
             }
             Instr::AssertFailure { expr: Some(expr) } => {
                 *expr = expression_reduce(expr, &vars, ns);

+ 13 - 6
src/codegen/subexpression_elimination/instruction.rs

@@ -14,7 +14,6 @@ impl AvailableExpressionSet {
     ) {
         match instr {
             Instr::BranchCond { cond: expr, .. }
-            | Instr::Store { dest: expr, .. }
             | Instr::LoadStorage { storage: expr, .. }
             | Instr::ClearStorage { storage: expr, .. }
             | Instr::Print { expr }
@@ -45,9 +44,17 @@ impl AvailableExpressionSet {
                 let _ = self.gen_expression(expr, ave, cst);
             }
 
-            Instr::SetStorage { value, storage, .. } => {
-                let _ = self.gen_expression(value, ave, cst);
-                let _ = self.gen_expression(storage, ave, cst);
+            Instr::SetStorage {
+                value: item_1,
+                storage: item_2,
+                ..
+            }
+            | Instr::Store {
+                dest: item_1,
+                data: item_2,
+            } => {
+                let _ = self.gen_expression(item_1, ave, cst);
+                let _ = self.gen_expression(item_2, ave, cst);
             }
             Instr::PushStorage { value, storage, .. } => {
                 if let Some(value) = value {
@@ -204,9 +211,9 @@ impl AvailableExpressionSet {
                 false_block: *false_block,
             },
 
-            Instr::Store { dest, pos } => Instr::Store {
+            Instr::Store { dest, data } => Instr::Store {
                 dest: self.regenerate_expression(dest, ave, cst).1,
-                pos: *pos,
+                data: self.regenerate_expression(data, ave, cst).1,
             },
 
             Instr::AssertFailure { expr: Some(exp) } => Instr::AssertFailure {

+ 5 - 3
src/codegen/vector_to_slice.rs

@@ -73,9 +73,11 @@ fn find_writable_vectors(
 
                 apply_transfers(&block.transfers[instr_no], vars, writable);
             }
-            Instr::Store { pos, .. } => {
-                if let Some(entry) = vars.get_mut(pos) {
-                    writable.extend(entry.keys());
+            Instr::Store { data, .. } => {
+                if let Expression::Variable(_, _, var_no) = data {
+                    if let Some(entry) = vars.get_mut(var_no) {
+                        writable.extend(entry.keys());
+                    }
                 }
 
                 apply_transfers(&block.transfers[instr_no], vars, writable);

+ 2 - 3
src/emit/mod.rs

@@ -3449,8 +3449,8 @@ pub trait TargetRuntime<'a> {
                         bin.builder.position_at_end(pos);
                         bin.builder.build_unconditional_branch(bb.bb);
                     }
-                    Instr::Store { dest, pos } => {
-                        let value_ref = w.vars[pos].value;
+                    Instr::Store { dest, data } => {
+                        let value_ref = self.expression(bin, data, &w.vars, function, ns);
                         let dest_ref = self
                             .expression(bin, dest, &w.vars, function, ns)
                             .into_pointer_value();
@@ -3592,7 +3592,6 @@ pub trait TargetRuntime<'a> {
                         let arr = w.vars[array].value;
 
                         let llvm_ty = bin.llvm_type(ty, ns);
-
                         let elem_ty = ty.array_elem();
 
                         // Calculate total size for reallocation

+ 15 - 0
src/sema/expression.rs

@@ -370,6 +370,21 @@ impl Expression {
                     BigRational::from(n.clone()),
                 ));
             }
+
+            (
+                &Expression::ArrayLiteral(..),
+                Type::Array(from_ty, from_dims),
+                Type::Array(to_ty, to_dims),
+            ) => {
+                if from_ty == to_ty
+                    && from_dims.len() == to_dims.len()
+                    && from_dims.len() == 1
+                    && matches!(to_dims.last().unwrap(), ArrayLength::Dynamic)
+                {
+                    return Ok(Expression::Cast(*loc, to.clone(), Box::new(self.clone())));
+                }
+            }
+
             _ => (),
         };
 

+ 41 - 11
tests/codegen_testcases/solidity/array_boundary_opt.sol

@@ -8,25 +8,25 @@ contract Array_bound_Test {
         uint32 size32
     ) public pure returns (uint256) {
         // CHECK: ty:uint32 %1.cse_temp = (trunc uint32 (arg #1))
-	    // CHECK: ty:uint32 %array_length.temp.23 = %1.cse_temp
+	    // CHECK: ty:uint32 %array_length.temp.32 = %1.cse_temp
         uint256[] a = new uint256[](size);
 
-        // CHECK: ty:uint32 %array_length.temp.24 = (arg #2)
+        // CHECK: ty:uint32 %array_length.temp.33 = (arg #2)
         uint256[] c = new uint256[](size32);
 
-        // CHECK: ty:uint32 %array_length.temp.25 = uint32 20
+        // CHECK: ty:uint32 %array_length.temp.34 = uint32 20
         uint256[] d = new uint256[](20);
 
-        // CHECK: ty:uint32 %array_length.temp.23 = (%1.cse_temp + uint32 1)
+        // CHECK: ty:uint32 %array_length.temp.32 = (%1.cse_temp + uint32 1)
         a.push();
 
-        // CHECK: ty:uint32 %array_length.temp.24 = ((arg #2) - uint32 1)
+        // CHECK: ty:uint32 %array_length.temp.33 = ((arg #2) - uint32 1)
         c.pop();
 
-        // CHECK: ty:uint32 %array_length.temp.25 = uint32 21
+        // CHECK: ty:uint32 %array_length.temp.34 = uint32 21
         d.push();
 
-        // CHECK: return (zext uint256 (((%array_length.temp.23 + (builtin ArrayLength ((arg #0)))) + ((arg #2) - uint32 1)) + uint32 21))
+        // CHECK: return (zext uint256 (((%array_length.temp.32 + (builtin ArrayLength ((arg #0)))) + ((arg #2) - uint32 1)) + uint32 21))
         return a.length + b.length + c.length + d.length;
     }
 
@@ -35,11 +35,11 @@ contract Array_bound_Test {
         bool[] b = new bool[](210);
 
         if (cond) {
-            // CHECK: ty:uint32 %array_length.temp.30 = uint32 211
+            // CHECK: ty:uint32 %array_length.temp.39 = uint32 211
             b.push(true);
         }
 
-        // CHECK: return %array_length.temp.30
+        // CHECK: return %array_length.temp.39
         return b.length;
     }
 
@@ -79,16 +79,46 @@ contract Array_bound_Test {
         int256[] vec = new int256[](10);
 
         for (int256 i = 0; i < 5; i++) {
-            // CHECK: branchcond (unsigned more %array_length.temp.40 > uint32 20), block5, block6
+            // CHECK: branchcond (unsigned more %array_length.temp.49 > uint32 20), block5, block6
             if (vec.length > 20) {
                 break;
             }
             vec.push(3);
         }
 
-        // CHECK: branchcond (%array_length.temp.40 == uint32 15), block7, block8
+        // CHECK: branchcond (%array_length.temp.49 == uint32 15), block7, block8
         assert(vec.length == 15);
     }
 
+    // BEGIN-CHECK: Array_bound_Test::Array_bound_Test::function::getVec__int32_int32
+    function getVec(int32 a, int32 b) public pure returns (uint32) {
+        int32[] memory vec;
+        vec = [a, b];
+        // CHECK: ty:int32[] %vec = undef
+	    // CHECK: ty:uint32 %array_length.temp.52 = uint32 0
+	    // CHECK: ty:int32[] %temp.53 = (alloc int32[] len uint32 2)
+        // CHECK: ty:uint32 %array_length.temp.54 = uint32 2
+	    // CHECK: ty:int32[] %vec = %temp.53
+
+
+        vec.push(5);
+        // CHECK: ty:uint32 %array_length.temp.54 = uint32 3
+        // CHECK: return uint32 3
+        return vec.length;
+    }
+
+    // BEGIN-CHECK: Array_bound_Test::Array_bound_Test::function::testVec__uint32_uint32_uint32
+    function testVec(uint32 a, uint32 b, uint32 c) public pure returns (uint32) {
+        // CHECK: ty:uint32[] %temp.56 = (alloc uint32[] len uint32 3)
+        // CHECK: ty:uint32 %array_length.temp.57 = uint32 3
+        uint32[] memory vec = [a, b, b];
+        // CHECK: ty:uint32[] %vec = %temp.56
+
+        vec.pop();
+        // CHECK: ty:uint32 %array_length.temp.57 = uint32 2
+        // CHECK: return uint32 2
+        return vec.length;
+    }
+
 
 }

+ 63 - 1
tests/solana_tests/arrays.rs

@@ -1,5 +1,5 @@
 use crate::build_solidity;
-use ethabi::{ethereum_types::U256, Token};
+use ethabi::{ethereum_types::U256, FixedBytes, Token, Uint};
 
 #[test]
 fn fixed_array() {
@@ -960,3 +960,65 @@ fn storage_pop_push() {
     // make sure every thing has been freed
     assert_eq!(vm.validate_account_data_heap(), 0);
 }
+
+#[test]
+fn initialization_with_literal() {
+    let mut vm = build_solidity(
+        r#"
+        contract Testing {
+            address[] splitAddresses;
+
+            function split(address addr1, address addr2) public {
+                splitAddresses = [addr1, addr2];
+            }
+
+            function getIdx(uint32 idx) public view returns (address) {
+                return splitAddresses[idx];
+            }
+
+            function getVec(uint32 a, uint32 b) public pure returns (uint32[] memory) {
+                uint32[] memory vec;
+                vec = [a, b];
+                return vec;
+            }
+        }
+        "#,
+    );
+
+    vm.constructor("Testing", &[]);
+
+    let mut addr1: Vec<u8> = Vec::new();
+    addr1.resize(32, 0);
+    addr1[0] = 1;
+    let mut addr2: Vec<u8> = Vec::new();
+    addr2.resize(32, 0);
+    addr2[0] = 2;
+    let _ = vm.function(
+        "split",
+        &[
+            Token::FixedBytes(FixedBytes::from(&addr1[..])),
+            Token::FixedBytes(FixedBytes::from(&addr2[..])),
+        ],
+        &[],
+        None,
+    );
+    let returns = vm.function("getIdx", &[Token::Uint(Uint::from(0))], &[], None);
+    let returned_addr1 = returns[0].clone().into_fixed_bytes().unwrap();
+    assert_eq!(addr1, returned_addr1);
+
+    let returns = vm.function("getIdx", &[Token::Uint(Uint::from(1))], &[], None);
+    let returned_addr2 = returns[0].clone().into_fixed_bytes().unwrap();
+    assert_eq!(addr2, returned_addr2);
+
+    let returns = vm.function(
+        "getVec",
+        &[Token::Uint(Uint::from(563)), Token::Uint(Uint::from(895))],
+        &[],
+        None,
+    );
+    let array = returns[0].clone().into_array().unwrap();
+    assert_eq!(
+        array,
+        vec![Token::Uint(Uint::from(563)), Token::Uint(Uint::from(895))]
+    );
+}