Ver código fonte

Allow using {func} for type to use library functions (#1528)

Fixes https://github.com/hyperledger/solang/issues/1525

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 2 anos atrás
pai
commit
a74ab4db8a

+ 3 - 2
src/sema/contracts.rs

@@ -63,8 +63,6 @@ impl ast::Contract {
 
 /// Resolve the following contract
 pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::Namespace) {
-    resolve_using(contracts, file_no, ns);
-
     // we need to resolve declarations first, so we call functions/constructors of
     // contracts before they are declared
     let mut delayed: ResolveLater = Default::default();
@@ -73,6 +71,9 @@ pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::N
         resolve_declarations(def, file_no, ns, &mut delayed);
     }
 
+    // using may use functions declared in contracts
+    resolve_using(contracts, file_no, ns);
+
     // Resolve base contract constructor arguments on contract definition (not constructor definitions)
     resolve_base_args(contracts, file_no, ns);
 

+ 4 - 3
src/sema/expression/function_call.rs

@@ -1324,7 +1324,6 @@ pub(super) fn method_call_pos_args(
         resolve_to,
     )? {
         return Ok(resolved_call);
-    } else {
     }
 
     if let Some(resolved_call) = try_user_type(
@@ -1352,8 +1351,9 @@ pub(super) fn method_call_pos_args(
     if let Some(mut path) = ns.expr_to_identifier_path(var) {
         path.identifiers.push(func.clone());
 
-        if let Ok(list) = ns.resolve_free_function_with_namespace(
+        if let Ok(list) = ns.resolve_function_with_namespace(
             context.file_no,
+            None,
             &path,
             &mut Diagnostics::default(),
         ) {
@@ -1640,8 +1640,9 @@ pub(super) fn method_call_named_args(
     if let Some(mut path) = ns.expr_to_identifier_path(var) {
         path.identifiers.push(func_name.clone());
 
-        if let Ok(list) = ns.resolve_free_function_with_namespace(
+        if let Ok(list) = ns.resolve_function_with_namespace(
             context.file_no,
+            None,
             &path,
             &mut Diagnostics::default(),
         ) {

+ 10 - 4
src/sema/namespace.rs

@@ -364,9 +364,10 @@ impl Namespace {
     }
 
     /// Resolve a free function name with namespace
-    pub(super) fn resolve_free_function_with_namespace(
+    pub(super) fn resolve_function_with_namespace(
         &mut self,
         file_no: usize,
+        contract_no: Option<usize>,
         name: &pt::IdentifierPath,
         diagnostics: &mut Diagnostics,
     ) -> Result<Vec<(pt::Loc, usize)>, ()> {
@@ -376,12 +377,12 @@ impl Namespace {
             .map(|(id, namespace)| (id, namespace.iter().collect()))
             .unwrap();
 
-        let s = self.resolve_namespace(namespace, file_no, None, id, diagnostics)?;
+        let symbol = self.resolve_namespace(namespace, file_no, contract_no, id, diagnostics)?;
 
-        if let Some(Symbol::Function(list)) = s {
+        if let Some(Symbol::Function(list)) = symbol {
             Ok(list.clone())
         } else {
-            let error = Namespace::wrong_symbol(s, id);
+            let error = Namespace::wrong_symbol(symbol, id);
 
             diagnostics.push(error);
 
@@ -1335,6 +1336,7 @@ impl Namespace {
                         ));
                         return Err(());
                     };
+                    namespace.clear();
                     Some(*n)
                 }
                 Some(Symbol::Function(_)) => {
@@ -1390,6 +1392,10 @@ impl Namespace {
             };
         }
 
+        if !namespace.is_empty() {
+            return Ok(None);
+        }
+
         let mut s = self
             .variable_symbols
             .get(&(import_file_no, contract_no, id.name.to_owned()))

+ 30 - 2
src/sema/using.rs

@@ -94,8 +94,9 @@ pub(crate) fn using_decl(
 
             for using_function in functions {
                 let function_name = &using_function.path;
-                if let Ok(list) = ns.resolve_free_function_with_namespace(
+                if let Ok(list) = ns.resolve_function_with_namespace(
                     file_no,
+                    contract_no,
                     &using_function.path,
                     &mut diagnostics,
                 ) {
@@ -120,6 +121,18 @@ pub(crate) fn using_decl(
 
                     let func = &ns.functions[func_no];
 
+                    if let Some(contract_no) = func.contract_no {
+                        if !ns.contracts[contract_no].is_library() {
+                            diagnostics.push(Diagnostic::error_with_note(
+                                function_name.loc,
+                                format!("'{function_name}' is not a library function"),
+                                func.loc,
+                                format!("definition of {}", using_function.path),
+                            ));
+                            continue;
+                        }
+                    }
+
                     if func.params.is_empty() {
                         diagnostics.push(Diagnostic::error_with_note(
                             function_name.loc,
@@ -251,7 +264,22 @@ pub(crate) fn using_decl(
                         Some(oper)
                     } else {
                         if let Some(ty) = &ty {
-                            if *ty != func.params[0].ty {
+                            let dummy = Expression::Variable {
+                                loc,
+                                ty: ty.clone(),
+                                var_no: 0,
+                            };
+
+                            if dummy
+                                .cast(
+                                    &loc,
+                                    &func.params[0].ty,
+                                    true,
+                                    ns,
+                                    &mut Diagnostics::default(),
+                                )
+                                .is_err()
+                            {
                                 diagnostics.push(Diagnostic::error_with_note(
                                     function_name.loc,
                                     format!("function cannot be used since first argument is '{}' rather than the required '{}'", func.params[0].ty.to_string(ns), ty.to_string(ns)),

+ 31 - 0
tests/contract_testcases/solana/using_functions.sol

@@ -0,0 +1,31 @@
+contract C {
+	function foo(int256 a) internal pure returns (int256) {
+		return a;
+	}
+}
+
+library L {
+	function bar(int256 a) internal pure returns (int256) {
+		return a;
+	}
+}
+
+library Lib {
+	function baz(int256 a, bool b) internal pure returns (int256) {
+		if (b) {
+			return 1;
+		} else {
+			return a;
+		}
+	}
+	using {L.bar, baz} for int256;
+}
+
+library Lib2 {
+	using {L.foo.bar, C.foo} for int256;
+}
+
+// ---- Expect: diagnostics ----
+// error: 25:15-18: 'foo' not found
+// error: 25:20-25: 'C.foo' is not a library function
+// 	note 2:2-55: definition of C.foo

+ 1 - 1
tests/evm.rs

@@ -249,7 +249,7 @@ fn ethereum_solidity_tests() {
         })
         .sum();
 
-    assert_eq!(errors, 1024);
+    assert_eq!(errors, 1018);
 }
 
 fn set_file_contents(source: &str, path: &Path) -> (FileResolver, Vec<String>) {

+ 1 - 1
tests/polkadot_tests/libraries.rs

@@ -74,7 +74,7 @@ fn using() {
     let mut runtime = build_solidity(
         r##"
         contract test {
-            using ints for uint32;
+            using {ints.max} for uint32;
             function foo(uint32 x) public pure returns (uint64) {
                 // x is 32 bit but the max function takes 64 bit uint
                 return x.max(65536);

+ 127 - 0
tests/solana_tests/using.rs

@@ -255,3 +255,130 @@ contract C {
 
     assert_eq!(res, BorshToken::Bool(true));
 }
+
+#[test]
+fn using_function_for_struct() {
+    let mut vm = build_solidity(
+        r#"
+struct Pet {
+    string name;
+    uint8 age;
+}
+
+library Info {
+    function isCat(Pet memory myPet) public pure returns (bool) {
+        return myPet.name == "cat";
+    }
+
+    function setAge(Pet memory myPet, uint8 age) pure public {
+        myPet.age = age;
+    }
+}
+
+contract C {
+    using {Info.isCat, Info.setAge} for Pet;
+
+    function testPet(string memory name, uint8 age) pure public returns (bool) {
+        Pet memory my_pet = Pet(name, age);
+        return my_pet.isCat();
+    }
+
+    function changeAge(Pet memory myPet) public pure returns (Pet memory) {
+        myPet.setAge(5);
+        return myPet;
+    }
+
+}
+        "#,
+    );
+
+    let data_account = vm.initialize_data_account();
+
+    vm.function("new")
+        .accounts(vec![("dataAccount", data_account)])
+        .call();
+
+    let res = vm
+        .function("testPet")
+        .arguments(&[
+            BorshToken::String("cat".to_string()),
+            BorshToken::Uint {
+                width: 8,
+                value: BigInt::from(2u8),
+            },
+        ])
+        .call()
+        .unwrap();
+
+    assert_eq!(res, BorshToken::Bool(true));
+
+    let res = vm
+        .function("changeAge")
+        .arguments(&[BorshToken::Tuple(vec![
+            BorshToken::String("cat".to_string()),
+            BorshToken::Uint {
+                width: 8,
+                value: BigInt::from(2u8),
+            },
+        ])])
+        .call()
+        .unwrap();
+
+    assert_eq!(
+        res,
+        BorshToken::Tuple(vec![
+            BorshToken::String("cat".to_string()),
+            BorshToken::Uint {
+                width: 8,
+                value: BigInt::from(5u8),
+            }
+        ])
+    );
+}
+
+#[test]
+fn using_function_overload() {
+    let mut vm = build_solidity(
+        r#"
+        library LibInLib {
+            function get0(bytes x) public pure returns (bytes1) {
+                return x[0];
+            }
+
+            function get1(bytes x) public pure returns (bytes1) {
+                return x[1];
+            }
+        }
+
+        library MyBytes {
+            using {LibInLib.get0, LibInLib.get1} for bytes;
+
+            function push(bytes memory b, uint8[] memory a) pure public returns (bool) {
+                return b.get0() == a[0] && b.get1()== a[1];
+            }
+        }
+
+        contract C {
+            using {MyBytes.push} for bytes;
+
+            function check() public pure returns (bool) {
+                bytes memory b;
+                b.push(1);
+                b.push(2);
+                uint8[] memory vec = new uint8[](2);
+                vec[0] = 1;
+                vec[1] = 2;
+                return b.push(vec);
+            }
+        }"#,
+    );
+
+    let data_account = vm.initialize_data_account();
+    vm.function("new")
+        .accounts(vec![("dataAccount", data_account)])
+        .call();
+
+    let res = vm.function("check").call().unwrap();
+
+    assert_eq!(res, BorshToken::Bool(true));
+}