소스 검색

Implement assignment to selector and address in Yul (#945)

* Implement assignment to selector and address in Yul

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>

* Remove FunctionSelector and ExternalFunctionAddress builtins

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 3 년 전
부모
커밋
225b367208

+ 2 - 12
src/codegen/encoding/borsh_encoding.rs

@@ -273,19 +273,9 @@ impl BorshEncoding {
             }
 
             Type::ExternalFunction { .. } => {
-                let selector = Expression::Builtin(
-                    Loc::Codegen,
-                    vec![Type::Uint(32)],
-                    Builtin::FunctionSelector,
-                    vec![expr.clone()],
-                );
+                let selector = expr.external_function_selector();
 
-                let address = Expression::Builtin(
-                    Loc::Codegen,
-                    vec![Type::Address(false)],
-                    Builtin::ExternalFunctionAddress,
-                    vec![expr.clone()],
-                );
+                let address = expr.external_function_address();
 
                 cfg.add(
                     vartab,

+ 6 - 31
src/codegen/expression.rs

@@ -407,26 +407,16 @@ pub fn expression(
                 _ => unreachable!(),
             }
         }
-        ast::Expression::Builtin(
-            loc,
-            returns,
-            ast::Builtin::ExternalFunctionAddress,
-            func_expr,
-        ) => {
+        ast::Expression::Builtin(_, _, ast::Builtin::ExternalFunctionAddress, func_expr) => {
             if let ast::Expression::ExternalFunction { address, .. } = &func_expr[0] {
                 expression(address, cfg, contract_no, func, ns, vartab, opt)
             } else {
                 let func_expr = expression(&func_expr[0], cfg, contract_no, func, ns, vartab, opt);
 
-                Expression::Builtin(
-                    *loc,
-                    returns.clone(),
-                    Builtin::ExternalFunctionAddress,
-                    vec![func_expr],
-                )
+                func_expr.external_function_address()
             }
         }
-        ast::Expression::Builtin(loc, returns, ast::Builtin::FunctionSelector, func_expr) => {
+        ast::Expression::Builtin(loc, _, ast::Builtin::FunctionSelector, func_expr) => {
             match &func_expr[0] {
                 ast::Expression::ExternalFunction { function_no, .. }
                 | ast::Expression::InternalFunction { function_no, .. } => {
@@ -437,12 +427,7 @@ pub fn expression(
                     let func_expr =
                         expression(&func_expr[0], cfg, contract_no, func, ns, vartab, opt);
 
-                    Expression::Builtin(
-                        *loc,
-                        returns.clone(),
-                        Builtin::FunctionSelector,
-                        vec![func_expr],
-                    )
+                    func_expr.external_function_selector()
                 }
             }
         }
@@ -2382,18 +2367,8 @@ pub fn emit_function_call(
                     Expression::NumberLiteral(pt::Loc::Codegen, Type::Value, BigInt::zero())
                 };
 
-                let selector = Expression::Builtin(
-                    *loc,
-                    vec![Type::Bytes(4)],
-                    Builtin::FunctionSelector,
-                    vec![function.clone()],
-                );
-                let address = Expression::Builtin(
-                    *loc,
-                    vec![Type::Address(false)],
-                    Builtin::ExternalFunctionAddress,
-                    vec![function],
-                );
+                let selector = function.external_function_selector();
+                let address = function.external_function_address();
 
                 let (payload, address) = if ns.target == Target::Solana {
                     tys.insert(0, Type::Address(false));

+ 46 - 7
src/codegen/mod.rs

@@ -907,9 +907,22 @@ impl Expression {
                 Box::new(self.clone()),
             ),
 
-            (Type::Bytes(_), Type::Uint(_))
-            | (Type::Bytes(_), Type::Int(_))
-            | (Type::Uint(_), Type::Bytes(_))
+            (Type::Bytes(n), Type::Uint(bits) | Type::Int(bits)) => {
+                let num_bytes = (bits / 8) as u8;
+                match n.cmp(&num_bytes) {
+                    Ordering::Greater => {
+                        Expression::Trunc(self.loc(), to.clone(), Box::new(self.clone()))
+                    }
+                    Ordering::Less => {
+                        Expression::ZeroExt(self.loc(), to.clone(), Box::new(self.clone()))
+                    }
+                    Ordering::Equal => {
+                        Expression::Cast(self.loc(), to.clone(), Box::new(self.clone()))
+                    }
+                }
+            }
+
+            (Type::Uint(_), Type::Bytes(_))
             | (Type::Int(_), Type::Bytes(_))
             | (Type::Bytes(_), Type::Address(_))
             | (Type::Address(false), Type::Address(true))
@@ -1197,6 +1210,36 @@ impl Expression {
             ctx,
         )
     }
+
+    fn external_function_address(&self) -> Expression {
+        debug_assert!(
+            matches!(self.ty(), Type::ExternalFunction { .. }),
+            "This is not an external function"
+        );
+        let loc = self.loc();
+        let struct_member = Expression::StructMember(
+            loc,
+            Type::Ref(Box::new(Type::Address(false))),
+            Box::new(self.clone()),
+            0,
+        );
+        Expression::Load(loc, Type::Address(false), Box::new(struct_member))
+    }
+
+    fn external_function_selector(&self) -> Expression {
+        debug_assert!(
+            matches!(self.ty(), Type::ExternalFunction { .. }),
+            "This is not an external function"
+        );
+        let loc = self.loc();
+        let struct_member = Expression::StructMember(
+            loc,
+            Type::Ref(Box::new(Type::Bytes(4))),
+            Box::new(self.clone()),
+            1,
+        );
+        Expression::Load(loc, Type::Bytes(4), Box::new(struct_member))
+    }
 }
 
 #[derive(PartialEq, Debug, Clone, Copy)]
@@ -1212,8 +1255,6 @@ pub enum Builtin {
     BlockHash,
     BlockNumber,
     Calldata,
-    ExternalFunctionAddress,
-    FunctionSelector,
     Gasleft,
     GasLimit,
     Gasprice,
@@ -1263,8 +1304,6 @@ impl From<&ast::Builtin> for Builtin {
             ast::Builtin::BlockHash => Builtin::BlockHash,
             ast::Builtin::BlockNumber => Builtin::BlockNumber,
             ast::Builtin::Calldata => Builtin::Calldata,
-            ast::Builtin::ExternalFunctionAddress => Builtin::ExternalFunctionAddress,
-            ast::Builtin::FunctionSelector => Builtin::FunctionSelector,
             ast::Builtin::Gasleft => Builtin::Gasleft,
             ast::Builtin::GasLimit => Builtin::GasLimit,
             ast::Builtin::Gasprice => Builtin::Gasprice,

+ 3 - 13
src/codegen/statements.rs

@@ -12,7 +12,7 @@ use crate::codegen::unused_variable::{
     should_remove_assignment, should_remove_variable, SideEffectsCheckParameters,
 };
 use crate::codegen::yul::inline_assembly_cfg;
-use crate::codegen::{Builtin, Expression};
+use crate::codegen::Expression;
 use crate::sema::ast::RetrieveType;
 use crate::sema::ast::{
     ArrayLength, CallTy, DestructureField, Function, Namespace, Parameter, Statement, TryCatch,
@@ -1124,19 +1124,9 @@ fn try_catch(
                     .map(|a| expression(a, cfg, callee_contract_no, Some(func), ns, vartab, opt))
                     .collect();
 
-                let selector = Expression::Builtin(
-                    *loc,
-                    vec![Type::Bytes(4)],
-                    Builtin::FunctionSelector,
-                    vec![function.clone()],
-                );
+                let selector = function.external_function_selector();
 
-                let address = Expression::Builtin(
-                    *loc,
-                    vec![Type::Address(false)],
-                    Builtin::ExternalFunctionAddress,
-                    vec![function],
-                );
+                let address = function.external_function_address();
 
                 let payload = Expression::AbiEncode {
                     loc: *loc,

+ 0 - 4
src/codegen/tests.rs

@@ -15,8 +15,6 @@ fn test_builtin_conversion() {
         ast::Builtin::BlockHash,
         ast::Builtin::BlockNumber,
         ast::Builtin::Calldata,
-        ast::Builtin::ExternalFunctionAddress,
-        ast::Builtin::FunctionSelector,
         ast::Builtin::Gasleft,
         ast::Builtin::GasLimit,
         ast::Builtin::Gasprice,
@@ -73,8 +71,6 @@ fn test_builtin_conversion() {
         codegen::Builtin::BlockHash,
         codegen::Builtin::BlockNumber,
         codegen::Builtin::Calldata,
-        codegen::Builtin::ExternalFunctionAddress,
-        codegen::Builtin::FunctionSelector,
         codegen::Builtin::Gasleft,
         codegen::Builtin::GasLimit,
         codegen::Builtin::Gasprice,

+ 4 - 12
src/codegen/yul/expression.rs

@@ -173,12 +173,8 @@ fn process_suffix_access(
             if let ast::YulExpression::SolidityLocalVariable(_, Type::ExternalFunction { .. }, ..) =
                 expr
             {
-                return Expression::Builtin(
-                    *loc,
-                    vec![Type::Address(false)],
-                    Builtin::ExternalFunctionAddress,
-                    vec![expression(expr, contract_no, ns, vartab, cfg, opt)],
-                );
+                let func_expr = expression(expr, contract_no, ns, vartab, cfg, opt);
+                return func_expr.external_function_address();
             }
         }
 
@@ -186,12 +182,8 @@ fn process_suffix_access(
             if let ast::YulExpression::SolidityLocalVariable(_, Type::ExternalFunction { .. }, ..) =
                 expr
             {
-                return Expression::Builtin(
-                    *loc,
-                    vec![Type::Uint(32)],
-                    Builtin::FunctionSelector,
-                    vec![expression(expr, contract_no, ns, vartab, cfg, opt)],
-                );
+                let func_expr = expression(expr, contract_no, ns, vartab, cfg, opt);
+                return func_expr.external_function_selector();
             }
         }
     }

+ 28 - 6
src/codegen/yul/statements.rs

@@ -254,12 +254,34 @@ fn cfg_single_assigment(
 
                     _ => unreachable!(),
                 },
-                ast::YulExpression::SolidityLocalVariable(_, Type::ExternalFunction { .. }, ..) => {
-                    if matches!(suffix, YulSuffix::Address | YulSuffix::Selector) {
-                        unimplemented!(
-                            "Assignment to a function's address/selector is no implemented."
-                        )
-                    }
+                ast::YulExpression::SolidityLocalVariable(
+                    _,
+                    ty @ Type::ExternalFunction { .. },
+                    _,
+                    var_no,
+                ) => {
+                    let (member_no, casted_expr, member_ty) = match suffix {
+                        YulSuffix::Address => {
+                            (0, rhs.cast(&Type::Address(false), ns), Type::Address(false))
+                        }
+                        YulSuffix::Selector => (1, rhs.cast(&Type::Uint(32), ns), Type::Uint(32)),
+                        _ => unreachable!(),
+                    };
+
+                    let ptr = Expression::StructMember(
+                        *loc,
+                        Type::Ref(Box::new(member_ty)),
+                        Box::new(Expression::Variable(*loc, ty.clone(), *var_no)),
+                        member_no,
+                    );
+
+                    cfg.add(
+                        vartab,
+                        Instr::Store {
+                            dest: ptr,
+                            data: casted_expr,
+                        },
+                    );
                 }
 
                 ast::YulExpression::SolidityLocalVariable(

+ 32 - 24
src/codegen/yul/tests/expression.rs

@@ -542,20 +542,24 @@ fn selector_suffix() {
 
     assert_eq!(
         res,
-        Expression::Builtin(
+        Expression::Load(
             loc,
-            vec![Type::Uint(32)],
-            Builtin::FunctionSelector,
-            vec![Expression::Variable(
+            Type::Bytes(4),
+            Box::new(Expression::StructMember(
                 loc,
-                Type::ExternalFunction {
-                    mutability: Mutability::Pure(loc),
-                    params: vec![],
-                    returns: vec![]
-                },
-                4
-            )],
-        ),
+                Type::Ref(Box::new(Type::Bytes(4))),
+                Box::new(Expression::Variable(
+                    loc,
+                    Type::ExternalFunction {
+                        mutability: Mutability::Pure(loc),
+                        params: vec![],
+                        returns: vec![],
+                    },
+                    4
+                )),
+                1
+            ))
+        )
     );
 }
 
@@ -607,20 +611,24 @@ fn address_suffix() {
 
     assert_eq!(
         res,
-        Expression::Builtin(
+        Expression::Load(
             loc,
-            vec![Type::Address(false)],
-            Builtin::ExternalFunctionAddress,
-            vec![Expression::Variable(
+            Type::Address(false),
+            Box::new(Expression::StructMember(
                 loc,
-                Type::ExternalFunction {
-                    mutability: Mutability::Pure(loc),
-                    params: vec![],
-                    returns: vec![]
-                },
-                4
-            )],
-        ),
+                Type::Ref(Box::new(Type::Address(false))),
+                Box::new(Expression::Variable(
+                    loc,
+                    Type::ExternalFunction {
+                        mutability: Mutability::Pure(loc),
+                        params: vec![],
+                        returns: vec![]
+                    },
+                    4
+                )),
+                0
+            ))
+        )
     );
 }
 

+ 0 - 36
src/emit/mod.rs

@@ -2961,42 +2961,6 @@ pub trait TargetRuntime<'a> {
 
                 ef.into()
             }
-            Expression::Builtin(_, _, Builtin::FunctionSelector, args) => {
-                let ef = self
-                    .expression(bin, &args[0], vartab, function, ns)
-                    .into_pointer_value();
-
-                let selector_member = unsafe {
-                    bin.builder.build_gep(
-                        ef,
-                        &[
-                            bin.context.i32_type().const_zero(),
-                            bin.context.i32_type().const_int(1, false),
-                        ],
-                        "selector",
-                    )
-                };
-
-                bin.builder.build_load(selector_member, "selector")
-            }
-            Expression::Builtin(_, _, Builtin::ExternalFunctionAddress, args) => {
-                let ef = self
-                    .expression(bin, &args[0], vartab, function, ns)
-                    .into_pointer_value();
-
-                let selector_member = unsafe {
-                    bin.builder.build_gep(
-                        ef,
-                        &[
-                            bin.context.i32_type().const_zero(),
-                            bin.context.i32_type().const_zero(),
-                        ],
-                        "address",
-                    )
-                };
-
-                bin.builder.build_load(selector_member, "address")
-            }
             Expression::Builtin(_, _, hash @ Builtin::Ripemd160, args)
             | Expression::Builtin(_, _, hash @ Builtin::Keccak256, args)
             | Expression::Builtin(_, _, hash @ Builtin::Blake2_128, args)

+ 1 - 1
src/sema/ast.rs

@@ -61,7 +61,7 @@ pub enum Type {
     Unresolved,
     /// When we advance a pointer, it cannot be any of the previous types.
     /// e.g. Type::Bytes is a pointer to struct.vector. When we advance it, it is a pointer
-    /// to latter's data region
+    /// to latter's data region.
     BufferPointer,
 }
 

+ 0 - 17
src/sema/yul/expression.rs

@@ -732,23 +732,6 @@ pub(crate) fn check_type(
                     ));
                 }
             }
-
-            YulExpression::SuffixAccess(_, exp, YulSuffix::Address)
-            | YulExpression::SuffixAccess(_, exp, YulSuffix::Selector) => {
-                if matches!(
-                    **exp,
-                    YulExpression::SolidityLocalVariable(_, Type::ExternalFunction { .. }, _, _)
-                ) {
-                    return Some(Diagnostic::error(
-                        expr.loc(),
-                        "assignment to selector and address is not implemented. \
-                        If there is need for these features, please file a GitHub issue at \
-                        https://github.com/hyperledger-labs/solang/issues"
-                            .to_string(),
-                    ));
-                }
-            }
-
             _ => (),
         }
 

+ 1 - 5
src/sema/yul/tests/expression.rs

@@ -1413,11 +1413,7 @@ contract C {
     "#;
 
     let ns = parse(file);
-    assert_eq!(ns.diagnostics.len(), 2);
-    assert!(ns
-        .diagnostics
-        .contains_message("assignment to selector and address is not implemented. If there is need for these features, please file a GitHub issue at https://github.com/hyperledger-labs/solang/issues"));
-
+    assert_eq!(ns.diagnostics.len(), 1);
     assert!(ns.diagnostics.contains_message("found contract 'C'"));
 }
 

+ 2 - 2
tests/codegen_testcases/solidity/borsh_encoding_simple_types.sol

@@ -240,8 +240,8 @@ contract EncodingTest {
         uint64 pr = 9234;
 
         // CHECK: ty:bytes %abi_encoded.temp.106 = (alloc bytes len (uint32 36 + uint32 8))
-        // CHECK: writebuffer buffer:%abi_encoded.temp.106 offset:uint32 0 value:(builtin FunctionSelector (%fPtr))
-        // CHECK: writebuffer buffer:%abi_encoded.temp.106 offset:(uint32 0 + uint32 4) value:(builtin ExternalFunctionAddress (%fPtr))
+        // CHECK: writebuffer buffer:%abi_encoded.temp.106 offset:uint32 0 value:(load (struct %fPtr field 1))
+        // CHECK: writebuffer buffer:%abi_encoded.temp.106 offset:(uint32 0 + uint32 4) value:(load (struct %fPtr field 0))
         // CHECK: writebuffer buffer:%abi_encoded.temp.106 offset:(uint32 0 + uint32 36) value:%pr
 
         bytes memory b = abi.encode(fPtr, pr);

+ 2 - 2
tests/codegen_testcases/yul/expression.sol

@@ -144,10 +144,10 @@ contract testing {
             // CHECK: ty:uint256 %p = (zext uint256 (builtin ArrayLength ((arg #0))))
             let p := vl.length
 
-            // CHECK: ty:uint256 %q = uint256((builtin ExternalFunctionAddress (%fPtr)))
+            // CHECK: ty:uint256 %q = uint256((load (struct %fPtr field 0)))
             let q := fPtr.address
 
-            // CHECK: ty:uint256 %r = (zext uint256 (builtin FunctionSelector (%fPtr)))
+            // CHECK: ty:uint256 %r = (zext uint256 (load (struct %fPtr field 1)))
             let r := fPtr.selector
 
             // CHECK: ty:uint256 %s = (arg #1)

+ 16 - 0
tests/codegen_testcases/yul/external_function.sol

@@ -0,0 +1,16 @@
+// RUN: --target solana --emit cfg
+
+contract C {
+    // BEGIN-CHECK: C::C::function::combineToFunctionPointer__address_bytes4
+    function combineToFunctionPointer(address newAddress, bytes4 newSelector) public pure returns (bytes4, address) {
+        function() external fun;
+        assembly {
+            // CHECK: store (struct %fun field 1), uint32((arg #1))
+            fun.selector := newSelector
+            // CHECK: store (struct %fun field 0), (arg #0)
+            fun.address  := newAddress
+        }
+
+        return (fun.selector, fun.address);
+    }
+}

+ 44 - 1
tests/solana_tests/yul.rs

@@ -1,5 +1,5 @@
 use crate::build_solidity;
-use ethabi::{ethereum_types::U256, Token, Uint};
+use ethabi::{ethereum_types::U256, FixedBytes, Token, Uint};
 
 #[test]
 fn suffixes_access() {
@@ -248,3 +248,46 @@ contract c {
         vec![Token::Uint(U256::from(0)), Token::Uint(U256::from(0))]
     );
 }
+
+#[test]
+fn external_function() {
+    let mut vm = build_solidity(
+        r#"
+    contract C {
+
+        function myFun() public {
+
+        }
+
+        function test(uint256 newAddress, bytes4 newSelector) public view returns (bytes4, address) {
+            function() external fun = this.myFun;
+            address myAddr = address(newAddress);
+            assembly {
+                fun.selector := newSelector
+                fun.address  := myAddr
+            }
+
+            return (fun.selector, fun.address);
+        }
+    }
+        "#,
+    );
+
+    vm.constructor("C", &[]);
+    let mut addr: Vec<u8> = vec![0; 32];
+    addr[5] = 90;
+    let returns = vm.function(
+        "test",
+        &[
+            Token::Uint(U256::from_little_endian(&addr[..])),
+            Token::FixedBytes(FixedBytes::from([1, 2, 3, 4])),
+        ],
+        &[],
+        None,
+    );
+
+    let selector = returns[0].clone().into_fixed_bytes().unwrap();
+    assert_eq!(selector, vec![1, 2, 3, 4]);
+    let addr = returns[1].clone().into_fixed_bytes().unwrap();
+    assert_eq!(addr[26], 90);
+}