Эх сурвалжийг харах

Load value before shifting (#1133)

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 2 жил өмнө
parent
commit
0d529dd676

+ 4 - 4
src/sema/expression.rs

@@ -3065,12 +3065,12 @@ fn shift_left(
     let _ = get_int_length(&left.ty(), &l.loc(), true, ns, diagnostics)?;
     let (right_length, _) = get_int_length(&right.ty(), &r.loc(), false, ns, diagnostics)?;
 
-    let left_type = left.ty();
+    let left_type = left.ty().deref_any().clone();
 
     Ok(Expression::ShiftLeft {
         loc: *loc,
         ty: left_type.clone(),
-        left: Box::new(left),
+        left: Box::new(left.cast(loc, &left_type, true, ns, diagnostics)?),
         right: Box::new(cast_shift_arg(loc, right, right_length, &left_type, ns)),
     })
 }
@@ -3090,7 +3090,7 @@ fn shift_right(
 
     check_var_usage_expression(ns, &left, &right, symtable);
 
-    let left_type = left.ty();
+    let left_type = left.ty().deref_any().clone();
     // left hand side may be bytes/int/uint
     // right hand size may be int/uint
     let _ = get_int_length(&left_type, &l.loc(), true, ns, diagnostics)?;
@@ -3099,7 +3099,7 @@ fn shift_right(
     Ok(Expression::ShiftRight {
         loc: *loc,
         ty: left_type.clone(),
-        left: Box::new(left),
+        left: Box::new(left.cast(loc, &left_type, true, ns, diagnostics)?),
         right: Box::new(cast_shift_arg(loc, right, right_length, &left_type, ns)),
         sign: left_type.is_signed_int(),
     })

+ 10 - 10
tests/contract_testcases/solana/shift_struct_member.dot

@@ -7,8 +7,8 @@ strict digraph "tests/contract_testcases/solana/shift_struct_member.sol" {
 	returns [label="returns\nuint112 "]
 	return [label="return\ntests/contract_testcases/solana/shift_struct_member.sol:14:9-39"]
 	trunc [label="truncate uint112\ntests/contract_testcases/solana/shift_struct_member.sol:14:16-39"]
-	load [label="load uint224\ntests/contract_testcases/solana/shift_struct_member.sol:14:16-39"]
 	shift_right [label="shift right\nuint224\ntests/contract_testcases/solana/shift_struct_member.sol:14:24-38"]
+	load [label="load uint224\ntests/contract_testcases/solana/shift_struct_member.sol:14:24-38"]
 	structmember [label="struct member #0 uint224\ntests/contract_testcases/solana/shift_struct_member.sol:14:29-31"]
 	variable [label="variable: self\nstruct FixedPoint.uq112x112\ntests/contract_testcases/solana/shift_struct_member.sol:14:24-28"]
 	zero_ext [label="zero extend uint224\ntests/contract_testcases/solana/shift_struct_member.sol:14:24-38"]
@@ -18,8 +18,8 @@ strict digraph "tests/contract_testcases/solana/shift_struct_member.sol" {
 	returns_18 [label="returns\nuint144 "]
 	return_19 [label="return\ntests/contract_testcases/solana/shift_struct_member.sol:19:9-39"]
 	trunc_20 [label="truncate uint144\ntests/contract_testcases/solana/shift_struct_member.sol:19:16-39"]
-	load_21 [label="load uint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:16-39"]
-	shift_right_22 [label="shift right\nuint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:24-38"]
+	shift_right_21 [label="shift right\nuint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:24-38"]
+	load_22 [label="load uint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:24-38"]
 	structmember_23 [label="struct member #0 uint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:29-31"]
 	variable_24 [label="variable: self\nstruct FixedPoint.uq144x112\ntests/contract_testcases/solana/shift_struct_member.sol:19:24-28"]
 	zero_ext_25 [label="zero extend uint256\ntests/contract_testcases/solana/shift_struct_member.sol:19:24-38"]
@@ -33,9 +33,9 @@ strict digraph "tests/contract_testcases/solana/shift_struct_member.sol" {
 	decode -> returns [label="returns"]
 	decode -> return [label="body"]
 	return -> trunc [label="expr"]
-	trunc -> load [label="expr"]
-	load -> shift_right [label="expr"]
-	shift_right -> structmember [label="left"]
+	trunc -> shift_right [label="expr"]
+	shift_right -> load [label="left"]
+	load -> structmember [label="expr"]
 	structmember -> variable [label="var"]
 	shift_right -> zero_ext [label="right"]
 	zero_ext -> number_literal [label="expr"]
@@ -44,11 +44,11 @@ strict digraph "tests/contract_testcases/solana/shift_struct_member.sol" {
 	decode144 -> returns_18 [label="returns"]
 	decode144 -> return_19 [label="body"]
 	return_19 -> trunc_20 [label="expr"]
-	trunc_20 -> load_21 [label="expr"]
-	load_21 -> shift_right_22 [label="expr"]
-	shift_right_22 -> structmember_23 [label="left"]
+	trunc_20 -> shift_right_21 [label="expr"]
+	shift_right_21 -> load_22 [label="left"]
+	load_22 -> structmember_23 [label="expr"]
 	structmember_23 -> variable_24 [label="var"]
-	shift_right_22 -> zero_ext_25 [label="right"]
+	shift_right_21 -> zero_ext_25 [label="right"]
 	zero_ext_25 -> number_literal_26 [label="expr"]
 	diagnostics -> diagnostic [label="Debug"]
 }

+ 45 - 1
tests/solana_tests/primitives.rs

@@ -3,7 +3,7 @@
 use crate::build_solidity_with_overflow_check;
 use crate::{build_solidity, BorshToken};
 use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt};
-use num_traits::{ToPrimitive, Zero};
+use num_traits::{One, ToPrimitive, Zero};
 use rand::seq::SliceRandom;
 use rand::Rng;
 use std::ops::BitAnd;
@@ -1578,3 +1578,47 @@ fn bytes_cast() {
 
     assert_eq!(returns, BorshToken::uint8_fixed_array(b"abcde".to_vec()));
 }
+
+#[test]
+fn shift_after_load() {
+    let mut vm = build_solidity(
+        r#"
+    contract OneSwapToken {
+        function testIt(uint256[] calldata mixedAddrVal) public pure returns (uint256, uint256) {
+            uint256 a = mixedAddrVal[0]<<2;
+            uint256 b = mixedAddrVal[1]>>2;
+            return (a, b);
+        }
+    }
+        "#,
+    );
+
+    vm.constructor(&[]);
+    let args = BorshToken::Array(vec![
+        BorshToken::Uint {
+            width: 256,
+            value: BigInt::one(),
+        },
+        BorshToken::Uint {
+            width: 256,
+            value: BigInt::from(4u8),
+        },
+    ]);
+    let returns = vm.function("testIt", &[args]).unwrap().unwrap_tuple();
+
+    assert_eq!(returns.len(), 2);
+    assert_eq!(
+        returns[0],
+        BorshToken::Uint {
+            width: 256,
+            value: BigInt::from(4u8)
+        }
+    );
+    assert_eq!(
+        returns[1],
+        BorshToken::Uint {
+            width: 256,
+            value: BigInt::one(),
+        }
+    );
+}