From 7fc4aecdb951e20ab47d6cea7e59fdc9e6c064b1 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Thu, 8 Feb 2024 16:27:18 +0800 Subject: [PATCH] Fix udp and win route (#16) * robust udp tunnel * fix windows route add * use pnet to get index * windows disable udp reset --- easytier-core/Cargo.toml | 3 + easytier-core/src/arch/mod.rs | 2 + easytier-core/src/arch/windows.rs | 142 ++++++++++++++++++ easytier-core/src/common/ifcfg.rs | 23 ++- easytier-core/src/connector/udp_hole_punch.rs | 12 +- easytier-core/src/main.rs | 1 + easytier-core/src/tunnels/common.rs | 6 + easytier-core/src/tunnels/udp_tunnel.rs | 107 ++++++++++++- 8 files changed, 285 insertions(+), 11 deletions(-) create mode 100644 easytier-core/src/arch/mod.rs create mode 100644 easytier-core/src/arch/windows.rs diff --git a/easytier-core/Cargo.toml b/easytier-core/Cargo.toml index 8827ebf..11da408 100644 --- a/easytier-core/Cargo.toml +++ b/easytier-core/Cargo.toml @@ -87,6 +87,9 @@ public-ip = { version = "0.2", features = ["default"] } clap = { version = "4.4", features = ["derive"] } +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.52", features = ["Win32_Networking_WinSock", "Win32_NetworkManagement_IpHelper", "Win32_Foundation", "Win32_System_IO"] } + [build-dependencies] tonic-build = "0.10" diff --git a/easytier-core/src/arch/mod.rs b/easytier-core/src/arch/mod.rs new file mode 100644 index 0000000..581af77 --- /dev/null +++ b/easytier-core/src/arch/mod.rs @@ -0,0 +1,2 @@ +#[cfg(target_os = "windows")] +pub mod windows; diff --git a/easytier-core/src/arch/windows.rs b/easytier-core/src/arch/windows.rs new file mode 100644 index 0000000..d3040e6 --- /dev/null +++ b/easytier-core/src/arch/windows.rs @@ -0,0 +1,142 @@ +use std::{ + ffi::c_void, + io::{self, ErrorKind}, + mem, + net::SocketAddr, + os::windows::io::AsRawSocket, + ptr, +}; + +use windows_sys::{ + core::PCSTR, + Win32::{ + Foundation::{BOOL, FALSE}, + Networking::WinSock::{ + htonl, setsockopt, WSAGetLastError, WSAIoctl, IPPROTO_IP, IPPROTO_IPV6, + IPV6_UNICAST_IF, IP_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET, SOCKET_ERROR, + }, + }, +}; + +use crate::tunnels::common::get_interface_name_by_ip; + +pub fn disable_connection_reset(socket: &S) -> io::Result<()> { + let handle = socket.as_raw_socket() as SOCKET; + + unsafe { + // Ignoring UdpSocket's WSAECONNRESET error + // https://github.com/shadowsocks/shadowsocks-rust/issues/179 + // https://stackoverflow.com/questions/30749423/is-winsock-error-10054-wsaeconnreset-normal-with-udp-to-from-localhost + // + // This is because `UdpSocket::recv_from` may return WSAECONNRESET + // if you called `UdpSocket::send_to` a destination that is not existed (may be closed). + // + // It is not an error. Could be ignored completely. + // We have to ignore it here because it will crash the server. + + let mut bytes_returned: u32 = 0; + let enable: BOOL = FALSE; + + let ret = WSAIoctl( + handle, + SIO_UDP_CONNRESET, + &enable as *const _ as *const c_void, + mem::size_of_val(&enable) as u32, + ptr::null_mut(), + 0, + &mut bytes_returned as *mut _, + ptr::null_mut(), + None, + ); + + if ret == SOCKET_ERROR { + use std::io::Error; + + // Error occurs + let err_code = WSAGetLastError(); + return Err(Error::from_raw_os_error(err_code)); + } + } + + Ok(()) +} + +pub fn find_interface_index_cached(iface_name: &str) -> io::Result { + let ifaces = pnet::datalink::interfaces(); + for iface in ifaces { + if iface.name == iface_name { + return Ok(iface.index); + } + } + let err = io::Error::new( + ErrorKind::NotFound, + format!("Failed to find interface index for {}", iface_name), + ); + Err(err) +} + +pub fn set_ip_unicast_if( + socket: &S, + addr: &SocketAddr, + iface: &str, +) -> io::Result<()> { + let handle = socket.as_raw_socket() as SOCKET; + + let if_index = find_interface_index_cached(iface)?; + + unsafe { + // https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + let ret = match addr { + SocketAddr::V4(..) => { + // Interface index is in network byte order for IPPROTO_IP. + let if_index = htonl(if_index); + setsockopt( + handle, + IPPROTO_IP as i32, + IP_UNICAST_IF as i32, + &if_index as *const _ as PCSTR, + mem::size_of_val(&if_index) as i32, + ) + } + SocketAddr::V6(..) => { + // Interface index is in host byte order for IPPROTO_IPV6. + setsockopt( + handle, + IPPROTO_IPV6 as i32, + IPV6_UNICAST_IF as i32, + &if_index as *const _ as PCSTR, + mem::size_of_val(&if_index) as i32, + ) + } + }; + + if ret == SOCKET_ERROR { + let err = io::Error::from_raw_os_error(WSAGetLastError()); + tracing::error!( + "set IP_UNICAST_IF / IPV6_UNICAST_IF interface: {}, index: {}, error: {}", + iface, + if_index, + err + ); + return Err(err); + } + } + + Ok(()) +} + +pub fn setup_socket_for_win( + socket: &S, + bind_addr: &SocketAddr, + is_udp: bool, +) -> io::Result<()> { + if is_udp { + disable_connection_reset(socket)?; + } + + if let Some(iface) = get_interface_name_by_ip(&bind_addr.ip()) { + set_ip_unicast_if(socket, bind_addr, iface.as_str())?; + } + + Ok(()) +} diff --git a/easytier-core/src/common/ifcfg.rs b/easytier-core/src/common/ifcfg.rs index a50324d..38e5f57 100644 --- a/easytier-core/src/common/ifcfg.rs +++ b/easytier-core/src/common/ifcfg.rs @@ -196,8 +196,17 @@ impl IfConfiguerTrait for LinuxIfConfiger { } } +#[cfg(target_os = "windows")] pub struct WindowsIfConfiger {} +#[cfg(target_os = "windows")] +impl WindowsIfConfiger { + pub fn get_interface_index(name: &str) -> Option { + crate::arch::windows::find_interface_index_cached(name).ok() + } +} + +#[cfg(target_os = "windows")] #[async_trait] impl IfConfiguerTrait for WindowsIfConfiger { async fn add_ipv4_route( @@ -206,12 +215,15 @@ impl IfConfiguerTrait for WindowsIfConfiger { address: Ipv4Addr, cidr_prefix: u8, ) -> Result<(), Error> { + let Some(idx) = Self::get_interface_index(name) else { + return Err(Error::NotFound); + }; run_shell_cmd( format!( - "route add {} mask {} {}", + "route ADD {} MASK {} 10.1.1.1 IF {} METRIC 255", address, cidr_to_subnet_mask(cidr_prefix), - name + idx ) .as_str(), ) @@ -224,12 +236,15 @@ impl IfConfiguerTrait for WindowsIfConfiger { address: Ipv4Addr, cidr_prefix: u8, ) -> Result<(), Error> { + let Some(idx) = Self::get_interface_index(name) else { + return Err(Error::NotFound); + }; run_shell_cmd( format!( - "route delete {} mask {} {}", + "route DELETE {} MASK {} IF {}", address, cidr_to_subnet_mask(cidr_prefix), - name + idx ) .as_str(), ) diff --git a/easytier-core/src/connector/udp_hole_punch.rs b/easytier-core/src/connector/udp_hole_punch.rs index ffe9a7f..a7b2a51 100644 --- a/easytier-core/src/connector/udp_hole_punch.rs +++ b/easytier-core/src/connector/udp_hole_punch.rs @@ -14,6 +14,7 @@ use crate::{ }, peers::{peer_manager::PeerManager, PeerId}, tunnels::{ + common::setup_sokcet2, udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener, }, @@ -387,9 +388,14 @@ impl UdpHolePunchConnector { .unwrap(), ); - let socket = UdpSocket::bind(local_socket_addr) - .await - .with_context(|| "")?; + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(local_socket_addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + setup_sokcet2(&socket2_socket, &local_socket_addr)?; + let socket = UdpSocket::from_std(socket2_socket.into())?; + Ok(connector .try_connect_with_socket(socket) .await diff --git a/easytier-core/src/main.rs b/easytier-core/src/main.rs index 928674b..d673370 100644 --- a/easytier-core/src/main.rs +++ b/easytier-core/src/main.rs @@ -5,6 +5,7 @@ mod tests; use clap::Parser; +mod arch; mod common; mod connector; mod gateway; diff --git a/easytier-core/src/tunnels/common.rs b/easytier-core/src/tunnels/common.rs index fc3ab5b..cae44d5 100644 --- a/easytier-core/src/tunnels/common.rs +++ b/easytier-core/src/tunnels/common.rs @@ -273,6 +273,12 @@ pub(crate) fn setup_sokcet2( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, ) -> Result<(), TunnelError> { + #[cfg(target_os = "windows")] + { + let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM); + crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, is_udp)?; + } + socket2_socket.set_nonblocking(true)?; socket2_socket.set_reuse_address(true)?; socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?; diff --git a/easytier-core/src/tunnels/udp_tunnel.rs b/easytier-core/src/tunnels/udp_tunnel.rs index a4934c9..37d477a 100644 --- a/easytier-core/src/tunnels/udp_tunnel.rs +++ b/easytier-core/src/tunnels/udp_tunnel.rs @@ -121,7 +121,11 @@ fn get_tunnel_from_socket( } let (buf, addr) = v.unwrap(); - assert_eq!(addr, recv_addr.clone()); + if recv_addr != addr { + tracing::warn!(?addr, ?recv_addr, "udp recv addr not match"); + return None; + } + Some(Ok(try_get_data_payload(buf, conn_id.clone())?)) }); let stream = Box::pin(stream); @@ -304,7 +308,7 @@ impl TunnelListener for UdpTunnelListener { }; if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) { - let conn = Self::handle_connect( + let Ok(conn) = Self::handle_connect( socket.clone(), addr, forward_tasks.clone(), @@ -313,7 +317,10 @@ impl TunnelListener for UdpTunnelListener { udp_packet.conn_id.into(), ) .await - .unwrap(); + else { + tracing::error!(?addr, "udp handle connect error"); + continue; + }; if let Err(e) = conn_send.send(conn).await { tracing::warn!(?e, "udp send conn to accept channel error"); } @@ -465,6 +472,9 @@ impl UdpTunnelConnector { let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "udp")?; log::warn!("udp connect: {:?}", self.addr); + #[cfg(target_os = "windows")] + crate::arch::windows::disable_connection_reset(&socket)?; + // send syn let conn_id = rand::random(); let udp_packet = UdpPacket::new_syn_packet(conn_id); @@ -544,7 +554,12 @@ impl super::TunnelConnector for UdpTunnelConnector { #[cfg(test)] mod tests { - use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; + use std::time::Duration; + + use rand::Rng; + use tokio::time::timeout; + + use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}; use super::*; @@ -578,4 +593,88 @@ mod tests { connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); _tunnel_pingpong(listener, connector).await } + + async fn send_random_data_to_socket(remote_url: url::Url) { + let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); + socket + .connect(format!( + "{}:{}", + remote_url.host().unwrap(), + remote_url.port().unwrap() + )) + .await + .unwrap(); + + // get a random 100-len buf + loop { + let mut buf = vec![0u8; 100]; + rand::thread_rng().fill(&mut buf[..]); + socket.send(&buf).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + } + + #[tokio::test] + async fn udp_multiple_conns() { + let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); + listener.listen().await.unwrap(); + + let _lis = tokio::spawn(async move { + loop { + let ret = listener.accept().await.unwrap(); + assert_eq!( + ret.info().unwrap().local_addr, + listener.local_url().to_string() + ); + tokio::spawn(async move { _tunnel_echo_server(ret, false).await }); + } + }); + + let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); + let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); + + let t1 = connector1.connect().await.unwrap(); + let t2 = connector2.connect().await.unwrap(); + + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()), + )); + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()), + )); + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()), + )); + + let sender1 = tokio::spawn(async move { + let mut sink = t1.pin_sink(); + let mut stream = t1.pin_stream(); + + for i in 0..10 { + sink.send(Bytes::from("hello1")).await.unwrap(); + let recv = stream.next().await.unwrap().unwrap(); + println!("t1 recv: {:?}, {:?}", recv, i); + assert_eq!(recv, Bytes::from("hello1")); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + }); + + let sender2 = tokio::spawn(async move { + let mut sink = t2.pin_sink(); + let mut stream = t2.pin_stream(); + + for i in 0..10 { + sink.send(Bytes::from("hello2")).await.unwrap(); + let recv = stream.next().await.unwrap().unwrap(); + println!("t2 recv: {:?}, {:?}", recv, i); + assert_eq!(recv, Bytes::from("hello2")); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + }); + + let _ = tokio::join!(sender1, sender2); + } }