diff --git a/easytier/src/gateway/kcp_proxy.rs b/easytier/src/gateway/kcp_proxy.rs index 5a9c82a..163b636 100644 --- a/easytier/src/gateway/kcp_proxy.rs +++ b/easytier/src/gateway/kcp_proxy.rs @@ -20,7 +20,7 @@ use pnet::packet::{ Packet as _, }; use prost::Message; -use tokio::{io::copy_bidirectional, task::JoinSet}; +use tokio::{io::copy_bidirectional, select, task::JoinSet}; use super::{ tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy}, @@ -134,21 +134,54 @@ impl NatDstConnector for NatDstKcpConnector { return Err(anyhow::anyhow!("no dst peer found for nat dst: {}", nat_dst).into()); } - let ret = self - .kcp_endpoint - .connect( - Duration::from_secs(10), - self.peer_mgr.my_peer_id(), - dst_peers[0], - Bytes::from(conn_data.encode_to_vec()), - ) - .await - .with_context(|| format!("failed to connect to nat dst: {}", nat_dst.to_string()))?; + let mut connect_tasks: JoinSet> = JoinSet::new(); + let mut retry_remain = 5; + loop { + select! { + Some(Ok(Ok(ret))) = connect_tasks.join_next() => { + // just wait for the previous connection to finish + let stream = KcpStream::new(&self.kcp_endpoint, ret) + .ok_or(anyhow::anyhow!("failed to create kcp stream"))?; + return Ok(stream); + } + _ = tokio::time::sleep(Duration::from_millis(200)), if !connect_tasks.is_empty() && retry_remain > 0 => { + // no successful connection yet, trigger another connection attempt + } + else => { + // got error in connect_tasks, continue to retry + if retry_remain == 0 && connect_tasks.is_empty() { + break; + } + } + } - let stream = KcpStream::new(&self.kcp_endpoint, ret) - .ok_or(anyhow::anyhow!("failed to create kcp stream"))?; + // create a new connection task + if retry_remain == 0 { + continue; + } + retry_remain -= 1; - Ok(stream) + let kcp_endpoint = self.kcp_endpoint.clone(); + let peer_mgr = self.peer_mgr.clone(); + let dst_peer = dst_peers[0]; + let conn_data_clone = conn_data.clone(); + + connect_tasks.spawn(async move { + kcp_endpoint + .connect( + Duration::from_secs(10), + peer_mgr.my_peer_id(), + dst_peer, + Bytes::from(conn_data_clone.encode_to_vec()), + ) + .await + .with_context(|| { + format!("failed to connect to nat dst: {}", nat_dst.to_string()) + }) + }); + } + + Err(anyhow::anyhow!("failed to connect to nat dst: {}", nat_dst).into()) } fn check_packet_from_peer_fast(&self, _cidr_set: &CidrSet, _global_ctx: &GlobalCtx) -> bool {