mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-13 05:07:23 +08:00
fix smoltcp not wakeup closed socket
This commit is contained in:
@@ -8,14 +8,17 @@ use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
|||||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
|
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
|
||||||
use pnet::packet::MutablePacket;
|
use pnet::packet::MutablePacket;
|
||||||
use pnet::packet::Packet;
|
use pnet::packet::Packet;
|
||||||
|
use socket2::{SockRef, TcpKeepalive};
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU16};
|
use std::sync::atomic::{AtomicBool, AtomicU16};
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||||
|
use tokio::select;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
|
use tokio::time::timeout;
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::common::error::Result;
|
use crate::common::error::Result;
|
||||||
@@ -59,18 +62,31 @@ pub struct NatDstTcpConnector;
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl NatDstConnector for NatDstTcpConnector {
|
impl NatDstConnector for NatDstTcpConnector {
|
||||||
type DstStream = TcpStream;
|
type DstStream = TcpStream;
|
||||||
|
|
||||||
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
|
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
|
||||||
let socket = TcpSocket::new_v4().unwrap();
|
let socket = TcpSocket::new_v4().unwrap();
|
||||||
if let Err(e) = socket.set_nodelay(true) {
|
if let Err(e) = socket.set_nodelay(true) {
|
||||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(
|
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(5);
|
||||||
tokio::time::timeout(Duration::from_secs(10), socket.connect(nat_dst))
|
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
|
||||||
|
const TCP_KEEPALIVE_RETRIES: u32 = 2;
|
||||||
|
|
||||||
|
let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst))
|
||||||
.await?
|
.await?
|
||||||
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?,
|
.with_context(|| format!("connect to nat dst failed: {:?}", nat_dst))?;
|
||||||
)
|
|
||||||
|
let ka = TcpKeepalive::new()
|
||||||
|
.with_time(TCP_KEEPALIVE_TIME)
|
||||||
|
.with_interval(TCP_KEEPALIVE_INTERVAL);
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
let ka = ka.with_retries(TCP_KEEPALIVE_RETRIES);
|
||||||
|
|
||||||
|
let sf = SockRef::from(&stream);
|
||||||
|
sf.set_tcp_keepalive(&ka)?;
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool {
|
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool {
|
||||||
@@ -214,11 +230,29 @@ impl SmolTcpListener {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let tx = tx.clone();
|
let tx = tx.clone();
|
||||||
tasks.spawn(async move {
|
tasks.spawn(async move {
|
||||||
|
let mut not_listening_count = 0;
|
||||||
loop {
|
loop {
|
||||||
tx.send(tcp.accept().await.map_err(|e| {
|
select! {
|
||||||
|
_ = tokio::time::sleep(Duration::from_secs(2)) => {
|
||||||
|
if tcp.is_listening() {
|
||||||
|
not_listening_count = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
not_listening_count += 1;
|
||||||
|
if not_listening_count >= 2 {
|
||||||
|
tracing::error!("smol tcp listener not listening");
|
||||||
|
tcp.relisten();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
accept_ret = tcp.accept() => {
|
||||||
|
tx.send(accept_ret.map_err(|e| {
|
||||||
anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into()
|
anyhow::anyhow!("smol tcp listener accept failed: {:?}", e).into()
|
||||||
}))
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
not_listening_count = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -705,17 +739,19 @@ impl<C: NatDstConnector> TcpProxy<C> {
|
|||||||
nat_entry.tasks.lock().await.spawn(async move {
|
nat_entry.tasks.lock().await.spawn(async move {
|
||||||
let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await;
|
let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await;
|
||||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
||||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
|
||||||
|
|
||||||
let ret = src_tcp_stream.shutdown().await;
|
nat_entry_clone.state.store(NatDstEntryState::ClosingSrc);
|
||||||
|
let ret = timeout(Duration::from_secs(10), src_tcp_stream.shutdown()).await;
|
||||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown");
|
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown");
|
||||||
|
|
||||||
let ret = dst_tcp_stream.shutdown().await;
|
nat_entry_clone.state.store(NatDstEntryState::ClosingDst);
|
||||||
|
let ret = timeout(Duration::from_secs(10), dst_tcp_stream.shutdown()).await;
|
||||||
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown");
|
tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown");
|
||||||
|
|
||||||
drop(src_tcp_stream);
|
drop(src_tcp_stream);
|
||||||
drop(dst_tcp_stream);
|
drop(dst_tcp_stream);
|
||||||
|
|
||||||
|
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||||
// sleep later so the fin packet can be processed
|
// sleep later so the fin packet can be processed
|
||||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||||
|
|
||||||
|
|||||||
@@ -91,6 +91,19 @@ async fn run(
|
|||||||
&mut device,
|
&mut device,
|
||||||
&mut socket_allocator.sockets().lock(),
|
&mut socket_allocator.sockets().lock(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
|
||||||
|
for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
|
||||||
|
match socket {
|
||||||
|
Socket::Tcp(tcp) => {
|
||||||
|
if tcp.state() == smoltcp::socket::tcp::State::Closed {
|
||||||
|
tcp.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[allow(unreachable_patterns)]
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -67,6 +67,19 @@ impl TcpListener {
|
|||||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||||
Ok(self.local_addr)
|
Ok(self.local_addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn relisten(&mut self) {
|
||||||
|
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||||
|
let local_endpoint = socket.local_endpoint().unwrap();
|
||||||
|
socket.abort();
|
||||||
|
socket.listen(local_endpoint).unwrap();
|
||||||
|
self.reactor.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_listening(&self) -> bool {
|
||||||
|
let socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||||
|
socket.is_listening()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Incoming(TcpListener);
|
pub struct Incoming(TcpListener);
|
||||||
@@ -210,6 +223,9 @@ impl AsyncWrite for TcpStream {
|
|||||||
}
|
}
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
|
||||||
|
if !socket.may_send() {
|
||||||
|
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
|
||||||
|
}
|
||||||
if socket.send_queue() == 0 {
|
if socket.send_queue() == 0 {
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,6 +193,10 @@ enum TcpProxyEntryState {
|
|||||||
Connected = 3;
|
Connected = 3;
|
||||||
// connection closed
|
// connection closed
|
||||||
Closed = 4;
|
Closed = 4;
|
||||||
|
// closing src
|
||||||
|
ClosingSrc = 5;
|
||||||
|
// closing dst
|
||||||
|
ClosingDst = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TcpProxyEntry {
|
message TcpProxyEntry {
|
||||||
|
|||||||
Reference in New Issue
Block a user