Browse Source

Use anticipated expressions to fix CSE bug (#1189)

* Use anticipated expressions to fix CSE bug

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 2 years ago
parent
commit
5e717ca8f2

+ 0 - 1
Cargo.toml

@@ -50,7 +50,6 @@ solang-parser = { path = "solang-parser", version = "0.2.2" }
 codespan-reporting = "0.11"
 phf = { version = "0.11", features = ["macros"] }
 rust-lapper = "1.1"
-bitflags = "1.3"
 anchor-syn = { version = "0.26", features = ["idl"] }
 convert_case = "0.6"
 parse-display = "0.8.0"

+ 5 - 28
src/codegen/cfg.rs

@@ -384,31 +384,16 @@ impl fmt::Display for HashTy {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Default)]
 pub struct BasicBlock {
     pub phis: Option<BTreeSet<usize>>,
     pub name: String,
-    pub instr: Vec<(InstrOrigin, Instr)>,
+    pub instr: Vec<Instr>,
     pub defs: reaching_definitions::VarDefs,
     pub loop_reaching_variables: HashSet<usize>,
     pub transfers: Vec<Vec<reaching_definitions::Transfer>>,
 }
 
-/// This enum saves information about the origin of each instruction. They can originate from
-/// Solidity code, Yul code or during code generation.
-#[derive(Clone)]
-pub enum InstrOrigin {
-    Solidity,
-    Yul,
-    Codegen,
-}
-
-impl BasicBlock {
-    fn add(&mut self, instr_origin: InstrOrigin, ins: Instr) {
-        self.instr.push((instr_origin, ins));
-    }
-}
-
 #[derive(Clone)]
 pub struct ControlFlowGraph {
     pub name: String,
@@ -503,20 +488,12 @@ impl ControlFlowGraph {
         self.current = pos;
     }
 
-    /// Add an instruction from Solidity to the CFG
+    /// Add an instruction to the CFG
     pub fn add(&mut self, vartab: &mut Vartable, ins: Instr) {
         if let Instr::Set { res, .. } = ins {
             vartab.set_dirty(res);
         }
-        self.blocks[self.current].add(InstrOrigin::Solidity, ins);
-    }
-
-    /// Add an instruction from Yul to the CFG
-    pub fn add_yul(&mut self, vartab: &mut Vartable, ins: Instr) {
-        if let Instr::Set { res, .. } = ins {
-            vartab.set_dirty(res);
-        }
-        self.blocks[self.current].add(InstrOrigin::Yul, ins);
+        self.blocks[self.current].instr.push(ins);
     }
 
     /// Retrieve the basic block being processed
@@ -1311,7 +1288,7 @@ impl ControlFlowGraph {
             .unwrap();
         }
 
-        for (_, ins) in &self.blocks[pos].instr {
+        for ins in &self.blocks[pos].instr {
             writeln!(s, "\t{}", self.instr_to_string(contract, ns, ins)).unwrap();
         }
 

+ 26 - 26
src/codegen/constant_folding.rs

@@ -22,7 +22,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
         let mut vars = cfg.blocks[block_no].defs.clone();
 
         for instr_no in 0..cfg.blocks[block_no].instr.len() {
-            match &cfg.blocks[block_no].instr[instr_no].1 {
+            match &cfg.blocks[block_no].instr[instr_no] {
                 Instr::Set { loc, res, expr, .. } => {
                     let (expr, expr_constant) = expression(expr, Some(&vars), cfg, ns);
 
@@ -30,7 +30,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         ns.var_constants.insert(*loc, expr.clone());
                     }
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Set {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Set {
                         loc: *loc,
                         res: *res,
                         expr,
@@ -47,7 +47,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .map(|e| expression(e, Some(&vars), cfg, ns).0)
                         .collect();
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Call {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Call {
                         res: res.clone(),
                         call: call.clone(),
                         args,
@@ -60,7 +60,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .map(|e| expression(e, Some(&vars), cfg, ns).0)
                         .collect();
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Return { value };
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Return { value };
                 }
                 Instr::BranchCond {
                     cond,
@@ -70,11 +70,11 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                     let (cond, _) = expression(cond, Some(&vars), cfg, ns);
 
                     if let Expression::BoolLiteral(_, cond) = cond {
-                        cfg.blocks[block_no].instr[instr_no].1 = Instr::Branch {
+                        cfg.blocks[block_no].instr[instr_no] = Instr::Branch {
                             block: if cond { *true_block } else { *false_block },
                         };
                     } else {
-                        cfg.blocks[block_no].instr[instr_no].1 = Instr::BranchCond {
+                        cfg.blocks[block_no].instr[instr_no] = Instr::BranchCond {
                             cond,
                             true_block: *true_block,
                             false_block: *false_block,
@@ -85,26 +85,26 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                     let (dest, _) = expression(dest, Some(&vars), cfg, ns);
                     let (data, _) = expression(data, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Store { dest, data };
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Store { dest, data };
                 }
                 Instr::AssertFailure {
                     encoded_args: Some(expr),
                 } => {
                     let (buf, _) = expression(expr, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::AssertFailure {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::AssertFailure {
                         encoded_args: Some(buf),
                     };
                 }
                 Instr::Print { expr } => {
                     let (expr, _) = expression(expr, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Print { expr };
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Print { expr };
                 }
                 Instr::ClearStorage { ty, storage } => {
                     let (storage, _) = expression(storage, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::ClearStorage {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::ClearStorage {
                         ty: ty.clone(),
                         storage,
                     };
@@ -113,7 +113,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                     let (storage, _) = expression(storage, Some(&vars), cfg, ns);
                     let (value, _) = expression(value, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::SetStorage {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::SetStorage {
                         ty: ty.clone(),
                         storage,
                         value,
@@ -122,7 +122,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                 Instr::LoadStorage { ty, storage, res } => {
                     let (storage, _) = expression(storage, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::LoadStorage {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::LoadStorage {
                         ty: ty.clone(),
                         storage,
                         res: *res,
@@ -137,7 +137,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                     let (value, _) = expression(value, Some(&vars), cfg, ns);
                     let (offset, _) = expression(offset, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::SetStorageBytes {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::SetStorageBytes {
                         storage,
                         value,
                         offset,
@@ -154,7 +154,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .as_ref()
                         .map(|expr| expression(expr, Some(&vars), cfg, ns).0);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::PushStorage {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::PushStorage {
                         res: *res,
                         ty: ty.clone(),
                         storage,
@@ -164,7 +164,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                 Instr::PopStorage { res, ty, storage } => {
                     let (storage, _) = expression(storage, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::PopStorage {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::PopStorage {
                         res: *res,
                         ty: ty.clone(),
                         storage,
@@ -178,7 +178,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                 } => {
                     let (value, _) = expression(value, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::PushMemory {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::PushMemory {
                         res: *res,
                         ty: ty.clone(),
                         array: *array,
@@ -214,7 +214,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .as_ref()
                         .map(|expr| expression(expr, Some(&vars), cfg, ns).0);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Constructor {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Constructor {
                         success: *success,
                         res: *res,
                         contract_no: *contract_no,
@@ -251,7 +251,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .as_ref()
                         .map(|expr| expression(expr, Some(&vars), cfg, ns).0);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::ExternalCall {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::ExternalCall {
                         success: *success,
                         address,
                         accounts,
@@ -275,7 +275,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .as_ref()
                         .map(|e| expression(e, Some(&vars), cfg, ns).0);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::AbiDecode {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::AbiDecode {
                         res: res.clone(),
                         selector: *selector,
                         exception_block: *exception_block,
@@ -287,7 +287,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                 Instr::SelfDestruct { recipient } => {
                     let (recipient, _) = expression(recipient, Some(&vars), cfg, ns);
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::SelfDestruct { recipient };
+                    cfg.blocks[block_no].instr[instr_no] = Instr::SelfDestruct { recipient };
                 }
                 Instr::EmitEvent {
                     event_no,
@@ -299,7 +299,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .map(|e| expression(e, Some(&vars), cfg, ns).0)
                         .collect();
 
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::EmitEvent {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::EmitEvent {
                         event_no: *event_no,
                         data: expression(data, Some(&vars), cfg, ns).0,
                         topics,
@@ -313,7 +313,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                     let bytes = expression(bytes, Some(&vars), cfg, ns);
                     let source = expression(source, Some(&vars), cfg, ns);
                     let destination = expression(destination, Some(&vars), cfg, ns);
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::MemCopy {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::MemCopy {
                         source: source.0,
                         destination: destination.0,
                         bytes: bytes.0,
@@ -329,7 +329,7 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .iter()
                         .map(|(exp, goto)| (expression(exp, Some(&vars), cfg, ns).0, *goto))
                         .collect::<Vec<(Expression, usize)>>();
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Switch {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::Switch {
                         cond: cond.0,
                         cases,
                         default: *default,
@@ -338,13 +338,13 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                 Instr::ReturnData { data, data_len } => {
                     let data = expression(data, Some(&vars), cfg, ns);
                     let data_len = expression(data_len, Some(&vars), cfg, ns);
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::ReturnData {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::ReturnData {
                         data: data.0,
                         data_len: data_len.0,
                     };
                 }
                 Instr::WriteBuffer { buf, offset, value } => {
-                    cfg.blocks[block_no].instr[instr_no].1 = Instr::WriteBuffer {
+                    cfg.blocks[block_no].instr[instr_no] = Instr::WriteBuffer {
                         buf: buf.clone(),
                         offset: expression(offset, Some(&vars), cfg, ns).0,
                         value: expression(value, Some(&vars), cfg, ns).0,
@@ -1204,7 +1204,7 @@ fn get_definition<'a>(
     def: &reaching_definitions::Def,
     cfg: &'a ControlFlowGraph,
 ) -> Option<&'a Expression> {
-    if let Instr::Set { expr, .. } = &cfg.blocks[def.block_no].instr[def.instr_no].1 {
+    if let Instr::Set { expr, .. } = &cfg.blocks[def.block_no].instr[def.instr_no] {
         Some(expr)
     } else {
         None

+ 7 - 7
src/codegen/dead_storage.rs

@@ -181,7 +181,7 @@ fn reaching_definitions(cfg: &mut ControlFlowGraph) -> (Vec<Vec<Vec<Transfer>>>,
 fn instr_transfers(block_no: usize, block: &BasicBlock) -> Vec<Vec<Transfer>> {
     let mut transfers = Vec::new();
 
-    for (instr_no, (_, instr)) in block.instr.iter().enumerate() {
+    for (instr_no, instr) in block.instr.iter().enumerate() {
         let def = Definition::Instr {
             block_no,
             instr_no,
@@ -490,7 +490,7 @@ pub fn dead_storage(cfg: &mut ControlFlowGraph, _ns: &mut Namespace) {
 
             let vars = &block_vars[&block_no][instr_no];
 
-            match &cfg.blocks[block_no].instr[instr_no].1 {
+            match &cfg.blocks[block_no].instr[instr_no] {
                 Instr::LoadStorage { res, ty, storage } => {
                     // is there a definition which has the same storage expression
                     let mut found = None;
@@ -523,7 +523,7 @@ pub fn dead_storage(cfg: &mut ControlFlowGraph, _ns: &mut Namespace) {
                     }
 
                     if let Some(var_no) = found {
-                        cfg.blocks[block_no].instr[instr_no].1 = Instr::Set {
+                        cfg.blocks[block_no].instr[instr_no] = Instr::Set {
                             loc: Loc::Codegen,
                             res: *res,
                             expr: Expression::Variable(Loc::Codegen, ty.clone(), *var_no),
@@ -594,10 +594,10 @@ pub fn dead_storage(cfg: &mut ControlFlowGraph, _ns: &mut Namespace) {
                 // Function calls should never be eliminated from the CFG, as they might have side effects
                 // In addition, AbiDecode might fail and halt the execution.
                 if !matches!(
-                    cfg.blocks[*block_no].instr[*instr_no].1,
+                    cfg.blocks[*block_no].instr[*instr_no],
                     Instr::Call { .. } | Instr::AbiDecode { .. }
                 ) {
-                    cfg.blocks[*block_no].instr[*instr_no].1 = Instr::Nop;
+                    cfg.blocks[*block_no].instr[*instr_no] = Instr::Nop;
                 }
             }
         }
@@ -618,7 +618,7 @@ fn get_storage_definition<'a>(
         block_no, instr_no, ..
     } = def
     {
-        match &cfg.blocks[*block_no].instr[*instr_no].1 {
+        match &cfg.blocks[*block_no].instr[*instr_no] {
             Instr::LoadStorage {
                 storage, res, ty, ..
             } => Some(StorageDef {
@@ -641,7 +641,7 @@ fn get_definition<'a>(
         block_no, instr_no, ..
     } = def
     {
-        match &cfg.blocks[*block_no].instr[*instr_no].1 {
+        match &cfg.blocks[*block_no].instr[*instr_no] {
             Instr::LoadStorage { storage, ty, .. } => Some((storage, ty.clone())),
             Instr::Set { expr, .. } => Some((expr, expr.ty())),
             _ => None,

+ 2 - 2
src/codegen/reaching_definitions.rs

@@ -98,7 +98,7 @@ pub fn find(cfg: &mut ControlFlowGraph) {
 fn instr_transfers(block_no: usize, block: &BasicBlock) -> Vec<Vec<Transfer>> {
     let mut transfers = Vec::new();
 
-    for (instr_no, (_, instr)) in block.instr.iter().enumerate() {
+    for (instr_no, instr) in block.instr.iter().enumerate() {
         let set_var = |var_nos: &[usize]| {
             let mut transfer = Vec::new();
 
@@ -229,7 +229,7 @@ pub(super) fn block_edges(block: &BasicBlock) -> Vec<usize> {
 
     // out cfg has edge as the last instruction in a block; EXCEPT
     // Instr::AbiDecode() which has an edge when decoding fails
-    for (_, instr) in &block.instr {
+    for instr in &block.instr {
         match instr {
             Instr::Branch { block } => {
                 out.push(*block);

+ 1 - 1
src/codegen/strength_reduce/mod.rs

@@ -101,7 +101,7 @@ fn block_reduce(
     mut vars: Variables,
     ns: &mut Namespace,
 ) {
-    for (_, instr) in &mut cfg.blocks[block_no].instr {
+    for instr in &mut cfg.blocks[block_no].instr {
         match instr {
             Instr::Set { expr, .. } => {
                 *expr = expression_reduce(expr, &vars, ns);

+ 1 - 1
src/codegen/strength_reduce/reaching_values.rs

@@ -36,7 +36,7 @@ pub(super) fn reaching_values(
         block_vars.insert(block_no, vars.clone());
     }
 
-    for (_, instr) in &cfg.blocks[block_no].instr {
+    for instr in &cfg.blocks[block_no].instr {
         transfer(instr, vars, ns);
 
         match instr {

+ 253 - 0
src/codegen/subexpression_elimination/anticipated_expressions.rs

@@ -0,0 +1,253 @@
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::codegen::cfg::{ControlFlowGraph, Instr};
+use crate::codegen::subexpression_elimination::{
+    kill_loop_variables, AvailableExpression, AvailableExpressionSet,
+};
+use crate::codegen::Expression;
+use std::collections::HashMap;
+
+/// The AnticipatedExpression struct manages everything related to traversing the CFG backwards, so
+/// we can perform an anticipated expression analysis.
+///
+/// "An expression is anticipated at a point if it is certain to be evaluated along any
+/// path before this expression's value is changed (any variables its evaluation depends
+/// on are reassigned)."
+///
+/// Available expressions tell us which sub expressions we have already evaluated. The anticipated
+/// expression analysis lets us go further and determine where a subexpression can be evaluated
+/// before any variable it depends on changes value.
+///
+/// Chapters 9.5.4 and 9.5.5 of the book "Compilers, Principles, Techniques & Tools" from
+/// Alfred V. Aho present more details about anticipated expressions.
+#[derive(Default, Clone)]
+pub(super) struct AnticipatedExpressions<'a> {
+    /// The AvailableExpressionSet for each CFG block, when the graph is evaluated in reverse
+    reverse_sets: HashMap<usize, AvailableExpressionSet<'a>>,
+    /// The CFG represented as a DAG, but with each edge reversed
+    reverse_dag: Vec<Vec<usize>>,
+    /// The order in which we must traverse the CFG. It is its topological sort but reversed.
+    traversing_order: Vec<(usize, bool)>,
+    /// The depth (distance from the entry block) for each one of the CFG blocks.
+    depth: Vec<u16>,
+}
+
+impl<'a> AnticipatedExpressions<'a> {
+    pub(super) fn new(
+        dag: &Vec<Vec<usize>>,
+        reverse_dag: Vec<Vec<usize>>,
+        traversing_order: Vec<(usize, bool)>,
+    ) -> AnticipatedExpressions {
+        let mut depth: Vec<u16> = vec![u16::MAX; dag.len()];
+        AnticipatedExpressions::blocks_depth(dag, 0, 0, &mut depth);
+        AnticipatedExpressions {
+            reverse_sets: HashMap::new(),
+            reverse_dag,
+            traversing_order,
+            depth,
+        }
+    }
+
+    /// Calculate the depth of each CFG block, using dfs (depth first search) traversal.
+    fn blocks_depth(dag: &Vec<Vec<usize>>, block: usize, level: u16, depth: &mut [u16]) -> u16 {
+        if level < depth[block] {
+            depth[block] = level;
+        } else {
+            return level;
+        }
+
+        dag[block]
+            .iter()
+            .map(|child| AnticipatedExpressions::blocks_depth(dag, *child, level + 1, depth))
+            .min()
+            .unwrap_or(u16::MAX)
+    }
+
+    /// This function calculates the anticipated expressions for each block. The analysis is similar
+    /// to available expressions with a few differences:
+    ///
+    /// 1. The CFG is traversed backwards: from the last executed block to the entry block.
+    /// 2. In each block, we traverse instructions from the last to the first.
+    /// 3. When a block has multiple children, we must unite all the anticipated expressions from them
+    ///    to perform the analysis for this block.
+    pub(super) fn calculate_anticipated_expressions<'b: 'a>(
+        &mut self,
+        instructions: &'b [Vec<Instr>],
+        cfg: &ControlFlowGraph,
+    ) {
+        let mut reverse_ave = AvailableExpression::default();
+        // Traverse the CFG according to its reversed topological order
+        for (block_no, cycle) in &self.traversing_order {
+            reverse_ave.set_cur_block(*block_no);
+            let mut cur_set = self.reverse_sets.get(block_no).cloned().unwrap_or_default();
+            kill_loop_variables(&cfg.blocks[*block_no], &mut cur_set, *cycle);
+
+            // Iterate over all instructions in reverse
+            for instr in instructions[*block_no].iter().rev() {
+                cur_set.process_instruction(instr, &mut reverse_ave, &mut None);
+            }
+
+            for edge in &self.reverse_dag[*block_no] {
+                if let Some(set) = self.reverse_sets.get_mut(edge) {
+                    // Instead of intersecting two sets as in available expressions,
+                    // in anticipated expressions we need to unite them, because the expressions
+                    // of all of a block's descendants can be anticipated there.
+                    set.union_sets(&cur_set);
+                } else {
+                    self.reverse_sets
+                        .insert(*edge, cur_set.clone_for_parent_block());
+                }
+            }
+        }
+    }
+
+    /// We calculate the flow in the graph considering block_1 and block_2 as sources, and using
+    /// the reversed CFG DAG. If any block has a flow that equals the sum of the two sources,
+    /// it can be used to calculate a common expression that exists in both block_1 and block_2.
+    /// In the common subexpression elimination context, a block that has a total flow of
+    /// flow[block_1]+flow[block_2] means that it is a code path that leads to both block_1 and
+    /// block_2.
+    ///
+    /// When I use the term flow, I am referring to a flow network (https://en.wikipedia.org/wiki/Flow_network).
+    /// The flow of each vertex is equally divided between its children, and the flow a vertex
+    /// receives is the sum of the flows from its incoming edges.
+    pub(super) fn calculate_flow(&self, block_1: usize, block_2: usize) -> Vec<f64> {
+        let mut flow: Vec<f64> = vec![0.0; self.reverse_dag.len()];
+        flow[block_1] = 1000.0;
+        flow[block_2] = 1000.0;
+
+        for (block_no, _) in &self.traversing_order {
+            let divided_flow = flow[*block_no] / (self.reverse_dag[*block_no].len() as f64);
+            for child in &self.reverse_dag[*block_no] {
+                flow[*child] += divided_flow;
+            }
+        }
+
+        flow
+    }
+
+    /// This function finds the correct block to place the evaluation of the common subexpression
+    /// 'expr', considering flow, depth and its anticipated availability.
+    pub(super) fn find_ancestor(
+        &self,
+        block_1: usize,
+        block_2: usize,
+        expr: &Expression,
+    ) -> Option<usize> {
+        if block_1 == block_2 {
+            return Some(block_1);
+        }
+
+        let flow = self.calculate_flow(block_1, block_2);
+
+        let mut candidate = usize::MAX;
+
+        for (block_no, flow_magnitude) in flow.iter().enumerate() {
+            // The condition is the following:
+            // 1. We prefer deeper blocks to evaluate the subexpression (depth[block_no] > depth[candidate]).
+            //    This is because if we evaluate a subexpression too early, we risk taking a branch
+            //    where the subexpression is not even used.
+            // 2. The flow_magnitude must be 2000. (2000.0 - *flow_magnitude).abs() deals with
+            //    floating point imprecision. We can also set a lower threshold for the comparison.
+            //    Ideally, it should be greater than the machine epsilon.
+            // 3. The expression must be available at the anticipated expression set for the block
+            //    we are analysing.
+            if (candidate == usize::MAX || self.depth[block_no] > self.depth[candidate])
+                && (2000.0 - *flow_magnitude).abs() < 0.000001
+                && self
+                    .reverse_sets
+                    .get(&block_no)
+                    .unwrap()
+                    .find_expression(expr)
+                    .is_some()
+            {
+                candidate = block_no;
+            }
+        }
+
+        if candidate < usize::MAX {
+            Some(candidate)
+        } else {
+            None
+        }
+    }
+}
+
+#[test]
+fn test_depth() {
+    let dag = vec![
+        vec![1, 2], // 0 -> 1, 2
+        vec![3, 4], // 1 -> 3, 4
+        vec![3, 4], // 2 -> 3, 4
+        vec![],     // 3
+        vec![],     // 4
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 5];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 1, 2, 2]);
+
+    let dag = vec![
+        vec![1, 2, 4], // 0 -> 1, 2, 4
+        vec![2, 3],    // 1 -> 2, 3
+        vec![4],       // 2 -> 4
+        vec![],        // 3
+        vec![],        // 4
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 5];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 1, 2, 1]);
+
+    let dag = vec![
+        vec![1, 4], // 0 -> 1, 4
+        vec![2, 3], // 1 -> 2, 3
+        vec![],     // 2
+        vec![5],    // 3 -> 5
+        vec![5],    // 4 -> 5
+        vec![],     // 5
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 6];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 2, 2, 1, 2]);
+
+    let dag = vec![
+        vec![1, 6],    // 0 -> 1, 6
+        vec![2, 4],    // 1 -> 2, 4
+        vec![3, 4],    // 2 -> 3, 4
+        vec![],        // 3
+        vec![5],       // 4 -> 5
+        vec![],        // 5
+        vec![4, 7, 8], // 6 -> 4, 7, 8
+        vec![5],       // 7 -> 5
+        vec![],        // 8
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 9];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 2, 3, 2, 3, 1, 2, 2]);
+
+    // Case 5
+    let dag = vec![
+        vec![1, 3],    // 0 -> 1, 3
+        vec![2, 4],    // 1 -> 2, 4
+        vec![7, 8, 6], // 2 -> 7, 8, 6
+        vec![2, 6],    // 3 -> 2, 6
+        vec![7, 5],    // 4 -> 7, 5
+        vec![],        // 5
+        vec![],        // 6
+        vec![],        // 7
+        vec![],        // 8
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 9];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 2, 1, 2, 3, 2, 3, 3]);
+
+    // Loop case
+    let dag = vec![
+        vec![1], // 0 -> 1
+        vec![2], // 1 -> 2
+        vec![3], // 2 -> 3
+        vec![1], // 3 -> 1
+    ];
+    let mut depth: Vec<u16> = vec![u16::MAX; 4];
+    AnticipatedExpressions::blocks_depth(&dag, 0, 0, &mut depth);
+    assert_eq!(depth, vec![0, 1, 2, 3]);
+}

+ 23 - 22
src/codegen/subexpression_elimination/available_expression.rs

@@ -10,23 +10,23 @@ use std::rc::Rc;
 
 impl AvailableExpression {
     /// Add a node to represent a literal
-    pub fn add_literal_node(
+    pub fn add_literal_node<'b, 'a: 'b>(
         &mut self,
-        expr: &Expression,
-        expr_set: &mut AvailableExpressionSet,
+        expr: &'a Expression,
+        expr_set: &mut AvailableExpressionSet<'b>,
     ) -> NodeId {
         let expr_type = expr.get_constant_expression_type();
 
-        self.add_node_to_memory(expr_set, expr_type);
+        self.add_node_to_memory(expr_set, expr_type, expr);
 
         self.global_id_counter - 1
     }
 
     /// Add a node to represent a variable
-    pub fn add_variable_node(
+    pub fn add_variable_node<'b, 'a: 'b>(
         &mut self,
-        expr: &Expression,
-        expr_set: &mut AvailableExpressionSet,
+        expr: &'a Expression,
+        expr_set: &mut AvailableExpressionSet<'b>,
     ) -> NodeId {
         let expr_type = match expr {
             Expression::Variable(_, _, pos) => ExpressionType::Variable(*pos),
@@ -36,16 +36,16 @@ impl AvailableExpression {
             _ => unreachable!("This expression is not a variable or a function argument"),
         };
 
-        self.add_node_to_memory(expr_set, expr_type);
+        self.add_node_to_memory(expr_set, expr_type, expr);
 
         self.global_id_counter - 1
     }
 
     /// Add a node to represent a binary expression
-    pub fn add_binary_node(
+    pub fn add_binary_node<'b, 'a: 'b>(
         &mut self,
-        exp: &Expression,
-        expr_set: &mut AvailableExpressionSet,
+        exp: &'a Expression,
+        expr_set: &mut AvailableExpressionSet<'b>,
         left: NodeId,
         right: NodeId,
     ) -> NodeId {
@@ -55,9 +55,9 @@ impl AvailableExpression {
             expression_id: self.global_id_counter,
             children: Default::default(),
             available_variable: AvailableVariable::Unavailable,
-            parent_block: expr_set.parent_block_no,
-            on_parent_block: false,
+            parent_block: None,
             block: self.cur_block,
+            reference: exp,
         }));
         expr_set
             .expression_memory
@@ -88,11 +88,11 @@ impl AvailableExpression {
     }
 
     /// Add a node to represent an unary operation
-    pub fn add_unary_node(
+    pub fn add_unary_node<'b, 'a: 'b>(
         &mut self,
-        exp: &Expression,
+        exp: &'a Expression,
         parent: usize,
-        expr_set: &mut AvailableExpressionSet,
+        expr_set: &mut AvailableExpressionSet<'b>,
     ) -> NodeId {
         let operation = exp.get_ave_operator();
         let new_node = Rc::new(RefCell::new(BasicExpression {
@@ -100,9 +100,9 @@ impl AvailableExpression {
             expression_id: self.global_id_counter,
             children: Default::default(),
             available_variable: AvailableVariable::Unavailable,
-            parent_block: expr_set.parent_block_no,
-            on_parent_block: false,
+            parent_block: None,
             block: self.cur_block,
+            reference: exp,
         }));
 
         expr_set
@@ -126,10 +126,11 @@ impl AvailableExpression {
         self.global_id_counter - 1
     }
 
-    fn add_node_to_memory(
+    fn add_node_to_memory<'b, 'a: 'b>(
         &mut self,
-        expr_set: &mut AvailableExpressionSet,
+        expr_set: &mut AvailableExpressionSet<'b>,
         expr_type: ExpressionType,
+        expr: &'a Expression,
     ) {
         expr_set.expression_memory.insert(
             self.global_id_counter,
@@ -138,9 +139,9 @@ impl AvailableExpression {
                 expression_id: self.global_id_counter,
                 children: Default::default(),
                 available_variable: AvailableVariable::Unavailable,
-                parent_block: expr_set.parent_block_no,
-                on_parent_block: false,
+                parent_block: None,
                 block: self.cur_block,
+                reference: expr,
             })),
         );
 

+ 84 - 26
src/codegen/subexpression_elimination/available_expression_set.rs

@@ -12,13 +12,13 @@ use std::cell::RefCell;
 use std::collections::{HashMap, HashSet};
 use std::rc::Rc;
 
-impl AvailableExpressionSet {
+impl<'a, 'b: 'a> AvailableExpressionSet<'a> {
     /// Clone a set for a given parent block
-    pub fn clone_for_parent_block(&self, parent_block: usize) -> AvailableExpressionSet {
+    pub fn clone_for_parent_block(&self) -> AvailableExpressionSet<'a> {
         let mut new_set = AvailableExpressionSet {
             expression_memory: HashMap::default(),
             expr_map: self.expr_map.clone(),
-            parent_block_no: parent_block,
+            mapped_variable: self.mapped_variable.clone(),
         };
 
         for (key, value) in &self.expression_memory {
@@ -29,9 +29,9 @@ impl AvailableExpressionSet {
                     expression_id: value.borrow().expression_id,
                     children: HashMap::default(),
                     available_variable: value.borrow().available_variable.clone(),
-                    parent_block: value.borrow().parent_block,
-                    on_parent_block: value.borrow().on_parent_block,
                     block: value.borrow().block,
+                    parent_block: value.borrow().parent_block,
+                    reference: value.borrow().reference,
                 })),
             );
         }
@@ -87,12 +87,12 @@ impl AvailableExpressionSet {
                 node_1.children.clear();
                 let node_2_id = set_2.expr_map.get(key).unwrap();
 
-                node_1.on_parent_block = true;
                 // Find the common ancestor of both blocks. The deepest block after which there are
                 // multiple paths to both blocks.
                 node_1.parent_block = cst.find_parent_block(
-                    node_1.parent_block,
-                    set_2.expression_memory[node_2_id].borrow().parent_block,
+                    node_1.block,
+                    set_2.expression_memory[node_2_id].borrow().block,
+                    node_1.reference,
                 );
                 if let (Some(var_id_1), Some(var_id_2)) = (
                     set_2.expression_memory[node_2_id]
@@ -127,6 +127,42 @@ impl AvailableExpressionSet {
         }
     }
 
+    /// Calculate the union between two sets
+    pub fn union_sets(&mut self, set_2: &AvailableExpressionSet<'a>) {
+        let mut node_translation: HashMap<NodeId, NodeId> = HashMap::new();
+        for (key, node_id) in &set_2.expr_map {
+            if let Some(other_id) = self.expr_map.get(key) {
+                node_translation.insert(*node_id, *other_id);
+            }
+        }
+
+        for (key, node_id) in &set_2.expr_map {
+            if !self.expr_map.contains_key(key) {
+                let new_key = match key {
+                    ExpressionType::BinaryOperation(id_1, id_2, op) => {
+                        ExpressionType::BinaryOperation(
+                            node_translation.get(id_1).cloned().unwrap_or(*id_1),
+                            node_translation.get(id_2).cloned().unwrap_or(*id_2),
+                            op.clone(),
+                        )
+                    }
+                    ExpressionType::UnaryOperation(id, op) => ExpressionType::UnaryOperation(
+                        node_translation.get(id).cloned().unwrap_or(*id),
+                        op.clone(),
+                    ),
+                    _ => key.clone(),
+                };
+                self.expr_map.insert(new_key, *node_id);
+            }
+        }
+
+        for (key, expr) in &set_2.expression_memory {
+            if !self.expression_memory.contains_key(key) {
+                self.expression_memory.insert(*key, expr.clone());
+            }
+        }
+    }
+
     /// Check if a commutative expression exists in the set
     fn find_commutative(
         &self,
@@ -163,10 +199,10 @@ impl AvailableExpressionSet {
     /// Try to fetch the ID of left and right operands.
     fn process_left_right(
         &mut self,
-        left: &Expression,
-        right: &Expression,
+        left: &'a Expression,
+        right: &'a Expression,
         ave: &mut AvailableExpression,
-        cst: &mut CommonSubExpressionTracker,
+        cst: &mut Option<&mut CommonSubExpressionTracker>,
     ) -> Option<(NodeId, NodeId)> {
         let left_id = self.gen_expression(left, ave, cst)?;
         let right_id = self.gen_expression(right, ave, cst)?;
@@ -177,11 +213,11 @@ impl AvailableExpressionSet {
     /// Add a commutative expression to the set if it is not there yet
     fn process_commutative(
         &mut self,
-        exp: &Expression,
-        left: &Expression,
-        right: &Expression,
+        exp: &'a Expression,
+        left: &'a Expression,
+        right: &'a Expression,
         ave: &mut AvailableExpression,
-        cst: &mut CommonSubExpressionTracker,
+        cst: &mut Option<&mut CommonSubExpressionTracker>,
     ) -> Option<NodeId> {
         let (left_id, right_id) = self.process_left_right(left, right, ave, cst)?;
         Some(ave.add_binary_node(exp, self, left_id, right_id))
@@ -190,14 +226,16 @@ impl AvailableExpressionSet {
     /// Add expression to the graph and check if it is available on a parallel branch.
     pub fn gen_expression(
         &mut self,
-        exp: &Expression,
+        exp: &'a Expression,
         ave: &mut AvailableExpression,
-        cst: &mut CommonSubExpressionTracker,
+        cst: &mut Option<&mut CommonSubExpressionTracker>,
     ) -> Option<NodeId> {
         let id = self.gen_expression_aux(exp, ave, cst);
         if let Some(id) = id {
             let node = &*self.expression_memory.get(&id).unwrap().borrow();
-            cst.check_availability_on_branches(&node.expr_type);
+            if let Some(tracker) = cst.as_mut() {
+                tracker.check_availability_on_branches(&node.expr_type, exp);
+            }
         }
         id
     }
@@ -205,12 +243,14 @@ impl AvailableExpressionSet {
     /// Add an expression to the graph if it is not there
     pub fn gen_expression_aux(
         &mut self,
-        exp: &Expression,
+        exp: &'a Expression,
         ave: &mut AvailableExpression,
-        cst: &mut CommonSubExpressionTracker,
+        cst: &mut Option<&mut CommonSubExpressionTracker>,
     ) -> Option<NodeId> {
         if let Some(id) = self.find_expression(exp) {
-            self.add_to_cst(exp, &id, cst);
+            if let Some(tracker) = cst.as_mut() {
+                self.add_to_cst(exp, &id, tracker);
+            }
             return Some(id);
         }
 
@@ -297,6 +337,24 @@ impl AvailableExpressionSet {
         self.expr_map.remove(&basic_exp.expr_type);
     }
 
+    /// This function indicates that an available node that was once mapped to an existing variable
+    /// should no longer be linked to that variable.
+    ///
+    /// When we have an assignment 'x = a + b', and later we find the usage of 'a + b', we can
+    /// replace it by 'x', instead of creating a new cse temporary. Nonetheless, whenever the 'x'
+    /// is reassigned, we must indicate that 'x' does not represent 'a + b' anymore, so we would
+    /// need a temporary if we were to replace a repeated occurrence of 'a + b'
+    pub fn remove_mapped(&mut self, var_no: usize) {
+        if let Some(node_id) = self.mapped_variable.remove(&var_no) {
+            if let Some(node) = self.expression_memory.get(&node_id) {
+                let mut node_mut = node.borrow_mut();
+                if node_mut.available_variable.is_available() {
+                    node_mut.available_variable = AvailableVariable::Unavailable;
+                }
+            }
+        }
+    }
+
     /// When a reaching definition change, we remove the variable node and all its descendants from
     /// the graph
     pub fn kill(&mut self, var_no: usize) {
@@ -391,9 +449,9 @@ impl AvailableExpressionSet {
     /// Regenerate commutative expressions
     fn regenerate_commutative(
         &mut self,
-        exp: &Expression,
-        left: &Expression,
-        right: &Expression,
+        exp: &'a Expression,
+        left: &'a Expression,
+        right: &'a Expression,
         ave: &mut AvailableExpression,
         cst: &mut CommonSubExpressionTracker,
     ) -> (Option<NodeId>, Expression) {
@@ -436,7 +494,7 @@ impl AvailableExpressionSet {
     /// a temporary, we do it here.
     pub fn regenerate_expression(
         &mut self,
-        exp: &Expression,
+        exp: &'a Expression,
         ave: &mut AvailableExpression,
         cst: &mut CommonSubExpressionTracker,
     ) -> (Option<NodeId>, Expression) {
@@ -448,7 +506,7 @@ impl AvailableExpressionSet {
             | Expression::NumberLiteral(..)
             | Expression::BoolLiteral(..)
             | Expression::BytesLiteral(..) => {
-                return (self.gen_expression(exp, ave, cst), exp.clone());
+                return (self.gen_expression(exp, ave, &mut Some(cst)), exp.clone());
             }
 
             Expression::StringCompare(_, left, right)

+ 58 - 147
src/codegen/subexpression_elimination/common_subexpression_tracker.rs

@@ -1,6 +1,6 @@
 // SPDX-License-Identifier: Apache-2.0
 
-use crate::codegen::cfg::InstrOrigin;
+use crate::codegen::subexpression_elimination::anticipated_expressions::AnticipatedExpressions;
 use crate::codegen::subexpression_elimination::{BasicExpression, ExpressionType};
 use crate::codegen::{
     vartable::{Storage, Variable},
@@ -8,10 +8,9 @@ use crate::codegen::{
 };
 use crate::sema::ast::RetrieveType;
 use crate::sema::ast::{Namespace, Type};
-use bitflags::bitflags;
 use solang_parser::pt::OptionalCodeLocation;
 use solang_parser::pt::{Identifier, Loc};
-use std::collections::{HashMap, VecDeque};
+use std::collections::HashMap;
 
 #[derive(Clone)]
 struct CommonSubexpression {
@@ -24,35 +23,29 @@ struct CommonSubexpression {
     on_parent_block: Option<usize>,
 }
 
-bitflags! {
-  struct Color: u8 {
-     const WHITE = 0;
-     const BLUE = 2;
-     const YELLOW = 4;
-     const GREEN = 6;
-  }
-}
-
 #[derive(Default, Clone)]
-pub struct CommonSubExpressionTracker {
+pub struct CommonSubExpressionTracker<'a> {
+    /// This hash map tracks the inserted common subexpressions. The usize is the index into
+    /// common_subexpressions vector
     inserted_subexpressions: HashMap<ExpressionType, usize>,
+    /// We store common subexpressions in this vector
     common_subexpressions: Vec<CommonSubexpression>,
-    len: usize,
-    name_cnt: usize,
+    /// The cur_block tracks the current block we are currently analysing
     cur_block: usize,
-    new_cfg_instr: Vec<(InstrOrigin, Instr)>,
+    /// We save here the new instructions we need to add to the current block
+    new_cfg_instr: Vec<Instr>,
+    /// Here, we store the instructions we must add to blocks other than the one we are
+    /// analysing now
     parent_block_instr: Vec<(usize, Instr)>,
     /// Map from variable number to common subexpression
     mapped_variables: HashMap<usize, usize>,
-    /// The CFG is a cyclic graph. In order properly find the lowest common block,
-    /// we transformed it in a DAG, removing cycles from loops.
-    cfg_dag: Vec<Vec<usize>>,
+    /// anticipated_expressions saves the anticipated expressions for every block in the CFG
+    anticipated_expressions: AnticipatedExpressions<'a>,
 }
 
-impl CommonSubExpressionTracker {
-    /// Save the DAG to the CST
-    pub fn set_dag(&mut self, dag: Vec<Vec<usize>>) {
-        self.cfg_dag = dag;
+impl<'a> CommonSubExpressionTracker<'a> {
+    pub(super) fn set_anticipated(&mut self, anticipated: AnticipatedExpressions<'a>) {
+        self.anticipated_expressions = anticipated;
     }
 
     /// Add an expression to the tracker.
@@ -79,7 +72,13 @@ impl CommonSubExpressionTracker {
         }
 
         self.inserted_subexpressions
-            .insert(expr_type.clone(), self.len);
+            .insert(expr_type.clone(), self.common_subexpressions.len());
+
+        if let Some(var_no) = node.available_variable.get_var_number() {
+            // If we encounter an expression like 'x = y+2', we can map 'x' to 'y+2', whenever possible.
+            self.mapped_variables
+                .insert(var_no, self.common_subexpressions.len());
+        }
 
         self.common_subexpressions.push(CommonSubexpression {
             in_cfg: node.available_variable.is_available(),
@@ -88,19 +87,8 @@ impl CommonSubExpressionTracker {
             instantiated: false,
             var_type: exp.ty(),
             block: node.block,
-            on_parent_block: if node.on_parent_block {
-                Some(node.parent_block)
-            } else {
-                None
-            },
+            on_parent_block: node.parent_block,
         });
-
-        if let Some(var_no) = node.available_variable.get_var_number() {
-            // If we encounter an expression like 'x = y+2', we can map 'x' to 'y+2', whenever possible.
-            self.mapped_variables.insert(var_no, self.len);
-        }
-
-        self.len += 1;
     }
 
     /// Invalidate a mapped variable
@@ -114,15 +102,16 @@ impl CommonSubExpressionTracker {
 
     /// Create variables in the CFG
     pub fn create_variables(&mut self, ns: &mut Namespace, cfg: &mut ControlFlowGraph) {
+        let mut name_cnt: usize = 0;
         for exp in self.common_subexpressions.iter_mut() {
             if exp.var_no.is_none() {
-                self.name_cnt += 1;
+                name_cnt += 1;
                 cfg.vars.insert(
                     ns.next_id,
                     Variable {
                         id: Identifier {
                             loc: Loc::Codegen,
-                            name: format!("{}.cse_temp", self.name_cnt),
+                            name: format!("{name_cnt}.cse_temp"),
                         },
                         ty: exp.var_type.clone(),
                         storage: Storage::Local,
@@ -157,15 +146,27 @@ impl CommonSubExpressionTracker {
     /// '''
     ///
     /// This avoids the repeated calculation of 'a+b'
-    pub fn check_availability_on_branches(&mut self, expr_type: &ExpressionType) {
+    pub fn check_availability_on_branches(
+        &mut self,
+        expr_type: &ExpressionType,
+        expr: &Expression,
+    ) {
         if let Some(expr_id) = self.inserted_subexpressions.get(expr_type) {
             let expr_block = self.common_subexpressions[*expr_id].block;
             let expr_block = self.common_subexpressions[*expr_id]
                 .on_parent_block
                 .unwrap_or(expr_block);
-            let ancestor = self.find_parent_block(self.cur_block, expr_block);
-            if ancestor != expr_block {
-                self.common_subexpressions[*expr_id].on_parent_block = Some(ancestor);
+            let ancestor = self.find_parent_block(self.cur_block, expr_block, expr);
+            if let Some(ancestor_no) = ancestor {
+                if ancestor_no != expr_block {
+                    let common_expression = &mut self.common_subexpressions[*expr_id];
+                    // When an expression is going to be evaluated on a block that's different from
+                    // the place where we first saw it, it cannot be replaced by an existing variable.
+                    common_expression.var_no = None;
+                    common_expression.var_loc = None;
+                    common_expression.in_cfg = false;
+                    common_expression.on_parent_block = Some(ancestor_no);
+                }
             }
         }
     }
@@ -196,7 +197,7 @@ impl CommonSubExpressionTracker {
             };
 
             if common_expression.on_parent_block.is_none() {
-                self.new_cfg_instr.push((InstrOrigin::Codegen, new_instr));
+                self.new_cfg_instr.push(new_instr);
             } else {
                 self.parent_block_instr
                     .push((common_expression.on_parent_block.unwrap(), new_instr));
@@ -217,18 +218,16 @@ impl CommonSubExpressionTracker {
     }
 
     /// Add new instructions to the instruction vector
-    pub fn add_new_instructions(&mut self, instr_vec: &mut Vec<(InstrOrigin, Instr)>) {
+    pub fn add_new_instructions(&mut self, instr_vec: &mut Vec<Instr>) {
         instr_vec.append(&mut self.new_cfg_instr);
     }
 
-    /// If a variable create should be hoisted in a different block than where it it read, we
+    /// If a variable creation should be placed in a different block than where it is read, we
     /// do it here.
     pub fn add_parent_block_instructions(&self, cfg: &mut ControlFlowGraph) {
         for (block_no, instr) in &self.parent_block_instr {
             let index = cfg.blocks[*block_no].instr.len() - 1;
-            cfg.blocks[*block_no]
-                .instr
-                .insert(index, (InstrOrigin::Codegen, instr.to_owned()));
+            cfg.blocks[*block_no].instr.insert(index, instr.to_owned());
         }
     }
 
@@ -240,105 +239,17 @@ impl CommonSubExpressionTracker {
 
     /// For common subexpression elimination to work properly, we need to find the common parent of
     /// two blocks. The parent is the deepest block in which every path from the entry block to both
-    /// 'block_1' and 'block_2' passes through such a block.
-    pub fn find_parent_block(&self, block_1: usize, block_2: usize) -> usize {
-        if block_1 == block_2 {
-            return block_1;
-        }
-        let mut colors: Vec<Color> = vec![Color::WHITE; self.cfg_dag.len()];
-        let mut visited: Vec<bool> = vec![false; self.cfg_dag.len()];
-        /*
-        Given a DAG (directed acyclic graph), we color all the ancestors of 'block_1' with yellow.
-        Then, we color every ancestor of 'block_2' with blue. As the mixture of blue and yellow
-        results in green, green blocks are all possible common ancestors!
-
-        We can't add colors to code. Here, bitwise ORing 2 to a block's color mean painting with yellow.
-        Likewise, bitwise ORing 4 means painting with blue. Green blocks have 6 (2|4) as their color
-        number.
-
-         */
-
-        self.coloring_dfs(block_1, 0, Color::BLUE, &mut colors, &mut visited);
-        visited.fill(false);
-        self.coloring_dfs(block_2, 0, Color::YELLOW, &mut colors, &mut visited);
-
-        /*
-        Having a bunch of green block, which of them are we looking for?
-        We must choose the deepest block, in which all paths from the entry block to both block_1
-        and block_2 pass through this block.
-
-        Have a look at the 'find_ancestor' function to know more about the algorithm.
-         */
-        self.find_ancestor(0, &colors)
-    }
-
-    /// Given a colored graph, find the lowest common ancestor.
-    fn find_ancestor(&self, start_block: usize, colors: &[Color]) -> usize {
-        let mut candidate = start_block;
-        let mut queue: VecDeque<usize> = VecDeque::new();
-        let mut visited: Vec<bool> = vec![false; self.cfg_dag.len()];
-
-        visited[start_block] = true;
-        queue.push_back(start_block);
-
-        let mut six_child: usize = 0;
-        // This is a BFS (breadth first search) traversal
-        while let Some(cur_block) = queue.pop_front() {
-            let mut not_ancestors: usize = 0;
-            for child in &self.cfg_dag[cur_block] {
-                if colors[*child] == Color::WHITE {
-                    // counting the number of children which are not ancestors from neither block_1
-                    // nor block_2
-                    not_ancestors += 1;
-                }
-
-                if colors[*child] == Color::GREEN {
-                    // This is the possible candidate to search next.
-                    six_child = *child;
-                }
-            }
-
-            // If the current block has only one child that leads to both block_1 and block_2, it is
-            // a candidate to be the lowest common ancestor.
-            if not_ancestors + 1 == self.cfg_dag[cur_block].len() && !visited[six_child] {
-                visited[six_child] = true;
-                queue.push_back(six_child);
-                candidate = six_child;
-            }
-        }
-
-        candidate
-    }
-
-    /// This function performs a DFS (depth first search) to color all the ancestors of a block.
-    fn coloring_dfs(
+    /// 'block_1' and 'block_2' passes through such a block, provided that the expression is
+    /// anticipated there.
+    pub fn find_parent_block(
         &self,
-        search_block: usize,
-        cur_block: usize,
-        color: Color,
-        colors: &mut Vec<Color>,
-        visited: &mut Vec<bool>,
-    ) -> bool {
-        if colors[cur_block].contains(color) {
-            return true;
-        }
-
-        if visited[cur_block] {
-            return false;
-        }
-
-        visited[cur_block] = true;
-        if cur_block == search_block {
-            colors[cur_block].insert(color);
-            return true;
-        }
-
-        for next in &self.cfg_dag[cur_block] {
-            if self.coloring_dfs(search_block, *next, color, colors, visited) {
-                colors[cur_block].insert(color);
-            }
-        }
-
-        colors[cur_block].contains(color)
+        block_1: usize,
+        block_2: usize,
+        expr: &Expression,
+    ) -> Option<usize> {
+        // The analysis is done at another data structure to isolate the logic of traversing the
+        // CFG from the end to the beginning (backwards).
+        self.anticipated_expressions
+            .find_ancestor(block_1, block_2, expr)
     }
 }

+ 29 - 13
src/codegen/subexpression_elimination/instruction.rs

@@ -6,13 +6,13 @@ use crate::codegen::subexpression_elimination::AvailableExpression;
 use crate::codegen::subexpression_elimination::{AvailableExpressionSet, AvailableVariable};
 use crate::codegen::Expression;
 
-impl AvailableExpressionSet {
+impl<'a, 'b: 'a> AvailableExpressionSet<'a> {
     /// Check if we can add the expressions of an instruction to the graph
     pub fn process_instruction(
         &mut self,
-        instr: &Instr,
+        instr: &'b Instr,
         ave: &mut AvailableExpression,
-        cst: &mut CommonSubExpressionTracker,
+        cst: &mut Option<&mut CommonSubExpressionTracker>,
     ) {
         match instr {
             Instr::BranchCond { cond: expr, .. }
@@ -29,19 +29,35 @@ impl AvailableExpressionSet {
             }
 
             Instr::Set { res, expr, loc } => {
-                let node_id = self.gen_expression(expr, ave, cst);
-                if node_id.is_some() {
-                    let node = &mut *self
-                        .expression_memory
-                        .get(node_id.as_ref().unwrap())
-                        .unwrap()
-                        .borrow_mut();
+                if cst.is_none() {
+                    // If there is no cst, we are traversing the CFG in reverse, so we kill the
+                    // definition before processing the assignment
+                    // e.g.
+                    // -- Here we have a previous definition of x and x + y is available
+                    // x = x + y -> kill x first, then make x+y available
+                    // -- x+y is not available
+                    self.kill(*res);
+                }
+
+                self.remove_mapped(*res);
+                if let Some(node_id) = self.gen_expression(expr, ave, cst) {
+                    let node = &mut *self.expression_memory.get(&node_id).unwrap().borrow_mut();
                     if !node.available_variable.is_available() {
                         node.available_variable = AvailableVariable::Available(*res, *loc);
+                        self.mapped_variable.insert(*res, node_id);
                     }
                 }
-                cst.invalidate_mapped_variable(res);
-                self.kill(*res);
+
+                if let Some(tracker) = cst {
+                    // If there is a cst, we are traversing the CFG in the same order as code
+                    // execution, so we kill the definition after processing the assignment
+                    // e.g.
+                    // -- x+y not available
+                    // x = x + y -> make x+y available, then kill x, which also kills x+y
+                    // -- x + y is not available here, because x has a new definition
+                    self.kill(*res);
+                    tracker.invalidate_mapped_variable(res);
+                }
             }
 
             Instr::PushMemory { value: expr, .. } => {
@@ -185,7 +201,7 @@ impl AvailableExpressionSet {
     /// Regenerate instructions after that we exchanged common subexpressions for temporaries
     pub fn regenerate_instruction(
         &mut self,
-        instr: &Instr,
+        instr: &'b Instr,
         ave: &mut AvailableExpression,
         cst: &mut CommonSubExpressionTracker,
     ) -> Instr {

+ 99 - 34
src/codegen/subexpression_elimination/mod.rs

@@ -1,16 +1,19 @@
 // SPDX-License-Identifier: Apache-2.0
 
-use crate::codegen::cfg::{BasicBlock, ControlFlowGraph, Instr, InstrOrigin};
+use crate::codegen::cfg::{BasicBlock, ControlFlowGraph, Instr};
 use crate::codegen::reaching_definitions::block_edges;
+use crate::codegen::subexpression_elimination::anticipated_expressions::AnticipatedExpressions;
 use crate::codegen::subexpression_elimination::available_variable::AvailableVariable;
 use crate::codegen::subexpression_elimination::common_subexpression_tracker::CommonSubExpressionTracker;
 use crate::codegen::subexpression_elimination::operator::Operator;
+use crate::codegen::Expression;
 use crate::sema::ast::Namespace;
 use num_bigint::BigInt;
 use std::cell::RefCell;
 use std::collections::{HashMap, HashSet, VecDeque};
 use std::rc::Rc;
 
+mod anticipated_expressions;
 mod available_expression;
 mod available_expression_set;
 mod available_variable;
@@ -46,16 +49,26 @@ pub struct AvailableExpression {
     cur_block: usize,
 }
 
-/// Each BasicExpression is a graph node
+/// Each BasicExpression is a graph node that tracks a real codegen::Expression
 #[derive(Clone)]
-pub struct BasicExpression {
+pub struct BasicExpression<'a> {
+    /// The expression type for this node
     expr_type: ExpressionType,
+    /// The node global id
     expression_id: NodeId,
-    children: HashMap<NodeId, Rc<RefCell<BasicExpression>>>,
+    /// This map tracks all the node's children
+    children: HashMap<NodeId, Rc<RefCell<BasicExpression<'a>>>>,
+    /// Reference points to the real codegen::Expression this node represents
+    pub reference: &'a Expression,
+    /// Available_variable tells us if a CFG variable is available for this node.
+    /// E.g. if 'x=a+b' is evaluated, 'a+b' is already assigned to a variable in the CFG, so it
+    /// does not need a temporary in case it happens to be a common subexpression.
     pub available_variable: AvailableVariable,
+    /// Block is the CFG block where the expression was first seen.
     pub block: usize,
-    pub parent_block: usize,
-    pub on_parent_block: bool,
+    /// When parent_block is set, the expression should be evaluated at the parent_block instead of
+    /// block (the parameter right above this one).
+    pub parent_block: Option<usize>,
 }
 
 /// Type of constant to streamline the use of a hashmap
@@ -77,37 +90,68 @@ pub enum ExpressionType {
 }
 
 /// Sets contain the available expression at a certain portion of the CFG
-#[derive(Default)]
-pub struct AvailableExpressionSet {
+#[derive(Default, Clone)]
+pub struct AvailableExpressionSet<'a> {
     // node_no => BasicExpression
-    expression_memory: HashMap<NodeId, Rc<RefCell<BasicExpression>>>,
+    expression_memory: HashMap<NodeId, Rc<RefCell<BasicExpression<'a>>>>,
     // Expression => node_id
     expr_map: HashMap<ExpressionType, NodeId>,
-    parent_block_no: usize,
+    mapped_variable: HashMap<usize, NodeId>,
+}
+
+/// This struct is the return type of function 'find_visiting_order', which helps find all
+/// the CFG representations for the analysis.
+struct CfgAsDag {
+    visiting_order: Vec<(usize, bool)>,
+    reverse_visiting_order: Vec<(usize, bool)>,
+    dag: Vec<Vec<usize>>,
+    reverse_dag: Vec<Vec<usize>>,
 }
 
 /// Performs common subexpression elimination
 pub fn common_sub_expression_elimination(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
+    // visiting_order: the order in which we should traverse the CFG (this is its topological sorting)
+    // dag: The CFG represented as a DAG (directed acyclic graph)
+    // reverse_dag: The CFG represented as a DAG, but with all the edges reversed.
+    let cfg_as_dag = find_visiting_order(cfg);
+
+    let mut old_instr: Vec<Vec<Instr>> = vec![Vec::new(); cfg.blocks.len()];
+    // We need to remove the instructions from the blocks, so we can store references in the
+    // available expressions set.
+    for (block_no, block) in cfg.blocks.iter_mut().enumerate() {
+        std::mem::swap(&mut old_instr[block_no], &mut block.instr);
+    }
+
+    // Anticipated expressions is an analysis that calculates where in the CFG we can anticipate
+    // the evaluation of an expression, provided that none of its constituent variables have
+    // been assigned a new value.
+    let mut anticipated_expressions = AnticipatedExpressions::new(
+        &cfg_as_dag.dag,
+        cfg_as_dag.reverse_dag,
+        cfg_as_dag.reverse_visiting_order,
+    );
+    anticipated_expressions.calculate_anticipated_expressions(&old_instr, cfg);
+
     let mut ave = AvailableExpression::default();
     let mut cst = CommonSubExpressionTracker::default();
-
     let mut sets: HashMap<usize, AvailableExpressionSet> = HashMap::new();
-    let (visiting_order, dag) = find_visiting_order(cfg);
-    cst.set_dag(dag);
     sets.insert(0, AvailableExpressionSet::default());
+    // The anticipated expression values are part of the common subexpression tracker, so we
+    // can always evaluate the common subexpressions in the correct place.
+    cst.set_anticipated(anticipated_expressions);
 
     // First pass: identify common subexpressions using available expressions analysis
-    for (block_no, cycle) in &visiting_order {
+    for (block_no, cycle) in &cfg_as_dag.visiting_order {
         let cur_block = &cfg.blocks[*block_no];
         ave.set_cur_block(*block_no);
         cst.set_cur_block(*block_no);
         let mut cur_set = sets.remove(block_no).unwrap();
         kill_loop_variables(cur_block, &mut cur_set, *cycle);
-        for (_, instr) in cur_block.instr.iter() {
-            cur_set.process_instruction(instr, &mut ave, &mut cst);
+        for instr in old_instr[*block_no].iter() {
+            cur_set.process_instruction(instr, &mut ave, &mut Some(&mut cst));
         }
 
-        add_neighbor_blocks(cur_block, &cur_set, block_no, &mut sets, &cst);
+        add_neighbor_blocks(cur_set, &cfg_as_dag.dag[*block_no], &mut sets, &cst);
     }
 
     cst.create_variables(ns, cfg);
@@ -117,39 +161,38 @@ pub fn common_sub_expression_elimination(cfg: &mut ControlFlowGraph, ns: &mut Na
     sets.insert(0, AvailableExpressionSet::default());
 
     // Second pass: eliminate common subexpressions
-    for (block_no, cycle) in &visiting_order {
+    for (block_no, cycle) in &cfg_as_dag.visiting_order {
         let mut cur_set = sets.remove(block_no).unwrap();
         let mut cur_block = &mut cfg.blocks[*block_no];
         ave.set_cur_block(*block_no);
         cst.set_cur_block(*block_no);
-        let mut new_instructions: Vec<(InstrOrigin, Instr)> = Vec::new();
+        let mut new_instructions: Vec<Instr> = Vec::new();
         kill_loop_variables(cur_block, &mut cur_set, *cycle);
-        for (origin, instr) in cur_block.instr.iter() {
+        for instr in old_instr[*block_no].iter() {
             let instr = cur_set.regenerate_instruction(instr, &mut ave, &mut cst);
             cst.add_new_instructions(&mut new_instructions);
-            new_instructions.push((origin.clone(), instr));
+            new_instructions.push(instr);
         }
 
         cur_block.instr = new_instructions;
-        add_neighbor_blocks(cur_block, &cur_set, block_no, &mut sets, &cst);
+        add_neighbor_blocks(cur_set, &cfg_as_dag.dag[*block_no], &mut sets, &cst);
     }
 
     cst.add_parent_block_instructions(cfg);
 }
 
 /// Hand the current block's available expressions to every block listed in `edges`: intersect
 /// them into the set already stored in `sets` for that block, or store a parent-block clone
 /// of `cur_set` when the map has no entry for it yet.
-fn add_neighbor_blocks(
-    cur_block: &BasicBlock,
-    cur_set: &AvailableExpressionSet,
-    block_no: &usize,
-    sets: &mut HashMap<usize, AvailableExpressionSet>,
+fn add_neighbor_blocks<'b>(
+    cur_set: AvailableExpressionSet<'b>,
+    edges: &[usize],
+    sets: &mut HashMap<usize, AvailableExpressionSet<'b>>,
     cst: &CommonSubExpressionTracker,
 ) {
-    for edge in block_edges(cur_block) {
-        if let Some(set) = sets.get_mut(&edge) {
-            set.intersect_sets(cur_set, cst);
+    for edge in edges {
+        if let Some(set) = sets.get_mut(edge) {
+            set.intersect_sets(&cur_set, cst);
         } else {
-            sets.insert(edge, cur_set.clone_for_parent_block(*block_no));
+            sets.insert(*edge, cur_set.clone_for_parent_block());
         }
     }
 }
@@ -161,6 +204,7 @@ fn kill_loop_variables(block: &BasicBlock, cur_set: &mut AvailableExpressionSet,
         return;
     }
     for var_no in &block.loop_reaching_variables {
+        cur_set.remove_mapped(*var_no);
         cur_set.kill(*var_no);
     }
 }
@@ -168,14 +212,16 @@ fn kill_loop_variables(block: &BasicBlock, cur_set: &mut AvailableExpressionSet,
 /// Find the correct visiting order for the CFG traversal, using topological sorting. The visiting
 /// order should be the same as the execution order. This function also returns a DAG for the
 /// execution graph. This helps us find the lowest common ancestor later.
-fn find_visiting_order(cfg: &ControlFlowGraph) -> (Vec<(usize, bool)>, Vec<Vec<usize>>) {
+fn find_visiting_order(cfg: &ControlFlowGraph) -> CfgAsDag {
     let mut order: Vec<(usize, bool)> = Vec::with_capacity(cfg.blocks.len());
     let mut visited: HashSet<usize> = HashSet::new();
     let mut stack: HashSet<usize> = HashSet::new();
     let mut has_cycle: Vec<bool> = vec![false; cfg.blocks.len()];
     let mut degrees: Vec<i32> = vec![0; cfg.blocks.len()];
     let mut dag: Vec<Vec<usize>> = Vec::new();
+    let mut reverse_dag: Vec<Vec<usize>> = Vec::new();
     dag.resize(cfg.blocks.len(), vec![]);
+    reverse_dag.resize(cfg.blocks.len(), vec![]);
 
     cfg_dfs(
         0,
@@ -185,6 +231,7 @@ fn find_visiting_order(cfg: &ControlFlowGraph) -> (Vec<(usize, bool)>, Vec<Vec<u
         &mut degrees,
         &mut has_cycle,
         &mut dag,
+        &mut reverse_dag,
     );
 
     let mut queue: VecDeque<usize> = VecDeque::new();
@@ -200,7 +247,14 @@ fn find_visiting_order(cfg: &ControlFlowGraph) -> (Vec<(usize, bool)>, Vec<Vec<u
         }
     }
 
-    (order, dag)
+    let mut reverse_visiting_order = order.clone();
+    reverse_visiting_order.reverse();
+    CfgAsDag {
+        visiting_order: order,
+        reverse_visiting_order,
+        dag,
+        reverse_dag,
+    }
 }
 
 /// Run DFS (depth first search) in the CFG to find cycles.
@@ -212,6 +266,7 @@ fn cfg_dfs(
     degrees: &mut Vec<i32>,
     has_cycle: &mut Vec<bool>,
     dag: &mut Vec<Vec<usize>>,
+    reverse_dag: &mut Vec<Vec<usize>>,
 ) -> bool {
     if visited.contains(&block_no) {
         return true;
@@ -227,8 +282,18 @@ fn cfg_dfs(
 
     for edge in block_edges(&cfg.blocks[block_no]) {
         degrees[edge] += 1;
-        if cfg_dfs(edge, cfg, visited, stack, degrees, has_cycle, dag) {
+        if cfg_dfs(
+            edge,
+            cfg,
+            visited,
+            stack,
+            degrees,
+            has_cycle,
+            dag,
+            reverse_dag,
+        ) {
             dag[block_no].push(edge);
+            reverse_dag[edge].push(block_no);
         }
     }
 

+ 422 - 30
src/codegen/subexpression_elimination/tests.rs

@@ -2,7 +2,8 @@
 
 #![cfg(test)]
 
-use crate::codegen::cfg::Instr;
+use crate::codegen::cfg::{ASTFunction, BasicBlock, ControlFlowGraph, Instr};
+use crate::codegen::subexpression_elimination::anticipated_expressions::AnticipatedExpressions;
 use crate::codegen::subexpression_elimination::common_subexpression_tracker::CommonSubExpressionTracker;
 use crate::codegen::subexpression_elimination::{AvailableExpression, AvailableExpressionSet};
 use crate::codegen::Expression;
@@ -33,7 +34,7 @@ fn add_variable_function_arg() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&var).is_some());
     assert!(set.find_expression(&arg).is_some());
@@ -59,7 +60,7 @@ fn add_constants() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&var).is_some());
     assert!(set.find_expression(&num).is_some());
@@ -95,7 +96,7 @@ fn add_commutative() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&expr).is_some());
     assert!(set.find_expression(&expr_other).is_some());
@@ -130,7 +131,7 @@ fn non_commutative() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&sub).is_some());
     assert!(set.find_expression(&num).is_some());
@@ -159,7 +160,7 @@ fn unary_operation() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&cast).is_some());
     assert!(set.find_expression(&exp).is_some());
@@ -189,7 +190,7 @@ fn not_tracked() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&minus).is_none());
     assert!(set.find_expression(&exp).is_none());
@@ -214,7 +215,7 @@ fn invalid() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&arg).is_none());
     assert!(set.find_expression(&exp).is_none());
@@ -289,7 +290,7 @@ fn complex_expression() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&pot).is_some());
     assert!(set.find_expression(&unary).is_some());
@@ -349,7 +350,7 @@ fn string() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
 
     assert!(set.find_expression(&concat).is_some());
     assert!(set.find_expression(&compare).is_some());
@@ -429,7 +430,7 @@ fn kill() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
     set.kill(2);
 
     // Available expressions
@@ -518,8 +519,8 @@ fn clone() {
     let mut set = AvailableExpressionSet::default();
     let mut cst = CommonSubExpressionTracker::default();
 
-    set.process_instruction(&instr, &mut ave, &mut cst);
-    let set_2 = set.clone_for_parent_block(1);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
+    let set_2 = set.clone_for_parent_block();
 
     // Available expressions
     assert!(set_2.find_expression(&unary).is_some());
@@ -608,21 +609,6 @@ fn intersect() {
         value: vec![var2.clone(), var3.clone()],
     };
 
-    let mut ave = AvailableExpression::default();
-    let mut set = AvailableExpressionSet::default();
-    let mut cst = CommonSubExpressionTracker::default();
-    let cfg_dag = vec![vec![1, 2], vec![], vec![1]];
-    cst.set_dag(cfg_dag);
-
-    ave.set_cur_block(0);
-    cst.set_cur_block(0);
-    set.process_instruction(&instr, &mut ave, &mut cst);
-    set.process_instruction(&instr2, &mut ave, &mut cst);
-    let mut set_2 = set.clone_for_parent_block(0);
-    cst.set_cur_block(2);
-    ave.set_cur_block(2);
-    set.kill(1);
-
     let sum2 = Expression::Add(
         Loc::Codegen,
         Type::Int(8),
@@ -645,10 +631,38 @@ fn intersect() {
         value: Box::new(sub2.clone()),
     };
 
-    set.process_instruction(&instr3, &mut ave, &mut cst);
+    let mut ave = AvailableExpression::default();
+    let mut set = AvailableExpressionSet::default();
+    let mut cst = CommonSubExpressionTracker::default();
+    let cfg_dag = vec![vec![1, 2], vec![], vec![1]];
+    let reverse_dag = vec![vec![], vec![0, 2], vec![0]];
+    let traversing_order = vec![(1, false), (2, false), (0, false)];
+    let mut anticipated = AnticipatedExpressions::new(&cfg_dag, reverse_dag, traversing_order);
+
+    let instructions = vec![
+        vec![instr.clone(), instr2.clone()],
+        vec![instr3.clone()],
+        vec![instr3.clone()],
+    ];
+
+    let mut cfg = ControlFlowGraph::new("name".to_string(), ASTFunction::None);
+    cfg.blocks = vec![BasicBlock::default(); 3];
+    anticipated.calculate_anticipated_expressions(&instructions, &cfg);
+    cst.set_anticipated(anticipated);
+
+    ave.set_cur_block(0);
+    cst.set_cur_block(0);
+    set.process_instruction(&instr, &mut ave, &mut Some(&mut cst));
+    set.process_instruction(&instr2, &mut ave, &mut Some(&mut cst));
+    let mut set_2 = set.clone().clone_for_parent_block();
+    cst.set_cur_block(2);
+    ave.set_cur_block(2);
+    set.kill(1);
+
+    set.process_instruction(&instr3, &mut ave, &mut Some(&mut cst));
     cst.set_cur_block(1);
     ave.set_cur_block(1);
-    set_2.process_instruction(&instr3, &mut ave, &mut cst);
+    set_2.process_instruction(&instr3, &mut ave, &mut Some(&mut cst));
 
     set_2.intersect_sets(&set, &cst);
 
@@ -674,3 +688,381 @@ fn intersect() {
     // Child of expression created on both sets should not be available
     assert!(set_2.find_expression(&sub2).is_none());
 }
+
+#[test]
+fn test_flow() {
+    /// Asserts that block 0 is the only block whose flow magnitude equals 2000 (within a
+    /// tolerance of 1e-6); every other block must differ from 2000 by more than the tolerance.
+    fn assert_only_block_zero_has_full_flow<T: Copy + Into<f64>>(flow: &[T]) {
+        for (item_no, flow_mag) in flow.iter().enumerate() {
+            let flow_mag: f64 = (*flow_mag).into();
+            if item_no == 0 {
+                assert!((flow_mag - 2000.0).abs() < 0.000001);
+            } else {
+                assert!((flow_mag - 2000.0).abs() > 0.000001);
+            }
+        }
+    }
+
+    // Each case builds a DAG, its reversed edges and a reversed traversal order by hand, then
+    // checks the per-block flow magnitudes returned by `calculate_flow` for two chosen blocks.
+
+    // Case 1:
+    let dag = vec![
+        vec![1, 2], // 0 -> 1, 2
+        vec![3, 4], // 1 -> 3, 4
+        vec![3, 4], // 2 -> 3, 4
+        vec![],     // 3
+        vec![],     // 4
+    ];
+    let reverse_dag = vec![
+        vec![],     // 0
+        vec![0],    // 1 -> 0
+        vec![0],    // 2 -> 0
+        vec![1, 2], // 3 -> 1, 2
+        vec![1, 2], // 4 -> 1, 2
+    ];
+    let mut traversing_order = vec![(0, false), (1, false), (2, false), (3, false), (4, false)];
+    traversing_order.reverse();
+    let anticipated = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let flow = anticipated.calculate_flow(3, 4);
+    assert_eq!(flow, vec![2000.0, 1000.0, 1000.0, 1000.0, 1000.0]);
+
+    // Case 2:
+    let dag = vec![
+        vec![1, 2, 4], // 0 -> 1, 2, 4
+        vec![2, 3],    // 1 -> 2, 3
+        vec![4],       // 2 -> 4
+        vec![],        // 3
+        vec![],        // 4
+    ];
+    let reverse_dag = vec![
+        vec![],     // 0
+        vec![0],    // 1 -> 0
+        vec![1, 0], // 2 -> 1, 0
+        vec![1],    // 3 -> 1
+        vec![0, 2], // 4 -> 0, 2
+    ];
+    let mut traversing_order = vec![(0, false), (1, false), (2, false), (3, false), (4, false)];
+    traversing_order.reverse();
+    let anticipated_expressions = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let flow = anticipated_expressions.calculate_flow(3, 4);
+    assert_eq!(flow, vec![2000.0, 1250.0, 500.0, 1000.0, 1000.0]);
+
+    // Case 3
+    let dag = vec![
+        vec![1, 4], // 0 -> 1, 4
+        vec![2, 3], // 1 -> 2, 3
+        vec![],     // 2
+        vec![5],    // 3 -> 5
+        vec![5],    // 4 -> 5
+        vec![],     // 5
+    ];
+    let reverse_dag = vec![
+        vec![],     // 0
+        vec![0],    // 1 -> 0
+        vec![1],    // 2 -> 1
+        vec![1],    // 3 -> 1
+        vec![0],    // 4 -> 0
+        vec![3, 4], // 5 -> 3, 4
+    ];
+    let mut traversing_order = vec![
+        (0, false),
+        (1, false),
+        (4, false),
+        (2, false),
+        (3, false),
+        (5, false),
+    ];
+    traversing_order.reverse();
+    let anticipated_expressions = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let flow = anticipated_expressions.calculate_flow(4, 5);
+    assert_eq!(flow, vec![2000.0, 500.0, 0.0, 500.0, 1500.0, 1000.0]);
+
+    // Case 4
+    let dag = vec![
+        vec![1, 6],    // 0 -> 1, 6
+        vec![2, 4],    // 1 -> 2, 4
+        vec![3, 4],    // 2 -> 3, 4
+        vec![],        // 3
+        vec![5],       // 4 -> 5
+        vec![],        // 5
+        vec![4, 7, 8], // 6 -> 4, 7, 8
+        vec![5],       // 7 -> 5
+        vec![],        // 8
+    ];
+    let reverse_dag = vec![
+        vec![],        // 0
+        vec![0],       // 1 -> 0
+        vec![1],       // 2 -> 1
+        vec![2],       // 3 -> 2
+        vec![2, 1, 6], // 4 -> 2, 1, 6
+        // Block 5 is reached from both 4 and 7 in `dag`, so both must appear here.
+        vec![4, 7], // 5 -> 4, 7
+        vec![0],    // 6 -> 0
+        vec![6],    // 7 -> 6
+        vec![6],    // 8 -> 6
+    ];
+    let mut traversing_order = vec![
+        (0, false),
+        (1, false),
+        (6, false),
+        (8, false),
+        (7, false),
+        (2, false),
+        (3, false),
+        (4, false),
+        (5, false),
+    ];
+    traversing_order.reverse();
+    let anticipated_expressions = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let flow = anticipated_expressions.calculate_flow(5, 8);
+    assert_only_block_zero_has_full_flow(&flow);
+
+    let flow = anticipated_expressions.calculate_flow(3, 5);
+    assert_only_block_zero_has_full_flow(&flow);
+
+    // Case 5
+    let dag = vec![
+        vec![1, 3],    // 0 -> 1, 3
+        vec![2, 4],    // 1 -> 2, 4
+        vec![7, 8, 6], // 2 -> 7, 8, 6
+        vec![2, 6],    // 3 -> 2, 6
+        vec![7, 5],    // 4 -> 7, 5
+        vec![],        // 5
+        vec![],        // 6
+        vec![],        // 7
+        vec![],        // 8
+    ];
+    let reverse_dag = vec![
+        vec![],     // 0
+        vec![0],    // 1 -> 0
+        vec![1, 3], // 2 -> 1, 3
+        vec![0],    // 3 -> 0
+        vec![1],    // 4 -> 1
+        vec![4],    // 5 -> 4
+        vec![2, 3], // 6 -> 2, 3
+        vec![4, 2], // 7 -> 4, 2
+        vec![2],    // 8 -> 2
+    ];
+    let mut traversing_order = vec![
+        (0, false),
+        (1, false),
+        (3, false),
+        (4, false),
+        (2, false),
+        (5, false),
+        (6, false),
+        (7, false),
+        (8, false),
+    ];
+    traversing_order.reverse();
+    let anticipated_expressions = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let flow = anticipated_expressions.calculate_flow(8, 6);
+    assert_only_block_zero_has_full_flow(&flow);
+
+    let flow = anticipated_expressions.calculate_flow(8, 3);
+    assert_only_block_zero_has_full_flow(&flow);
+}
+
+#[test]
+fn unite_expressions() {
+    // Leaves shared by the two expression trees built below.
+    let variable = Expression::Variable(Loc::Codegen, Type::Int(8), 1);
+    let literal = Expression::NumberLiteral(Loc::Codegen, Type::Int(8), BigInt::from(3));
+    let argument = Expression::FunctionArg(Loc::Codegen, Type::Int(9), 5);
+
+    // First tree: (variable * literal) >> ((variable + literal) / (literal - argument))
+    let product = Expression::Multiply(
+        Loc::Codegen,
+        Type::Int(8),
+        false,
+        Box::new(variable.clone()),
+        Box::new(literal.clone()),
+    );
+    let addition = Expression::Add(
+        Loc::Codegen,
+        Type::Int(8),
+        false,
+        Box::new(variable),
+        Box::new(literal.clone()),
+    );
+    let difference = Expression::Subtract(
+        Loc::Codegen,
+        Type::Int(3),
+        false,
+        Box::new(literal.clone()),
+        Box::new(argument.clone()),
+    );
+    let quotient = Expression::SignedDivide(
+        Loc::Codegen,
+        Type::Int(8),
+        Box::new(addition),
+        Box::new(difference),
+    );
+    let shifted = Expression::ShiftRight(
+        Loc::Codegen,
+        Type::Int(2),
+        Box::new(product),
+        Box::new(quotient),
+        true,
+    );
+
+    // Second tree: literal % argument
+    let remainder = Expression::SignedModulo(
+        Loc::Codegen,
+        Type::Int(8),
+        Box::new(literal),
+        Box::new(argument),
+    );
+
+    // Each set receives exactly one of the trees; both share the same `AvailableExpression`.
+    let mut first_set = AvailableExpressionSet::default();
+    let mut ave = AvailableExpression::default();
+    let _ = first_set.gen_expression_aux(&shifted, &mut ave, &mut None);
+
+    let mut second_set = AvailableExpressionSet::default();
+    let _ = second_set.gen_expression_aux(&remainder, &mut ave, &mut None);
+
+    // Before the union, every set only knows its own tree.
+    assert!(first_set.find_expression(&shifted).is_some());
+    assert!(first_set.find_expression(&remainder).is_none());
+    assert!(second_set.find_expression(&shifted).is_none());
+    assert!(second_set.find_expression(&remainder).is_some());
+
+    first_set.union_sets(&second_set);
+
+    // After the union, the first set must contain both trees.
+    assert!(first_set.find_expression(&shifted).is_some());
+    assert!(first_set.find_expression(&remainder).is_some());
+}
+
+#[test]
+fn ancestor_found() {
+    // CFG under test: 0 -> 1, and block 1 branches to blocks 2 and 3.
+    let forward_edges = vec![vec![1], vec![2, 3], vec![], vec![]];
+    let backward_edges = vec![vec![], vec![0], vec![1], vec![1]];
+    // Blocks are listed from last to first: 3, 2, 1, 0.
+    let reverse_order: Vec<(usize, bool)> =
+        (0..4).rev().map(|block_no| (block_no, false)).collect();
+    let mut anticipated =
+        AnticipatedExpressions::new(&forward_edges, backward_edges, reverse_order);
+
+    let lhs = Expression::Variable(Loc::Implicit, Type::Uint(32), 0);
+    let rhs = Expression::Variable(Loc::Implicit, Type::Uint(32), 1);
+    let addition = Expression::Add(
+        Loc::Implicit,
+        Type::Uint(32),
+        false,
+        Box::new(lhs.clone()),
+        Box::new(rhs.clone()),
+    );
+
+    // Block 0 defines both operands and jumps to block 1.
+    let entry_block = vec![
+        Instr::Set {
+            loc: Loc::Implicit,
+            res: 0,
+            expr: Expression::NumberLiteral(Loc::Implicit, Type::Uint(32), BigInt::from(9)),
+        },
+        Instr::Set {
+            loc: Loc::Implicit,
+            res: 1,
+            expr: Expression::NumberLiteral(Loc::Implicit, Type::Uint(32), BigInt::from(8)),
+        },
+        Instr::Branch { block: 1 },
+    ];
+    // Block 1 only compares the operands and branches.
+    let branching_block = vec![Instr::BranchCond {
+        cond: Expression::LessEqual(Loc::Implicit, Box::new(lhs), Box::new(rhs)),
+        true_block: 2,
+        false_block: 3,
+    }];
+    // Blocks 2 and 3 both evaluate the same addition into different variables.
+    let true_block = vec![Instr::Set {
+        loc: Loc::Implicit,
+        res: 3,
+        expr: addition.clone(),
+    }];
+    let false_block = vec![Instr::Set {
+        loc: Loc::Implicit,
+        res: 4,
+        expr: addition.clone(),
+    }];
+    let instructions = vec![entry_block, branching_block, true_block, false_block];
+
+    let mut cfg = ControlFlowGraph::new("func".to_string(), ASTFunction::None);
+    cfg.blocks = vec![BasicBlock::default(); 4];
+    anticipated.calculate_anticipated_expressions(&instructions, &cfg);
+
+    // Block 1 is expected as the place where the shared addition can be evaluated.
+    assert_eq!(anticipated.find_ancestor(2, 3, &addition), Some(1));
+}
+
+#[test]
+fn ancestor_not_found() {
+    // CFG under test: block 0 branches to blocks 1 and 2, and block 1 jumps to block 3.
+    // Blocks 2 and 3 evaluate the same multiplication, but block 1 reassigns both operands
+    // first, so no ancestor is expected for the pair (2, 3).
+    let dag = vec![
+        vec![1, 2], // 0 -> 1, 2
+        vec![3],    // 1 -> 3
+        vec![],     // 2
+        vec![],     // 3
+    ];
+    let reverse_dag = vec![
+        vec![],  // 0
+        vec![0], // 1 -> 0
+        vec![0], // 2 -> 0
+        vec![1], // 3 -> 1
+    ];
+    // Every block must appear exactly once; block 1 was previously missing (block 0 was listed
+    // twice), so the block that overwrites the operands was never analysed.
+    let mut traversing_order = vec![(0, false), (1, false), (2, false), (3, false)];
+    traversing_order.reverse();
+    let mut anticipated = AnticipatedExpressions::new(&dag, reverse_dag, traversing_order);
+    let var1 = Expression::Variable(Loc::Implicit, Type::Int(32), 0);
+    let var2 = Expression::Variable(Loc::Implicit, Type::Int(32), 1);
+    let expr = Expression::Multiply(
+        Loc::Implicit,
+        Type::Int(32),
+        false,
+        Box::new(var1.clone()),
+        Box::new(var2.clone()),
+    );
+    let instr = vec![
+        // Block 0: define both operands, then branch on their comparison.
+        vec![
+            Instr::Set {
+                res: 0,
+                loc: Loc::Implicit,
+                expr: Expression::NumberLiteral(Loc::Implicit, Type::Int(32), 8.into()),
+            },
+            Instr::Set {
+                res: 1,
+                loc: Loc::Implicit,
+                expr: Expression::NumberLiteral(Loc::Implicit, Type::Int(32), 7.into()),
+            },
+            Instr::BranchCond {
+                cond: Expression::MoreEqual(Loc::Implicit, Box::new(var1), Box::new(var2)),
+                true_block: 1,
+                false_block: 2,
+            },
+        ],
+        // Block 1: overwrite both operands before jumping to block 3.
+        vec![
+            Instr::Set {
+                res: 0,
+                loc: Loc::Implicit,
+                expr: Expression::NumberLiteral(Loc::Implicit, Type::Int(32), 10.into()),
+            },
+            Instr::Set {
+                res: 1,
+                loc: Loc::Implicit,
+                expr: Expression::NumberLiteral(Loc::Implicit, Type::Int(32), 27.into()),
+            },
+            Instr::Branch { block: 3 },
+        ],
+        // Block 2: evaluate the multiplication.
+        vec![Instr::Set {
+            res: 2,
+            loc: Loc::Implicit,
+            expr: expr.clone(),
+        }],
+        // Block 3: evaluate the same multiplication.
+        vec![Instr::Set {
+            res: 3,
+            loc: Loc::Implicit,
+            expr: expr.clone(),
+        }],
+    ];
+    let mut cfg = ControlFlowGraph::new("func".to_string(), ASTFunction::None);
+    cfg.blocks = vec![BasicBlock::default(); 4];
+    anticipated.calculate_anticipated_expressions(&instr, &cfg);
+    let ancestor = anticipated.find_ancestor(2, 3, &expr);
+    assert!(ancestor.is_none());
+}

+ 2 - 2
src/codegen/undefined_variable.rs

@@ -29,7 +29,7 @@ pub fn find_undefined_variables(
     let mut diagnostics: HashMap<usize, Diagnostic> = HashMap::new();
     for block in &cfg.blocks {
         let mut var_defs: VarDefs = block.defs.clone();
-        for (instr_no, (_, instruction)) in block.instr.iter().enumerate() {
+        for (instr_no, instruction) in block.instr.iter().enumerate() {
             check_variables_in_expression(
                 func_no,
                 instruction,
@@ -93,7 +93,7 @@ pub fn find_undefined_variables_in_expression(
                 for (def, modified) in def_map {
                     if let Instr::Set {
                         expr: instr_expr, ..
-                    } = &ctx.cfg.blocks[def.block_no].instr[def.instr_no].1
+                    } = &ctx.cfg.blocks[def.block_no].instr[def.instr_no]
                     {
                         // If an undefined definition reaches this read and the variable
                         // has not been modified since its definition, it is undefined

+ 4 - 4
src/codegen/vector_to_slice.rs

@@ -37,7 +37,7 @@ fn find_writable_vectors(
     writable: &mut HashSet<Def>,
 ) {
     for instr_no in 0..block.instr.len() {
-        match &block.instr[instr_no].1 {
+        match &block.instr[instr_no] {
             Instr::Set {
                 res,
                 expr: Expression::Variable(_, _, var_no),
@@ -177,7 +177,7 @@ fn update_vectors_to_slice(
             if let Instr::Set {
                 expr: Expression::AllocDynamicBytes(..),
                 ..
-            } = &cfg.blocks[block_no].instr[instr_no].1
+            } = &cfg.blocks[block_no].instr[instr_no]
             {
                 let cur = Def {
                     block_no,
@@ -216,10 +216,10 @@ fn update_vectors_to_slice(
             loc,
             res,
             expr: Expression::AllocDynamicBytes(_, _, len, Some(bs)),
-        } = &cfg.blocks[def.block_no].instr[def.instr_no].1
+        } = &cfg.blocks[def.block_no].instr[def.instr_no]
         {
             let res = *res;
-            cfg.blocks[def.block_no].instr[def.instr_no].1 = Instr::Set {
+            cfg.blocks[def.block_no].instr[def.instr_no] = Instr::Set {
                 loc: *loc,
                 res,
                 expr: Expression::AllocDynamicBytes(

+ 11 - 11
src/codegen/yul/builtin.rs

@@ -163,7 +163,7 @@ pub(crate) fn process_builtin(
 
         YulBuiltInFunction::SelfDestruct => {
             let recipient = expression(&args[0], contract_no, ns, vartab, cfg, opt).cast(&Type::Address(true), ns);
-            cfg.add_yul(vartab, Instr::SelfDestruct { recipient });
+            cfg.add(vartab, Instr::SelfDestruct { recipient });
             Expression::Poison
         }
 
@@ -404,7 +404,7 @@ fn branch_if_zero(
     let then = cfg.new_basic_block("then".to_string());
     let else_ = cfg.new_basic_block("else".to_string());
     let endif = cfg.new_basic_block("endif".to_string());
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::BranchCond {
             cond,
@@ -415,7 +415,7 @@ fn branch_if_zero(
 
     cfg.set_basic_block(then);
     vartab.new_dirty_tracker();
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Set {
             loc: pt::Loc::Codegen,
@@ -423,10 +423,10 @@ fn branch_if_zero(
             expr: Expression::NumberLiteral(pt::Loc::Codegen, Type::Uint(256), BigInt::from(0)),
         },
     );
-    cfg.add_yul(vartab, Instr::Branch { block: endif });
+    cfg.add(vartab, Instr::Branch { block: endif });
 
     cfg.set_basic_block(else_);
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Set {
             loc: pt::Loc::Codegen,
@@ -434,7 +434,7 @@ fn branch_if_zero(
             expr: codegen_expr,
         },
     );
-    cfg.add_yul(vartab, Instr::Branch { block: endif });
+    cfg.add(vartab, Instr::Branch { block: endif });
     cfg.set_phis(endif, vartab.pop_dirty_tracker());
     cfg.set_basic_block(endif);
 
@@ -468,7 +468,7 @@ fn byte_builtin(
     let else_ = cfg.new_basic_block("else".to_string());
     let endif = cfg.new_basic_block("endif".to_string());
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::BranchCond {
             cond,
@@ -479,7 +479,7 @@ fn byte_builtin(
 
     cfg.set_basic_block(then);
     vartab.new_dirty_tracker();
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Set {
             loc: pt::Loc::Codegen,
@@ -487,7 +487,7 @@ fn byte_builtin(
             expr: Expression::NumberLiteral(pt::Loc::Codegen, Type::Uint(256), BigInt::zero()),
         },
     );
-    cfg.add_yul(vartab, Instr::Branch { block: endif });
+    cfg.add(vartab, Instr::Branch { block: endif });
 
     cfg.set_basic_block(else_);
 
@@ -533,7 +533,7 @@ fn byte_builtin(
         )),
     );
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Set {
             loc: *loc,
@@ -541,7 +541,7 @@ fn byte_builtin(
             expr: masked_result,
         },
     );
-    cfg.add_yul(vartab, Instr::Branch { block: endif });
+    cfg.add(vartab, Instr::Branch { block: endif });
 
     cfg.set_phis(endif, vartab.pop_dirty_tracker());
     cfg.set_basic_block(endif);

+ 2 - 2
src/codegen/yul/expression.rs

@@ -214,7 +214,7 @@ pub(crate) fn process_function_call(
     let cfg_no = ns.yul_functions[function_no].cfg_no;
 
     if ns.yul_functions[function_no].returns.is_empty() {
-        cfg.add_yul(
+        cfg.add(
             vartab,
             Instr::Call {
                 res: Vec::new(),
@@ -243,7 +243,7 @@ pub(crate) fn process_function_call(
         returns.push(Expression::Variable(id.loc, ret.ty.clone(), temp_pos));
     }
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Call {
             res,

+ 1 - 1
src/codegen/yul/mod.rs

@@ -113,7 +113,7 @@ fn yul_function_cfg(
     if yul_func.body.is_empty()
         || (!yul_func.body.is_empty() && yul_func.body.last().unwrap().is_reachable())
     {
-        cfg.add_yul(&mut vartab, returns);
+        cfg.add(&mut vartab, returns);
     }
 
     vartab.finalize(ns, &mut cfg);

+ 19 - 19
src/codegen/yul/statements.rs

@@ -109,14 +109,14 @@ pub(crate) fn statement(
 
         YulStatement::Leave(..) => {
             if let Some(early_leave) = early_return {
-                cfg.add_yul(vartab, early_leave.clone());
+                cfg.add(vartab, early_leave.clone());
             } else {
-                cfg.add_yul(vartab, Instr::Return { value: vec![] });
+                cfg.add(vartab, Instr::Return { value: vec![] });
             }
         }
 
         YulStatement::Break(..) => {
-            cfg.add_yul(
+            cfg.add(
                 vartab,
                 Instr::Branch {
                     block: loops.do_break(),
@@ -125,7 +125,7 @@ pub(crate) fn statement(
         }
 
         YulStatement::Continue(..) => {
-            cfg.add_yul(
+            cfg.add(
                 vartab,
                 Instr::Branch {
                     block: loops.do_continue(),
@@ -162,7 +162,7 @@ fn process_variable_declaration(
     };
 
     for (var_index, item) in vars.iter().enumerate() {
-        cfg.add_yul(
+        cfg.add(
             vartab,
             Instr::Set {
                 loc: *loc,
@@ -216,7 +216,7 @@ fn cfg_single_assigment(
         | ast::YulExpression::SolidityLocalVariable(_, ty, None, var_no) => {
             // Ensure both types are compatible
             let rhs = rhs.cast(ty, ns);
-            cfg.add_yul(
+            cfg.add(
                 vartab,
                 Instr::Set {
                     loc: *loc,
@@ -234,7 +234,7 @@ fn cfg_single_assigment(
         ) => {
             // This is an assignment to a pointer, so we make sure the rhs has a compatible size
             let rhs = rhs.cast(ty, ns);
-            cfg.add_yul(
+            cfg.add(
                 vartab,
                 Instr::Set {
                     loc: *loc,
@@ -254,7 +254,7 @@ fn cfg_single_assigment(
                 ) => match suffix {
                     YulSuffix::Offset => {
                         let rhs = rhs.cast(&lhs.ty(), ns);
-                        cfg.add_yul(
+                        cfg.add(
                             vartab,
                             Instr::Set {
                                 loc: *loc,
@@ -294,7 +294,7 @@ fn cfg_single_assigment(
                         member_no,
                     );
 
-                    cfg.add_yul(
+                    cfg.add(
                         vartab,
                         Instr::Store {
                             dest: ptr,
@@ -312,7 +312,7 @@ fn cfg_single_assigment(
                     // This assignment changes the value of a pointer to storage
                     if matches!(suffix, YulSuffix::Slot) {
                         let rhs = rhs.cast(&lhs.ty(), ns);
-                        cfg.add_yul(
+                        cfg.add(
                             vartab,
                             Instr::Set {
                                 loc: *loc,
@@ -371,7 +371,7 @@ fn process_if_block(
     let then = cfg.new_basic_block("then".to_string());
     let endif = cfg.new_basic_block("endif".to_string());
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::BranchCond {
             cond: bool_cond,
@@ -388,7 +388,7 @@ fn process_if_block(
     }
 
     if block.is_next_reachable() {
-        cfg.add_yul(vartab, Instr::Branch { block: endif });
+        cfg.add(vartab, Instr::Branch { block: endif });
     }
 
     cfg.set_phis(endif, vartab.pop_dirty_tracker());
@@ -424,7 +424,7 @@ fn process_for_block(
     let body_block = cfg.new_basic_block("body".to_string());
     let end_block = cfg.new_basic_block("end_for".to_string());
 
-    cfg.add_yul(vartab, Instr::Branch { block: cond_block });
+    cfg.add(vartab, Instr::Branch { block: cond_block });
     cfg.set_basic_block(cond_block);
 
     let cond_expr = expression(condition, contract_no, ns, vartab, cfg, opt);
@@ -443,7 +443,7 @@ fn process_for_block(
         )
     };
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::BranchCond {
             cond: cond_expr,
@@ -461,7 +461,7 @@ fn process_for_block(
     }
 
     if execution_block.is_next_reachable() {
-        cfg.add_yul(vartab, Instr::Branch { block: next_block });
+        cfg.add(vartab, Instr::Branch { block: next_block });
     }
 
     loops.leave_scope();
@@ -473,7 +473,7 @@ fn process_for_block(
     }
 
     if post_block.is_next_reachable() {
-        cfg.add_yul(vartab, Instr::Branch { block: cond_block });
+        cfg.add(vartab, Instr::Branch { block: cond_block });
     }
 
     cfg.set_basic_block(end_block);
@@ -512,7 +512,7 @@ fn switch(
             statement(stmt, contract_no, loops, ns, cfg, vartab, early_return, opt);
         }
         if item.block.is_next_reachable() {
-            cfg.add_yul(vartab, Instr::Branch { block: end_switch });
+            cfg.add(vartab, Instr::Branch { block: end_switch });
         }
         cases_cfg.push((case_cond, case_block));
     }
@@ -524,7 +524,7 @@ fn switch(
             statement(stmt, contract_no, loops, ns, cfg, vartab, early_return, opt);
         }
         if default_block.is_next_reachable() {
-            cfg.add_yul(vartab, Instr::Branch { block: end_switch });
+            cfg.add(vartab, Instr::Branch { block: end_switch });
         }
         new_block
     } else {
@@ -535,7 +535,7 @@ fn switch(
 
     cfg.set_basic_block(current_block);
 
-    cfg.add_yul(
+    cfg.add(
         vartab,
         Instr::Switch {
             cond,

+ 2 - 2
src/emit/cfg.rs

@@ -76,7 +76,7 @@ pub(super) fn emit_cfg<'a, T: TargetRuntime<'a> + ?Sized>(
                     di_flags,
                 );
 
-                let func_loc = cfg.blocks[0].instr.first().unwrap().1.loc();
+                let func_loc = cfg.blocks[0].instr.first().unwrap().loc();
                 let line_num = if let pt::Loc::File(file_offset, offset, _) = func_loc {
                     let (line, _) = ns.files[file_offset].offset_to_line_column(offset);
                     line
@@ -180,7 +180,7 @@ pub(super) fn emit_cfg<'a, T: TargetRuntime<'a> + ?Sized>(
             w.vars.get_mut(v).unwrap().value = (*phi).as_basic_value();
         }
 
-        for (_, ins) in &cfg.blocks[w.block_no].instr {
+        for ins in &cfg.blocks[w.block_no].instr {
             if bin.options.generate_debug_information {
                 let debug_loc = ins.loc();
                 if let pt::Loc::File(file_offset, offset, _) = debug_loc {

+ 148 - 0
tests/codegen_testcases/solidity/anticipated_expressions.sol

@@ -0,0 +1,148 @@
+// RUN: --target solana --emit cfg
+
+// Codegen test for common subexpression elimination (CSE). Every function
+// repeats an arithmetic expression on several control-flow paths, and the
+// directives pin where the emitted CFG is expected to define and reuse the
+// shared `cse_temp` variable (or to define none at all).
+contract Test {
+
+    // test1: `a + b` and `a - b` are evaluated before the loop and again inside
+    // it, while `a % b` is evaluated in the nested `if` and after the loop.
+    // The expected CFG holds one temporary per expression and reuses it.
+    // BEGIN-CHECK: Test::Test::function::test1__int256_int256
+    function test1(int a, int b) pure public returns (int) {
+        // CHECK: ty:int256 %1.cse_temp = ((arg #0) + (arg #1))
+	    // CHECK: ty:int256 %x = %1.cse_temp
+	    // CHECK: ty:int256 %3.cse_temp = ((arg #0) - (arg #1))
+	    // CHECK: ty:int256 %z = %3.cse_temp
+        int x = a + b;
+        int z = a-b;
+        int p=0;
+
+        // The temporary for `a % b` is expected in the loop condition block.
+        // CHECK: block1: # cond
+        // CHECK: ty:int256 %2.cse_temp = (signed modulo (arg #0) % (arg #1))
+        while(x != 0) {
+        // CHECK: block2: # body
+        // CHECK: ty:int256 %z = (%z + int256 9)
+	    // CHECK: ty:int256 %y = %1.cse_temp
+	    // CHECK: ty:int256 %x = (%x - %1.cse_temp)
+            z+=9;
+            int y = a + b;
+            x -= y;
+        // The expectations for the code after the loop (block3) are listed
+        // here, ahead of the `then` block (block4): the directives follow the
+        // block numbering of the expected CFG rather than the source order.
+        // CHECK: block3: # endwhile
+        // CHECK: ty:int256 %p2 = %2.cse_temp
+	    // CHECK: return ((((%x + int256 9) - %z) + %p) - (int256 2 * %2.cse_temp))
+
+            if (x == 9) {
+                // CHECK: block4: # then
+                // CHECK: ty:int256 %y = %3.cse_temp
+	            // CHECK: ty:int256 %p = %2.cse_temp
+	            // CHECK: ty:int256 %x = (%x + %3.cse_temp)
+                y = a - b;
+                p = a % b;
+                x += y;
+            }
+        }
+
+        int p2 = a%b;
+        return x+9-z + p - 2*p2;
+    }
+
+    // test2: `a + b` is evaluated on every branch of the if/else chain and once
+    // more after it. The expected CFG computes it a single time before the
+    // first branch, and every use reads that temporary.
+    // BEGIN-CHECK: Test::Test::function::test2__int256_int256
+    function test2(int a, int b) public pure returns (int) {
+        int y = a-b;
+        int j=0;
+        int k=0;
+        int l=0;
+        int m=0;
+        // CHECK: ty:int256 %1.cse_temp = ((arg #0) + (arg #1))
+        if(y == 5) {
+            // CHECK: block1: # then
+            // CHECK: ty:int256 %j = %1.cse_temp
+            j = a+b;
+            // Expectations for the code after the chain (block3) come next,
+            // matching the block numbering of the expected CFG.
+            // CHECK: block3: # endif
+            // CHECK: ty:int256 %n = %1.cse_temp
+            // CHECK: return ((((%j + %k) + %l) + %m) + %1.cse_temp)
+        } else if (y == 2) {
+            // CHECK: block4: # then
+            // CHECK: ty:int256 %k = %1.cse_temp
+            k = a+b;
+        } else if (y == 3) {
+            // CHECK: block7: # then
+            // CHECK: ty:int256 %l = %1.cse_temp
+            l = a+b;
+        } else {
+            // CHECK: block8: # else
+            // CHECK: ty:int256 %m = %1.cse_temp
+            m = a+b;
+        }
+
+        int n = a+b;
+        return j+k+l+m+n;
+    }
+
+    // test3: same shape as test2, except that the `y == 2` branch overwrites
+    // `a` before adding. The expectations require that no `cse_temp` is
+    // defined and that every use spells out its own addition.
+    // BEGIN-CHECK: Test::Test::function::test3__int256_int256
+    function test3(int a, int b) public pure returns (int) {
+        int y = a-b;
+        int j=0;
+        int k=0;
+        int l=0;
+        int m=0;
+        // NOT-CHECK: ty:int256 %1.cse_temp
+        if(y == 5) {
+            // CHECK: block1: # then
+            // CHECK: ty:int256 %j = ((arg #0) + (arg #1))
+            j = a+b;
+            // After the chain, the addition is expected to read `%a` rather
+            // than `(arg #0)`, since one branch assigned to `a`.
+            // CHECK: block3: # endif
+            // CHECK: ty:int256 %n = (%a + (arg #1))
+	        // CHECK: return ((((%j + %k) + %l) + %m) + %n)
+        } else if (y == 2) {
+            // CHECK: block4: # then
+	        // CHECK: ty:int256 %a = int256 9
+	        // CHECK: ty:int256 %k = (int256 9 + (arg #1))
+            a = 9;
+            k = a+b;
+        } else if (y == 3) {
+            // CHECK: block7: # then
+	        // CHECK: ty:int256 %l = ((arg #0) + (arg #1))
+            l = a+b;
+        } else {
+            // CHECK: block8: # else
+	        // CHECK: ty:int256 %m = ((arg #0) + (arg #1))
+            m = a+b;
+        }
+
+        int n = a+b;
+        return j+k+l+m+n;
+    }
+
+    // test4: `mul(a, b)` appears in every case of the Yul switch and the
+    // unchecked `a * b` follows the assembly block. All of them are expected
+    // to share one temporary that is defined before `block1` in the output.
+    // BEGIN-CHECK: Test::Test::function::test4__int256_int256
+    function test4(int a, int b) public pure returns (int) {
+        int y = a-b;
+        int j=0;
+        int k=0;
+        int l=0;
+        int m=0;
+        // CHECK: ty:int256 %1.cse_temp = (unchecked (arg #0) * (arg #1))
+        // CHECK: block1: # end_switch
+	    // CHECK: ty:int256 %m = %1.cse_temp
+	    // CHECK: return (%l + %1.cse_temp)
+        assembly {
+            switch y 
+                case 1 {
+                    // CHECK: block2: # case_0
+	                // CHECK: ty:int256 %j = %1.cse_temp
+                    j := mul(a, b)
+                }
+                case 2 {
+                    // CHECK: block3: # case_1
+	                // CHECK: ty:int256 %k = %1.cse_temp
+                    k := mul(a, b)
+                }
+                default {
+                    // CHECK: block4: # default
+	                // CHECK: ty:int256 %l = %1.cse_temp
+                    l := mul(a, b)
+                }
+        }
+
+        unchecked {
+            m = a*b;
+        }
+
+        return l+m;
+    }
+}

+ 2 - 2
tests/codegen_testcases/solidity/common_subexpression_elimination.sol

@@ -226,7 +226,6 @@ contract c1 {
         // CHECK:  ty:int256 %1.cse_temp = ((arg #0) + (arg #1))
         int x = a + b + instance.a;
         // CHECK: ty:int256 %x = (%1.cse_temp + (load (struct %instance field 0)))
-        // CHECK: ty:int256 %2.cse_temp = ((arg #0) * (arg #1))
         // CHECK: branchcond (signed less (%x + int256((load (struct %instance field 1)))) < int256 0)
         if(x  + int(instance.b) < 0) {
             // CHECK: ty:uint256 %p = uint256((%1.cse_temp + (load (struct %instance field 0))))
@@ -244,6 +243,7 @@ contract c1 {
         // CHECK: branchcond %e3, block3, block4
         if (trunc2 < trunc && trunc > 2) {
             // CHECK: = %e2
+            // CHECK: ty:int256 %2.cse_temp = ((arg #0) * (arg #1))
             // CHECK: ty:int256 %p2 = %1.cse_temp
             int p2 = a+b;
             int p3 = p2 - x + a + b;
@@ -457,7 +457,6 @@ contract c1 {
             return (a << b) + 1;
         }
 
-        // CHECK: ty:uint256 %3.cse_temp = ((arg #0) & (arg #1))
         // CHECK: branchcond %2.cse_temp, block4, block3
         if(!b1 || c > 0) {
             // CHECK: = %b1
@@ -470,6 +469,7 @@ contract c1 {
             c++;
         }
 
+        // CHECK: ty:uint256 %3.cse_temp = ((arg #0) & (arg #1))
         // CHECK: branchcond (%3.cse_temp == uint256 0), block13, block14
         if (a & b == 0) {
             return c--;

+ 0 - 1
tests/codegen_testcases/yul/common_subexpression_elimination.sol

@@ -17,7 +17,6 @@ contract testing  {
 
             // CHECK: block1: # cond
             // CHECK: ty:uint256 %1.cse_temp = (zext uint256 (arg #0))
-            // CHECK: branch block1
             for {let i := 0} lt(i, 10) {i := add(i, 1)} {
                 // CHECK: block3: # body
                 // CHECK: branchcond (%1.cse_temp == uint256 259), block5, block6

+ 42 - 0
tests/codegen_testcases/yul/cse_switch.sol

@@ -0,0 +1,42 @@
+// RUN: --target solana --emit cfg
+
+// Codegen test: a Yul switch reassigns `x` and `y` in each of its cases, and
+// `add(x, y)` is then used in two conditions after the switch. The directives
+// pin where the shared `cse_temp` for that addition appears in the CFG.
+contract foo {
+    // BEGIN-CHECK: foo::foo::function::test
+    function test() public {
+        uint256 yy=0;
+        assembly {
+        // Ensure the CSE temp for add(x, y) is not placed before the switch:
+        // the cases overwrite x and y, so the sum is only expected in the
+        // end_switch block. The switch condition and(x, 3) is expected as
+        // the constant 2 (54 & 3).
+        // CHECK: ty:uint256 %x = uint256 54
+        // CHECK: ty:uint256 %y = uint256 5
+	    // CHECK: switch uint256 2:
+            let x := 54
+            let y := 5
+
+            switch and(x, 3)
+                case 0 {
+                    y := 5
+                    x := 5
+                }
+                case 1 {
+                    y := 7
+                    x := 9
+                }
+                case 3 {
+                    y := 10
+                    x := 80
+                }
+        
+            // Both conditions below use add(x, y); the second one is expected
+            // to reuse the temporary defined for the first.
+            // CHECK: block1: # end_switch
+	        // CHECK: ty:uint256 %1.cse_temp = (unchecked %x + %y)
+	        // CHECK: branchcond (%1.cse_temp == uint256 90), block5, block6
+            if eq(add(x, y), 90) {
+                yy := 9
+            }
+
+            // CHECK: branchcond (%1.cse_temp == uint256 80), block7, block8
+            if eq(add(x, y), 80) {
+                yy := 90
+            }
+        }
+    }
+}