allow tcp port forward use kcp (#838)

This commit is contained in:
Sijie.Sun
2025-05-11 00:48:34 +08:00
committed by GitHub
parent c5580feb64
commit 72be46e8fa
6 changed files with 107 additions and 14 deletions

View File

@@ -1,10 +1,11 @@
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
sync::{Arc, Weak},
time::{Duration, Instant},
};
use crossbeam::atomic::AtomicCell;
use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream};
use crate::{
common::{
@@ -19,6 +20,7 @@ use crate::{
util::stream::tcp_connect_with_timeout,
},
ip_reassembler::IpReassembler,
kcp_proxy::NatDstKcpConnector,
tokio_smoltcp::{channel_device, Net, NetConfig},
},
tunnel::packet_def::{PacketType, ZCPacket},
@@ -43,6 +45,8 @@ use crate::{
peers::{peer_manager::PeerManager, PeerPacketFilter},
};
use super::tcp_proxy::NatDstConnector as _;
enum SocksUdpSocket {
UdpSocket(Arc<tokio::net::UdpSocket>),
SmolUdpSocket(super::tokio_smoltcp::UdpSocket),
@@ -67,6 +71,7 @@ impl SocksUdpSocket {
enum SocksTcpStream {
TcpStream(tokio::net::TcpStream),
SmolTcpStream(super::tokio_smoltcp::TcpStream),
KcpStream(KcpStream),
}
impl AsyncRead for SocksTcpStream {
@@ -82,6 +87,9 @@ impl AsyncRead for SocksTcpStream {
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
}
}
}
@@ -99,6 +107,9 @@ impl AsyncWrite for SocksTcpStream {
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
}
}
@@ -111,6 +122,7 @@ impl AsyncWrite for SocksTcpStream {
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_flush(cx)
}
SocksTcpStream::KcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
}
}
@@ -125,6 +137,9 @@ impl AsyncWrite for SocksTcpStream {
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
}
}
}
@@ -204,6 +219,40 @@ impl Drop for SmolTcpConnector {
}
}
struct Socks5KcpConnector {
kcp_endpoint: Weak<KcpEndpoint>,
peer_mgr: Weak<PeerManager>,
src_addr: SocketAddr,
}
#[async_trait::async_trait]
impl AsyncTcpConnector for Socks5KcpConnector {
type S = SocksTcpStream;
async fn tcp_connect(
&self,
addr: SocketAddr,
_timeout_s: u64,
) -> crate::gateway::fast_socks5::Result<SocksTcpStream> {
let Some(kcp_endpoint) = self.kcp_endpoint.upgrade() else {
return Err(anyhow::anyhow!("kcp endpoint is not ready").into());
};
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer mgr is not ready").into());
};
let c = NatDstKcpConnector {
kcp_endpoint,
peer_mgr,
};
println!("connect to kcp endpoint, addr = {:?}", addr);
let ret = c
.connect(self.src_addr, addr)
.await
.map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?;
Ok(SocksTcpStream::KcpStream(ret))
}
}
struct Socks5ServerNet {
ipv4_addr: cidr::Ipv4Inet,
auth: Option<SimpleUserPassword>,
@@ -345,6 +394,8 @@ pub struct Socks5Server {
tcp_forward_task: Arc<std::sync::Mutex<JoinSet<()>>>,
udp_client_map: Arc<DashMap<UdpClientKey, Arc<UdpClientInfo>>>,
udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>,
kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>,
}
#[async_trait::async_trait]
@@ -442,6 +493,8 @@ impl Socks5Server {
tcp_forward_task: Arc::new(std::sync::Mutex::new(JoinSet::new())),
udp_client_map: Arc::new(DashMap::new()),
udp_forward_task: Arc::new(DashMap::new()),
kcp_endpoint: Mutex::new(None),
})
}
@@ -487,7 +540,11 @@ impl Socks5Server {
});
}
pub async fn run(self: &Arc<Self>) -> Result<(), Error> {
pub async fn run(
self: &Arc<Self>,
kcp_endpoint: Option<Weak<KcpEndpoint>>,
) -> Result<(), Error> {
*self.kcp_endpoint.lock().await = kcp_endpoint;
let mut need_start = false;
if let Some(proxy_url) = self.global_ctx.config.get_socks5_portal() {
let bind_addr = format!(
@@ -539,7 +596,7 @@ impl Socks5Server {
async fn handle_port_forward_connection(
mut incoming_socket: tokio::net::TcpStream,
connector: SmolTcpConnector,
connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send>,
dst_addr: SocketAddr,
) {
let outgoing_socket = match connector.tcp_connect(dst_addr, 10).await {
@@ -601,10 +658,12 @@ impl Socks5Server {
let entries = self.entries.clone();
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
let forward_tasks = tasks.clone();
let kcp_endpoint = self.kcp_endpoint.lock().await.clone();
let peer_mgr = Arc::downgrade(&self.peer_manager.clone());
self.tasks.lock().unwrap().spawn(async move {
loop {
let (incoming_socket, _addr) = match listener.accept().await {
let (incoming_socket, addr) = match listener.accept().await {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
@@ -624,11 +683,21 @@ impl Socks5Server {
continue;
};
let connector = SmolTcpConnector {
net: net.smoltcp_net.clone(),
entries: entries.clone(),
current_entry: std::sync::Mutex::new(None),
};
let connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send> =
if kcp_endpoint.is_none() {
Box::new(SmolTcpConnector {
net: net.smoltcp_net.clone(),
entries: entries.clone(),
current_entry: std::sync::Mutex::new(None),
})
} else {
let kcp_endpoint = kcp_endpoint.as_ref().unwrap().clone();
Box::new(Socks5KcpConnector {
kcp_endpoint,
peer_mgr: peer_mgr.clone(),
src_addr: addr,
})
};
forward_tasks
.lock()