Kaynağa Gözat

fix: use proper serialization for versioned messages in get_fee_for_m… (#7719)

* fix: use proper serialization for versioned messages in get_fee_for_message

* Clippy satiated

Co-authored-by: kirill lykov <lykov.kirill@gmail.com>

* Try harder to satiate clippy

---------

Co-authored-by: Steven Luscher <steveluscher@users.noreply.github.com>
Co-authored-by: kirill lykov <lykov.kirill@gmail.com>
Swarna 2 ay önce
ebeveyn
işleme
3f25767626

+ 1 - 0
Cargo.lock

@@ -10137,6 +10137,7 @@ dependencies = [
  "solana-version",
  "solana-vote-interface",
  "static_assertions",
+ "test-case",
  "tokio",
 ]
 

+ 1 - 0
rpc-client/Cargo.toml

@@ -64,3 +64,4 @@ solana-pubkey = { workspace = true, features = ["rand"] }
 solana-signer = { workspace = true }
 solana-system-transaction = { workspace = true }
 static_assertions = { workspace = true }
+test-case = { workspace = true }

+ 2 - 1
rpc-client/src/nonblocking/rpc_client.rs

@@ -4671,7 +4671,8 @@ impl RpcClient {
         &self,
         message: &impl SerializableMessage,
     ) -> ClientResult<u64> {
-        let serialized_encoded = serialize_and_encode(message, UiTransactionEncoding::Base64)?;
+        let serialized = message.serialize();
+        let serialized_encoded = BASE64_STANDARD.encode(serialized);
         let result = self
             .send::<Response<Option<u64>>>(
                 RpcRequest::GetFeeForMessage,

+ 98 - 3
rpc-client/src/rpc_client.rs

@@ -62,9 +62,19 @@ impl RpcClientConfig {
 
 /// Trait used to add support for versioned messages to RPC APIs while
 /// retaining backwards compatibility
-pub trait SerializableMessage: Serialize {}
-impl SerializableMessage for LegacyMessage {}
-impl SerializableMessage for v0::Message {}
+pub trait SerializableMessage {
+    fn serialize(&self) -> Vec<u8>;
+}
+impl SerializableMessage for LegacyMessage {
+    fn serialize(&self) -> Vec<u8> {
+        self.serialize()
+    }
+}
+impl SerializableMessage for v0::Message {
+    fn serialize(&self) -> Vec<u8> {
+        self.serialize()
+    }
+}
 
 /// Trait used to add support for versioned transactions to RPC APIs while
 /// retaining backwards compatibility
@@ -3797,19 +3807,23 @@ mod tests {
         super::*,
         crate::mock_sender::PUBKEY,
         assert_matches::assert_matches,
+        base64::{prelude::BASE64_STANDARD, Engine},
         crossbeam_channel::unbounded,
         jsonrpc_core::{futures::prelude::*, Error, IoHandler, Params},
         jsonrpc_http_server::{AccessControlAllowOrigin, DomainsValidation, ServerBuilder},
         serde_json::{json, Number},
         solana_account_decoder::encode_ui_account,
         solana_account_decoder_client_types::UiAccountEncoding,
+        solana_hash::Hash,
         solana_instruction::error::InstructionError,
         solana_keypair::Keypair,
+        solana_message::{compiled_instruction::CompiledInstruction, MessageHeader},
         solana_rpc_client_api::client_error::ErrorKind,
         solana_signer::Signer,
         solana_system_transaction as system_transaction,
         solana_transaction_error::TransactionError,
         std::{io, thread},
+        test_case::test_case,
     };
 
     #[test]
@@ -4254,4 +4268,85 @@ mod tests {
             assert_eq!(expected_result, result1);
         }
     }
+
+    #[test_case(LegacyMessage {
+        header: MessageHeader {
+            num_required_signatures: 1,
+            num_readonly_signed_accounts: 0,
+            num_readonly_unsigned_accounts: 1,
+        },
+        account_keys: vec![Pubkey::new_unique()],
+        recent_blockhash: Hash::new_unique(),
+        instructions: vec![CompiledInstruction {
+            program_id_index: 1,
+            accounts: vec![0],
+            data: vec![],
+        }],
+    }; "legacy message")]
+    #[test_case(v0::Message {
+            header: MessageHeader {
+                num_required_signatures: 1,
+                num_readonly_signed_accounts: 0,
+                num_readonly_unsigned_accounts: 0,
+            },
+            account_keys: vec![Pubkey::new_unique()],
+            recent_blockhash: Hash::new_unique(),
+            instructions: vec![CompiledInstruction {
+                program_id_index: 0,
+                accounts: vec![],
+                data: vec![],
+            }],
+            address_table_lookups: vec![],
+        }; "v0 message")]
+    fn test_get_fee_for_message_sends_properly_serialized_v0_transaction<M>(message: M)
+    where
+        M: SerializableMessage,
+    {
+        let serialized_message = message.serialize();
+        let serialized_message_base64 = BASE64_STANDARD.encode(serialized_message);
+
+        let (sender, receiver) = unbounded();
+        thread::spawn(move || {
+            let rpc_addr = "0.0.0.0:0".parse().unwrap();
+            let mut io = IoHandler::default();
+            // Successful request
+            io.add_method("getFeeForMessage", move |params: Params| match params {
+                Params::Array(p) => {
+                    let first_element = p.first().unwrap();
+                    if let Value::String(actual_serialized_message) = first_element {
+                        assert_eq!(actual_serialized_message, &serialized_message_base64);
+                        return future::ok(json!(Response {
+                            context: RpcResponseContext {
+                                api_version: None,
+                                slot: 1,
+                            },
+                            value: json!(42),
+                        }));
+                    }
+                    future::err(Error::invalid_params(
+                        "Expected the serialized message to be the first element of the params",
+                    ))
+                }
+                _ => {
+                    panic!("Expected an array of params to be forwarded to `getFeeForMessage");
+                }
+            });
+
+            let server = ServerBuilder::new(io)
+                .threads(1)
+                .cors(DomainsValidation::AllowOnly(vec![
+                    AccessControlAllowOrigin::Any,
+                ]))
+                .start_http(&rpc_addr)
+                .expect("Unable to start RPC server");
+            sender.send(*server.address()).unwrap();
+            server.wait();
+        });
+
+        let rpc_addr = receiver.recv().unwrap();
+        let rpc_client = RpcClient::new_socket(rpc_addr);
+
+        let fee: u64 = rpc_client.get_fee_for_message(&message).unwrap();
+        assert_eq!(fee, 42);
+    }
 }