Răsfoiți Sursa

Create switch instruction (#1023)

* Remove duplicate code from instruction.rs

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>

* Create switch instruction

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>

* Include additional test cases

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 3 ani în urmă
părinte
comite
af99c224c0

+ 37 - 0
src/codegen/cfg.rs

@@ -169,6 +169,11 @@ pub enum Instr {
         destination: Expression,
         bytes: Expression,
     },
+    Switch {
+        cond: Expression,
+        cases: Vec<(Expression, usize)>,
+        default: usize,
+    },
     /// Do nothing
     Nop,
 }
@@ -301,6 +306,13 @@ impl Instr {
                 bytes.recurse(cx, f);
             }
 
+            Instr::Switch { cond, cases, .. } => {
+                cond.recurse(cx, f);
+                for (case, _) in cases {
+                    case.recurse(cx, f);
+                }
+            }
+
             Instr::AssertFailure { expr: None }
             | Instr::Unreachable
             | Instr::Nop
@@ -474,6 +486,11 @@ impl ControlFlowGraph {
         self.blocks[self.current].add(InstrOrigin::Yul, ins);
     }
 
+    /// Retrieve the basic block being processed
+    pub fn current_block(&self) -> usize {
+        self.current
+    }
+
     /// Function to modify array length temp by inserting an add/sub instruction in the cfg right after a push/pop instruction.
     /// The operands of the add/sub instruction are the temp variable, and +/- 1.
     pub fn modify_temp_array_length(
@@ -1159,6 +1176,26 @@ impl ControlFlowGraph {
                     self.expr_to_string(contract, ns, bytes)
                 )
             }
+            Instr::Switch {
+                cond,
+                cases,
+                default,
+            } => {
+                let mut description =
+                    format!("switch {}:", self.expr_to_string(contract, ns, cond),);
+                for item in cases {
+                    description.push_str(
+                        format!(
+                            "\n\t\tcase {}: goto block #{}",
+                            self.expr_to_string(contract, ns, &item.0),
+                            item.1
+                        )
+                        .as_str(),
+                    );
+                }
+                description.push_str(format!("\n\t\tdefault: goto block #{}", default).as_str());
+                description
+            }
         }
     }
 

+ 16 - 0
src/codegen/constant_folding.rs

@@ -315,6 +315,22 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         bytes: bytes.0,
                     };
                 }
+                Instr::Switch {
+                    cond,
+                    cases,
+                    default,
+                } => {
+                    let cond = expression(cond, Some(&vars), cfg, ns);
+                    let cases = cases
+                        .iter()
+                        .map(|(exp, goto)| (expression(exp, Some(&vars), cfg, ns).0, *goto))
+                        .collect::<Vec<(Expression, usize)>>();
+                    cfg.blocks[block_no].instr[instr_no].1 = Instr::Switch {
+                        cond: cond.0,
+                        cases,
+                        default: *default,
+                    };
+                }
                 _ => (),
             }
 

+ 2 - 32
src/codegen/dead_storage.rs

@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use super::cfg::{BasicBlock, ControlFlowGraph, Instr};
+use crate::codegen::reaching_definitions::block_edges;
 use crate::codegen::Expression;
 use crate::sema::ast::{Namespace, RetrieveType, Type};
 use solang_parser::pt::Loc;
@@ -332,7 +333,7 @@ fn apply_transfers(
 
     debug_assert_eq!(transfers.len(), cfg.blocks[block_no].instr.len());
 
-    // this is done in two paseses. The first pass just deals with variables.
+    // this is done in two passes. The first pass just deals with variables.
     // The second pass deals with storage stores
 
     // for each instruction
@@ -472,37 +473,6 @@ fn apply_transfers(
     block_vars.insert(block_no, res);
 }
 
-fn block_edges(block: &BasicBlock) -> Vec<usize> {
-    let mut out = Vec::new();
-
-    // out cfg has edge as the last instruction in a block; EXCEPT
-    // Instr::AbiDecode() which has an edge when decoding fails
-    for (_, instr) in &block.instr {
-        match instr {
-            Instr::Branch { block } => {
-                out.push(*block);
-            }
-            Instr::BranchCond {
-                true_block,
-                false_block,
-                ..
-            } => {
-                out.push(*true_block);
-                out.push(*false_block);
-            }
-            Instr::AbiDecode {
-                exception_block: Some(block),
-                ..
-            } => {
-                out.push(*block);
-            }
-            _ => (),
-        }
-    }
-
-    out
-}
-
 /// Eliminate dead storage load/store.
 pub fn dead_storage(cfg: &mut ControlFlowGraph, _ns: &mut Namespace) {
     // first calculate reaching definitions. We use a special case reaching definitions, which we track

+ 8 - 1
src/codegen/reaching_definitions.rs

@@ -223,7 +223,8 @@ pub fn apply_transfers(transfers: &[Transfer], vars: &mut IndexMap<usize, IndexM
     }
 }
 
-pub fn block_edges(block: &BasicBlock) -> Vec<usize> {
+/// Fetch the blocks that can be executed after the block passed as argument
+pub(super) fn block_edges(block: &BasicBlock) -> Vec<usize> {
     let mut out = Vec::new();
 
     // out cfg has edge as the last instruction in a block; EXCEPT
@@ -247,6 +248,12 @@ pub fn block_edges(block: &BasicBlock) -> Vec<usize> {
             } => {
                 out.push(*block);
             }
+            Instr::Switch { default, cases, .. } => {
+                out.push(*default);
+                for (_, goto) in cases {
+                    out.push(*goto);
+                }
+            }
             _ => (),
         }
     }

+ 20 - 0
src/codegen/subexpression_elimination/instruction.rs

@@ -164,6 +164,13 @@ impl AvailableExpressionSet {
                 let _ = self.gen_expression(bytes, ave, cst);
             }
 
+            Instr::Switch { cond, cases, .. } => {
+                let _ = self.gen_expression(cond, ave, cst);
+                for (case, _) in cases {
+                    let _ = self.gen_expression(case, ave, cst);
+                }
+            }
+
             Instr::AssertFailure { expr: None }
             | Instr::Unreachable
             | Instr::Nop
@@ -425,6 +432,19 @@ impl AvailableExpressionSet {
                 bytes: self.regenerate_expression(bytes, ave, cst).1,
             },
 
+            Instr::Switch {
+                cond,
+                cases,
+                default,
+            } => Instr::Switch {
+                cond: self.regenerate_expression(cond, ave, cst).1,
+                cases: cases
+                    .iter()
+                    .map(|(case, goto)| (self.regenerate_expression(case, ave, cst).1, *goto))
+                    .collect::<Vec<(Expression, usize)>>(),
+                default: *default,
+            },
+
             Instr::WriteBuffer { buf, offset, value } => Instr::WriteBuffer {
                 buf: self.regenerate_expression(buf, ave, cst).1,
                 offset: self.regenerate_expression(offset, ave, cst).1,

+ 1 - 0
src/codegen/vector_to_slice.rs

@@ -101,6 +101,7 @@ fn find_writable_vectors(
             | Instr::Nop
             | Instr::Branch { .. }
             | Instr::BranchCond { .. }
+            | Instr::Switch { .. }
             | Instr::PopMemory { .. }
             | Instr::LoadStorage { .. }
             | Instr::SetStorage { .. }

+ 82 - 5
src/codegen/yul/statements.rs

@@ -8,7 +8,7 @@ use crate::codegen::yul::expression::{expression, process_function_call};
 use crate::codegen::{Expression, Options};
 use crate::sema::ast::{Namespace, RetrieveType, Type};
 use crate::sema::yul::ast;
-use crate::sema::yul::ast::{YulStatement, YulSuffix};
+use crate::sema::yul::ast::{CaseBlock, YulBlock, YulExpression, YulStatement, YulSuffix};
 use num_bigint::BigInt;
 use num_traits::FromPrimitive;
 use solang_parser::pt;
@@ -67,10 +67,23 @@ pub(crate) fn statement(
             opt,
         ),
 
-        YulStatement::Switch { .. } => {
-            // Switch statements should use LLVM switch instruction, which requires changes in emit.
-            unreachable!("Switch statements for yul are not implemented yet");
-        }
+        YulStatement::Switch {
+            condition,
+            cases,
+            default,
+            ..
+        } => switch(
+            condition,
+            cases,
+            default,
+            loops,
+            contract_no,
+            ns,
+            vartab,
+            cfg,
+            early_return,
+            opt,
+        ),
 
         YulStatement::For {
             loc,
@@ -465,3 +478,67 @@ fn process_for_block(
     cfg.set_phis(end_block, set.clone());
     cfg.set_phis(cond_block, set);
 }
+
+/// Generate CFG code for a switch statement
+fn switch(
+    condition: &YulExpression,
+    cases: &[CaseBlock],
+    default: &Option<YulBlock>,
+    loops: &mut LoopScopes,
+    contract_no: usize,
+    ns: &Namespace,
+    vartab: &mut Vartable,
+    cfg: &mut ControlFlowGraph,
+    early_return: &Option<Instr>,
+    opt: &Options,
+) {
+    let cond = expression(condition, contract_no, ns, vartab, cfg, opt);
+    let end_switch = cfg.new_basic_block("end_switch".to_string());
+
+    let current_block = cfg.current_block();
+
+    vartab.new_dirty_tracker();
+    let mut cases_cfg: Vec<(Expression, usize)> = Vec::with_capacity(cases.len());
+    for (item_no, item) in cases.iter().enumerate() {
+        let case_cond =
+            expression(&item.condition, contract_no, ns, vartab, cfg, opt).cast(&cond.ty(), ns);
+        let case_block = cfg.new_basic_block(format!("case_{}", item_no));
+        cfg.set_basic_block(case_block);
+        for stmt in &item.block.body {
+            statement(stmt, contract_no, loops, ns, cfg, vartab, early_return, opt);
+        }
+        if item.block.is_next_reachable() {
+            cfg.add_yul(vartab, Instr::Branch { block: end_switch });
+        }
+        cases_cfg.push((case_cond, case_block));
+    }
+
+    let default_block = if let Some(default_block) = default {
+        let new_block = cfg.new_basic_block("default".to_string());
+        cfg.set_basic_block(new_block);
+        for stmt in &default_block.body {
+            statement(stmt, contract_no, loops, ns, cfg, vartab, early_return, opt);
+        }
+        if default_block.is_next_reachable() {
+            cfg.add_yul(vartab, Instr::Branch { block: end_switch });
+        }
+        new_block
+    } else {
+        end_switch
+    };
+
+    cfg.set_phis(end_switch, vartab.pop_dirty_tracker());
+
+    cfg.set_basic_block(current_block);
+
+    cfg.add_yul(
+        vartab,
+        Instr::Switch {
+            cond,
+            cases: cases_cfg,
+            default: default_block,
+        },
+    );
+
+    cfg.set_basic_block(end_switch);
+}

+ 62 - 49
src/emit/instructions.rs

@@ -9,7 +9,9 @@ use crate::emit::{ReturnCode, TargetRuntime};
 use crate::sema::ast::{Contract, Namespace, RetrieveType, Type};
 use crate::Target;
 use inkwell::types::BasicType;
-use inkwell::values::{BasicMetadataValueEnum, BasicValueEnum, CallableValue, FunctionValue};
+use inkwell::values::{
+    BasicMetadataValueEnum, BasicValueEnum, CallableValue, FunctionValue, IntValue,
+};
 use inkwell::{AddressSpace, IntPredicate};
 use num_traits::ToPrimitive;
 use std::collections::{HashMap, VecDeque};
@@ -60,22 +62,10 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
         Instr::Branch { block: dest } => {
             let pos = bin.builder.get_insert_block().unwrap();
 
-            if !blocks.contains_key(dest) {
-                blocks.insert(*dest, create_block(*dest, bin, cfg, function, ns));
-                work.push_back(Work {
-                    block_no: *dest,
-                    vars: w.vars.clone(),
-                });
-            }
-
-            let bb = blocks.get(dest).unwrap();
-
-            for (v, phi) in bb.phis.iter() {
-                phi.add_incoming(&[(&w.vars[v].value, pos)]);
-            }
+            let bb = add_or_retrieve_block(*dest, pos, bin, function, blocks, work, w, cfg, ns);
 
             bin.builder.position_at_end(pos);
-            bin.builder.build_unconditional_branch(bb.bb);
+            bin.builder.build_unconditional_branch(bb);
         }
         Instr::Store { dest, data } => {
             let value_ref = expression(target, bin, data, &w.vars, function, ns);
@@ -92,41 +82,11 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
 
             let pos = bin.builder.get_insert_block().unwrap();
 
-            let bb_true = {
-                if !blocks.contains_key(true_) {
-                    blocks.insert(*true_, create_block(*true_, bin, cfg, function, ns));
-                    work.push_back(Work {
-                        block_no: *true_,
-                        vars: w.vars.clone(),
-                    });
-                }
-
-                let bb = blocks.get(true_).unwrap();
-
-                for (v, phi) in bb.phis.iter() {
-                    phi.add_incoming(&[(&w.vars[v].value, pos)]);
-                }
+            let bb_true =
+                add_or_retrieve_block(*true_, pos, bin, function, blocks, work, w, cfg, ns);
 
-                bb.bb
-            };
-
-            let bb_false = {
-                if !blocks.contains_key(false_) {
-                    blocks.insert(*false_, create_block(*false_, bin, cfg, function, ns));
-                    work.push_back(Work {
-                        block_no: *false_,
-                        vars: w.vars.clone(),
-                    });
-                }
-
-                let bb = blocks.get(false_).unwrap();
-
-                for (v, phi) in bb.phis.iter() {
-                    phi.add_incoming(&[(&w.vars[v].value, pos)]);
-                }
-
-                bb.bb
-            };
+            let bb_false =
+                add_or_retrieve_block(*false_, pos, bin, function, blocks, work, w, cfg, ns);
 
             bin.builder.position_at_end(pos);
             bin.builder
@@ -1185,5 +1145,58 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
                 );
             }
         }
+        Instr::Switch {
+            cond,
+            cases,
+            default,
+        } => {
+            let pos = bin.builder.get_insert_block().unwrap();
+            let cond = expression(target, bin, cond, &w.vars, function, ns);
+            let cases = cases
+                .iter()
+                .map(|(exp, block_no)| {
+                    let exp = expression(target, bin, exp, &w.vars, function, ns);
+                    let bb = add_or_retrieve_block(
+                        *block_no, pos, bin, function, blocks, work, w, cfg, ns,
+                    );
+                    (exp.into_int_value(), bb)
+                })
+                .collect::<Vec<(IntValue, inkwell::basic_block::BasicBlock)>>();
+
+            let default_bb =
+                add_or_retrieve_block(*default, pos, bin, function, blocks, work, w, cfg, ns);
+            bin.builder.position_at_end(pos);
+            bin.builder
+                .build_switch(cond.into_int_value(), default_bb, cases.as_ref());
+        }
     }
 }
+
+/// Add or retrieve a basic block from the blocks' hashmap
+fn add_or_retrieve_block<'a>(
+    block_no: usize,
+    pos: inkwell::basic_block::BasicBlock<'a>,
+    bin: &Binary<'a>,
+    function: FunctionValue,
+    blocks: &mut HashMap<usize, BasicBlock<'a>>,
+    work: &mut VecDeque<Work<'a>>,
+    w: &mut Work<'a>,
+    cfg: &ControlFlowGraph,
+    ns: &Namespace,
+) -> inkwell::basic_block::BasicBlock<'a> {
+    if let std::collections::hash_map::Entry::Vacant(e) = blocks.entry(block_no) {
+        e.insert(create_block(block_no, bin, cfg, function, ns));
+        work.push_back(Work {
+            block_no,
+            vars: w.vars.clone(),
+        });
+    }
+
+    let bb = blocks.get(&block_no).unwrap();
+
+    for (v, phi) in bb.phis.iter() {
+        phi.add_incoming(&[(&w.vars[v].value, pos)]);
+    }
+
+    bb.bb
+}

+ 1 - 0
src/sema/ast.rs

@@ -1056,6 +1056,7 @@ impl CodeLocation for Instr {
                 pt::Loc::File(_, _, _) => source.loc(),
                 _ => destination.loc(),
             },
+            Instr::Switch { cond, .. } => cond.loc(),
             Instr::Branch { .. }
             | Instr::Unreachable
             | Instr::Nop

+ 0 - 7
src/sema/yul/statements.rs

@@ -110,13 +110,6 @@ pub(crate) fn resolve_yul_statement(
                 ns,
             )?;
             resolved_statements.push(resolved_switch.0);
-            ns.diagnostics.push(
-                Diagnostic::error(
-                    switch_statement.loc,
-                    "switch statements have no implementation in code generation yet. Please, file a GitHub issue \
-                    if there is urgent need for such a feature".to_string()
-                )
-            );
             Ok(resolved_switch.1)
         }
 

+ 32 - 0
src/sema/yul/switch.rs

@@ -8,8 +8,11 @@ use crate::sema::yul::block::resolve_yul_block;
 use crate::sema::yul::expression::{check_type, resolve_yul_expression};
 use crate::sema::yul::functions::FunctionsTable;
 use crate::sema::yul::types::verify_type_from_expression;
+use num_bigint::{BigInt, Sign};
+use num_traits::{One, Zero};
 use solang_parser::pt::{CodeLocation, YulSwitchOptions};
 use solang_parser::{diagnostics::Diagnostic, pt};
+use std::collections::HashMap;
 
 /// Resolve switch statement
 /// Returns the resolved block and a bool to indicate if the next statement is reachable.
@@ -42,6 +45,35 @@ pub(crate) fn resolve_switch(
         next_reachable |= block_reachable;
     }
 
+    let mut conditions: HashMap<BigInt, pt::Loc> = HashMap::new();
+    for item in &case_blocks {
+        let big_int = match &item.condition {
+            YulExpression::BoolLiteral(_, value, _) => {
+                if *value {
+                    BigInt::one()
+                } else {
+                    BigInt::zero()
+                }
+            }
+            YulExpression::NumberLiteral(_, value, _) => value.clone(),
+            YulExpression::StringLiteral(_, value, _) => BigInt::from_bytes_be(Sign::Plus, value),
+            _ => unreachable!("Switch condition should be a literal"),
+        };
+
+        let repeated_loc = conditions.get(&big_int);
+
+        if let Some(repeated) = repeated_loc {
+            ns.diagnostics.push(Diagnostic::error_with_note(
+                item.condition.loc(),
+                "duplicate case for switch".to_string(),
+                *repeated,
+                "repeated case found here".to_string(),
+            ));
+        } else {
+            conditions.insert(big_int, item.condition.loc());
+        }
+    }
+
     if yul_switch.default.is_some() && default_block.is_some() {
         ns.diagnostics.push(Diagnostic::error(
             yul_switch.default.as_ref().unwrap().loc(),

+ 6 - 7
src/sema/yul/tests/block.rs

@@ -186,7 +186,10 @@ contract testTypes {
     "#;
 
     let ns = parse(file);
-    assert!(ns.diagnostics.contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
+    for item in ns.diagnostics.iter() {
+        std::println!("{}", item.message);
+    }
+    assert!(ns.diagnostics.contains_message("unreachable yul statement"));
 
     let file = r#"
 contract testTypes {
@@ -209,12 +212,10 @@ contract testTypes {
     }
 }    "#;
     let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 1);
     assert!(ns
         .diagnostics
         .contains_message("found contract 'testTypes'"));
-    assert!(ns
-        .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
 
     let file = r#"
     contract testTypes {
@@ -237,12 +238,10 @@ contract testTypes {
 }    "#;
 
     let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 1);
     assert!(ns
         .diagnostics
         .contains_message("found contract 'testTypes'"));
-    assert!(ns
-        .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
 }
 
 #[test]

+ 3 - 4
src/sema/yul/tests/mutability.rs

@@ -178,7 +178,6 @@ fn if_block() {
 
 #[test]
 fn switch() {
-    // TODO: switch statements are not yet supported, so there is no way to test mutability here
     let file = r#"
     contract testTypes {
     function testAsm(uint[] calldata vl) public pure {
@@ -198,7 +197,7 @@ fn switch() {
     let ns = parse(file);
     assert!(ns
         .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
+        .contains_message("function declared 'pure' but this expression reads from state"));
 
     let file = r#"
     contract testTypes {
@@ -220,7 +219,7 @@ fn switch() {
     let ns = parse(file);
     assert!(ns
         .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
+        .contains_message("function declared 'pure' but this expression reads from state"));
 
     let file = r#"
     contract testTypes {
@@ -242,7 +241,7 @@ fn switch() {
     let ns = parse(file);
     assert!(ns
         .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
+        .contains_message("function declared 'pure' but this expression reads from state"));
 }
 
 #[test]

+ 131 - 4
src/sema/yul/tests/switch.rs

@@ -74,7 +74,6 @@ contract testTypes {
 
 #[test]
 fn correct_switch() {
-    // TODO: switch statements are not yet implemented
     let file = r#"
 contract testTypes {
     function testAsm() public pure {
@@ -101,7 +100,135 @@ contract testTypes {
     "#;
 
     let ns = parse(file);
-    assert!(ns
-        .diagnostics
-        .contains_message("switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature"));
+    for item in ns.diagnostics.iter() {
+        std::println!("{}", item.message);
+    }
+    assert_eq!(ns.diagnostics.len(), 1);
+    assert_eq!(
+        ns.diagnostics.iter().next().unwrap().message,
+        "found contract 'testTypes'"
+    );
+}
+
+#[test]
+fn repeated_switch_case() {
+    let file = r#"
+contract Testing {
+    function duplicate_cases(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            case hex"019a" {
+                b := 5
+            }
+            case 410 {
+                b := 6
+            }
+        }
+    }
+}
+    "#;
+    let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 2);
+    assert!(ns.diagnostics.contains_message("found contract 'Testing'"));
+    assert!(ns.diagnostics.contains_message("duplicate case for switch"));
+    let errors = ns.diagnostics.errors();
+    assert_eq!(errors.len(), 1);
+    assert_eq!(errors[0].notes.len(), 1);
+    assert_eq!(errors[0].notes[0].message, "repeated case found here");
+
+    let file = r#"
+contract Testing {
+    function duplicate_cases(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            case true {
+                b := 5
+            }
+            case 1 {
+                b := 6
+            }
+        }
+    }
+}
+    "#;
+    let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 2);
+    assert!(ns.diagnostics.contains_message("found contract 'Testing'"));
+    assert!(ns.diagnostics.contains_message("duplicate case for switch"));
+    let errors = ns.diagnostics.errors();
+    assert_eq!(errors.len(), 1);
+    assert_eq!(errors[0].notes.len(), 1);
+    assert_eq!(errors[0].notes[0].message, "repeated case found here");
+
+    let file = r#"
+contract Testing {
+    function duplicate_cases(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            case 0 {
+                b := 5
+            }
+            case false {
+                b := 6
+            }
+        }
+    }
+}
+    "#;
+    let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 2);
+    assert!(ns.diagnostics.contains_message("found contract 'Testing'"));
+    assert!(ns.diagnostics.contains_message("duplicate case for switch"));
+    let errors = ns.diagnostics.errors();
+    assert_eq!(errors.len(), 1);
+    assert_eq!(errors[0].notes.len(), 1);
+    assert_eq!(errors[0].notes[0].message, "repeated case found here");
+
+    let file = r#"
+contract Testing {
+    function duplicate_cases(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            case 16705 {
+                b := 5
+            }
+            case "AA" {
+                b := 6
+            }
+        }
+    }
+}
+    "#;
+    let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 2);
+    assert!(ns.diagnostics.contains_message("found contract 'Testing'"));
+    assert!(ns.diagnostics.contains_message("duplicate case for switch"));
+    let errors = ns.diagnostics.errors();
+    assert_eq!(errors.len(), 1);
+    assert_eq!(errors[0].notes.len(), 1);
+    assert_eq!(errors[0].notes[0].message, "repeated case found here");
+
+    let file = r#"
+contract Testing {
+    function duplicate_cases(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            case 16705 {
+                b := 5
+            }
+            case 16705 {
+                b := 6
+            }
+        }
+    }
+}
+    "#;
+    let ns = parse(file);
+    assert_eq!(ns.diagnostics.len(), 2);
+    assert!(ns.diagnostics.contains_message("found contract 'Testing'"));
+    assert!(ns.diagnostics.contains_message("duplicate case for switch"));
+    let errors = ns.diagnostics.errors();
+    assert_eq!(errors.len(), 1);
+    assert_eq!(errors[0].notes.len(), 1);
+    assert_eq!(errors[0].notes[0].message, "repeated case found here");
 }

+ 70 - 0
tests/codegen_testcases/yul/switch.sol

@@ -0,0 +1,70 @@
+// RUN: --target solana --emit cfg 
+
+contract Testing {
+    // BEGIN-CHECK: Testing::Testing::function::switch_default__uint256
+    function switch_default(uint a) public pure returns (uint b) {
+        assembly {
+            // CHECK: switch (arg #0):
+            switch a
+            // CHECK: case uint256 1: goto block #2
+            // CHECK: case uint256 2: goto block #3
+            // CHECK: default: goto block #4
+
+            // CHECK: block1: # end_switch
+            // CHECK: branchcond (%b == uint256 7), block5, block6
+            case 1 {
+                // CHECK: block2: # case_0
+                // CHECK: ty:uint256 %b = uint256 5
+                b := 5
+                // CHECK: branch block1
+            }
+            case 2 {
+                // CHECK: block3: # case_1
+                // CHECK: ty:uint256 %b = uint256 6
+                b := 6
+                // CHECK: branch block1
+            }
+            default {
+                // CHECK: block4: # default
+                // CHECK: ty:uint256 %b = uint256 7
+                b := 7
+                // CHECK: branch block1
+            }
+        }
+
+        if (b == 7) {
+            b += 1;
+        }
+    }
+
+    // BEGIN-CHECK: Testing::Testing::function::switch_no_default__uint256
+    function switch_no_default(uint a) public pure returns (uint b) {
+        assembly {
+            switch a
+            // CHECK: switch (arg #0):
+		    // CHECK: case uint256 1: goto block #2
+		    // CHECK: case uint256 2: goto block #3
+		    // CHECK: default: goto block #1
+
+            // CHECK: block1: # end_switch
+	        // CHECK: branchcond (%b == uint256 5), block4, block5
+
+            case 1 {
+            // CHECK: block2: # case_0
+            // CHECK: ty:uint256 %b = uint256 5
+            // CHECK: branch block1
+                b := 5
+            }
+            case 2 {
+            // CHECK: block3: # case_1
+            // CHECK: ty:uint256 %b = uint256 6
+	        // CHECK: branch block1
+                b := 6
+            }
+        }
+
+        if (b == 5) {
+            b += 1;
+        }
+    }
+}

+ 0 - 2
tests/contract_testcases/solana/yul/yul_switch.dot

@@ -41,7 +41,6 @@ strict digraph "tests/contract_testcases/solana/yul/yul_switch.sol" {
 	return [label="return\ntests/contract_testcases/solana/yul/yul_switch.sol:16:9-17"]
 	variable [label="variable: y\nuint256\ntests/contract_testcases/solana/yul/yul_switch.sol:16:16-17"]
 	diagnostic [label="found contract 'testTypes'\nlevel Debug\ntests/contract_testcases/solana/yul/yul_switch.sol:1:1-18:2"]
-	diagnostic_44 [label="switch statements have no implementation in code generation yet. Please, file a GitHub issue if there is urgent need for such a feature\nlevel Error\ntests/contract_testcases/solana/yul/yul_switch.sol:6:13-11:14"]
 	contracts -> contract
 	contract -> var [label="variable"]
 	contract -> testAsm [label="function"]
@@ -84,5 +83,4 @@ strict digraph "tests/contract_testcases/solana/yul/yul_switch.sol" {
 	inline_assembly -> return [label="next"]
 	return -> variable [label="expr"]
 	diagnostics -> diagnostic [label="Debug"]
-	diagnostics -> diagnostic_44 [label="Error"]
 }

+ 100 - 1
tests/solana_tests/yul.rs

@@ -384,6 +384,105 @@ fn addmod_mulmod() {
     let returns = vm.function("testMod", &[], &[], None);
     assert_eq!(
         returns,
-        vec![Token::Uint(Uint::from(0)), Token::Uint(Uint::from(7)),]
+        vec![Token::Uint(Uint::from(0)), Token::Uint(Uint::from(7))]
     );
 }
+
+#[test]
+fn switch_statement() {
+    let mut vm = build_solidity(
+        r#"
+
+contract Testing {
+    function switch_default(uint a) public pure returns (uint b) {
+        b = 4;
+        assembly {
+            switch a
+            case 1 {
+                b := 5
+            }
+            case 2 {
+                b := 6
+            }
+            default {
+                b := 7
+            }
+        }
+
+        if (b == 7) {
+            b += 2;
+        }
+    }
+
+    function switch_no_default(uint a) public pure returns (uint b) {
+        b = 4;
+        assembly {
+            switch a
+            case 1 {
+                b := 5
+            }
+            case 2 {
+                b := 6
+            }
+        }
+
+        if (b == 5) {
+            b -= 2;
+        }
+    }
+
+    function switch_no_case(uint a) public pure returns (uint b) {
+        b = 7;
+        assembly {
+            switch a
+            default {
+                b := 5
+            }
+        }
+
+        if (b == 5) {
+            b -= 1;
+        }
+    }
+}
+        "#,
+    );
+
+    vm.constructor("Testing", &[]);
+
+    let returns = vm.function("switch_default", &[Token::Uint(Uint::from(1))], &[], None);
+    assert_eq!(returns[0], Token::Uint(Uint::from(5)));
+
+    let returns = vm.function("switch_default", &[Token::Uint(Uint::from(2))], &[], None);
+    assert_eq!(returns[0], Token::Uint(Uint::from(6)));
+
+    let returns = vm.function("switch_default", &[Token::Uint(Uint::from(6))], &[], None);
+    assert_eq!(returns[0], Token::Uint(Uint::from(9)));
+
+    let returns = vm.function(
+        "switch_no_default",
+        &[Token::Uint(Uint::from(1))],
+        &[],
+        None,
+    );
+    assert_eq!(returns[0], Token::Uint(Uint::from(3)));
+
+    let returns = vm.function(
+        "switch_no_default",
+        &[Token::Uint(Uint::from(2))],
+        &[],
+        None,
+    );
+    assert_eq!(returns[0], Token::Uint(Uint::from(6)));
+
+    let returns = vm.function(
+        "switch_no_default",
+        &[Token::Uint(Uint::from(6))],
+        &[],
+        None,
+    );
+    assert_eq!(returns[0], Token::Uint(Uint::from(4)));
+
+    let returns = vm.function("switch_no_case", &[Token::Uint(Uint::from(3))], &[], None);
+    assert_eq!(returns[0], Token::Uint(Uint::from(4)));
+}

+ 83 - 0
tests/substrate_tests/yul.rs

@@ -192,3 +192,86 @@ contract testing  {
 
     assert_eq!(runtime.vm.output, expected);
 }
+
+#[test]
+fn switch_statement() {
+    let mut runtime = build_solidity(
+        r#"
+contract Testing {
+    function switch_default(uint a) public pure returns (uint b) {
+        b = 4;
+        assembly {
+            switch a
+            case 1 {
+                b := 5
+            }
+            case 2 {
+                b := 6
+            }
+            default {
+                b := 7
+            }
+        }
+
+        if (b == 7) {
+            b += 2;
+        }
+    }
+
+    function switch_no_default(uint a) public pure returns (uint b) {
+        b = 4;
+        assembly {
+            switch a
+            case 1 {
+                b := 5
+            }
+            case 2 {
+                b := 6
+            }
+        }
+
+        if (b == 5) {
+            b -= 2;
+        }
+    }
+
+    function switch_no_case(uint a) public pure returns (uint b) {
+        b = 7;
+        assembly {
+            switch a
+            default {
+                b := 5
+            }
+        }
+
+        if (b == 5) {
+            b -= 1;
+        }
+    }
+}
+        "#,
+    );
+
+    runtime.constructor(0, Vec::new());
+
+    runtime.function("switch_default", Val256(U256::from(1)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(5)).encode());
+
+    runtime.function("switch_default", Val256(U256::from(2)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(6)).encode());
+
+    runtime.function("switch_default", Val256(U256::from(6)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(9)).encode());
+
+    runtime.function("switch_no_default", Val256(U256::from(1)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(3)).encode());
+
+    runtime.function("switch_no_default", Val256(U256::from(2)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(6)).encode());
+
+    runtime.function("switch_no_default", Val256(U256::from(6)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(4)).encode());
+
+    runtime.function("switch_no_case", Val256(U256::from(3)).encode());
+    assert_eq!(runtime.vm.output, Val256(U256::from(4)).encode());
+}