瀏覽代碼

Introduce string compare using == or != operators

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 5 年之前
父節點
當前提交
176f244d37
共有 6 個文件被更改,包括 245 次插入33 次删除
  1. 71 1
      src/emit/mod.rs
  2. 16 1
      src/resolver/cfg.rs
  3. 94 31
      src/resolver/expression.rs
  4. 二進制
      stdlib/stdlib.bc
  5. 14 0
      stdlib/stdlib.c
  6. 50 0
      tests/substrate_strings/mod.rs

+ 71 - 1
src/emit/mod.rs

@@ -2,7 +2,7 @@ use hex;
 use parser::ast;
 use resolver;
 use resolver::cfg;
-use resolver::expression::Expression;
+use resolver::expression::{Expression, StringLocation};
 use std::path::Path;
 use std::str;
 
@@ -1141,10 +1141,80 @@ impl<'a> Contract<'a> {
 
                 self.builder.build_load(dst, "keccak256_hash")
             }
+            Expression::StringCompare(_, l, r) => {
+                let (left, left_len) = self.string_location(l, vartab, function, runtime);
+                let (right, right_len) = self.string_location(r, vartab, function, runtime);
+
+                self.builder
+                    .build_call(
+                        self.module.get_function("memcmp").unwrap(),
+                        &[left.into(), left_len.into(), right.into(), right_len.into()],
+                        "",
+                    )
+                    .try_as_basic_value()
+                    .left()
+                    .unwrap()
+            }
             Expression::Poison => unreachable!(),
         }
     }
 
+    /// Load a string from expression or create global
+    fn string_location(
+        &self,
+        location: &StringLocation,
+        vartab: &[Variable<'a>],
+        function: FunctionValue<'a>,
+        runtime: &dyn TargetRuntime,
+    ) -> (PointerValue<'a>, IntValue<'a>) {
+        match location {
+            StringLocation::CompileTime(literal) => (
+                self.emit_global_string("const_string", literal, false),
+                self.context
+                    .i32_type()
+                    .const_int(literal.len() as u64, false),
+            ),
+            StringLocation::RunTime(e) => {
+                let v = self
+                    .expression(e, vartab, function, runtime)
+                    .into_pointer_value();
+
+                let data = unsafe {
+                    self.builder.build_gep(
+                        v,
+                        &[
+                            self.context.i32_type().const_zero(),
+                            self.context.i32_type().const_int(2, false),
+                        ],
+                        "data",
+                    )
+                };
+
+                let data_len = unsafe {
+                    self.builder.build_gep(
+                        v,
+                        &[
+                            self.context.i32_type().const_zero(),
+                            self.context.i32_type().const_zero(),
+                        ],
+                        "data_len",
+                    )
+                };
+
+                (
+                    self.builder.build_pointer_cast(
+                        data,
+                        self.context.i8_type().ptr_type(AddressSpace::Generic),
+                        "data",
+                    ),
+                    self.builder
+                        .build_load(data_len, "data_len")
+                        .into_int_value(),
+                )
+            }
+        }
+    }
+
     /// Convert a BigInt number to llvm const value
     fn number_literal(&self, bits: u32, n: &BigInt) -> IntValue<'a> {
         let ty = self.context.custom_width_int_type(bits);

+ 16 - 1
src/resolver/cfg.rs

@@ -9,7 +9,7 @@ use output;
 use output::Output;
 use parser::ast;
 use resolver;
-use resolver::expression::{cast, expression, Expression};
+use resolver::expression::{cast, expression, Expression, StringLocation};
 
 pub enum Instr {
     FuncArg {
@@ -336,10 +336,25 @@ impl ControlFlowGraph {
             Expression::DynamicArrayLength(_, a) => {
                 format!("(array {} len)", self.expr_to_string(ns, a))
             }
+            Expression::StringCompare(_, l, r) => format!(
+                "(strcmp ({}) ({}))",
+                self.location_to_string(ns, l),
+                self.location_to_string(ns, r)
+            ),
             Expression::Keccak256(_, e) => format!("(keccak256 {})", self.expr_to_string(ns, e)),
         }
     }
 
+    fn location_to_string(&self, ns: &resolver::Contract, l: &StringLocation) -> String {
+        match l {
+            StringLocation::RunTime(e) => self.expr_to_string(ns, e),
+            StringLocation::CompileTime(literal) => match str::from_utf8(literal) {
+                Ok(s) => format!("\"{}\"", s.to_owned()),
+                Err(_) => format!("hex\"{}\"", hex::encode(literal)),
+            },
+        }
+    }
+
     pub fn instr_to_string(&self, ns: &resolver::Contract, instr: &Instr) -> String {
         match instr {
             Instr::Return { value } => {

+ 94 - 31
src/resolver/expression.rs

@@ -73,6 +73,7 @@ pub enum Expression {
     AllocDynamicArray(Loc, resolver::Type, Box<Expression>, Option<Vec<u8>>),
     DynamicArrayLength(Loc, Box<Expression>),
     DynamicArraySubscript(Loc, Box<Expression>, resolver::Type, Box<Expression>),
+    StringCompare(Loc, StringLocation, StringLocation),
 
     Or(Loc, Box<Expression>, Box<Expression>),
     And(Loc, Box<Expression>, Box<Expression>),
@@ -82,6 +83,12 @@ pub enum Expression {
     Poison,
 }
 
+#[derive(PartialEq, Clone, Debug)]
+pub enum StringLocation {
+    CompileTime(Vec<u8>),
+    RunTime(Box<Expression>),
+}
+
 impl Expression {
     /// Return the location for this expression
     pub fn loc(&self) -> Loc {
@@ -131,6 +138,7 @@ impl Expression {
             | Expression::AllocDynamicArray(loc, _, _, _)
             | Expression::DynamicArrayLength(loc, _)
             | Expression::DynamicArraySubscript(loc, _, _, _)
+            | Expression::StringCompare(loc, _, _)
             | Expression::Keccak256(loc, _)
             | Expression::And(loc, _, _) => *loc,
             Expression::Poison => unreachable!(),
@@ -231,6 +239,17 @@ impl Expression {
             Expression::Keccak256(_, e) => e.reads_contract_storage(),
             Expression::And(_, l, r) => l.reads_contract_storage() || r.reads_contract_storage(),
             Expression::Or(_, l, r) => l.reads_contract_storage() || r.reads_contract_storage(),
+            Expression::StringCompare(_, l, r) => {
+                if let StringLocation::RunTime(e) = l {
+                    if !e.reads_contract_storage() {
+                        return false;
+                    }
+                }
+                if let StringLocation::RunTime(e) = r {
+                    return e.reads_contract_storage();
+                }
+                false
+            }
             Expression::Poison => false,
         }
     }
@@ -1392,37 +1411,14 @@ pub fn expression(
                 ))
             }
         }
-        ast::Expression::Equal(loc, l, r) => {
-            let (left, left_type) = expression(l, cfg, ns, vartab, errors)?;
-            let (right, right_type) = expression(r, cfg, ns, vartab, errors)?;
-
-            let ty = coerce(&left_type, &l.loc(), &right_type, &r.loc(), ns, errors)?;
-
-            Ok((
-                Expression::Equal(
-                    *loc,
-                    Box::new(cast(&l.loc(), left, &left_type, &ty, true, ns, errors)?),
-                    Box::new(cast(&r.loc(), right, &right_type, &ty, true, ns, errors)?),
-                ),
-                resolver::Type::Bool,
-            ))
-        }
-        ast::Expression::NotEqual(loc, l, r) => {
-            let (left, left_type) = expression(l, cfg, ns, vartab, errors)?;
-            let (right, right_type) = expression(r, cfg, ns, vartab, errors)?;
-
-            let ty = coerce(&left_type, &l.loc(), &right_type, &r.loc(), ns, errors)?;
-
-            Ok((
-                Expression::NotEqual(
-                    *loc,
-                    Box::new(cast(&l.loc(), left, &left_type, &ty, true, ns, errors)?),
-                    Box::new(cast(&r.loc(), right, &right_type, &ty, true, ns, errors)?),
-                ),
-                resolver::Type::Bool,
-            ))
-        }
-
+        ast::Expression::Equal(loc, l, r) => Ok((
+            equal(loc, l, r, cfg, ns, vartab, errors)?,
+            resolver::Type::Bool,
+        )),
+        ast::Expression::NotEqual(loc, l, r) => Ok((
+            Expression::Not(*loc, Box::new(equal(loc, l, r, cfg, ns, vartab, errors)?)),
+            resolver::Type::Bool,
+        )),
         // unary expressions
         ast::Expression::Not(loc, e) => {
             let (expr, expr_type) = expression(e, cfg, ns, vartab, errors)?;
@@ -2571,6 +2567,73 @@ fn new(
     ))
 }
 
+/// Resolve an array subscript expression
+fn equal(
+    loc: &ast::Loc,
+    l: &ast::Expression,
+    r: &ast::Expression,
+    cfg: &mut ControlFlowGraph,
+    ns: &resolver::Contract,
+    vartab: &mut Option<&mut Vartable>,
+    errors: &mut Vec<output::Output>,
+) -> Result<Expression, ()> {
+    let (left, left_type) = expression(l, cfg, ns, vartab, errors)?;
+    let (right, right_type) = expression(r, cfg, ns, vartab, errors)?;
+
+    // Comparing stringliteral against stringliteral
+    if let (Expression::BytesLiteral(_, l), Expression::BytesLiteral(_, r)) = (&left, &right) {
+        return Ok(Expression::BoolLiteral(*loc, l == r));
+    }
+
+    // compare string against literal
+    match (&left, &right_type) {
+        (Expression::BytesLiteral(_, l), resolver::Type::String)
+        | (Expression::BytesLiteral(_, l), resolver::Type::DynamicBytes) => {
+            return Ok(Expression::StringCompare(
+                *loc,
+                StringLocation::RunTime(Box::new(right)),
+                StringLocation::CompileTime(l.clone()),
+            ));
+        }
+        _ => {}
+    }
+
+    match (&right, &left_type) {
+        (Expression::BytesLiteral(_, l), resolver::Type::String)
+        | (Expression::BytesLiteral(_, l), resolver::Type::DynamicBytes) => {
+            return Ok(Expression::StringCompare(
+                *loc,
+                StringLocation::RunTime(Box::new(left)),
+                StringLocation::CompileTime(l.clone()),
+            ));
+        }
+        _ => {}
+    }
+
+    // compare string
+    match (&left_type, &right_type) {
+        (resolver::Type::String, resolver::Type::String)
+        | (resolver::Type::DynamicBytes, resolver::Type::DynamicBytes)
+        | (resolver::Type::String, resolver::Type::DynamicBytes)
+        | (resolver::Type::DynamicBytes, resolver::Type::String) => {
+            return Ok(Expression::StringCompare(
+                *loc,
+                StringLocation::RunTime(Box::new(left)),
+                StringLocation::RunTime(Box::new(right)),
+            ));
+        }
+        _ => {}
+    }
+
+    let ty = coerce(&left_type, &l.loc(), &right_type, &r.loc(), ns, errors)?;
+
+    Ok(Expression::Equal(
+        *loc,
+        Box::new(cast(&l.loc(), left, &left_type, &ty, true, ns, errors)?),
+        Box::new(cast(&r.loc(), right, &right_type, &ty, true, ns, errors)?),
+    ))
+}
+
 /// Resolve an array subscript expression
 fn array_subscript(
     loc: &ast::Loc,

二進制
stdlib/stdlib.bc


+ 14 - 0
stdlib/stdlib.c

@@ -421,3 +421,17 @@ __attribute__((visibility("hidden"))) struct vector *vector_new(uint32_t members
 
 	return v;
 }
+
+__attribute__((visibility("hidden"))) bool memcmp(uint8_t *left, uint32_t left_len, uint8_t *right, uint32_t right_len)
+{
+	if (left_len != right_len)
+		return false;
+
+	while (left_len--)
+	{
+		if (*left++ != *right++)
+			return false;
+	}
+
+	return true;
+}

+ 50 - 0
tests/substrate_strings/mod.rs

@@ -156,3 +156,53 @@ fn more_tests() {
 
     runtime.function(&mut store, "test", Vec::new());
 }
+
+#[test]
+fn string_compare() {
+    // compare literal to literal. This should be compile-time thing
+    let (runtime, mut store) = build_solidity(
+        r##"
+        contract foo {
+            function test() public {
+                assert(hex"414243" == "ABC");
+
+                assert(hex"414243" != "ABD");
+            }
+        }"##,
+    );
+
+    runtime.function(&mut store, "test", Vec::new());
+
+    let (runtime, mut store) = build_solidity(
+        r##"
+        contract foo {
+            function lets_compare1(string s) private returns (bool) {
+                return s == "the quick brown fox jumps over the lazy dog";
+            }
+
+            function lets_compare2(string s) private returns (bool) {
+                return "the quick brown fox jumps over the lazy dog" == s;
+            }
+
+            function test() public {
+                string s1 = "the quick brown fox jumps over the lazy dog";
+
+                assert(lets_compare1(s1));
+                assert(lets_compare2(s1));
+
+                string s2 = "the quick brown dog jumps over the lazy fox";
+
+                assert(!lets_compare1(s2));
+                assert(!lets_compare2(s2));
+
+                assert(s1 != s2);
+
+                s1 = "the quick brown dog jumps over the lazy fox";
+
+                assert(s1 == s2);
+            }
+        }"##,
+    );
+
+    runtime.function(&mut store, "test", Vec::new());
+}