fix smoltcp not wakeup closed socket

This commit is contained in:
sijie.sun
2025-03-12 23:14:36 +08:00
committed by Sijie.Sun
parent c2f535ead4
commit 389ea709ce
4 changed files with 82 additions and 13 deletions

View File

@@ -8,14 +8,17 @@ use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
use pnet::packet::MutablePacket;
use pnet::packet::Packet;
use socket2::{SockRef, TcpKeepalive};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::select;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinSet;
use tokio::time::timeout;
use tracing::Instrument;
use crate::common::error::Result;
@@ -59,18 +62,31 @@ pub struct NatDstTcpConnector;
#[async_trait::async_trait]
impl NatDstConnector for NatDstTcpConnector {
type DstStream = TcpStream;
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
let socket = TcpSocket::new_v4().unwrap();
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
}
Ok(
tokio::time::timeout(Duration::from_secs(10), socket.connect(nat_dst))
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(5);
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
const TCP_KEEPALIVE_RETRIES: u32 = 2;
let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst))
.await?
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?,
)
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?;
let ka = TcpKeepalive::new()
.with_time(TCP_KEEPALIVE_TIME)
.with_interval(TCP_KEEPALIVE_INTERVAL);
#[cfg(not(target_os = "windows"))]
let ka = ka.with_retries(TCP_KEEPALIVE_RETRIES);
let sf = SockRef::from(&stream);
sf.set_tcp_keepalive(&ka)?;
Ok(stream)
}
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool {
@@ -214,11 +230,29 @@ impl SmolTcpListener {
.unwrap();
let tx = tx.clone();
tasks.spawn(async move {
let mut not_listening_count = 0;
loop {
tx.send(tcp.accept().await.map_err(|e| {
select! {
_ = tokio::time::sleep(Duration::from_secs(2)) => {
if tcp.is_listening() {
not_listening_count = 0;
continue;
}
not_listening_count += 1;
if not_listening_count >= 2 {
tracing::error!("smol tcp listener not listening");
tcp.relisten();
}
}
accept_ret = tcp.accept() => {
tx.send(accept_ret.map_err(|e| {
anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into()
}))
.unwrap();
not_listening_count = 0;
}
}
}
});
}
@@ -705,17 +739,19 @@ impl<C: NatDstConnector> TcpProxy<C> {
nat_entry.tasks.lock().await.spawn(async move {
let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await;
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
nat_entry_clone.state.store(NatDstEntryState::Closed);
let ret = src_tcp_stream.shutdown().await;
nat_entry_clone.state.store(NatDstEntryState::ClosingSrc);
let ret = timeout(Duration::from_secs(10), src_tcp_stream.shutdown()).await;
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown");
let ret = dst_tcp_stream.shutdown().await;
nat_entry_clone.state.store(NatDstEntryState::ClosingDst);
let ret = timeout(Duration::from_secs(10), dst_tcp_stream.shutdown()).await;
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown");
drop(src_tcp_stream);
drop(dst_tcp_stream);
nat_entry_clone.state.store(NatDstEntryState::Closed);
// sleep later so the fin packet can be processed
tokio::time::sleep(Duration::from_secs(10)).await;

View File

@@ -91,6 +91,19 @@ async fn run(
&mut device,
&mut socket_allocator.sockets().lock(),
);
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
match socket {
Socket::Tcp(tcp) => {
if tcp.state() == smoltcp::socket::tcp::State::Closed {
tcp.abort();
}
}
#[allow(unreachable_patterns)]
_ => {}
}
}
}
Ok(())

View File

@@ -67,6 +67,19 @@ impl TcpListener {
pub fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
pub fn relisten(&mut self) {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
let local_endpoint = socket.local_endpoint().unwrap();
socket.abort();
socket.listen(local_endpoint).unwrap();
self.reactor.notify();
}
pub fn is_listening(&self) -> bool {
let socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
socket.is_listening()
}
}
pub struct Incoming(TcpListener);
@@ -210,6 +223,9 @@ impl AsyncWrite for TcpStream {
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if !socket.may_send() {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
}
if socket.send_queue() == 0 {
return Poll::Ready(Ok(()));
}

View File

@@ -193,6 +193,10 @@ enum TcpProxyEntryState {
Connected = 3;
// connection closed
Closed = 4;
// closing src
ClosingSrc = 5;
// closing dst
ClosingDst = 6;
}
message TcpProxyEntry {