fix socks5 access local virtual ip

This commit is contained in:
sijie.sun
2024-08-17 22:51:44 +08:00
committed by Sijie.Sun
parent db660ee3b1
commit ad4cbbea6d
2 changed files with 113 additions and 19 deletions

View File

@@ -1,13 +1,16 @@
use std::{ use std::{
net::{Ipv4Addr, SocketAddr}, net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
}; };
use crate::{ use crate::{
gateway::{ gateway::{
fast_socks5::server::{ fast_socks5::{
AcceptAuthentication, AsyncTcpConnector, Config, SimpleUserPassword, Socks5Socket, server::{
AcceptAuthentication, AsyncTcpConnector, Config, SimpleUserPassword, Socks5Socket,
},
util::stream::tcp_connect_with_timeout,
}, },
tokio_smoltcp::TcpStream, tokio_smoltcp::TcpStream,
}, },
@@ -16,7 +19,10 @@ use crate::{
use anyhow::Context; use anyhow::Context;
use dashmap::DashSet; use dashmap::DashSet;
use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, Packet}; use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, Packet};
use tokio::select; use tokio::{
io::{AsyncRead, AsyncWrite},
select,
};
use tokio::{ use tokio::{
net::TcpListener, net::TcpListener,
sync::{mpsc, Mutex}, sync::{mpsc, Mutex},
@@ -31,6 +37,71 @@ use crate::{
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
}; };
enum SocksTcpStream {
TcpStream(tokio::net::TcpStream),
SmolTcpStream(TcpStream),
}
impl AsyncRead for SocksTcpStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
}
}
}
impl AsyncWrite for SocksTcpStream {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_flush(cx)
}
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
}
}
}
#[derive(Debug, Eq, PartialEq, Hash, Clone)] #[derive(Debug, Eq, PartialEq, Hash, Clone)]
struct Socks5Entry { struct Socks5Entry {
src: SocketAddr, src: SocketAddr,
@@ -132,30 +203,42 @@ impl Socks5ServerNet {
#[async_trait::async_trait] #[async_trait::async_trait]
impl AsyncTcpConnector for SmolTcpConnector { impl AsyncTcpConnector for SmolTcpConnector {
type S = TcpStream; type S = SocksTcpStream;
async fn tcp_connect( async fn tcp_connect(
&self, &self,
addr: SocketAddr, addr: SocketAddr,
timeout_s: u64, timeout_s: u64,
) -> crate::gateway::fast_socks5::Result<TcpStream> { ) -> crate::gateway::fast_socks5::Result<SocksTcpStream> {
let local_addr = self.0.get_address();
let port = self.0.get_port(); let port = self.0.get_port();
let entry = Socks5Entry { let entry = Socks5Entry {
src: SocketAddr::new(self.0.get_address(), port), src: SocketAddr::new(local_addr, port),
dst: addr, dst: addr,
}; };
*self.2.lock().unwrap() = Some(entry.clone()); *self.2.lock().unwrap() = Some(entry.clone());
self.1.insert(entry); self.1.insert(entry);
let remote_socket = timeout( if addr.ip() == local_addr {
Duration::from_secs(timeout_s), let modified_addr =
self.0.tcp_connect(addr, port), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port());
)
.await
.with_context(|| "connect to remote timeout")?;
remote_socket.map_err(|e| super::fast_socks5::SocksError::Other(e.into())) Ok(SocksTcpStream::TcpStream(
tcp_connect_with_timeout(modified_addr, timeout_s).await?,
))
} else {
let remote_socket = timeout(
Duration::from_secs(timeout_s),
self.0.tcp_connect(addr, port),
)
.await
.with_context(|| "connect to remote timeout")?;
Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err(
|e| super::fast_socks5::SocksError::Other(e.into()),
)?))
}
} }
} }

View File

@@ -633,9 +633,10 @@ pub async fn wireguard_vpn_portal() {
} }
#[cfg(feature = "wireguard")] #[cfg(feature = "wireguard")]
#[rstest::rstest]
#[tokio::test] #[tokio::test]
#[serial_test::serial] #[serial_test::serial]
pub async fn socks5_vpn_portal() { pub async fn socks5_vpn_portal(#[values("10.144.144.1", "10.144.144.3")] dst_addr: &str) {
use rand::Rng as _; use rand::Rng as _;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
@@ -649,13 +650,23 @@ pub async fn socks5_vpn_portal() {
rand::thread_rng().fill(&mut buf[..]); rand::thread_rng().fill(&mut buf[..]);
let buf_clone = buf.clone(); let buf_clone = buf.clone();
let dst_addr_clone = dst_addr.to_owned();
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let net_ns = NetNS::new(Some("net_c".into())); let net_ns = if dst_addr_clone == "10.144.144.1" {
NetNS::new(Some("net_a".into()))
} else {
NetNS::new(Some("net_c".into()))
};
let _g = net_ns.guard(); let _g = net_ns.guard();
let socket = TcpListener::bind("10.144.144.3:22222").await.unwrap(); let socket = TcpListener::bind("0.0.0.0:22222").await.unwrap();
let (mut st, addr) = socket.accept().await.unwrap(); let (mut st, addr) = socket.accept().await.unwrap();
assert_eq!(addr.ip().to_string(), "10.144.144.1".to_string());
if dst_addr_clone == "10.144.144.3" {
assert_eq!(addr.ip().to_string(), "10.144.144.1".to_string());
} else {
assert_eq!(addr.ip().to_string(), "127.0.0.1".to_string());
}
let rbuf = &mut [0u8; 1024]; let rbuf = &mut [0u8; 1024];
st.read_exact(rbuf).await.unwrap(); st.read_exact(rbuf).await.unwrap();
@@ -670,7 +681,7 @@ pub async fn socks5_vpn_portal() {
println!("connect to socks5 portal done"); println!("connect to socks5 portal done");
stream.set_nodelay(true).unwrap(); stream.set_nodelay(true).unwrap();
let mut conn = Socks5Stream::connect_with_socket(stream, "10.144.144.3:22222") let mut conn = Socks5Stream::connect_with_socket(stream, format!("{}:22222", dst_addr))
.await .await
.unwrap(); .unwrap();