diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index 8ad0080..ed5e273 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -97,10 +97,55 @@ impl ProxyTcpStream { } } +#[cfg(feature = "smoltcp")] +struct SmolTcpListener { + listener_task: JoinSet<()>, + listen_count: usize, + + stream_rx: mpsc::UnboundedReceiver>, +} + +#[cfg(feature = "smoltcp")] +impl SmolTcpListener { + pub async fn new(net: Arc>>, listen_count: usize) -> Self { + let mut tasks = JoinSet::new(); + + let (tx, rx) = mpsc::unbounded_channel(); + let locked_net = net.lock().await; + for _ in 0..listen_count { + let mut tcp = locked_net + .as_ref() + .unwrap() + .tcp_bind("0.0.0.0:8899".parse().unwrap()) + .await + .unwrap(); + let tx = tx.clone(); + tasks.spawn(async move { + loop { + tx.send(tcp.accept().await.map_err(|e| { + anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into() + })) + .unwrap(); + } + }); + } + + Self { + listener_task: tasks, + listen_count, + stream_rx: rx, + } + } + + pub async fn accept(&mut self) -> Result<(tokio_smoltcp::TcpStream, SocketAddr)> { + self.stream_rx.recv().await.unwrap() + } +} + enum ProxyTcpListener { KernelTcpListener(TcpListener), #[cfg(feature = "smoltcp")] - SmolTcpListener(tokio_smoltcp::TcpListener), + SmolTcpListener(SmolTcpListener), } impl ProxyTcpListener { @@ -375,8 +420,8 @@ impl TcpProxy { ), ); net.set_any_ip(true); - let tcp = net.tcp_bind("0.0.0.0:8899".parse().unwrap()).await?; self.smoltcp_net.lock().await.replace(net); + let tcp = SmolTcpListener::new(self.smoltcp_net.clone(), 64).await; self.enable_smoltcp .store(true, std::sync::atomic::Ordering::Relaxed);