diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index ce4c7dc..5d7c2c1 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -1,13 +1,16 @@ use std::{ - net::{Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, time::Duration, }; use crate::{ gateway::{ - fast_socks5::server::{ - AcceptAuthentication, AsyncTcpConnector, Config, SimpleUserPassword, Socks5Socket, + fast_socks5::{ + server::{ + AcceptAuthentication, AsyncTcpConnector, Config, SimpleUserPassword, Socks5Socket, + }, + util::stream::tcp_connect_with_timeout, }, tokio_smoltcp::TcpStream, }, @@ -16,7 +19,10 @@ use crate::{ use anyhow::Context; use dashmap::DashSet; use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, Packet}; -use tokio::select; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + select, +}; use tokio::{ net::TcpListener, sync::{mpsc, Mutex}, @@ -31,6 +37,71 @@ use crate::{ tunnel::packet_def::ZCPacket, }; +enum SocksTcpStream { + TcpStream(tokio::net::TcpStream), + SmolTcpStream(TcpStream), +} + +impl AsyncRead for SocksTcpStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + SocksTcpStream::TcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_read(cx, buf) + } + SocksTcpStream::SmolTcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_read(cx, buf) + } + } + } +} + +impl AsyncWrite for SocksTcpStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + SocksTcpStream::TcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_write(cx, buf) + } + SocksTcpStream::SmolTcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_write(cx, buf) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + SocksTcpStream::TcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx), + SocksTcpStream::SmolTcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_flush(cx) + } + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + SocksTcpStream::TcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_shutdown(cx) + } + SocksTcpStream::SmolTcpStream(ref mut stream) => { + std::pin::Pin::new(stream).poll_shutdown(cx) + } + } + } +} + #[derive(Debug, Eq, PartialEq, Hash, Clone)] struct Socks5Entry { src: SocketAddr, @@ -132,30 +203,42 @@ impl Socks5ServerNet { #[async_trait::async_trait] impl AsyncTcpConnector for SmolTcpConnector { - type S = TcpStream; + type S = SocksTcpStream; async fn tcp_connect( &self, addr: SocketAddr, timeout_s: u64, - ) -> crate::gateway::fast_socks5::Result { + ) -> crate::gateway::fast_socks5::Result { + let local_addr = self.0.get_address(); let port = self.0.get_port(); let entry = Socks5Entry { - src: SocketAddr::new(self.0.get_address(), port), + src: SocketAddr::new(local_addr, port), dst: addr, }; *self.2.lock().unwrap() = Some(entry.clone()); self.1.insert(entry); - let remote_socket = timeout( - Duration::from_secs(timeout_s), - self.0.tcp_connect(addr, port), - ) - .await - .with_context(|| "connect to remote timeout")?; + if addr.ip() == local_addr { + let modified_addr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port()); - remote_socket.map_err(|e| super::fast_socks5::SocksError::Other(e.into())) + Ok(SocksTcpStream::TcpStream( + tcp_connect_with_timeout(modified_addr, timeout_s).await?, + )) + } else { + let remote_socket = timeout( + Duration::from_secs(timeout_s), + self.0.tcp_connect(addr, port), + ) + .await + .with_context(|| "connect to remote timeout")?; + + Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err( + |e| super::fast_socks5::SocksError::Other(e.into()), + )?)) + } } } diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 46d24b1..2d470e6 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -633,9 +633,10 @@ pub async fn wireguard_vpn_portal() { } #[cfg(feature = "wireguard")] +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn socks5_vpn_portal() { +pub async fn socks5_vpn_portal(#[values("10.144.144.1", "10.144.144.3")] dst_addr: &str) { use rand::Rng as _; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -649,13 +650,23 @@ pub async fn socks5_vpn_portal() { rand::thread_rng().fill(&mut buf[..]); let buf_clone = buf.clone(); + let dst_addr_clone = dst_addr.to_owned(); let task = tokio::spawn(async move { - let net_ns = NetNS::new(Some("net_c".into())); + let net_ns = if dst_addr_clone == "10.144.144.1" { + NetNS::new(Some("net_a".into())) + } else { + NetNS::new(Some("net_c".into())) + }; let _g = net_ns.guard(); - let socket = TcpListener::bind("10.144.144.3:22222").await.unwrap(); + let socket = TcpListener::bind("0.0.0.0:22222").await.unwrap(); let (mut st, addr) = socket.accept().await.unwrap(); - assert_eq!(addr.ip().to_string(), "10.144.144.1".to_string()); + + if dst_addr_clone == "10.144.144.3" { + assert_eq!(addr.ip().to_string(), "10.144.144.1".to_string()); + } else { + assert_eq!(addr.ip().to_string(), "127.0.0.1".to_string()); + } let rbuf = &mut [0u8; 1024]; st.read_exact(rbuf).await.unwrap(); @@ -670,7 +681,7 @@ pub async fn socks5_vpn_portal() { println!("connect to socks5 portal done"); stream.set_nodelay(true).unwrap(); - let mut conn = Socks5Stream::connect_with_socket(stream, "10.144.144.3:22222") + let mut conn = Socks5Stream::connect_with_socket(stream, format!("{}:22222", dst_addr)) .await .unwrap();