mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-12 20:57:26 +08:00
fix smoltcp not wakeup closed socket
This commit is contained in:
@@ -8,14 +8,17 @@ use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
|
||||
use pnet::packet::MutablePacket;
|
||||
use pnet::packet::Packet;
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU16};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::timeout;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::common::error::Result;
|
||||
@@ -59,18 +62,31 @@ pub struct NatDstTcpConnector;
|
||||
#[async_trait::async_trait]
|
||||
impl NatDstConnector for NatDstTcpConnector {
|
||||
type DstStream = TcpStream;
|
||||
|
||||
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
|
||||
let socket = TcpSocket::new_v4().unwrap();
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
|
||||
Ok(
|
||||
tokio::time::timeout(Duration::from_secs(10), socket.connect(nat_dst))
|
||||
.await?
|
||||
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?,
|
||||
)
|
||||
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(5);
|
||||
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
|
||||
const TCP_KEEPALIVE_RETRIES: u32 = 2;
|
||||
|
||||
let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst))
|
||||
.await?
|
||||
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?;
|
||||
|
||||
let ka = TcpKeepalive::new()
|
||||
.with_time(TCP_KEEPALIVE_TIME)
|
||||
.with_interval(TCP_KEEPALIVE_INTERVAL);
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
let ka = ka.with_retries(TCP_KEEPALIVE_RETRIES);
|
||||
|
||||
let sf = SockRef::from(&stream);
|
||||
sf.set_tcp_keepalive(&ka)?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool {
|
||||
@@ -214,11 +230,29 @@ impl SmolTcpListener {
|
||||
.unwrap();
|
||||
let tx = tx.clone();
|
||||
tasks.spawn(async move {
|
||||
let mut not_listening_count = 0;
|
||||
loop {
|
||||
tx.send(tcp.accept().await.map_err(|e| {
|
||||
anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into()
|
||||
}))
|
||||
.unwrap();
|
||||
select! {
|
||||
_ = tokio::time::sleep(Duration::from_secs(2)) => {
|
||||
if tcp.is_listening() {
|
||||
not_listening_count = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
not_listening_count += 1;
|
||||
if not_listening_count >= 2 {
|
||||
tracing::error!("smol tcp listener not listening");
|
||||
tcp.relisten();
|
||||
}
|
||||
}
|
||||
accept_ret = tcp.accept() => {
|
||||
tx.send(accept_ret.map_err(|e| {
|
||||
anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into()
|
||||
}))
|
||||
.unwrap();
|
||||
not_listening_count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -705,17 +739,19 @@ impl<C: NatDstConnector> TcpProxy<C> {
|
||||
nat_entry.tasks.lock().await.spawn(async move {
|
||||
let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await;
|
||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||
|
||||
let ret = src_tcp_stream.shutdown().await;
|
||||
nat_entry_clone.state.store(NatDstEntryState::ClosingSrc);
|
||||
let ret = timeout(Duration::from_secs(10), src_tcp_stream.shutdown()).await;
|
||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown");
|
||||
|
||||
let ret = dst_tcp_stream.shutdown().await;
|
||||
nat_entry_clone.state.store(NatDstEntryState::ClosingDst);
|
||||
let ret = timeout(Duration::from_secs(10), dst_tcp_stream.shutdown()).await;
|
||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown");
|
||||
|
||||
drop(src_tcp_stream);
|
||||
drop(dst_tcp_stream);
|
||||
|
||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||
// sleep later so the fin packet can be processed
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
|
||||
|
||||
@@ -91,6 +91,19 @@ async fn run(
|
||||
&mut device,
|
||||
&mut socket_allocator.sockets().lock(),
|
||||
);
|
||||
|
||||
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
|
||||
for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
|
||||
match socket {
|
||||
Socket::Tcp(tcp) => {
|
||||
if tcp.state() == smoltcp::socket::tcp::State::Closed {
|
||||
tcp.abort();
|
||||
}
|
||||
}
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -67,6 +67,19 @@ impl TcpListener {
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
Ok(self.local_addr)
|
||||
}
|
||||
|
||||
pub fn relisten(&mut self) {
|
||||
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||
let local_endpoint = socket.local_endpoint().unwrap();
|
||||
socket.abort();
|
||||
socket.listen(local_endpoint).unwrap();
|
||||
self.reactor.notify();
|
||||
}
|
||||
|
||||
pub fn is_listening(&self) -> bool {
|
||||
let socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||
socket.is_listening()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Incoming(TcpListener);
|
||||
@@ -210,6 +223,9 @@ impl AsyncWrite for TcpStream {
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||
if !socket.may_send() {
|
||||
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
|
||||
}
|
||||
if socket.send_queue() == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
@@ -193,6 +193,10 @@ enum TcpProxyEntryState {
|
||||
Connected = 3;
|
||||
// connection closed
|
||||
Closed = 4;
|
||||
// closing src
|
||||
ClosingSrc = 5;
|
||||
// closing dst
|
||||
ClosingDst = 6;
|
||||
}
|
||||
|
||||
message TcpProxyEntry {
|
||||
|
||||
Reference in New Issue
Block a user