瀏覽代碼

Refactor accounts collection to maintain proper ordering (#1232)

Signed-off-by: Lucas Steuernagel <lucas.tnagel@gmail.com>
Lucas Steuernagel 2 年之前
父節點
當前提交
3cfca59c80
共有 7 個文件被更改,包括 201 次插入157 次删除
  1. 13 71
      src/abi/anchor.rs
  2. 0 1
      src/abi/mod.rs
  3. 1 1
      src/abi/tests.rs
  4. 2 2
      src/bin/solang.rs
  5. 10 0
      src/codegen/mod.rs
  6. 162 82
      src/codegen/solana_accounts.rs
  7. 13 0
      src/sema/ast.rs

+ 13 - 71
src/abi/anchor.rs

@@ -1,8 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use crate::sema::ast::{
-    ArrayLength, Contract, Function, Mutability, Namespace, Parameter, StructDecl, StructType, Tag,
-    Type,
+    ArrayLength, Contract, Function, Namespace, Parameter, StructDecl, StructType, Tag, Type,
 };
 use anchor_syn::idl::{
     Idl, IdlAccount, IdlAccountItem, IdlEnumVariant, IdlEvent, IdlEventField, IdlField,
@@ -13,9 +12,7 @@ use num_traits::ToPrimitive;
 use semver::Version;
 use std::collections::{HashMap, HashSet};
 
-use crate::abi::solana_accounts::{collect_accounts_from_contract, SolanaAccount};
 use convert_case::{Boundary, Case, Casing};
-use indexmap::IndexSet;
 use serde_json::json;
 use sha2::{Digest, Sha256};
 use solang_parser::pt::FunctionTy;
@@ -126,7 +123,6 @@ fn idl_instructions(
         })
     }
 
-    let mut remaining_accounts = collect_accounts_from_contract(contract_no, ns);
     for func_no in contract.all_functions.keys() {
         if !ns.functions[*func_no].is_public()
             || matches!(
@@ -140,41 +136,6 @@ fn idl_instructions(
         let func = &ns.functions[*func_no];
         let tags = idl_docs(&func.tags);
 
-        let mut accounts = match &func.mutability {
-            Mutability::Pure(_) => {
-                vec![]
-            }
-            Mutability::View(_) => {
-                vec![IdlAccountItem::IdlAccount(IdlAccount {
-                    name: "dataAccount".to_string(),
-                    is_mut: false,
-                    is_signer: false,
-                    is_optional: Some(false),
-                    docs: None,
-                    pda: None,
-                    relations: vec![],
-                })]
-            }
-            _ => {
-                vec![IdlAccountItem::IdlAccount(IdlAccount {
-                    name: "dataAccount".to_string(),
-                    is_mut: true,
-                    /// With a @payer annotation, the account is created on-chain and needs a signer. The client
-                    /// provides an address that does not exist yet, so SystemProgram.CreateAccount is called
-                    /// on-chain.
-                    ///
-                    /// However, if a @seed is also provided, the program can sign for the account
-                    /// with the seed using program derived address (pda) when SystemProgram.CreateAccount is called,
-                    /// so no signer is required from the client.
-                    is_signer: func.has_payer_annotation() && !func.has_seed_annotation(),
-                    is_optional: Some(false),
-                    docs: None,
-                    pda: None,
-                    relations: vec![],
-                })]
-            }
-        };
-
         let mut args: Vec<IdlField> = Vec::with_capacity(func.params.len());
         let mut dedup = Deduplicate::new("arg".to_owned());
         for item in &*func.params {
@@ -191,29 +152,7 @@ fn idl_instructions(
             });
         }
 
-        let cfg_no = contract.all_functions[func_no];
-
         let name = if func.is_constructor() {
-            if func.has_payer_annotation() {
-                accounts.push(IdlAccountItem::IdlAccount(IdlAccount {
-                    name: "wallet".to_string(),
-                    is_mut: false,
-                    is_signer: true,
-                    is_optional: Some(false),
-                    docs: None,
-                    pda: None,
-                    relations: vec![],
-                }));
-
-                // Constructors with the payer annotation need the system account
-                if let Some(set) = remaining_accounts.get_mut(&cfg_no) {
-                    set.insert(SolanaAccount::SystemAccount);
-                } else {
-                    remaining_accounts
-                        .insert(cfg_no, IndexSet::from([SolanaAccount::SystemAccount]));
-                }
-            }
-
             "new".to_string()
         } else if func.mangled_name_contracts.contains(&contract_no) {
             func.mangled_name.clone()
@@ -221,19 +160,22 @@ fn idl_instructions(
             func.name.clone()
         };
 
-        if let Some(other_accounts) = remaining_accounts.get(&cfg_no) {
-            for account in other_accounts {
-                accounts.push(IdlAccountItem::IdlAccount(IdlAccount {
-                    name: account.name().to_string(),
-                    is_mut: false,
-                    is_signer: false,
+        let accounts = func
+            .solana_accounts
+            .borrow()
+            .iter()
+            .map(|(account_name, account)| {
+                IdlAccountItem::IdlAccount(IdlAccount {
+                    name: account_name.clone(),
+                    is_mut: account.is_writer,
+                    is_signer: account.is_signer,
                     is_optional: Some(false),
                     docs: None,
                     pda: None,
                     relations: vec![],
-                }));
-            }
-        }
+                })
+            })
+            .collect::<Vec<IdlAccountItem>>();
 
         let returns = if func.returns.is_empty() {
             None

+ 0 - 1
src/abi/mod.rs

@@ -5,7 +5,6 @@ use crate::Target;
 
 pub mod anchor;
 pub mod ethereum;
-mod solana_accounts;
 pub mod substrate;
 mod tests;
 

+ 1 - 1
src/abi/tests.rs

@@ -157,7 +157,7 @@ fn instructions_and_types() {
     }
 
     function notInIdl(uint256 c, MetaData dd) private pure returns (int256) {
-        if (dd.a && dd.b) {
+        if (dd.c && dd.b) {
             return 0;
         }
         return int256(c);

+ 2 - 2
src/bin/solang.rs

@@ -499,7 +499,7 @@ fn compile(matches: &ArgMatches) {
     }
 
     if !errors {
-        for ns in &namespaces {
+        for ns in namespaces.iter_mut() {
             for contract_no in 0..ns.contracts.len() {
                 contract_results(contract_no, matches, ns, &mut json_contracts, &opt);
             }
@@ -576,7 +576,7 @@ fn process_file(
 fn contract_results(
     contract_no: usize,
     matches: &ArgMatches,
-    ns: &Namespace,
+    ns: &mut Namespace,
     json_contracts: &mut HashMap<String, JsonContract>,
     opt: &Options,
 ) {

+ 10 - 0
src/codegen/mod.rs

@@ -11,6 +11,7 @@ mod events;
 mod expression;
 mod external_functions;
 mod reaching_definitions;
+mod solana_accounts;
 mod solana_deploy;
 mod statements;
 mod storage;
@@ -36,6 +37,7 @@ use std::cmp::Ordering;
 
 use crate::codegen::cfg::ASTFunction;
 use crate::codegen::dispatch::function_dispatch;
+use crate::codegen::solana_accounts::collect_accounts_from_contract;
 use crate::codegen::yul::generate_yul_function_cfg;
 use crate::sema::Recurse;
 use num_bigint::{BigInt, Sign};
@@ -150,6 +152,14 @@ pub fn codegen(ns: &mut Namespace, opt: &Options) {
             contracts_done[contract_no] = true;
         }
     }
+
+    if ns.target == Target::Solana {
+        for contract_no in 0..ns.contracts.len() {
+            if ns.contracts[contract_no].instantiable {
+                collect_accounts_from_contract(contract_no, ns);
+            }
+        }
+    }
 }
 
 fn contract(contract_no: usize, ns: &mut Namespace, opt: &Options) {

+ 162 - 82
src/abi/solana_accounts.rs → src/codegen/solana_accounts.rs

@@ -1,35 +1,37 @@
 // SPDX-License-Identifier: Apache-2.0
 
-use crate::codegen::cfg::{ControlFlowGraph, Instr, InternalCallTy};
+use crate::codegen::cfg::{ASTFunction, ControlFlowGraph, Instr, InternalCallTy};
 use crate::codegen::{Builtin, Expression};
-use crate::sema::ast::Namespace;
+use crate::sema::ast::{Contract, Function, Mutability, Namespace, SolanaAccount};
 use crate::sema::Recurse;
 use base58::FromBase58;
 use indexmap::IndexSet;
 use num_bigint::{BigInt, Sign};
 use num_traits::Zero;
 use once_cell::sync::Lazy;
-use std::collections::hash_map::Entry;
+use solang_parser::pt::FunctionTy;
 use std::collections::{HashMap, HashSet, VecDeque};
 
 /// These are the accounts that we can collect from a contract and that Anchor will populate
-/// automatically if their name matches the source code description:
+/// automatically if their names match the source code description:
 /// https://github.com/coral-xyz/anchor/blob/06c42327d4241e5f79c35bc5588ec0a6ad2fedeb/ts/packages/anchor/src/program/accounts-resolver.ts#L54-L60
-#[derive(Hash, Eq, PartialEq, Copy, Clone)]
-pub(super) enum SolanaAccount {
-    ClockAccount,
-    InstructionAccount,
-    SystemAccount,
-    AssociatedProgramId,
-    Rent,
-    TokenProgramId,
-}
+static CLOCK_ACCOUNT: &str = "clock";
+static SYSTEM_ACCOUNT: &str = "systemProgram";
+static ASSOCIATED_TOKEN_PROGRAM: &str = "associatedTokenProgram";
+static RENT_ACCOUNT: &str = "rent";
+static TOKEN_PROGRAM_ID: &str = "tokenProgram";
+
+/// We automatically include the following accounts in the IDL, but these are not
+/// automatically populated
+static DATA_ACCOUNT: &str = "dataAccount";
+static WALLET_ACCOUNT: &str = "wallet";
+static INSTRUCTION_ACCOUNT: &str = "SysvarInstruction";
 
 /// If the public keys available in AVAILABLE_ACCOUNTS are hardcoded in a Solidity contract
 /// for external calls, we can detect them and leverage Anchor's public key auto populate feature.
-static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, SolanaAccount>> = Lazy::new(|| {
+static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, &'static str>> = Lazy::new(|| {
     HashMap::from([
-        (BigInt::zero(), SolanaAccount::SystemAccount),
+        (BigInt::zero(), SYSTEM_ACCOUNT),
         (
             BigInt::from_bytes_be(
                 Sign::Plus,
@@ -37,7 +39,7 @@ static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, SolanaAccount>> = Lazy::new(|| {
                     .from_base58()
                     .unwrap(),
             ),
-            SolanaAccount::AssociatedProgramId,
+            ASSOCIATED_TOKEN_PROGRAM,
         ),
         (
             BigInt::from_bytes_be(
@@ -46,7 +48,7 @@ static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, SolanaAccount>> = Lazy::new(|| {
                     .from_base58()
                     .unwrap(),
             ),
-            SolanaAccount::Rent,
+            RENT_ACCOUNT,
         ),
         (
             BigInt::from_bytes_be(
@@ -55,7 +57,7 @@ static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, SolanaAccount>> = Lazy::new(|| {
                     .from_base58()
                     .unwrap(),
             ),
-            SolanaAccount::TokenProgramId,
+            TOKEN_PROGRAM_ID,
         ),
         (
             BigInt::from_bytes_be(
@@ -64,35 +66,19 @@ static AVAILABLE_ACCOUNTS: Lazy<HashMap<BigInt, SolanaAccount>> = Lazy::new(|| {
                     .from_base58()
                     .unwrap(),
             ),
-            SolanaAccount::ClockAccount,
+            CLOCK_ACCOUNT,
         ),
     ])
 });
 
-impl SolanaAccount {
-    /// Retrieve a name from an account, according to Anchor's constant accounts map
-    /// https://github.com/coral-xyz/anchor/blob/06c42327d4241e5f79c35bc5588ec0a6ad2fedeb/ts/packages/anchor/src/program/accounts-resolver.ts#L54-L60
-    pub(super) fn name(&self) -> &'static str {
-        match self {
-            SolanaAccount::AssociatedProgramId => "associatedTokenProgram",
-            SolanaAccount::Rent => "rent",
-            SolanaAccount::SystemAccount => "systemProgram",
-            SolanaAccount::TokenProgramId => "tokenProgram",
-            SolanaAccount::ClockAccount => "clock",
-            SolanaAccount::InstructionAccount => "SysvarInstruction",
-        }
-    }
-
-    fn from_number(num: &BigInt) -> Option<SolanaAccount> {
-        AVAILABLE_ACCOUNTS.get(num).cloned()
-    }
+/// Retrieve a name from an account, according to Anchor's constant accounts map
+/// https://github.com/coral-xyz/anchor/blob/06c42327d4241e5f79c35bc5588ec0a6ad2fedeb/ts/packages/anchor/src/program/accounts-resolver.ts#L54-L60
+fn account_from_number(num: &BigInt) -> Option<&'static str> {
+    AVAILABLE_ACCOUNTS.get(num).cloned()
 }
 
 /// Struct to save the recursion data when traversing all the CFG instructions
 struct RecurseData<'a> {
-    /// Here we collect all accounts. It is a map between the cfg number of a functions and the
-    /// accounts it requires
-    accounts: HashMap<usize, IndexSet<SolanaAccount>>,
     /// next_queue saves the set of functions we must check in the next iteration
     next_queue: IndexSet<(usize, usize)>,
     /// The number of the function we are currently traversing
@@ -101,45 +87,110 @@ struct RecurseData<'a> {
     contract_no: usize,
     /// The quantity of accounts we have added the the hashmap 'accounts'
     accounts_added: usize,
-    ns: &'a Namespace,
+    /// The number of the AST function we are currently traversing
+    ast_no: usize,
+    /// The namespace contracts
+    contracts: &'a [Contract],
+    /// The vector of functions from the contract
+    functions: &'a [Function],
 }
 
 impl RecurseData<'_> {
-    /// Add an account to the hashmap
-    fn add_account(&mut self, account: SolanaAccount) {
-        let inserted = match self.accounts.entry(self.cfg_func_no) {
-            Entry::Occupied(mut val) => val.get_mut().insert(account),
-            Entry::Vacant(val) => {
-                val.insert(IndexSet::from([account]));
-                true
-            }
-        };
-
-        if inserted {
+    /// Add an account to the function's indexmap
+    fn add_account(&mut self, account_name: String, account: SolanaAccount) {
+        if self.functions[self.ast_no]
+            .solana_accounts
+            .borrow_mut()
+            .insert(account_name, account)
+            .is_none()
+        {
             self.accounts_added += 1;
         }
     }
+
+    /// Add the system account to the function's indexmap
+    fn add_system_account(&mut self) {
+        self.add_account(
+            SYSTEM_ACCOUNT.to_string(),
+            SolanaAccount {
+                is_writer: false,
+                is_signer: false,
+            },
+        );
+    }
 }
 
 /// Collect the accounts this contract needs
-pub(super) fn collect_accounts_from_contract(
-    contract_no: usize,
-    ns: &Namespace,
-) -> HashMap<usize, IndexSet<SolanaAccount>> {
-    let mut visiting_queue: IndexSet<(usize, usize)> = IndexSet::from_iter(
-        ns.contracts[contract_no]
-            .functions
-            .iter()
-            .map(|ast_no| (contract_no, ns.contracts[contract_no].all_functions[ast_no])),
-    );
+pub(super) fn collect_accounts_from_contract(contract_no: usize, ns: &Namespace) {
+    let mut visiting_queue: IndexSet<(usize, usize)> = IndexSet::new();
+
+    for func_no in ns.contracts[contract_no].all_functions.keys() {
+        if ns.functions[*func_no].is_public()
+            && !matches!(
+                ns.functions[*func_no].ty,
+                FunctionTy::Fallback | FunctionTy::Receive | FunctionTy::Modifier
+            )
+        {
+            let func = &ns.functions[*func_no];
+            match &func.mutability {
+                Mutability::Pure(_) => (),
+                Mutability::View(_) => {
+                    func.solana_accounts.borrow_mut().insert(
+                        DATA_ACCOUNT.to_string(),
+                        SolanaAccount {
+                            is_writer: false,
+                            is_signer: false,
+                        },
+                    );
+                }
+                _ => {
+                    func.solana_accounts.borrow_mut().insert(
+                        DATA_ACCOUNT.to_string(),
+                        SolanaAccount {
+                            is_writer: true,
+                            /// With a @payer annotation, the account is created on-chain and needs a signer. The client
+                            /// provides an address that does not exist yet, so SystemProgram.CreateAccount is called
+                            /// on-chain.
+                            ///
+                            /// However, if a @seed is also provided, the program can sign for the account
+                            /// with the seed using program derived address (pda) when SystemProgram.CreateAccount is called,
+                            /// so no signer is required from the client.
+                            is_signer: func.has_payer_annotation() && !func.has_seed_annotation(),
+                        },
+                    );
+                }
+            }
+            if func.is_constructor() && func.has_payer_annotation() {
+                func.solana_accounts.borrow_mut().insert(
+                    WALLET_ACCOUNT.to_string(),
+                    SolanaAccount {
+                        is_signer: true,
+                        is_writer: false,
+                    },
+                );
+                func.solana_accounts.borrow_mut().insert(
+                    SYSTEM_ACCOUNT.to_string(),
+                    SolanaAccount {
+                        is_signer: false,
+                        is_writer: false,
+                    },
+                );
+            }
+        }
+        visiting_queue.insert((
+            contract_no,
+            ns.contracts[contract_no].all_functions[func_no],
+        ));
+    }
 
     let mut recurse_data = RecurseData {
-        accounts: HashMap::new(),
         next_queue: IndexSet::new(),
         cfg_func_no: 0,
+        ast_no: 0,
         accounts_added: 0,
         contract_no,
-        ns,
+        functions: &ns.functions,
+        contracts: &ns.contracts,
     };
 
     let mut old_size: usize = 0;
@@ -151,6 +202,12 @@ pub(super) fn collect_accounts_from_contract(
 
             recurse_data.contract_no = *contract_no;
             recurse_data.cfg_func_no = *func_no;
+            match &ns.contracts[*contract_no].cfg[*func_no].function_no {
+                ASTFunction::SolidityFunction(ast_no) | ASTFunction::YulFunction(ast_no) => {
+                    recurse_data.ast_no = *ast_no;
+                }
+                _ => (),
+            }
             check_function(&ns.contracts[*contract_no].cfg[*func_no], &mut recurse_data);
         }
 
@@ -166,8 +223,6 @@ pub(super) fn collect_accounts_from_contract(
         std::mem::swap(&mut visiting_queue, &mut recurse_data.next_queue);
         recurse_data.next_queue.clear();
     }
-
-    recurse_data.accounts
 }
 
 /// Collect the accounts in a function
@@ -215,15 +270,20 @@ fn check_instruction(instr: &Instr, data: &mut RecurseData) {
                 // recursive function calls
                 data.next_queue.insert((data.contract_no, *cfg_no));
                 data.next_queue.insert((data.contract_no, data.cfg_func_no));
-                if let Some(callee_accounts) = data.accounts.get(cfg_no).cloned() {
-                    for item in callee_accounts {
-                        data.add_account(item);
+                match &data.contracts[data.contract_no].cfg[*cfg_no].function_no {
+                    ASTFunction::SolidityFunction(ast_no) | ASTFunction::YulFunction(ast_no) => {
+                        let accounts_to_add =
+                            data.functions[*ast_no].solana_accounts.borrow().clone();
+                        for (account_name, account) in accounts_to_add {
+                            data.add_account(account_name, account);
+                        }
                     }
+                    _ => (),
                 }
             } else if let InternalCallTy::Builtin { ast_func_no } = call {
-                let name = &data.ns.functions[*ast_func_no].name;
+                let name = &data.functions[*ast_func_no].name;
                 if name == "create_program_address" {
-                    data.add_account(SolanaAccount::SystemAccount);
+                    data.add_system_account();
                 }
             }
 
@@ -337,7 +397,7 @@ fn check_instruction(instr: &Instr, data: &mut RecurseData) {
                 seeds.recurse(data, check_expression);
             }
 
-            data.add_account(SolanaAccount::SystemAccount);
+            data.add_system_account();
         }
         Instr::ExternalCall {
             address,
@@ -353,8 +413,14 @@ fn check_instruction(instr: &Instr, data: &mut RecurseData) {
                 address.recurse(data, check_expression);
                 if let Expression::NumberLiteral(_, _, num) = address {
                     // Check if we can auto populate this account
-                    if let Some(account) = SolanaAccount::from_number(num) {
-                        data.add_account(account);
+                    if let Some(account) = account_from_number(num) {
+                        data.add_account(
+                            account.to_string(),
+                            SolanaAccount {
+                                is_signer: false,
+                                is_writer: false,
+                            },
+                        );
                     }
                 }
             }
@@ -368,13 +434,15 @@ fn check_instruction(instr: &Instr, data: &mut RecurseData) {
             value.recurse(data, check_expression);
             gas.recurse(data, check_expression);
             // External calls always need the system account
-            data.add_account(SolanaAccount::SystemAccount);
+            data.add_system_account();
             if let Some((contract_no, function_no)) = contract_function_no {
-                let cfg_no = data.ns.contracts[*contract_no].all_functions[function_no];
-                if let Some(callee_accounts) = data.accounts.get(&cfg_no).cloned() {
-                    for item in callee_accounts {
-                        data.add_account(item);
-                    }
+                let cfg_no = data.contracts[*contract_no].all_functions[function_no];
+                let accounts_to_add = data.functions[*function_no]
+                    .solana_accounts
+                    .borrow()
+                    .clone();
+                for (account_name, account) in accounts_to_add {
+                    data.add_account(account_name, account);
                 }
                 data.next_queue.insert((*contract_no, cfg_no));
                 data.next_queue.insert((data.contract_no, data.cfg_func_no));
@@ -409,10 +477,22 @@ fn check_expression(expr: &Expression, data: &mut RecurseData) -> bool {
             Builtin::Timestamp | Builtin::BlockNumber | Builtin::Slot,
             ..,
         ) => {
-            data.add_account(SolanaAccount::ClockAccount);
+            data.add_account(
+                CLOCK_ACCOUNT.to_string(),
+                SolanaAccount {
+                    is_signer: false,
+                    is_writer: false,
+                },
+            );
         }
         Expression::Builtin(_, _, Builtin::SignatureVerify, ..) => {
-            data.add_account(SolanaAccount::InstructionAccount);
+            data.add_account(
+                INSTRUCTION_ACCOUNT.to_string(),
+                SolanaAccount {
+                    is_writer: false,
+                    is_signer: false,
+                },
+            );
         }
         Expression::Builtin(
             _,
@@ -420,7 +500,7 @@ fn check_expression(expr: &Expression, data: &mut RecurseData) -> bool {
             Builtin::Ripemd160 | Builtin::Keccak256 | Builtin::Sha256,
             ..,
         ) => {
-            data.add_account(SolanaAccount::SystemAccount);
+            data.add_system_account();
         }
 
         _ => (),

+ 13 - 0
src/sema/ast.rs

@@ -14,6 +14,7 @@ use once_cell::unsync::OnceCell;
 pub use solang_parser::diagnostics::*;
 use solang_parser::pt;
 use solang_parser::pt::{CodeLocation, FunctionTy, OptionalCodeLocation};
+use std::cell::RefCell;
 use std::{
     collections::HashSet,
     collections::{BTreeMap, HashMap},
@@ -305,6 +306,17 @@ pub struct Function {
     pub annotations: Vec<ConstructorAnnotation>,
     /// Which contracts should we use the mangled name in?
     pub mangled_name_contracts: HashSet<usize>,
+    /// This indexmap stores the accounts this functions needs to be called on Solana
+    /// The string is the account's name
+    pub solana_accounts: RefCell<IndexMap<String, SolanaAccount>>,
+}
+
+/// This struct represents a Solana account. There is no name field, because
+/// it is stored in a IndexMap<String, SolanaAccount> (see above)
+#[derive(Clone, Copy, Debug)]
+pub struct SolanaAccount {
+    pub is_signer: bool,
+    pub is_writer: bool,
 }
 
 pub enum ConstructorAnnotation {
@@ -406,6 +418,7 @@ impl Function {
             mangled_name,
             annotations: Vec::new(),
             mangled_name_contracts: HashSet::new(),
+            solana_accounts: IndexMap::new().into(),
         }
     }