From 389ea709ce1ec3d894871e5906743ce222698aaa Mon Sep 17 00:00:00 2001 From: "sijie.sun" Date: Wed, 12 Mar 2025 23:14:36 +0800 Subject: [PATCH] fix smoltcp not wakeup closed socket --- easytier/src/gateway/tcp_proxy.rs | 62 +++++++++++++++---- easytier/src/gateway/tokio_smoltcp/reactor.rs | 13 ++++ easytier/src/gateway/tokio_smoltcp/socket.rs | 16 +++++ easytier/src/proto/cli.proto | 4 ++ 4 files changed, 82 insertions(+), 13 deletions(-) diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index 4630dcf..05d617b 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -8,14 +8,17 @@ use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet}; use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket}; use pnet::packet::MutablePacket; use pnet::packet::Packet; +use socket2::{SockRef, TcpKeepalive}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::select; use tokio::sync::{mpsc, Mutex}; use tokio::task::JoinSet; +use tokio::time::timeout; use tracing::Instrument; use crate::common::error::Result; @@ -59,18 +62,31 @@ pub struct NatDstTcpConnector; #[async_trait::async_trait] impl NatDstConnector for NatDstTcpConnector { type DstStream = TcpStream; - async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result { let socket = TcpSocket::new_v4().unwrap(); if let Err(e) = socket.set_nodelay(true) { tracing::warn!("set_nodelay failed, ignore it: {:?}", e); } - Ok( - tokio::time::timeout(Duration::from_secs(10), socket.connect(nat_dst)) - .await? 
- .with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?, - ) + const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(5); + const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2); + const TCP_KEEPALIVE_RETRIES: u32 = 2; + + let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst)) + .await? + .with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?; + + let ka = TcpKeepalive::new() + .with_time(TCP_KEEPALIVE_TIME) + .with_interval(TCP_KEEPALIVE_INTERVAL); + + #[cfg(not(target_os = "windows"))] + let ka = ka.with_retries(TCP_KEEPALIVE_RETRIES); + + let sf = SockRef::from(&stream); + sf.set_tcp_keepalive(&ka)?; + + Ok(stream) } fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool { @@ -214,11 +230,29 @@ impl SmolTcpListener { .unwrap(); let tx = tx.clone(); tasks.spawn(async move { + let mut not_listening_count = 0; loop { - tx.send(tcp.accept().await.map_err(|e| { - anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into() - })) - .unwrap(); + select! 
{ + _ = tokio::time::sleep(Duration::from_secs(2)) => { + if tcp.is_listening() { + not_listening_count = 0; + continue; + } + + not_listening_count += 1; + if not_listening_count >= 2 { + tracing::error!("smol tcp listener not listening"); + tcp.relisten(); + } + } + accept_ret = tcp.accept() => { + tx.send(accept_ret.map_err(|e| { + anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into() + })) + .unwrap(); + not_listening_count = 0; + } + } } }); } @@ -705,17 +739,19 @@ impl TcpProxy { nat_entry.tasks.lock().await.spawn(async move { let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await; tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed"); - nat_entry_clone.state.store(NatDstEntryState::Closed); - let ret = src_tcp_stream.shutdown().await; + nat_entry_clone.state.store(NatDstEntryState::ClosingSrc); + let ret = timeout(Duration::from_secs(10), src_tcp_stream.shutdown()).await; tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown"); - let ret = dst_tcp_stream.shutdown().await; + nat_entry_clone.state.store(NatDstEntryState::ClosingDst); + let ret = timeout(Duration::from_secs(10), dst_tcp_stream.shutdown()).await; tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown"); drop(src_tcp_stream); drop(dst_tcp_stream); + nat_entry_clone.state.store(NatDstEntryState::Closed); // sleep later so the fin packet can be processed tokio::time::sleep(Duration::from_secs(10)).await; diff --git a/easytier/src/gateway/tokio_smoltcp/reactor.rs b/easytier/src/gateway/tokio_smoltcp/reactor.rs index c6ed005..b41c492 100644 --- a/easytier/src/gateway/tokio_smoltcp/reactor.rs +++ b/easytier/src/gateway/tokio_smoltcp/reactor.rs @@ -91,6 +91,19 @@ async fn run( &mut device, &mut socket_allocator.sockets().lock(), ); + + // wake up all closed sockets (smoltcp seems to have a bug that it doesn't wake up closed sockets) + for (_, socket) in 
socket_allocator.sockets().lock().iter_mut() { + match socket { + Socket::Tcp(tcp) => { + if tcp.state() == smoltcp::socket::tcp::State::Closed { + tcp.abort(); + } + } + #[allow(unreachable_patterns)] + _ => {} + } + } } Ok(()) diff --git a/easytier/src/gateway/tokio_smoltcp/socket.rs b/easytier/src/gateway/tokio_smoltcp/socket.rs index ae4e579..f46f80f 100644 --- a/easytier/src/gateway/tokio_smoltcp/socket.rs +++ b/easytier/src/gateway/tokio_smoltcp/socket.rs @@ -67,6 +67,19 @@ impl TcpListener { pub fn local_addr(&self) -> io::Result { Ok(self.local_addr) } + + pub fn relisten(&mut self) { + let mut socket = self.reactor.get_socket::(*self.handle); + let local_endpoint = socket.local_endpoint().unwrap(); + socket.abort(); + socket.listen(local_endpoint).unwrap(); + self.reactor.notify(); + } + + pub fn is_listening(&self) -> bool { + let socket = self.reactor.get_socket::(*self.handle); + socket.is_listening() + } } pub struct Incoming(TcpListener); @@ -210,6 +223,9 @@ impl AsyncWrite for TcpStream { } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut socket = self.reactor.get_socket::(*self.handle); + if !socket.may_send() { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())); + } if socket.send_queue() == 0 { return Poll::Ready(Ok(())); } diff --git a/easytier/src/proto/cli.proto b/easytier/src/proto/cli.proto index f53e102..22096d2 100644 --- a/easytier/src/proto/cli.proto +++ b/easytier/src/proto/cli.proto @@ -193,6 +193,10 @@ enum TcpProxyEntryState { Connected = 3; // connection closed Closed = 4; + // closing src + ClosingSrc = 5; + // closing dst + ClosingDst = 6; } message TcpProxyEntry {