lib.rs 14 KB


  1. //! The `net_utils` module assists with networking
  2. use log::*;
  3. use rand::{thread_rng, Rng};
  4. use socket2::{Domain, SockAddr, Socket, Type};
  5. use std::io::{self, Read, Write};
  6. use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
  7. use std::sync::mpsc::channel;
  8. use std::time::Duration;
  9. mod ip_echo_server;
  10. use ip_echo_server::IpEchoServerMessage;
  11. pub use ip_echo_server::{ip_echo_server, IpEchoServer};
  /// A data type representing a public UDP socket.
  #[derive(Debug)]
  pub struct UdpSocketPair {
      /// Public address of the socket.
      pub addr: SocketAddr,
      /// Locally bound socket that can receive from the public address.
      pub receiver: UdpSocket,
      /// Locally bound socket to send via the public address.
      pub sender: UdpSocket,
  }

  /// A `(start, end)` pair of ports. The `*_in_range` helpers treat it as half-open: `end`
  /// itself is never chosen.
  pub type PortRange = (u16, u16);
  19. fn ip_echo_server_request(
  20. ip_echo_server_addr: &SocketAddr,
  21. msg: IpEchoServerMessage,
  22. ) -> Result<IpAddr, String> {
  23. let mut data = Vec::new();
  24. let timeout = Duration::new(5, 0);
  25. TcpStream::connect_timeout(ip_echo_server_addr, timeout)
  26. .and_then(|mut stream| {
  27. let msg = bincode::serialize(&msg).expect("serialize IpEchoServerMessage");
  28. stream.write_all(&msg)?;
  29. stream.shutdown(std::net::Shutdown::Write)?;
  30. stream
  31. .set_read_timeout(Some(Duration::new(10, 0)))
  32. .expect("set_read_timeout");
  33. stream.read_to_end(&mut data)
  34. })
  35. .and_then(|_| {
  36. bincode::deserialize(&data).map_err(|err| {
  37. io::Error::new(
  38. io::ErrorKind::Other,
  39. format!("Failed to deserialize: {:?}", err),
  40. )
  41. })
  42. })
  43. .map_err(|err| err.to_string())
  44. }
  45. /// Determine the public IP address of this machine by asking an ip_echo_server at the given
  46. /// address
  47. pub fn get_public_ip_addr(ip_echo_server_addr: &SocketAddr) -> Result<IpAddr, String> {
  48. ip_echo_server_request(ip_echo_server_addr, IpEchoServerMessage::default())
  49. }
  50. // Aborts the process if any of the provided TCP/UDP ports are not reachable by the machine at
  51. // `ip_echo_server_addr`
  52. pub fn verify_reachable_ports(
  53. ip_echo_server_addr: &SocketAddr,
  54. tcp_listeners: Vec<(u16, TcpListener)>,
  55. udp_sockets: &[&UdpSocket],
  56. ) {
  57. let udp: Vec<(_, _)> = udp_sockets
  58. .iter()
  59. .map(|udp_socket| {
  60. (
  61. udp_socket.local_addr().unwrap().port(),
  62. udp_socket.try_clone().expect("Unable to clone udp socket"),
  63. )
  64. })
  65. .collect();
  66. let udp_ports: Vec<_> = udp.iter().map(|x| x.0).collect();
  67. info!(
  68. "Checking that tcp ports {:?} and udp ports {:?} are reachable from {:?}",
  69. tcp_listeners, udp_ports, ip_echo_server_addr
  70. );
  71. let tcp_ports: Vec<_> = tcp_listeners.iter().map(|(port, _)| *port).collect();
  72. let _ = ip_echo_server_request(
  73. ip_echo_server_addr,
  74. IpEchoServerMessage::new(&tcp_ports, &udp_ports),
  75. )
  76. .map_err(|err| warn!("ip_echo_server request failed: {}", err));
  77. // Wait for a connection to open on each TCP port
  78. for (port, tcp_listener) in tcp_listeners {
  79. let (sender, receiver) = channel();
  80. std::thread::spawn(move || {
  81. debug!("Waiting for incoming connection on tcp/{}", port);
  82. let _ = tcp_listener.incoming().next().expect("tcp incoming failed");
  83. sender.send(()).expect("send failure");
  84. });
  85. receiver
  86. .recv_timeout(Duration::from_secs(5))
  87. .unwrap_or_else(|err| {
  88. error!(
  89. "Received no response at tcp/{}, check your port configuration: {}",
  90. port, err
  91. );
  92. std::process::exit(1);
  93. });
  94. info!("tdp/{} is reachable", port);
  95. }
  96. // Wait for a datagram to arrive at each UDP port
  97. for (port, udp_socket) in udp {
  98. let (sender, receiver) = channel();
  99. std::thread::spawn(move || {
  100. let mut buf = [0; 1];
  101. debug!("Waiting for incoming datagram on udp/{}", port);
  102. let _ = udp_socket.recv(&mut buf).expect("udp recv failure");
  103. sender.send(()).expect("send failure");
  104. });
  105. receiver
  106. .recv_timeout(Duration::from_secs(5))
  107. .unwrap_or_else(|err| {
  108. error!(
  109. "Received no response at udp/{}, check your port configuration: {}",
  110. port, err
  111. );
  112. std::process::exit(1);
  113. });
  114. info!("udp/{} is reachable", port);
  115. }
  116. }
  /// Interprets `optstr` as either a bare port or a full socket address.
  ///
  /// - `None` → `default_addr`.
  /// - A string that parses as a `u16` → `default_addr` with only its port replaced.
  /// - A string that parses as a `SocketAddr` → that address.
  /// - Anything else → `default_addr`.
  pub fn parse_port_or_addr(optstr: Option<&str>, default_addr: SocketAddr) -> SocketAddr {
      let text = match optstr {
          Some(text) => text,
          None => return default_addr,
      };
      match text.parse::<u16>() {
          Ok(port) => {
              // `set_port` keeps every other field of the default address intact.
              let mut with_port = default_addr;
              with_port.set_port(port);
              with_port
          }
          Err(_) => text.parse::<SocketAddr>().unwrap_or(default_addr),
      }
  }
  132. pub fn parse_port_range(port_range: &str) -> Option<PortRange> {
  133. let ports: Vec<&str> = port_range.split('-').collect();
  134. if ports.len() != 2 {
  135. return None;
  136. }
  137. let start_port = ports[0].parse();
  138. let end_port = ports[1].parse();
  139. if start_port.is_err() || end_port.is_err() {
  140. return None;
  141. }
  142. let start_port = start_port.unwrap();
  143. let end_port = end_port.unwrap();
  144. if end_port < start_port {
  145. return None;
  146. }
  147. Some((start_port, end_port))
  148. }
  /// Resolves `host` (an IP literal or hostname, without a port) to its first IP address.
  ///
  /// Input that includes a port, such as `"localhost:1234"`, fails to resolve. Resolution
  /// errors are returned as strings.
  pub fn parse_host(host: &str) -> Result<IpAddr, String> {
      let mut resolved = (host, 0)
          .to_socket_addrs()
          .map_err(|err| err.to_string())?;
      match resolved.next() {
          Some(socket_address) => Ok(socket_address.ip()),
          None => Err(format!("Unable to resolve host: {}", host)),
      }
  }
  /// Resolves a `host:port` string to its first socket address.
  ///
  /// Input without a port (e.g. `"localhost"`) is an error. Resolution errors are returned
  /// as strings.
  pub fn parse_host_port(host_port: &str) -> Result<SocketAddr, String> {
      host_port
          .to_socket_addrs()
          .map_err(|err| err.to_string())?
          .next()
          .ok_or_else(|| format!("Unable to resolve host: {}", host_port))
  }
  172. pub fn is_host_port(string: String) -> Result<(), String> {
  173. parse_host_port(&string)?;
  174. Ok(())
  175. }
  /// Creates an unbound IPv4 UDP socket.
  ///
  /// On Windows the `_reuseaddr` flag is ignored: no address/port reuse options are set.
  #[cfg(windows)]
  fn udp_socket(_reuseaddr: bool) -> io::Result<Socket> {
      Socket::new(Domain::ipv4(), Type::dgram(), None)
  }
  181. #[cfg(not(windows))]
  182. fn udp_socket(reuseaddr: bool) -> io::Result<Socket> {
  183. use nix::sys::socket::setsockopt;
  184. use nix::sys::socket::sockopt::{ReuseAddr, ReusePort};
  185. use std::os::unix::io::AsRawFd;
  186. let sock = Socket::new(Domain::ipv4(), Type::dgram(), None)?;
  187. let sock_fd = sock.as_raw_fd();
  188. if reuseaddr {
  189. // best effort, i.e. ignore errors here, we'll get the failure in caller
  190. setsockopt(sock_fd, ReusePort, &true).ok();
  191. setsockopt(sock_fd, ReuseAddr, &true).ok();
  192. }
  193. Ok(sock)
  194. }
  195. // Find a port in the given range that is available for both TCP and UDP
  196. pub fn bind_common_in_range(range: PortRange) -> io::Result<(u16, (UdpSocket, TcpListener))> {
  197. let (start, end) = range;
  198. let mut tries_left = end - start;
  199. let mut rand_port = thread_rng().gen_range(start, end);
  200. loop {
  201. match bind_common(rand_port, false) {
  202. Ok((sock, listener)) => {
  203. break Result::Ok((sock.local_addr().unwrap().port(), (sock, listener)));
  204. }
  205. Err(err) => {
  206. if tries_left == 0 {
  207. return Err(err);
  208. }
  209. }
  210. }
  211. rand_port += 1;
  212. if rand_port == end {
  213. rand_port = start;
  214. }
  215. tries_left -= 1;
  216. }
  217. }
  218. pub fn bind_in_range(range: PortRange) -> io::Result<(u16, UdpSocket)> {
  219. let sock = udp_socket(false)?;
  220. let (start, end) = range;
  221. let mut tries_left = end - start;
  222. let mut rand_port = thread_rng().gen_range(start, end);
  223. loop {
  224. let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), rand_port);
  225. match sock.bind(&SockAddr::from(addr)) {
  226. Ok(_) => {
  227. let sock = sock.into_udp_socket();
  228. break Result::Ok((sock.local_addr().unwrap().port(), sock));
  229. }
  230. Err(err) => {
  231. if tries_left == 0 {
  232. return Err(err);
  233. }
  234. }
  235. }
  236. rand_port += 1;
  237. if rand_port == end {
  238. rand_port = start;
  239. }
  240. tries_left -= 1;
  241. }
  242. }
  243. // binds many sockets to the same port in a range
  244. pub fn multi_bind_in_range(range: PortRange, mut num: usize) -> io::Result<(u16, Vec<UdpSocket>)> {
  245. if cfg!(windows) && num != 1 {
  246. // See https://github.com/solana-labs/solana/issues/4607
  247. warn!(
  248. "multi_bind_in_range() only supports 1 socket in windows ({} requested)",
  249. num
  250. );
  251. num = 1;
  252. }
  253. let mut sockets = Vec::with_capacity(num);
  254. let port = {
  255. let (port, _) = bind_in_range(range)?;
  256. port
  257. }; // drop the probe, port should be available... briefly.
  258. for _ in 0..num {
  259. sockets.push(bind_to(port, true)?);
  260. }
  261. Ok((port, sockets))
  262. }
  263. pub fn bind_to(port: u16, reuseaddr: bool) -> io::Result<UdpSocket> {
  264. let sock = udp_socket(reuseaddr)?;
  265. let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
  266. match sock.bind(&SockAddr::from(addr)) {
  267. Ok(_) => Result::Ok(sock.into_udp_socket()),
  268. Err(err) => Err(err),
  269. }
  270. }
  271. // binds both a UdpSocket and a TcpListener
  272. pub fn bind_common(port: u16, reuseaddr: bool) -> io::Result<(UdpSocket, TcpListener)> {
  273. let sock = udp_socket(reuseaddr)?;
  274. let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
  275. let sock_addr = SockAddr::from(addr);
  276. match sock.bind(&sock_addr) {
  277. Ok(_) => match TcpListener::bind(&addr) {
  278. Ok(listener) => Result::Ok((sock.into_udp_socket(), listener)),
  279. Err(err) => Err(err),
  280. },
  281. Err(err) => Err(err),
  282. }
  283. }
  284. pub fn find_available_port_in_range(range: PortRange) -> io::Result<u16> {
  285. let (start, end) = range;
  286. let mut tries_left = end - start;
  287. let mut rand_port = thread_rng().gen_range(start, end);
  288. loop {
  289. match bind_common(rand_port, false) {
  290. Ok(_) => {
  291. break Ok(rand_port);
  292. }
  293. Err(err) => {
  294. if tries_left == 0 {
  295. return Err(err);
  296. }
  297. }
  298. }
  299. rand_port += 1;
  300. if rand_port == end {
  301. rand_port = start;
  302. }
  303. tries_left -= 1;
  304. }
  305. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Default address shared by the `parse_port_or_addr` cases.
      fn fallback_addr() -> SocketAddr {
          SocketAddr::from(([1, 2, 3, 4], 1))
      }

      #[test]
      fn test_parse_port_or_addr() {
          // A bare port overrides only the port of the default address.
          assert_eq!(parse_port_or_addr(Some("9000"), fallback_addr()).port(), 9000);
          // A full address replaces the default entirely.
          assert_eq!(
              parse_port_or_addr(Some("127.0.0.1:7000"), fallback_addr()).port(),
              7000
          );
          // Unparseable input and a missing value both fall back to the default.
          assert_eq!(parse_port_or_addr(Some("hi there"), fallback_addr()).port(), 1);
          assert_eq!(parse_port_or_addr(None, fallback_addr()).port(), 1);
      }

      #[test]
      fn test_parse_port_range() {
          for bad in &["garbage", "1-", "1-2-3", "2-1"] {
              assert_eq!(parse_port_range(bad), None);
          }
          assert_eq!(parse_port_range("1-2"), Some((1, 2)));
      }

      #[test]
      fn test_parse_host() {
          // Inputs carrying a port are rejected; bare hosts resolve.
          for with_port in &["localhost:1234", "127.0.0.0:1234"] {
              parse_host(with_port).unwrap_err();
          }
          for bare in &["localhost", "127.0.0.0"] {
              parse_host(bare).unwrap();
          }
      }

      #[test]
      fn test_parse_host_port() {
          // A port is required.
          for with_port in &["localhost:1234", "127.0.0.0:1234"] {
              parse_host_port(with_port).unwrap();
          }
          for bare in &["localhost", "127.0.0.0"] {
              parse_host_port(bare).unwrap_err();
          }
      }

      #[test]
      fn test_is_host_port() {
          assert!(is_host_port("localhost:1234".to_string()).is_ok());
          assert!(is_host_port("localhost".to_string()).is_err());
      }

      #[test]
      fn test_bind() {
          // A single-port range can only yield that port.
          assert_eq!(bind_in_range((2000, 2001)).unwrap().0, 2000);

          // Two reuse-enabled sockets share a port. Both stay alive (named bindings) for the
          // rest of the test, so the exclusive binds below must fail.
          let first = bind_to(2002, true).unwrap();
          let second = bind_to(2002, true).unwrap();
          assert_eq!(
              first.local_addr().unwrap().port(),
              second.local_addr().unwrap().port()
          );
          bind_to(2002, false).unwrap_err();
          bind_in_range((2002, 2003)).unwrap_err();

          let (port, sockets) = multi_bind_in_range((2010, 2110), 10).unwrap();
          assert!(sockets
              .iter()
              .all(|sock| sock.local_addr().unwrap().port() == port));
      }

      #[test]
      #[should_panic]
      fn test_bind_in_range_nil() {
          // An empty range is rejected by panicking.
          let _ = bind_in_range((2000, 2000));
      }

      #[test]
      fn test_find_available_port_in_range() {
          assert_eq!(find_available_port_in_range((3000, 3001)).unwrap(), 3000);
          let port = find_available_port_in_range((3000, 3050)).unwrap();
          assert!(port >= 3000 && port < 3050);
          // Hold the port so the single-port range below has nothing free.
          let _held = bind_to(port, false).unwrap();
          find_available_port_in_range((port, port + 1)).unwrap_err();
      }

      #[test]
      fn test_bind_common_in_range() {
          // `_held` keeps the bound pair alive so the port stays taken.
          let (port, _held) = bind_common_in_range((3100, 3150)).unwrap();
          assert!(port >= 3100 && port < 3150);
          bind_common_in_range((port, port + 1)).unwrap_err();
      }

      #[test]
      fn test_get_public_ip_addr() {
          let (_server_port, (server_udp_socket, server_tcp_listener)) =
              bind_common_in_range((3200, 3250)).unwrap();
          let (client_port, (client_udp_socket, client_tcp_listener)) =
              bind_common_in_range((3200, 3250)).unwrap();
          // Held in a named binding so it is not dropped before the end of the test
          // (presumably the server stops with it — confirm in `ip_echo_server`).
          let _runtime = ip_echo_server(server_tcp_listener);
          let ip_echo_server_addr = server_udp_socket.local_addr().unwrap();
          get_public_ip_addr(&ip_echo_server_addr).unwrap();
          verify_reachable_ports(
              &ip_echo_server_addr,
              vec![(client_port, client_tcp_listener)],
              &[&client_udp_socket],
          );
      }
  }