Эх сурвалжийг харах

encode bytes/string for substrate

Signed-off-by: Sean Young <sean@mess.org>
Sean Young 5 жил өмнө
parent
commit
2a4e5047df

+ 10 - 4
src/emit/ethabiencoder.rs

@@ -43,8 +43,11 @@ impl EthAbiEncoder {
                 if let Some(d) = &dim[0] {
                     contract.emit_static_loop_with_pointer(
                         function,
-                        0,
-                        d.to_u64().unwrap(),
+                        contract.context.i64_type().const_zero(),
+                        contract
+                            .context
+                            .i64_type()
+                            .const_int(d.to_u64().unwrap(), false),
                         data,
                         |index, data| {
                             let mut elem = unsafe {
@@ -613,8 +616,11 @@ impl EthAbiEncoder {
                 if let Some(d) = &dim[0] {
                     contract.emit_static_loop_with_pointer(
                         function,
-                        0,
-                        d.to_u64().unwrap(),
+                        contract.context.i64_type().const_zero(),
+                        contract
+                            .context
+                            .i64_type()
+                            .const_int(d.to_u64().unwrap(), false),
                         data,
                         |index: IntValue<'b>, data: &mut PointerValue<'b>| {
                             let elem = unsafe {

+ 7 - 10
src/emit/mod.rs

@@ -222,8 +222,8 @@ impl<'a> Contract<'a> {
     pub fn emit_static_loop_with_pointer<'b, F>(
         &'b self,
         function: FunctionValue,
-        from: u64,
-        to: u64,
+        from: IntValue<'b>,
+        to: IntValue<'b>,
         data_ref: &mut PointerValue<'b>,
         mut insert_body: F,
     ) where
@@ -236,7 +236,7 @@ impl<'a> Contract<'a> {
         self.builder.build_unconditional_branch(body);
         self.builder.position_at_end(body);
 
-        let loop_ty = self.context.i64_type();
+        let loop_ty = from.get_type();
         let loop_phi = self.builder.build_phi(loop_ty, "index");
         let data_phi = self.builder.build_phi(data_ref.get_type(), "data");
         let mut data = data_phi.as_basic_value().into_pointer_value();
@@ -250,15 +250,12 @@ impl<'a> Contract<'a> {
             .builder
             .build_int_add(loop_var, loop_ty.const_int(1, false), "next_index");
 
-        let comp = self.builder.build_int_compare(
-            IntPredicate::ULT,
-            next,
-            loop_ty.const_int(to, false),
-            "loop_cond",
-        );
+        let comp = self
+            .builder
+            .build_int_compare(IntPredicate::ULT, next, to, "loop_cond");
         self.builder.build_conditional_branch(comp, body, done);
 
-        loop_phi.add_incoming(&[(&loop_ty.const_int(from, false), entry), (&next, body)]);
+        loop_phi.add_incoming(&[(&from, entry), (&next, body)]);
         data_phi.add_incoming(&[(&*data_ref, entry), (&data, body)]);
 
         self.builder.position_at_end(done);

+ 224 - 64
src/emit/substrate.rs

@@ -441,8 +441,11 @@ impl SubstrateTarget {
                 if let Some(d) = &dim[0] {
                     contract.emit_static_loop_with_pointer(
                         function,
-                        0,
-                        d.to_u64().unwrap(),
+                        contract.context.i64_type().const_zero(),
+                        contract
+                            .context
+                            .i64_type()
+                            .const_int(d.to_u64().unwrap(), false),
                         data,
                         |index: IntValue<'b>, data: &mut PointerValue<'b>| {
                             let elem = unsafe {
@@ -598,7 +601,7 @@ impl SubstrateTarget {
         contract: &'a Contract,
         function: FunctionValue,
         ty: &resolver::Type,
-        arg: BasicValueEnum,
+        arg: BasicValueEnum<'a>,
         data: &mut PointerValue<'a>,
     ) {
         match &ty {
@@ -624,8 +627,11 @@ impl SubstrateTarget {
                 if let Some(d) = &dim[0] {
                     contract.emit_static_loop_with_pointer(
                         function,
-                        0,
-                        d.to_u64().unwrap(),
+                        contract.context.i64_type().const_zero(),
+                        contract
+                            .context
+                            .i64_type()
+                            .const_int(d.to_u64().unwrap(), false),
                         data,
                         |index, data| {
                             let mut elem = unsafe {
@@ -674,37 +680,199 @@ impl SubstrateTarget {
             resolver::Type::Ref(ty) => {
                 self.encode_ty(contract, function, ty, arg, data);
             }
-            resolver::Type::String | resolver::Type::DynamicBytes => unimplemented!(),
+            resolver::Type::String | resolver::Type::DynamicBytes => {
+                *data = contract
+                    .builder
+                    .build_call(
+                        contract.module.get_function("scale_encode_string").unwrap(),
+                        &[(*data).into(), arg],
+                        "",
+                    )
+                    .try_as_basic_value()
+                    .left()
+                    .unwrap()
+                    .into_pointer_value();
+            }
         };
     }
 
-    /// Return the encoded length of the given type
-    pub fn encoded_length(&self, ty: &resolver::Type, contract: &resolver::Contract) -> u64 {
+    /// Calculate the maximum space a type will need when encoded. This is used for
+    /// allocating enough space to do abi encoding. The length for vectors is always
+    /// assumed to be five, even when it can be encoded in less bytes. The overhead
+    /// of calculating the exact size is not worth reducing the malloc by a few bytes.
+    pub fn encoded_length<'a>(
+        &self,
+        arg: BasicValueEnum<'a>,
+        ty: &resolver::Type,
+        function: FunctionValue,
+        contract: &'a Contract,
+    ) -> IntValue<'a> {
         match ty {
-            resolver::Type::Bool => 1,
-            resolver::Type::Uint(n) | resolver::Type::Int(n) => *n as u64 / 8,
-            resolver::Type::Bytes(n) => *n as u64,
-            resolver::Type::Address => ADDRESS_LENGTH,
-            resolver::Type::Enum(n) => self.encoded_length(&contract.enums[*n].ty, contract),
-            resolver::Type::Struct(n) => contract.structs[*n]
-                .fields
-                .iter()
-                .map(|f| self.encoded_length(&f.ty, contract))
-                .sum(),
-            resolver::Type::Array(ty, dims) => {
-                self.encoded_length(ty, contract)
-                    * dims
-                        .iter()
-                        .map(|d| match d {
-                            Some(d) => d.to_u64().unwrap(),
-                            None => 1,
-                        })
-                        .product::<u64>()
+            resolver::Type::Bool => contract.context.i32_type().const_int(1, false),
+            resolver::Type::Uint(n) | resolver::Type::Int(n) => {
+                contract.context.i32_type().const_int(*n as u64 / 8, false)
+            }
+            resolver::Type::Bytes(n) => contract.context.i32_type().const_int(*n as u64, false),
+            resolver::Type::Address => contract.context.i32_type().const_int(ADDRESS_LENGTH, false),
+            resolver::Type::Enum(n) => {
+                self.encoded_length(arg, &contract.ns.enums[*n].ty, function, contract)
+            }
+            resolver::Type::Struct(n) => {
+                let mut sum = contract.context.i32_type().const_zero();
+
+                for (i, field) in contract.ns.structs[*n].fields.iter().enumerate() {
+                    let mut elem = unsafe {
+                        contract.builder.build_gep(
+                            arg.into_pointer_value(),
+                            &[
+                                contract.context.i32_type().const_zero(),
+                                contract.context.i32_type().const_int(i as u64, false),
+                            ],
+                            &field.name,
+                        )
+                    };
+
+                    if field.ty.is_reference_type() {
+                        elem = contract.builder.build_load(elem, "").into_pointer_value()
+                    }
+
+                    sum = contract.builder.build_int_add(
+                        sum,
+                        self.encoded_length(elem.into(), &field.ty, function, contract),
+                        "",
+                    );
+                }
+
+                sum
+            }
+            resolver::Type::Array(_, dims) => {
+                let len = match dims.last().unwrap() {
+                    None => {
+                        let len = unsafe {
+                            contract.builder.build_gep(
+                                arg.into_pointer_value(),
+                                &[
+                                    contract.context.i32_type().const_zero(),
+                                    contract.context.i32_type().const_zero(),
+                                ],
+                                "array.len",
+                            )
+                        };
+
+                        contract
+                            .builder
+                            .build_load(len, "array.len")
+                            .into_int_value()
+                    }
+                    Some(d) => contract
+                        .context
+                        .i32_type()
+                        .const_int(d.to_u64().unwrap(), false),
+                };
+
+                let elem_ty = ty.array_elem();
+                let llvm_elem_ty = contract.llvm_var(&elem_ty);
+
+                if elem_ty.is_dynamic(contract.ns) {
+                    let mut sum = contract.context.i32_type().const_zero();
+
+                    contract.emit_static_loop_with_int(
+                        function,
+                        contract.context.i32_type().const_zero(),
+                        len,
+                        &mut sum,
+                        |index, sum| {
+                            let index = contract.builder.build_int_mul(
+                                index,
+                                llvm_elem_ty
+                                    .into_pointer_type()
+                                    .get_element_type()
+                                    .size_of()
+                                    .unwrap()
+                                    .const_cast(contract.context.i32_type(), false),
+                                "",
+                            );
+
+                            let mut elem = unsafe {
+                                contract.builder.build_gep(
+                                    arg.into_pointer_value(),
+                                    &[
+                                        contract.context.i32_type().const_zero(),
+                                        contract.context.i32_type().const_int(2, false),
+                                        index,
+                                    ],
+                                    "index_access",
+                                )
+                            };
+
+                            if ty.is_reference_type() {
+                                elem = contract.builder.build_load(elem, "").into_pointer_value()
+                            }
+
+                            *sum = contract.builder.build_int_add(
+                                self.encoded_length(elem.into(), &elem_ty, function, contract),
+                                *sum,
+                                "",
+                            );
+                        },
+                    );
+
+                    sum
+                } else {
+                    // arg
+                    let elem_ty = ty.array_deref();
+
+                    let elem = unsafe {
+                        contract.builder.build_gep(
+                            arg.into_pointer_value(),
+                            &[
+                                contract.context.i32_type().const_zero(),
+                                contract.context.i32_type().const_zero(),
+                            ],
+                            "index_access",
+                        )
+                    };
+
+                    let arg = if elem_ty.is_reference_type() {
+                        contract.builder.build_load(elem, "")
+                    } else {
+                        elem.into()
+                    };
+
+                    contract.builder.build_int_mul(
+                        self.encoded_length(arg, &elem_ty, function, contract),
+                        len,
+                        "",
+                    )
+                }
             }
             resolver::Type::Undef => unreachable!(),
             resolver::Type::StorageRef(_) => unreachable!(),
-            resolver::Type::Ref(r) => self.encoded_length(r, contract),
-            resolver::Type::String | resolver::Type::DynamicBytes => unimplemented!(),
+            resolver::Type::Ref(r) => self.encoded_length(arg, r, function, contract),
+            resolver::Type::String | resolver::Type::DynamicBytes => {
+                // A string or bytes type has to be encoded by: one compact integer for
+                // the length, followed by the bytes themselves. Here we assume that the
+                // length requires 5 bytes.
+                let len = unsafe {
+                    contract.builder.build_gep(
+                        arg.into_pointer_value(),
+                        &[
+                            contract.context.i32_type().const_zero(),
+                            contract.context.i32_type().const_zero(),
+                        ],
+                        "string.len",
+                    )
+                };
+
+                contract.builder.build_int_add(
+                    contract
+                        .builder
+                        .build_load(len, "string.len")
+                        .into_int_value(),
+                    contract.context.i32_type().const_int(5, false),
+                    "",
+                )
+            }
         }
     }
 }
@@ -908,36 +1076,9 @@ impl TargetRuntime for SubstrateTarget {
         function: FunctionValue,
         args: &mut Vec<BasicValueEnum<'b>>,
         data: PointerValue<'b>,
-        datalength: IntValue,
+        _datalength: IntValue,
         spec: &resolver::FunctionDecl,
     ) {
-        let length = spec
-            .params
-            .iter()
-            .map(|arg| self.encoded_length(&arg.ty, contract.ns))
-            .sum();
-
-        let decode_block = contract.context.append_basic_block(function, "abi_decode");
-        let wrong_length_block = contract
-            .context
-            .append_basic_block(function, "wrong_abi_length");
-
-        let is_ok = contract.builder.build_int_compare(
-            IntPredicate::EQ,
-            datalength,
-            contract.context.i32_type().const_int(length, false),
-            "correct_length",
-        );
-
-        contract
-            .builder
-            .build_conditional_branch(is_ok, decode_block, wrong_length_block);
-
-        contract.builder.position_at_end(wrong_length_block);
-        contract.builder.build_unreachable();
-
-        contract.builder.position_at_end(decode_block);
-
         let mut argsdata = contract.builder.build_pointer_cast(
             data,
             contract.context.i8_type().ptr_type(AddressSpace::Generic),
@@ -957,13 +1098,23 @@ impl TargetRuntime for SubstrateTarget {
         args: &[BasicValueEnum<'b>],
         spec: &resolver::FunctionDecl,
     ) -> (PointerValue<'b>, IntValue<'b>) {
-        let length = spec
-            .returns
-            .iter()
-            .map(|arg| self.encoded_length(&arg.ty, contract.ns))
-            .sum();
+        let mut length = contract.context.i32_type().const_zero();
 
-        let length = contract.context.i32_type().const_int(length, false);
+        for (i, field) in spec.returns.iter().enumerate() {
+            let val = if field.ty.is_reference_type() {
+                contract
+                    .builder
+                    .build_load(args[i].into_pointer_value(), "")
+            } else {
+                args[i]
+            };
+
+            length = contract.builder.build_int_add(
+                length,
+                self.encoded_length(val, &field.ty, function, contract),
+                "",
+            );
+        }
 
         let data = contract
             .builder
@@ -991,6 +1142,15 @@ impl TargetRuntime for SubstrateTarget {
             self.encode_ty(contract, function, &arg.ty, val, &mut argsdata);
         }
 
-        (data, length)
+        let length = contract
+            .builder
+            .build_ptr_diff(argsdata, data, "datalength");
+
+        (
+            data,
+            contract
+                .builder
+                .build_int_cast(length, contract.context.i32_type(), "datalength"),
+        )
     }
 }

+ 29 - 0
src/resolver/mod.rs

@@ -126,6 +126,17 @@ impl Type {
         }
     }
 
+    /// Given an array, return the type of its elements
+    pub fn array_elem(&self) -> Self {
+        match self {
+            Type::Array(ty, dim) if dim.len() > 1 => {
+                Type::Array(ty.clone(), dim[..dim.len() - 1].to_vec())
+            }
+            Type::Array(ty, dim) if dim.len() == 1 => *ty.clone(),
+            _ => panic!("not an array"),
+        }
+    }
+
     /// Give the type of an storage array after dereference. This can only be used on
     /// array types and will cause a panic otherwise.
     pub fn storage_deref(&self) -> Self {
@@ -241,6 +252,24 @@ impl Type {
         }
     }
 
+    /// Does this type contain any types which are variable-length
+    pub fn is_dynamic(&self, ns: &Contract) -> bool {
+        match self {
+            Type::String | Type::DynamicBytes => true,
+            Type::Ref(r) => r.is_dynamic(ns),
+            Type::Array(ty, dim) => {
+                if dim.iter().any(|d| d.is_none()) {
+                    return true;
+                }
+
+                ty.is_dynamic(ns)
+            }
+            Type::Struct(n) => ns.structs[*n].fields.iter().any(|f| f.ty.is_dynamic(ns)),
+            Type::StorageRef(r) => r.is_dynamic(ns),
+            _ => false,
+        }
+    }
+
     /// Can this type have a calldata, memory, or storage location. This is to be
     /// compatible with ethereum solidity. Opinions on whether other types should be
     /// allowed be storage are welcome.

BIN
stdlib/stdlib.bc


+ 42 - 0
stdlib/stdlib.c

@@ -456,4 +456,46 @@ __attribute__((visibility("hidden"))) struct vector *concat(uint8_t *left, uint3
 	}
 
 	return v;
+}
+
+// Encode an 32 bit integer as as scale compact integer
+// https://substrate.dev/docs/en/conceptual/core/codec#vectors-lists-series-sets
+uint8_t *compact_encode_u32(uint8_t *dest, uint32_t val)
+{
+	if (val < 64)
+	{
+		*dest++ = val << 2;
+	}
+	else if (val < 4000)
+	{
+		*((uint16_t *)dest) = (val << 2) | 1;
+		dest += 2;
+	}
+	else if (val < 0x40000000)
+	{
+		*((uint32_t *)dest) = (val << 2) | 2;
+		dest += 4;
+	}
+	else
+	{
+		*dest++ = 3;
+		*((uint32_t *)dest) = val;
+		dest += 4;
+	}
+
+	return dest;
+}
+
+uint8_t *scale_encode_string(uint8_t *dest, struct vector *s)
+{
+	uint32_t len = s->len;
+	uint8_t *data_dst = compact_encode_u32(dest, len);
+	uint8_t *data = s->data;
+
+	while (len--)
+	{
+		*data_dst++ = *data++;
+	}
+
+	return data_dst;
 }

+ 22 - 0
tests/substrate_strings/mod.rs

@@ -1,3 +1,6 @@
+use parity_scale_codec::{Decode, Encode};
+use parity_scale_codec_derive::{Decode, Encode};
+
 use super::{build_solidity, first_error, no_errors};
 use solang::{parse_and_resolve, Target};
 
@@ -238,3 +241,22 @@ fn string_concat() {
 
     runtime.function(&mut store, "test", Vec::new());
 }
+
+#[test]
+fn string_abi_encode() {
+    #[derive(Debug, PartialEq, Encode, Decode)]
+    struct Val(String);
+
+    let (runtime, mut store) = build_solidity(
+        r##"
+        contract foo {
+            function test() public returns (string) {
+                return "foobar";
+            }
+        }"##,
+    );
+
+    runtime.function(&mut store, "test", Vec::new());
+
+    assert_eq!(store.scratch, Val("foobar".to_string()).encode());
+}