Эх сурвалжийг харах

Simplify switch in constant folding (#1208)

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 2 жил өмнө
parent
commit
ecf8c3b2e0

+ 16 - 0
src/codegen/constant_folding.rs

@@ -329,6 +329,22 @@ pub fn constant_folding(cfg: &mut ControlFlowGraph, ns: &mut Namespace) {
                         .iter()
                         .iter()
                         .map(|(exp, goto)| (expression(exp, Some(&vars), cfg, ns).0, *goto))
                         .map(|(exp, goto)| (expression(exp, Some(&vars), cfg, ns).0, *goto))
                         .collect::<Vec<(Expression, usize)>>();
                         .collect::<Vec<(Expression, usize)>>();
+
+                    if let Expression::NumberLiteral(_, _, num) = &cond.0 {
+                        let mut simplified_branch = None;
+                        for (match_item, block) in &cases {
+                            if let Expression::NumberLiteral(_, _, match_num) = match_item {
+                                if match_num == num {
+                                    simplified_branch = Some(*block);
+                                }
+                            }
+                        }
+                        cfg.blocks[block_no].instr[instr_no] = Instr::Branch {
+                            block: simplified_branch.unwrap_or(*default),
+                        };
+                        continue;
+                    }
+
                     cfg.blocks[block_no].instr[instr_no] = Instr::Switch {
                     cfg.blocks[block_no].instr[instr_no] = Instr::Switch {
                         cond: cond.0,
                         cond: cond.0,
                         cases,
                         cases,

+ 2 - 4
tests/codegen_testcases/yul/cse_switch.sol

@@ -2,14 +2,12 @@
 
 
 contract foo {
 contract foo {
     // BEGIN-CHECK: foo::foo::function::test
     // BEGIN-CHECK: foo::foo::function::test
-    function test() public {
+    function test(uint x) public {
         uint256 yy=0;
         uint256 yy=0;
         assembly {
         assembly {
         // Ensure the CSE temp is not before the switch
         // Ensure the CSE temp is not before the switch
-        // CHECK: ty:uint256 %x = uint256 54
         // CHECK: ty:uint256 %y = uint256 5
         // CHECK: ty:uint256 %y = uint256 5
-	    // CHECK: switch uint256 2:
-            let x := 54
+	    // CHECK: switch ((arg #0) & uint256 3):
             let y := 5
             let y := 5
 
 
             switch and(x, 3)
             switch and(x, 3)

+ 76 - 0
tests/codegen_testcases/yul/switch_simplify.sol

@@ -0,0 +1,76 @@
+// RUN: --target solana --emit cfg
+
+contract test {
+    // BEGIN-CHECK: test::test::function::test_1
+    function test_1() public pure returns (int) {
+        int gg = 56;
+
+        int res = 0;
+        assembly {
+            // NOT-CHECK: switch
+            // CHECK: branch block3
+            switch add(gg, 4)
+            case 5 {
+                res := 90
+            }
+            case 60 {
+                // CHECK: block3: # case_1
+	            // CHECK: ty:int256 %res = int256 4
+                res := 4
+            }
+            default {
+                res := 7
+            }
+        }
+
+        return res;
+    }
+
+    // BEGIN-CHECK: test::test::function::test_2
+    function test_2() public pure returns (int) {
+        int gg = 56;
+
+        int res = 0;
+        assembly {
+            // NOT-CHECK: switch
+            // CHECK: branch block4
+            switch add(gg, 4)
+            case 5 {
+                res := 90
+            }
+            case 6 {
+                res := 4
+            }
+            default {
+                // CHECK: block4: # default
+	            // CHECK: ty:int256 %res = int256 7
+                res := 7
+            }
+        }
+
+        return res;
+    }
+
+    // BEGIN-CHECK: test::test::function::test_3
+    function test_3() public pure returns (int) {
+        int gg = 56;
+
+        int res = 0;
+        assembly {
+            // NOT-CHECK: switch
+            // CHECK: branch block1
+            switch add(gg, 4)
+            case 5 {
+                res := 90
+            }
+            case 6 {
+                res := 4
+            }
+        }
+
+        // CHECK: block1: # end_switch
+	    // CHECK: return %res
+
+        return res;
+    }
+}