Explorar o código

Soroban: Add support for `bool`, `uint32`, and `int32` types (#1786)

Fixes #1677

---------

Signed-off-by: Tarek <tareknaser360@gmail.com>
Co-authored-by: salaheldinsoliman <49910731+salaheldinsoliman@users.noreply.github.com>
Tarek Elsayed hai 5 meses
pai
achega
601a8ed3eb

+ 120 - 0
src/codegen/encoding/soroban_encoding.rs

@@ -120,6 +120,15 @@ pub fn soroban_decode_arg(
     };
 
     match ty {
+        Type::Bool => Expression::NotEqual {
+            loc: Loc::Codegen,
+            left: arg.into(),
+            right: Box::new(Expression::NumberLiteral {
+                loc: Loc::Codegen,
+                ty: Type::Uint(64),
+                value: 0u64.into(),
+            }),
+        },
         Type::Uint(64) => Expression::ShiftRight {
             loc: Loc::Codegen,
             ty: Type::Uint(64),
@@ -135,6 +144,52 @@ pub fn soroban_decode_arg(
         Type::Address(_) => arg.clone(),
 
         Type::Int(128) | Type::Uint(128) => decode_i128(wrapper_cfg, vartab, arg),
+        Type::Uint(32) => {
+            // get payload out of major bits then truncate to 32‑bit
+            Expression::Trunc {
+                loc: Loc::Codegen,
+                ty: Type::Uint(32),
+                expr: Box::new(Expression::ShiftRight {
+                    loc: Loc::Codegen,
+                    ty: Type::Uint(64),
+                    left: arg.into(),
+                    right: Box::new(Expression::NumberLiteral {
+                        loc: Loc::Codegen,
+                        ty: Type::Uint(64),
+                        value: 32u64.into(),
+                    }),
+                    signed: false,
+                }),
+            }
+        }
+
+        Type::Int(32) => Expression::Trunc {
+            loc: Loc::Codegen,
+            ty: Type::Int(32),
+            expr: Box::new(Expression::ShiftRight {
+                loc: Loc::Codegen,
+                ty: Type::Int(64),
+                left: arg.into(),
+                right: Box::new(Expression::NumberLiteral {
+                    loc: Loc::Codegen,
+                    ty: Type::Uint(64),
+                    value: 32u64.into(),
+                }),
+                signed: true,
+            }),
+        },
+        Type::Int(64) => Expression::ShiftRight {
+            loc: Loc::Codegen,
+            ty: Type::Int(64),
+            left: arg.into(),
+            right: Box::new(Expression::NumberLiteral {
+                loc: Loc::Codegen,
+                ty: Type::Uint(64),
+                value: BigInt::from(8u64),
+            }),
+            signed: true,
+        },
+
         _ => unimplemented!(),
     }
 }
@@ -148,6 +203,22 @@ pub fn soroban_encode_arg(
     let obj = vartab.temp_name("obj_".to_string().as_str(), &Type::Uint(64));
 
     let ret = match item.ty() {
+        Type::Bool => {
+            let encoded = match item {
+                Expression::BoolLiteral { value, .. } => Expression::NumberLiteral {
+                    loc: item.loc(),
+                    ty: Type::Uint(64),
+                    value: if value { 1u64.into() } else { 0u64.into() },
+                },
+                _ => item.cast(&Type::Uint(64), ns),
+            };
+
+            Instr::Set {
+                loc: item.loc(),
+                res: obj,
+                expr: encoded,
+            }
+        }
         Type::String => {
             let inp = Expression::VectorData {
                 pointer: Box::new(item.clone()),
@@ -249,6 +320,55 @@ pub fn soroban_encode_arg(
                 args: vec![encoded, len],
             }
         }
+        Type::Uint(32) | Type::Int(32) => {
+            // widen to 64 bits so we can shift
+            let widened = match item.ty() {
+                Type::Uint(32) => Expression::ZeroExt {
+                    loc: item.loc(),
+                    ty: Type::Uint(64),
+                    expr: Box::new(item.clone()),
+                },
+                Type::Int(32) => Expression::SignExt {
+                    loc: item.loc(),
+                    ty: Type::Int(64),
+                    expr: Box::new(item.clone()),
+                },
+                _ => unreachable!(),
+            };
+
+            // the value goes into the major bits of the 64 bit value
+            let shifted = Expression::ShiftLeft {
+                loc: item.loc(),
+                ty: Type::Uint(64),
+                left: Box::new(widened.cast(&Type::Uint(64), ns)),
+                right: Box::new(Expression::NumberLiteral {
+                    loc: item.loc(),
+                    ty: Type::Uint(64),
+                    value: 32u64.into(), // 24 (minor) + 8 (tag)
+                }),
+            };
+
+            let tag = if matches!(item.ty(), Type::Uint(32)) {
+                4
+            } else {
+                5
+            };
+            Instr::Set {
+                loc: item.loc(),
+                res: obj,
+                expr: Expression::Add {
+                    loc: item.loc(),
+                    ty: Type::Uint(64),
+                    left: Box::new(shifted),
+                    right: Box::new(Expression::NumberLiteral {
+                        loc: item.loc(),
+                        ty: Type::Uint(64),
+                        value: tag.into(),
+                    }),
+                    overflowing: false,
+                },
+            }
+        }
         Type::Uint(64) | Type::Int(64) => {
             let shift_left = Expression::ShiftLeft {
                 loc: item.loc(),

+ 3 - 0
src/emit/soroban/mod.rs

@@ -258,7 +258,9 @@ impl SorobanTarget {
 
                             match ty {
                                 ast::Type::Uint(32) => ScSpecTypeDef::U32,
+                                &ast::Type::Int(32) => ScSpecTypeDef::I32,
                                 ast::Type::Uint(64) => ScSpecTypeDef::U64,
+                                &ast::Type::Int(64) => ScSpecTypeDef::I64,
                                 ast::Type::Int(128) => ScSpecTypeDef::I128,
                                 ast::Type::Uint(128) => ScSpecTypeDef::U128,
                                 ast::Type::Bool => ScSpecTypeDef::Bool,
@@ -285,6 +287,7 @@ impl SorobanTarget {
                         };
                         match ty {
                             ast::Type::Uint(32) => ScSpecTypeDef::U32,
+                            ast::Type::Int(32) => ScSpecTypeDef::I32,
                             ast::Type::Uint(64) => ScSpecTypeDef::U64,
                             ast::Type::Int(128) => ScSpecTypeDef::I128,
                             ast::Type::Uint(128) => ScSpecTypeDef::U128,

+ 218 - 0
tests/soroban_testcases/math.rs

@@ -189,3 +189,221 @@ fn u128_ops() {
     let expected: Val = 1_u128.into_val(&runtime.env);
     assert!(expected.shallow_eq(&res));
 }
+
+#[test]
+fn bool_roundtrip() {
+    let runtime = build_solidity(
+        r#"
+        contract test {
+            function flip(bool x) public returns (bool) {
+                return !x;
+            }
+        }"#,
+        |_| {},
+    );
+
+    let addr = runtime.contracts.last().unwrap();
+    let arg_true: Val = true.into_val(&runtime.env);
+    let res = runtime.invoke_contract(addr, "flip", vec![arg_true]);
+    let expected: Val = false.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn u32_roundtrip() {
+    let runtime = build_solidity(
+        r#"
+        contract test {
+            function id(uint32 x) public returns (uint32) {
+                return x;
+            }
+        }"#,
+        |_| {},
+    );
+
+    let addr = runtime.contracts.last().unwrap();
+    let arg: Val = (42_u32).into_val(&runtime.env);
+    let res = runtime.invoke_contract(addr, "id", vec![arg]);
+    let expected: Val = (42_u32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn u32_ops() {
+    let runtime = build_solidity(
+        r#"contract math {
+        function add(uint32 a, uint32 b) public returns (uint32) {
+            return a + b;
+        }
+
+        function sub(uint32 a, uint32 b) public returns (uint32) {
+            return a - b;
+        }
+
+        function mul(uint32 a, uint32 b) public returns (uint32) {
+            return a * b;
+        }
+
+        function div(uint32 a, uint32 b) public returns (uint32) {
+            return a / b;
+        }
+
+        function mod(uint32 a, uint32 b) public returns (uint32) {
+            return a % b;
+        }
+    }"#,
+        |_| {},
+    );
+
+    let arg: Val = 5_u32.into_val(&runtime.env);
+    let arg2: Val = 4_u32.into_val(&runtime.env);
+
+    let addr = runtime.contracts.last().unwrap();
+
+    let res = runtime.invoke_contract(addr, "add", vec![arg, arg2]);
+    let expected: Val = 9_u32.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "sub", vec![arg, arg2]);
+    let expected: Val = 1_u32.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mul", vec![arg, arg2]);
+    let expected: Val = 20_u32.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "div", vec![arg, arg2]);
+    let expected: Val = 1_u32.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mod", vec![arg, arg2]);
+    let expected: Val = 1_u32.into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn i32_roundtrip() {
+    let runtime = build_solidity(
+        r#"
+        contract test {
+            function id(int32 x) public returns (int32) {
+                return x;
+            }
+        }"#,
+        |_| {},
+    );
+
+    let addr = runtime.contracts.last().unwrap();
+    let arg: Val = (42_i32).into_val(&runtime.env);
+    let res = runtime.invoke_contract(addr, "id", vec![arg]);
+    let expected: Val = (42_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn i32_ops() {
+    let runtime = build_solidity(
+        r#"contract math {
+        function add(int32 a, int32 b) public returns (int32) {
+            return a + b;
+        }
+
+        function sub(int32 a, int32 b) public returns (int32) {
+            return a - b;
+        }
+
+        function mul(int32 a, int32 b) public returns (int32) {
+            return a * b;
+        }
+
+        function div(int32 a, int32 b) public returns (int32) {
+            return a / b;
+        }
+
+        function mod(int32 a, int32 b) public returns (int32) {
+            return a % b;
+        }
+    }"#,
+        |_| {},
+    );
+
+    let addr = runtime.contracts.last().unwrap();
+    let arg: Val = (5_i32).into_val(&runtime.env);
+    let arg2: Val = (4_i32).into_val(&runtime.env);
+
+    let res = runtime.invoke_contract(addr, "add", vec![arg, arg2]);
+    let expected: Val = (9_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "sub", vec![arg, arg2]);
+    let expected: Val = (1_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mul", vec![arg, arg2]);
+    let expected: Val = (20_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "div", vec![arg, arg2]);
+    let expected: Val = (1_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mod", vec![arg, arg2]);
+    let expected: Val = (1_i32).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn i64_roundtrip() {
+    let runtime = build_solidity(
+        r#"
+        contract test {
+            function id(int64 x) public returns (int64) {
+                return x;
+            }
+        }"#,
+        |_| {},
+    );
+    let addr = runtime.contracts.last().unwrap();
+    let arg: Val = (-42_i64).into_val(&runtime.env);
+    let res = runtime.invoke_contract(addr, "id", vec![arg]);
+    let expected: Val = (-42_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}
+
+#[test]
+fn i64_ops() {
+    let runtime = build_solidity(
+        r#"contract math {
+        function add(int64 a, int64 b) public returns (int64) { return a + b; }
+        function sub(int64 a, int64 b) public returns (int64) { return a - b; }
+        function mul(int64 a, int64 b) public returns (int64) { return a * b; }
+        function div(int64 a, int64 b) public returns (int64) { return a / b; }
+        function mod(int64 a, int64 b) public returns (int64) { return a % b; }
+    }"#,
+        |_| {},
+    );
+
+    let addr = runtime.contracts.last().unwrap();
+    let a: Val = (5_i64).into_val(&runtime.env);
+    let b: Val = (-4_i64).into_val(&runtime.env);
+
+    let res = runtime.invoke_contract(addr, "add", vec![a, b]);
+    let expected: Val = (1_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "sub", vec![a, b]);
+    let expected: Val = (9_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mul", vec![a, b]);
+    let expected: Val = (-20_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "div", vec![a, b]);
+    let expected: Val = (-1_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+
+    let res = runtime.invoke_contract(addr, "mod", vec![a, b]);
+    let expected: Val = (1_i64).into_val(&runtime.env);
+    assert!(expected.shallow_eq(&res));
+}