cli for port forward and tcp whitelist (#1165)

This commit is contained in:
Sijie.Sun
2025-07-29 09:30:47 +08:00
committed by GitHub
parent 5514de1187
commit 2ec88da823
8 changed files with 828 additions and 171 deletions

View File

@@ -6,6 +6,7 @@ use std::{
use crossbeam::atomic::AtomicCell;
use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream};
use tokio_util::sync::{CancellationToken, DropGuard};
use crate::{
common::{
@@ -432,6 +433,8 @@ pub struct Socks5Server {
udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>,
kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>,
cancel_tokens: DashMap<PortForwardConfig, DropGuard>,
}
#[async_trait::async_trait]
@@ -531,6 +534,8 @@ impl Socks5Server {
udp_forward_task: Arc::new(DashMap::new()),
kcp_endpoint: Mutex::new(None),
cancel_tokens: DashMap::new(),
})
}
@@ -614,10 +619,9 @@ impl Socks5Server {
need_start = true;
};
for port_forward in self.global_ctx.config.get_port_forwards() {
self.add_port_forward(port_forward).await?;
need_start = true;
}
let cfgs = self.global_ctx.config.get_port_forwards();
self.reload_port_forwards(&cfgs).await?;
need_start = need_start || cfgs.len() > 0;
if need_start {
self.peer_manager
@@ -630,6 +634,26 @@ impl Socks5Server {
Ok(())
}
pub async fn reload_port_forwards(&self, cfgs: &Vec<PortForwardConfig>) -> Result<(), Error> {
// remove entries not in new cfg
self.cancel_tokens.retain(|k, _| {
cfgs.iter().any(|cfg| {
if cfg.dst_addr.ip().is_unspecified() {
k.bind_addr == cfg.bind_addr && k.proto == cfg.proto
} else {
k == cfg
}
})
});
// add new ones
for cfg in cfgs {
if !self.cancel_tokens.contains_key(cfg) {
self.add_port_forward(cfg.clone()).await?;
}
}
Ok(())
}
async fn handle_port_forward_connection(
mut incoming_socket: tokio::net::TcpStream,
connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send>,
@@ -660,12 +684,10 @@ impl Socks5Server {
pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> {
match cfg.proto.to_lowercase().as_str() {
"tcp" => {
self.add_tcp_port_forward(cfg.bind_addr, cfg.dst_addr)
.await?;
self.add_tcp_port_forward(&cfg).await?;
}
"udp" => {
self.add_udp_port_forward(cfg.bind_addr, cfg.dst_addr)
.await?;
self.add_udp_port_forward(&cfg).await?;
}
_ => {
return Err(anyhow::anyhow!(
@@ -680,11 +702,12 @@ impl Socks5Server {
Ok(())
}
pub async fn add_tcp_port_forward(
&self,
bind_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<(), Error> {
pub fn remove_port_forward(&self, cfg: PortForwardConfig) {
let _ = self.cancel_tokens.remove(&cfg);
}
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
let net = self.net.clone();
@@ -693,14 +716,26 @@ impl Socks5Server {
let forward_tasks = tasks.clone();
let kcp_endpoint = self.kcp_endpoint.lock().await.clone();
let peer_mgr = Arc::downgrade(&self.peer_manager.clone());
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move {
loop {
let (incoming_socket, addr) = match listener.accept().await {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
continue;
let (incoming_socket, addr) = select! {
biased;
_ = cancel_token.cancelled() => {
tracing::info!("port forward for {:?} cancelled", bind_addr);
break;
}
res = listener.accept() => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
continue;
}
}
}
};
@@ -747,11 +782,8 @@ impl Socks5Server {
}
#[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward(
&self,
bind_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<(), Error> {
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
let entries = self.entries.clone();
@@ -759,16 +791,28 @@ impl Socks5Server {
let net = self.net.clone();
let udp_client_map = self.udp_client_map.clone();
let udp_forward_task = self.udp_forward_task.clone();
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move {
loop {
// we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here.
let mut buf = vec![0u8; 8192];
let (len, addr) = match socket.recv_from(&mut buf).await {
Ok(result) => result,
Err(err) => {
tracing::error!("udp port forward recv error = {:?}", err);
continue;
let (len, addr) = select! {
biased;
_ = cancel_token.cancelled() => {
tracing::info!("udp port forward for {:?} cancelled", bind_addr);
break;
}
res = socket.recv_from(&mut buf) => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("udp port forward recv error = {:?}", err);
continue;
}
}
}
};