Преглед изворни кода

svm: `AccountLoader::load_account()` refactor (#4680)

hana пре 6 месеци
родитељ
комит
cce41f97a0

+ 192 - 48
svm/src/account_loader.rs

@@ -147,56 +147,54 @@ pub struct FeesOnlyTransaction {
     pub fee_details: FeeDetails,
 }
 
-#[cfg_attr(feature = "dev-context-only-utils", derive(Clone))]
+// This is an internal SVM type that tracks account changes throughout a
+// transaction batch and obviates the need to load accounts from accounts-db
+// more than once. It effectively wraps an `impl TransactionProcessingCallback`
+// type, and itself implements `TransactionProcessingCallback`, behaving
+// exactly like the implementor of the trait, but also returning up-to-date
+// account states mid-batch.
 pub(crate) struct AccountLoader<'a, CB: TransactionProcessingCallback> {
-    account_cache: AHashMap<Pubkey, AccountSharedData>,
+    loaded_accounts: AHashMap<Pubkey, AccountSharedData>,
     callbacks: &'a CB,
     pub(crate) feature_set: &'a SVMFeatureSet,
 }
+
 impl<'a, CB: TransactionProcessingCallback> AccountLoader<'a, CB> {
-    pub(crate) fn new_with_account_cache_capacity(
+    // create a new AccountLoader for the transaction batch
+    pub(crate) fn new_with_loaded_accounts_capacity(
         account_overrides: Option<&'a AccountOverrides>,
         callbacks: &'a CB,
         feature_set: &'a SVMFeatureSet,
         capacity: usize,
     ) -> AccountLoader<'a, CB> {
-        let mut account_cache = AHashMap::with_capacity(capacity);
+        let mut loaded_accounts = AHashMap::with_capacity(capacity);
 
         // SlotHistory may be overridden for simulation.
         // No other uses of AccountOverrides are expected.
         if let Some(slot_history) =
             account_overrides.and_then(|overrides| overrides.get(&slot_history::id()))
         {
-            account_cache.insert(slot_history::id(), slot_history.clone());
+            loaded_accounts.insert(slot_history::id(), slot_history.clone());
         }
 
         Self {
-            account_cache,
+            loaded_accounts,
             callbacks,
             feature_set,
         }
     }
 
-    pub(crate) fn load_account(
+    // Load an account either from our own store or accounts-db and inspect it on behalf of Bank.
+    // Inspection is required prior to any modifications to the account. This function is used
+    // by load_transaction() and validate_transaction_fee_payer() for that purpose. It returns
+    // a different type than other AccountLoader load functions, which should prevent accidental
+    // mix and match of them.
+    pub(crate) fn load_transaction_account(
         &mut self,
         account_key: &Pubkey,
         is_writable: bool,
     ) -> Option<LoadedTransactionAccount> {
-        let account = if let Some(account) = self.account_cache.get(account_key) {
-            // If lamports is 0, a previous transaction deallocated this account.
-            // We return None instead of the account we found so it can be created fresh.
-            // We never evict from the cache, or else we would fetch stale state from accounts-db.
-            if account.lamports() == 0 {
-                None
-            } else {
-                Some(account.clone())
-            }
-        } else if let Some(account) = self.callbacks.get_account_shared_data(account_key) {
-            self.account_cache.insert(*account_key, account.clone());
-            Some(account)
-        } else {
-            None
-        };
+        let account = self.load_account(account_key);
 
         // Inspect prior to collecting rent, since rent collection can modify the account.
         self.callbacks.inspect_account(
@@ -216,6 +214,41 @@ impl<'a, CB: TransactionProcessingCallback> AccountLoader<'a, CB> {
         })
     }
 
+    // Load an account as above, with no inspection and no LoadedTransactionAccount wrapper.
+    // This is a general purpose function suitable for usage outside initial transaction loading.
+    pub(crate) fn load_account(&mut self, account_key: &Pubkey) -> Option<AccountSharedData> {
+        match self.do_load(account_key) {
+            (Some(account), true) => {
+                self.loaded_accounts.insert(*account_key, account.clone());
+                Some(account)
+            }
+            (account, false) => account,
+            (None, true) => unreachable!(),
+        }
+    }
+
+    // Internal helper for core loading logic to prevent code duplication. Returns a bool
+    // indicating whether the account came from accounts-db, which allows wrappers with
+    // &mut self to insert the account. Wrappers with &self ignore it.
+    fn do_load(&self, account_key: &Pubkey) -> (Option<AccountSharedData>, bool) {
+        if let Some(account) = self.loaded_accounts.get(account_key) {
+            // If lamports is 0, a previous transaction deallocated this account.
+            // We return None instead of the account we found so it can be created fresh.
+            // We *never* remove accounts, or else we would fetch stale state from accounts-db.
+            let option_account = if account.lamports() == 0 {
+                None
+            } else {
+                Some(account.clone())
+            };
+
+            (option_account, false)
+        } else if let Some(account) = self.callbacks.get_account_shared_data(account_key) {
+            (Some(account), true)
+        } else {
+            (None, false)
+        }
+    }
+
     pub(crate) fn update_accounts_for_executed_tx(
         &mut self,
         message: &impl SVMMessage,
@@ -242,20 +275,20 @@ impl<'a, CB: TransactionProcessingCallback> AccountLoader<'a, CB> {
         let fee_payer_address = message.fee_payer();
         match rollback_accounts {
             RollbackAccounts::FeePayerOnly { fee_payer_account } => {
-                self.account_cache
+                self.loaded_accounts
                     .insert(*fee_payer_address, fee_payer_account.clone());
             }
             RollbackAccounts::SameNonceAndFeePayer { nonce } => {
-                self.account_cache
+                self.loaded_accounts
                     .insert(*nonce.address(), nonce.account().clone());
             }
             RollbackAccounts::SeparateNonceAndFeePayer {
                 nonce,
                 fee_payer_account,
             } => {
-                self.account_cache
+                self.loaded_accounts
                     .insert(*nonce.address(), nonce.account().clone());
-                self.account_cache
+                self.loaded_accounts
                     .insert(*fee_payer_address, fee_payer_account.clone());
             }
         }
@@ -279,11 +312,36 @@ impl<'a, CB: TransactionProcessingCallback> AccountLoader<'a, CB> {
                 continue;
             }
 
-            self.account_cache.insert(*address, account.clone());
+            self.loaded_accounts.insert(*address, account.clone());
         }
     }
 }
 
+// Program loaders and parsers require a type that impls TransactionProcessingCallback,
+// because they are used in both SVM and by Bank. We impl it, with the consequence
+// that if we fall back to accounts-db, we cannot store the state for future loads.
+// In general, most accounts we load this way should already be in our accounts store.
+// Once SIMD-0186 is implemented, 100% of accounts will be.
+impl<CB: TransactionProcessingCallback> TransactionProcessingCallback for AccountLoader<'_, CB> {
+    fn get_account_shared_data(&self, pubkey: &Pubkey) -> Option<AccountSharedData> {
+        self.do_load(pubkey).0
+    }
+
+    fn account_matches_owners(&self, pubkey: &Pubkey, owners: &[Pubkey]) -> Option<usize> {
+        self.do_load(pubkey)
+            .0
+            .and_then(|account| owners.iter().position(|entry| entry == account.owner()))
+    }
+}
+
+// NOTE this is a required subtrait of TransactionProcessingCallback.
+// It may make sense to break out a second subtrait just for the above two functions,
+// but this would be a nontrivial breaking change and require careful consideration.
+impl<CB: TransactionProcessingCallback> solana_svm_callback::InvokeContextCallback
+    for AccountLoader<'_, CB>
+{
+}
+
 /// Collect rent from an account if rent is still enabled and regardless of
 /// whether rent is enabled, set the rent epoch to u64::MAX if the account is
 /// rent exempt.
@@ -481,11 +539,7 @@ fn load_transaction_accounts<CB: TransactionProcessingCallback>(
 
             let program_index = instruction.program_id_index as usize;
 
-            let Some(LoadedTransactionAccount {
-                account: program_account,
-                ..
-            }) = account_loader.load_account(program_id, false)
-            else {
+            let Some(program_account) = account_loader.load_account(program_id) else {
                 error_metrics.account_not_found += 1;
                 return Err(TransactionError::ProgramAccountNotFound);
             };
@@ -522,12 +576,7 @@ fn load_transaction_accounts<CB: TransactionProcessingCallback>(
                 // and SIMD-186 are active, we do not need to load loaders at all to comply with consensus rules
                 // we may verify program ids are owned by `PROGRAM_OWNERS` purely as an optimization
                 // this could even be done before loading the rest of the accounts for a transaction
-                if let Some(LoadedTransactionAccount {
-                    account: owner_account,
-                    loaded_size: owner_size,
-                    ..
-                }) = account_loader.load_account(owner_id, false)
-                {
+                if let Some(owner_account) = account_loader.load_account(owner_id) {
                     if !native_loader::check_id(owner_account.owner())
                         || (!account_loader
                             .feature_set
@@ -539,7 +588,7 @@ fn load_transaction_accounts<CB: TransactionProcessingCallback>(
                     }
                     accumulate_and_check_loaded_account_data_size(
                         &mut accumulated_accounts_data_size,
-                        owner_size,
+                        owner_account.data().len(),
                         loaded_accounts_bytes_limit,
                         error_metrics,
                     )?;
@@ -578,7 +627,9 @@ fn load_transaction_account<CB: TransactionProcessingCallback>(
             account: construct_instructions_account(message),
             rent_collected: 0,
         }
-    } else if let Some(mut loaded_account) = account_loader.load_account(account_key, is_writable) {
+    } else if let Some(mut loaded_account) =
+        account_loader.load_transaction_account(account_key, is_writable)
+    {
         loaded_account.rent_collected = if is_writable {
             collect_rent_from_account(
                 account_loader.feature_set,
@@ -741,7 +792,7 @@ mod tests {
 
     impl<'a> From<&'a TestCallbacks> for AccountLoader<'a, TestCallbacks> {
         fn from(callbacks: &'a TestCallbacks) -> AccountLoader<'a, TestCallbacks> {
-            AccountLoader::new_with_account_cache_capacity(
+            AccountLoader::new_with_loaded_accounts_capacity(
                 None,
                 callbacks,
                 &callbacks.feature_set,
@@ -1048,7 +1099,7 @@ mod tests {
             ..Default::default()
         };
         let feature_set = SVMFeatureSet::all_enabled();
-        let mut account_loader = AccountLoader::new_with_account_cache_capacity(
+        let mut account_loader = AccountLoader::new_with_loaded_accounts_capacity(
             account_overrides,
             &callbacks,
             &feature_set,
@@ -2226,11 +2277,7 @@ mod tests {
             // *not* key0, since it is loaded during fee payer validation
             (address1, vec![(Some(account1), true)]),
             (address2, vec![(None, true)]),
-            (
-                address3,
-                vec![(Some(account3.clone()), false), (Some(account3), false)],
-            ),
-            (bpf_loader::id(), vec![(None, false)]),
+            (address3, vec![(Some(account3), false)]),
         ];
         expected_inspected_accounts.sort_unstable_by(|a, b| a.0.cmp(&b.0));
 
@@ -2324,7 +2371,7 @@ mod tests {
         let feature_set = SVMFeatureSet::default();
         let test_transaction_data_size = |transaction, expected_size| {
             let mut account_loader =
-                AccountLoader::new_with_account_cache_capacity(None, &mock_bank, &feature_set, 0);
+                AccountLoader::new_with_loaded_accounts_capacity(None, &mock_bank, &feature_set, 0);
 
             let loaded_transaction_accounts = load_transaction_accounts(
                 &mut account_loader,
@@ -2500,4 +2547,101 @@ mod tests {
             );
         }
     }
+
+    #[test]
+    fn test_account_loader_wrappers() {
+        let fee_payer = Pubkey::new_unique();
+        let message = Message {
+            account_keys: vec![fee_payer],
+            header: MessageHeader::default(),
+            instructions: vec![],
+            recent_blockhash: Hash::default(),
+        };
+        let sanitized_message = new_unchecked_sanitized_message(message);
+
+        let mut fee_payer_account = AccountSharedData::default();
+        fee_payer_account.set_rent_epoch(u64::MAX);
+        fee_payer_account.set_lamports(5000);
+
+        let mut mock_bank = TestCallbacks::default();
+        mock_bank
+            .accounts_map
+            .insert(fee_payer, fee_payer_account.clone());
+
+        // test without stored account
+        let mut account_loader: AccountLoader<_> = (&mock_bank).into();
+        assert_eq!(
+            account_loader
+                .load_transaction_account(&fee_payer, false)
+                .unwrap()
+                .account,
+            fee_payer_account
+        );
+
+        let mut account_loader: AccountLoader<_> = (&mock_bank).into();
+        assert_eq!(
+            account_loader
+                .load_transaction_account(&fee_payer, true)
+                .unwrap()
+                .account,
+            fee_payer_account
+        );
+
+        let mut account_loader: AccountLoader<_> = (&mock_bank).into();
+        assert_eq!(
+            account_loader.load_account(&fee_payer).unwrap(),
+            fee_payer_account
+        );
+
+        let account_loader: AccountLoader<_> = (&mock_bank).into();
+        assert_eq!(
+            account_loader.get_account_shared_data(&fee_payer).unwrap(),
+            fee_payer_account
+        );
+
+        // test with stored account
+        let mut account_loader: AccountLoader<_> = (&mock_bank).into();
+        account_loader.load_account(&fee_payer).unwrap();
+
+        assert_eq!(
+            account_loader
+                .load_transaction_account(&fee_payer, false)
+                .unwrap()
+                .account,
+            fee_payer_account
+        );
+        assert_eq!(
+            account_loader
+                .load_transaction_account(&fee_payer, true)
+                .unwrap()
+                .account,
+            fee_payer_account
+        );
+        assert_eq!(
+            account_loader.load_account(&fee_payer).unwrap(),
+            fee_payer_account
+        );
+        assert_eq!(
+            account_loader.get_account_shared_data(&fee_payer).unwrap(),
+            fee_payer_account
+        );
+
+        // drop the account and ensure all deliver the updated state
+        fee_payer_account.set_lamports(0);
+        account_loader.update_accounts_for_failed_tx(
+            &sanitized_message,
+            &RollbackAccounts::FeePayerOnly { fee_payer_account },
+        );
+
+        assert_eq!(
+            account_loader.load_transaction_account(&fee_payer, false),
+            None
+        );
+        assert_eq!(
+            account_loader.load_transaction_account(&fee_payer, true),
+            None
+        );
+        assert_eq!(account_loader.load_account(&fee_payer), None);
+        assert_eq!(account_loader.get_account_shared_data(&fee_payer), None);
+    }
 }

+ 2 - 7
svm/src/transaction_balances.rs

@@ -86,11 +86,7 @@ impl BalanceCollector {
         let has_token_program = transaction.account_keys().iter().any(is_known_spl_token_id);
 
         for (index, key) in transaction.account_keys().iter().enumerate() {
-            // we load as read-only to avoid triggering a bad account inspection
-            let Some(account) = account_loader
-                .load_account(key, false)
-                .map(|loaded| loaded.account)
-            else {
+            let Some(account) = account_loader.load_account(key) else {
                 native_balances.push(0);
                 continue;
             };
@@ -190,8 +186,7 @@ impl SvmTokenInfo {
             amount,
         } = generic_token::Account::unpack(account.data(), &program_id)?;
 
-        // we load as read-only to avoid triggering a bad account inspection
-        let mint_account = account_loader.load_account(&mint, false)?.account;
+        let mint_account = account_loader.load_account(&mint)?;
         if *mint_account.owner() != program_id {
             return None;
         }

+ 13 - 6
svm/src/transaction_processor.rs

@@ -383,7 +383,7 @@ impl<FG: ForkGraph> TransactionBatchProcessor<FG> {
         let account_keys_in_batch = sanitized_txs.iter().map(|tx| tx.account_keys().len()).sum();
 
         // Create the account loader, which wraps all external account fetching.
-        let mut account_loader = AccountLoader::new_with_account_cache_capacity(
+        let mut account_loader = AccountLoader::new_with_loaded_accounts_capacity(
             config.account_overrides,
             callbacks,
             &environment.feature_set,
@@ -571,7 +571,11 @@ impl<FG: ForkGraph> TransactionBatchProcessor<FG> {
 
         let fee_payer_address = message.fee_payer();
 
-        let Some(mut loaded_fee_payer) = account_loader.load_account(fee_payer_address, true)
+        // We *must* use load_transaction_account() here because *this* is when the fee-payer
+        // is loaded for the transaction. Transaction loading skips the first account and
+        // loads (and thus inspects) all others normally.
+        let Some(mut loaded_fee_payer) =
+            account_loader.load_transaction_account(fee_payer_address, true)
         else {
             error_counters.account_not_found += 1;
             return Err(TransactionError::AccountNotFound);
@@ -628,11 +632,14 @@ impl<FG: ForkGraph> TransactionBatchProcessor<FG> {
         // We must validate the account in case it was reopened, either as a normal system account,
         // or a fake nonce account. We must also check the signer in case the authority was changed.
         //
+        // We do not need to inspect the nonce account here, because by definition it is either the
+        // first account, inspected in `validate_transaction_fee_payer()`, or the second through nth
+        // account, inspected in `load_transaction()`.
+        //
         // Note these checks are *not* obviated by fee-only transactions.
         let nonce_is_valid = account_loader
-            .load_account(nonce_info.address(), true)
-            .and_then(|loaded_nonce| {
-                let current_nonce_account = &loaded_nonce.account;
+            .load_account(nonce_info.address())
+            .and_then(|ref current_nonce_account| {
                 system_program::check_id(current_nonce_account.owner()).then_some(())?;
                 StateMut::<NonceVersions>::state(current_nonce_account).ok()
             })
@@ -1221,7 +1228,7 @@ mod tests {
 
     impl<'a> From<&'a MockBankCallback> for AccountLoader<'a, MockBankCallback> {
         fn from(callbacks: &'a MockBankCallback) -> AccountLoader<'a, MockBankCallback> {
-            AccountLoader::new_with_account_cache_capacity(
+            AccountLoader::new_with_loaded_accounts_capacity(
                 None,
                 callbacks,
                 &callbacks.feature_set,

+ 2 - 16
svm/tests/integration_test.rs

@@ -2446,12 +2446,6 @@ impl InspectedAccounts {
     fn inspect(&mut self, pubkey: Pubkey, inspect: Inspect) {
         self.0.entry(pubkey).or_default().push(inspect.into())
     }
-
-    fn inspect_n(&mut self, pubkey: Pubkey, inspect: Inspect, times: usize) {
-        for _ in 0..times {
-            self.inspect(pubkey, inspect.clone());
-        }
-    }
 }
 
 #[test]
@@ -2493,11 +2487,7 @@ fn svm_inspect_account() {
         true,
         0,
     );
-    expected_inspected_accounts.inspect_n(
-        system_program::id(),
-        Inspect::LiveRead(&system_account),
-        2,
-    );
+    expected_inspected_accounts.inspect(system_program::id(), Inspect::LiveRead(&system_account));
 
     let transfer_amount = 1_000_000;
     let transaction = Transaction::new_signed_with_payer(
@@ -2559,11 +2549,7 @@ fn svm_inspect_account() {
     );
 
     // system program
-    expected_inspected_accounts.inspect_n(
-        system_program::id(),
-        Inspect::LiveRead(&system_account),
-        2,
-    );
+    expected_inspected_accounts.inspect(system_program::id(), Inspect::LiveRead(&system_account));
 
     let mut final_test_entry = SvmTestEntry {
         initial_accounts: initial_test_entry.final_accounts.clone(),