diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 6da3a7f..baf8e93 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -61,11 +61,11 @@ pub struct GlobalCtx { cached_ipv4: AtomicCell>, cached_proxy_cidrs: AtomicCell>>, - ip_collector: Arc, + ip_collector: Mutex>>, hostname: Mutex, - stun_info_collection: Box, + stun_info_collection: Mutex>, running_listeners: Mutex>, @@ -120,11 +120,14 @@ impl GlobalCtx { cached_ipv4: AtomicCell::new(None), cached_proxy_cidrs: AtomicCell::new(None), - ip_collector: Arc::new(IPCollector::new(net_ns, stun_info_collection.clone())), + ip_collector: Mutex::new(Some(Arc::new(IPCollector::new( + net_ns, + stun_info_collection.clone(), + )))), hostname: Mutex::new(hostname), - stun_info_collection: Box::new(stun_info_collection), + stun_info_collection: Mutex::new(stun_info_collection), running_listeners: Mutex::new(Vec::new()), @@ -215,7 +218,7 @@ impl GlobalCtx { } pub fn get_ip_collector(&self) -> Arc { - self.ip_collector.clone() + self.ip_collector.lock().unwrap().as_ref().unwrap().clone() } pub fn get_hostname(&self) -> String { @@ -226,19 +229,19 @@ impl GlobalCtx { *self.hostname.lock().unwrap() = hostname; } - pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ { - self.stun_info_collection.as_ref() + pub fn get_stun_info_collector(&self) -> Arc { + self.stun_info_collection.lock().unwrap().clone() } pub fn replace_stun_info_collector(&self, collector: Box) { - // force replace the stun_info_collection without mut and drop the old one - let ptr = &self.stun_info_collection as *const Box; - let ptr = ptr as *mut Box; - unsafe { - std::ptr::drop_in_place(ptr); - #[allow(invalid_reference_casting)] - std::ptr::write(ptr, collector); - } + let arc_collector: Arc = Arc::new(collector); + *self.stun_info_collection.lock().unwrap() = arc_collector.clone(); + + // rebuild the ip collector + *self.ip_collector.lock().unwrap() = Some(Arc::new(IPCollector::new( + self.net_ns.clone(), + arc_collector, + ))); } pub fn get_running_listeners(&self) -> Vec { diff --git a/easytier/src/common/network.rs b/easytier/src/common/network.rs index 948a98f..8396a30 100644 --- a/easytier/src/common/network.rs +++ b/easytier/src/common/network.rs @@ -179,18 +179,16 @@ impl IPCollector { Self::do_collect_local_ip_addrs(self.net_ns.clone()).await; let net_ns = self.net_ns.clone(); let stun_info_collector = self.stun_info_collector.clone(); - task.spawn(async move { - loop { - let ip_addrs = Self::do_collect_local_ip_addrs(net_ns.clone()).await; - *cached_ip_list.write().await = ip_addrs; - tokio::time::sleep(std::time::Duration::from_secs(CACHED_IP_LIST_TIMEOUT_SEC)) - .await; - } - }); - let cached_ip_list = self.cached_ip_list.clone(); task.spawn(async move { + let mut last_fetch_iface_time = std::time::Instant::now(); loop { + if last_fetch_iface_time.elapsed().as_secs() > CACHED_IP_LIST_TIMEOUT_SEC { + let ifaces = Self::do_collect_local_ip_addrs(net_ns.clone()).await; + *cached_ip_list.write().await = ifaces; + last_fetch_iface_time = std::time::Instant::now(); + } + let stun_info = stun_info_collector.get_stun_info(); for ip in stun_info.public_ip.iter() { let Ok(ip_addr) = ip.parse::() else { @@ -199,14 +197,20 @@ impl IPCollector { match ip_addr { IpAddr::V4(v) => { - cached_ip_list.write().await.public_ipv4 = Some(v.into()) + cached_ip_list.write().await.public_ipv4.replace(v.into()); } IpAddr::V6(v) => { - cached_ip_list.write().await.public_ipv6 = Some(v.into()) + cached_ip_list.write().await.public_ipv6.replace(v.into()); } } } + tracing::debug!( + "got public ip: {:?}, {:?}", + cached_ip_list.read().await.public_ipv4, + cached_ip_list.read().await.public_ipv6 + ); + let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_none() { CACHED_IP_LIST_TIMEOUT_SEC } else { @@ -217,7 +221,7 @@ impl IPCollector { }); } - return self.cached_ip_list.read().await.deref().clone(); + self.cached_ip_list.read().await.deref().clone() } pub async fn collect_interfaces(net_ns: NetNS, filter: bool) -> Vec { diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index 18dfca9..a250ae0 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -890,7 +890,7 @@ impl StunInfoCollectorTrait for MockStunInfoCollector { last_update_time: std::time::Instant::now().elapsed().as_secs() as i64, min_port: 100, max_port: 200, - public_ip: vec!["127.0.0.1".to_string()], + public_ip: vec!["127.0.0.1".to_string(), "::1".to_string()], } } diff --git a/easytier/src/connector/direct.rs b/easytier/src/connector/direct.rs index 06408fe..e8896f8 100644 --- a/easytier/src/connector/direct.rs +++ b/easytier/src/connector/direct.rs @@ -12,29 +12,31 @@ use std::{ }; use crate::{ - common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, + common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, peers::{ - peer_manager::PeerManager, peer_rpc::PeerRpcManager, + peer_conn::PeerConnId, + peer_manager::PeerManager, + peer_rpc::PeerRpcManager, peer_rpc_service::DirectConnectorManagerRpcServer, + peer_task::{PeerTaskLauncher, PeerTaskManager}, }, proto::{ peer_rpc::{ DirectConnectorRpc, DirectConnectorRpcClientFactory, DirectConnectorRpcServer, - GetIpListRequest, GetIpListResponse, + GetIpListRequest, GetIpListResponse, SendV6HolePunchPacketRequest, }, rpc_types::controller::BaseController, }, - tunnel::IpVersion, + tunnel::{udp::UdpTunnelConnector, IpVersion}, }; use crate::proto::cli::PeerConnInfo; use anyhow::Context; use rand::Rng; -use tokio::{task::JoinSet, time::timeout}; -use tracing::Instrument; +use tokio::{net::UdpSocket, task::JoinSet, time::timeout}; use url::Host; -use super::create_connector_by_url; +use super::{create_connector_by_url, udp_hole_punch}; pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1; pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300; @@ -77,7 +79,7 @@ impl PeerManagerForDirectConnector for PeerManager { struct DstBlackListItem(PeerId, String); #[derive(Hash, Eq, PartialEq, Clone)] -struct DstListenerUrlBlackListItem(PeerId, url::Url); +struct DstListenerUrlBlackListItem(PeerId, String); struct DirectConnectorManagerData { global_ctx: ArcGlobalCtx, @@ -93,95 +95,114 @@ impl DirectConnectorManagerData { dst_listener_blacklist: timedmap::TimedMap::new(), } } -} -impl std::fmt::Debug for DirectConnectorManagerData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DirectConnectorManagerData") - .field("peer_manager", &self.peer_manager) - .finish() - } -} + async fn remote_send_v6_hole_punch_packet( + &self, + dst_peer_id: PeerId, + local_socket: &UdpSocket, + remote_url: &url::Url, + ) -> Result<(), Error> { + let global_ctx = self.peer_manager.get_global_ctx(); + let listener_port = remote_url.port().ok_or(anyhow::anyhow!( + "failed to parse port from remote url: {}", + remote_url + ))?; + let connector_ip = global_ctx + .get_stun_info_collector() + .get_stun_info() + .public_ip + .iter() + .find(|x| x.contains(":")) + .ok_or(anyhow::anyhow!( + "failed to get public ipv6 address from stun info" + ))? + .parse::() + .with_context(|| { + format!( + "failed to parse public ipv6 address from stun info: {:?}", + global_ctx.get_stun_info_collector().get_stun_info() + ) + })?; + let connector_addr = SocketAddr::new( + std::net::IpAddr::V6(connector_ip), + local_socket.local_addr()?.port(), + ); -pub struct DirectConnectorManager { - global_ctx: ArcGlobalCtx, - data: Arc, - - tasks: JoinSet<()>, -} - -impl DirectConnectorManager { - pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc) -> Self { - Self { - global_ctx: global_ctx.clone(), - data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)), - tasks: JoinSet::new(), - } - } - - pub fn run(&mut self) { - if self.global_ctx.get_flags().disable_p2p { - return; - } - - self.run_as_server(); - self.run_as_client(); - } - - pub fn run_as_server(&mut self) { - self.data + let rpc_stub = self .peer_manager .get_peer_rpc_mgr() - .rpc_server() - .registry() - .register( - DirectConnectorRpcServer::new(DirectConnectorManagerRpcServer::new( - self.global_ctx.clone(), - )), - &self.data.global_ctx.get_network_name(), - ); - } - - pub fn run_as_client(&mut self) { - let data = self.data.clone(); - let my_peer_id = self.data.peer_manager.my_peer_id(); - self.tasks.spawn( - async move { - loop { - let peers = data.peer_manager.list_peers().await; - let mut tasks = JoinSet::new(); - for peer_id in peers { - if peer_id == my_peer_id - || data.peer_manager.has_directly_connected_conn(peer_id) - { - continue; - } - tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id)); - } - - while let Some(task_ret) = tasks.join_next().await { - tracing::debug!(?task_ret, ?my_peer_id, "direct connect task ret"); - } - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - } - } - .instrument( - tracing::info_span!("direct_connector_client", my_id = ?self.global_ctx.id), - ), + .rpc_client() + .scoped_client::>( + self.peer_manager.my_peer_id(), + dst_peer_id, + global_ctx.get_network_name(), ); + + rpc_stub + .send_v6_hole_punch_packet( + BaseController::default(), + SendV6HolePunchPacketRequest { + listener_port: listener_port as u32, + connector_addr: Some(connector_addr.into()), + }, + ) + .await + .with_context(|| { + format!( + "do rpc, send v6 hole punch packet to peer {} at {}", + dst_peer_id, remote_url + ) + })?; + + Ok(()) } - async fn do_try_connect_to_ip( - data: Arc, + async fn connect_to_public_ipv6( + &self, dst_peer_id: PeerId, - addr: String, - ) -> Result<(), Error> { - let connector = create_connector_by_url(&addr, &data.global_ctx, IpVersion::Both).await?; - let (peer_id, conn_id) = timeout( - std::time::Duration::from_secs(3), - data.peer_manager.try_direct_connect(connector), + remote_url: &url::Url, + ) -> Result<(PeerId, PeerConnId), Error> { + let local_socket = Arc::new( + UdpSocket::bind("[::]:0") + .await + .with_context(|| format!("failed to bind local socket for {}", remote_url))?, + ); + + // ask remote to send v6 hole punch packet + // and no matter what the result is, continue to connect + let _ = self + .remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, &remote_url) + .await; + + let udp_connector = UdpTunnelConnector::new(remote_url.clone()); + let remote_addr = super::check_scheme_and_get_socket_addr::( + &remote_url, + "udp", + IpVersion::V6, ) - .await??; + .await?; + let ret = udp_connector + .try_connect_with_socket(local_socket, remote_addr) + .await?; + + // NOTICE: must add as directly connected tunnel + self.peer_manager.add_direct_tunnel(ret).await + } + + async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> { + let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?; + let remote_url = connector.remote_url(); + let (peer_id, conn_id) = + if remote_url.scheme() == "udp" && matches!(remote_url.host(), Some(Host::Ipv6(_))) { + self.connect_to_public_ipv6(dst_peer_id, &remote_url) + .await? + } else { + timeout( + std::time::Duration::from_secs(3), + self.peer_manager.try_direct_connect(connector), + ) + .await?? + }; if peer_id != dst_peer_id && !TESTING.load(Ordering::Relaxed) { tracing::info!( @@ -190,7 +211,7 @@ impl DirectConnectorManager { dst_peer_id, peer_id ); - data.peer_manager + self.peer_manager .get_peer_map() .close_peer_conn(peer_id, &conn_id) .await?; @@ -202,7 +223,7 @@ impl DirectConnectorManager { #[tracing::instrument] async fn try_connect_to_ip( - data: Arc, + self: Arc, dst_peer_id: PeerId, addr: String, ) -> Result<(), Error> { @@ -210,11 +231,23 @@ impl DirectConnectorManager { let backoff_ms = vec![1000, 2000]; let mut backoff_idx = 0; + self.dst_listener_blacklist.cleanup(); + + if self + .dst_listener_blacklist + .contains(&DstListenerUrlBlackListItem( + dst_peer_id.clone(), + addr.clone(), + )) + { + return Err(Error::UrlInBlacklist); + } + loop { - let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await; + let ret = self.do_try_connect_to_ip(dst_peer_id, addr.clone()).await; tracing::debug!(?ret, ?dst_peer_id, ?addr, "try_connect_to_ip return"); - if matches!(ret, Err(Error::UrlInBlacklist) | Ok(_)) { - return ret; + if ret.is_ok() { + return Ok(()); } if backoff_idx < backoff_ms.len() { @@ -230,6 +263,11 @@ impl DirectConnectorManager { backoff_idx += 1; continue; } else { + self.dst_listener_blacklist.insert( + DstListenerUrlBlackListItem(dst_peer_id.clone(), addr), + (), + std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC), + ); return ret; } } @@ -237,24 +275,17 @@ impl DirectConnectorManager { #[tracing::instrument] async fn do_try_direct_connect_internal( - data: Arc, + self: &Arc, dst_peer_id: PeerId, ip_list: GetIpListResponse, ) -> Result<(), Error> { - data.dst_listener_blacklist.cleanup(); - - let enable_ipv6 = data.global_ctx.get_flags().enable_ipv6; + let enable_ipv6 = self.global_ctx.get_flags().enable_ipv6; let available_listeners = ip_list .listeners .into_iter() .map(Into::::into) .filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None }) .filter(|l| l.port().is_some() && l.host().is_some()) - .filter(|l| { - !data - .dst_listener_blacklist - .contains(&DstListenerUrlBlackListItem(dst_peer_id.clone(), l.clone())) - }) .filter(|l| enable_ipv6 || !matches!(l.host().unwrap().to_owned(), Host::Ipv6(_))) .collect::>(); @@ -267,7 +298,7 @@ impl DirectConnectorManager { // if have default listener, use it first let listener = available_listeners .iter() - .find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol) + .find(|l| l.scheme() == self.global_ctx.get_flags().default_protocol) .unwrap_or(available_listeners.get(0).unwrap()); let mut tasks = bounded_join_set::JoinSet::new(2); @@ -284,7 +315,7 @@ impl DirectConnectorManager { let mut addr = (*listener).clone(); if addr.set_host(Some(ip.to_string().as_str())).is_ok() { tasks.spawn(Self::try_connect_to_ip( - data.clone(), + self.clone(), dst_peer_id.clone(), addr.to_string(), )); @@ -299,7 +330,7 @@ impl DirectConnectorManager { }); } else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) { tasks.spawn(Self::try_connect_to_ip( - data.clone(), + self.clone(), dst_peer_id.clone(), listener.to_string(), )); @@ -330,7 +361,7 @@ impl DirectConnectorManager { .is_ok() { tasks.spawn(Self::try_connect_to_ip( - data.clone(), + self.clone(), dst_peer_id.clone(), addr.to_string(), )); @@ -345,7 +376,7 @@ impl DirectConnectorManager { }); } else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) { tasks.spawn(Self::try_connect_to_ip( - data.clone(), + self.clone(), dst_peer_id.clone(), listener.to_string(), )); @@ -356,11 +387,9 @@ impl DirectConnectorManager { } } - let mut has_succ = false; while let Some(ret) = tasks.join_next().await { match ret { Ok(Ok(_)) => { - has_succ = true; tracing::info!( ?dst_peer_id, ?listener, @@ -377,42 +406,150 @@ impl DirectConnectorManager { } } - if !has_succ { - data.dst_listener_blacklist.insert( - DstListenerUrlBlackListItem(dst_peer_id.clone(), listener.clone()), - (), - std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC), - ); - } - Ok(()) } #[tracing::instrument] async fn do_try_direct_connect( - data: Arc, + self: Arc, dst_peer_id: PeerId, ) -> Result<(), Error> { - let peer_manager = data.peer_manager.clone(); - tracing::debug!("try direct connect to peer: {}", dst_peer_id); + let mut backoff = + udp_hole_punch::BackOff::new(vec![1000, 2000, 2000, 5000, 5000, 10000, 30000, 60000]); + loop { + let peer_manager = self.peer_manager.clone(); + tracing::debug!("try direct connect to peer: {}", dst_peer_id); - let rpc_stub = peer_manager - .get_peer_rpc_mgr() - .rpc_client() - .scoped_client::>( + let rpc_stub = peer_manager + .get_peer_rpc_mgr() + .rpc_client() + .scoped_client::>( peer_manager.my_peer_id(), dst_peer_id, - data.global_ctx.get_network_name(), + self.global_ctx.get_network_name(), ); - let ip_list = rpc_stub - .get_ip_list(BaseController::default(), GetIpListRequest {}) + let ip_list = rpc_stub + .get_ip_list(BaseController::default(), GetIpListRequest {}) + .await + .with_context(|| format!("get ip list from peer {}", dst_peer_id))?; + + tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list"); + + let ret = self + .do_try_direct_connect_internal(dst_peer_id, ip_list) + .await; + tracing::info!(?ret, ?dst_peer_id, "do_try_direct_connect return"); + + if peer_manager.has_directly_connected_conn(dst_peer_id) { + tracing::info!( + "direct connect to peer {} success, has direct conn", + dst_peer_id + ); + return Ok(()); + } + + tokio::time::sleep(Duration::from_millis(backoff.next_backoff())).await; + } + } +} + +impl std::fmt::Debug for DirectConnectorManagerData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DirectConnectorManagerData") + .field("peer_manager", &self.peer_manager) + .finish() + } +} + +pub struct DirectConnectorManager { + global_ctx: ArcGlobalCtx, + data: Arc, + client: PeerTaskManager, + tasks: JoinSet<()>, +} + +#[derive(Clone)] +struct DirectConnectorLauncher(Arc); + +#[async_trait::async_trait] +impl PeerTaskLauncher for DirectConnectorLauncher { + type Data = Arc; + type CollectPeerItem = PeerId; + type TaskRet = (); + + fn new_data(&self, _peer_mgr: Arc) -> Self::Data { + self.0.clone() + } + + async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec { + let my_peer_id = data.peer_manager.my_peer_id(); + data.peer_manager + .list_peers() .await - .with_context(|| format!("get ip list from peer {}", dst_peer_id))?; + .into_iter() + .filter(|peer_id| { + *peer_id != my_peer_id && !data.peer_manager.has_directly_connected_conn(*peer_id) + }) + .collect() + } - tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list"); + async fn launch_task( + &self, + data: &Self::Data, + item: Self::CollectPeerItem, + ) -> tokio::task::JoinHandle> { + let data = data.clone(); + tokio::spawn(async move { data.do_try_direct_connect(item).await.map_err(Into::into) }) + } - Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await + async fn all_task_done(&self, _data: &Self::Data) {} + + fn loop_interval_ms(&self) -> u64 { + 5000 + } +} + +impl DirectConnectorManager { + pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc) -> Self { + let data = Arc::new(DirectConnectorManagerData::new( + global_ctx.clone(), + peer_manager.clone(), + )); + let client = PeerTaskManager::new(DirectConnectorLauncher(data.clone()), peer_manager); + Self { + global_ctx, + data, + client, + tasks: JoinSet::new(), + } + } + + pub fn run(&mut self) { + if self.global_ctx.get_flags().disable_p2p { + return; + } + + self.run_as_server(); + self.run_as_client(); + } + + pub fn run_as_server(&mut self) { + self.data + .peer_manager + .get_peer_rpc_mgr() + .rpc_server() + .registry() + .register( + DirectConnectorRpcServer::new(DirectConnectorManagerRpcServer::new( + self.global_ctx.clone(), + )), + &self.data.global_ctx.get_network_name(), + ); + } + + pub fn run_as_client(&mut self) { + self.client.start(); } } @@ -491,6 +628,13 @@ mod tests { wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + p_c.get_global_ctx() + .get_ip_collector() + .collect_ip_addrs() + .await; + + tokio::time::sleep(std::time::Duration::from_secs(4)).await; + let mut dm_a = DirectConnectorManager::new(p_a.get_global_ctx(), p_a.clone()); let mut dm_c = DirectConnectorManager::new(p_c.get_global_ctx(), p_c.clone()); @@ -525,6 +669,7 @@ mod tests { #[tokio::test] async fn direct_connector_scheme_blacklist() { + TESTING.store(true, std::sync::atomic::Ordering::Relaxed); let p_a = create_mock_peer_manager().await; let data = Arc::new(DirectConnectorManagerData::new( p_a.get_global_ctx(), @@ -539,7 +684,7 @@ mod tests { .interface_ipv4s .push("127.0.0.1".parse::().unwrap().into()); - DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone()) + data.do_try_direct_connect_internal(1, ip_list.clone()) .await .unwrap(); diff --git a/easytier/src/connector/udp_hole_punch/common.rs b/easytier/src/connector/udp_hole_punch/common.rs index ec45b4f..ce94ba7 100644 --- a/easytier/src/connector/udp_hole_punch/common.rs +++ b/easytier/src/connector/udp_hole_punch/common.rs @@ -495,6 +495,7 @@ impl PunchHoleServerCommon { .udp_nat_type } + #[async_recursion::async_recursion] pub(crate) async fn select_listener( &self, use_new_listener: bool, @@ -515,24 +516,28 @@ impl PunchHoleServerCommon { let mut locked = all_listener_sockets.lock().await; let listener = if use_last { - locked.last_mut()? + Some(locked.last_mut()?) } else { // use the listener that is active most recently locked .iter_mut() - .max_by_key(|listener| listener.last_active_time.load())? + .filter(|l| !l.mapped_addr.ip().is_unspecified()) + .max_by_key(|listener| listener.last_active_time.load()) }; - if listener.mapped_addr.ip().is_unspecified() { - tracing::info!("listener mapped addr is unspecified, trying to get mapped addr"); - listener.mapped_addr = self - .get_global_ctx() - .get_stun_info_collector() - .get_udp_port_mapping(listener.mapped_addr.port()) - .await - .ok()?; + if listener.is_none() || listener.as_ref().unwrap().mapped_addr.ip().is_unspecified() { + tracing::warn!( + ?use_new_listener, + "no available udp hole punching listener with mapped address" + ); + if !use_new_listener { + return self.select_listener(true).await; + } else { + return None; + } } + let listener = listener.unwrap(); Some((listener.get_socket().await, listener.mapped_addr)) } diff --git a/easytier/src/connector/udp_hole_punch/mod.rs b/easytier/src/connector/udp_hole_punch/mod.rs index aabb0f6..010e9ed 100644 --- a/easytier/src/connector/udp_hole_punch/mod.rs +++ b/easytier/src/connector/udp_hole_punch/mod.rs @@ -143,7 +143,7 @@ impl UdpHolePunchRpc for UdpHolePunchServer { } #[derive(Debug)] -struct BackOff { +pub struct BackOff { backoffs_ms: Vec, current_idx: usize, } diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 6abad6f..2adcd8a 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -389,6 +389,15 @@ impl PeerManager { }); } + pub async fn add_direct_tunnel( + &self, + t: Box, + ) -> Result<(PeerId, PeerConnId), Error> { + let (peer_id, conn_id) = self.add_client_tunnel(t).await?; + self.add_directly_connected_conn(peer_id, conn_id); + Ok((peer_id, conn_id)) + } + #[tracing::instrument] pub async fn try_direct_connect( &self, @@ -401,9 +410,7 @@ impl PeerManager { let t = ns .run_async(|| async move { connector.connect().await }) .await?; - let (peer_id, conn_id) = self.add_client_tunnel(t).await?; - self.add_directly_connected_conn(peer_id, conn_id); - Ok((peer_id, conn_id)) + self.add_direct_tunnel(t).await } #[tracing::instrument] diff --git a/easytier/src/peers/peer_rpc_service.rs b/easytier/src/peers/peer_rpc_service.rs index 35c4fc7..8e3b5e7 100644 --- a/easytier/src/peers/peer_rpc_service.rs +++ b/easytier/src/peers/peer_rpc_service.rs @@ -1,9 +1,15 @@ +use std::net::SocketAddr; + use crate::{ common::global_ctx::ArcGlobalCtx, proto::{ - peer_rpc::{DirectConnectorRpc, GetIpListRequest, GetIpListResponse}, + common::Void, + peer_rpc::{ + DirectConnectorRpc, GetIpListRequest, GetIpListResponse, SendV6HolePunchPacketRequest, + }, rpc_types::{self, controller::BaseController}, }, + tunnel::udp, }; #[derive(Clone)] @@ -30,8 +36,42 @@ impl DirectConnectorRpc for DirectConnectorManagerRpcServer { .chain(self.global_ctx.get_running_listeners().into_iter()) .map(Into::into) .collect(); + tracing::trace!( + "get_ip_list: public_ipv4: {:?}, public_ipv6: {:?}, listeners: {:?}", + ret.public_ipv4, + ret.public_ipv6, + ret.listeners + ); Ok(ret) } + + async fn send_v6_hole_punch_packet( + &self, + _: BaseController, + req: SendV6HolePunchPacketRequest, + ) -> rpc_types::error::Result { + let listener_port = req.listener_port as u16; + let SocketAddr::V6(connector_addr) = req + .connector_addr + .ok_or(anyhow::anyhow!("connector_addr is required"))? + .into() + else { + return Err(anyhow::anyhow!("connector_addr is not a v6 address").into()); + }; + + tracing::info!( + "Sending v6 hole punch packet to {} from listener port {}", + connector_addr, + listener_port + ); + + // send 3 packets to the connector + for _ in 0..3 { + udp::send_v6_hole_punch_packet(listener_port, connector_addr).await?; + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + } + Ok(Default::default()) + } } impl DirectConnectorManagerRpcServer { diff --git a/easytier/src/proto/peer_rpc.proto b/easytier/src/proto/peer_rpc.proto index 7f3e2a2..56a0354 100644 --- a/easytier/src/proto/peer_rpc.proto +++ b/easytier/src/proto/peer_rpc.proto @@ -91,8 +91,14 @@ message GetIpListResponse { repeated common.Url listeners = 5; } +message SendV6HolePunchPacketRequest { + common.SocketAddr connector_addr = 1; + uint32 listener_port = 2; +} + service DirectConnectorRpc { rpc GetIpList(GetIpListRequest) returns (GetIpListResponse); + rpc SendV6HolePunchPacket(SendV6HolePunchPacketRequest) returns (common.Void); } message SelectPunchListenerRequest { diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index da018ef..cc39a85 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -177,21 +177,6 @@ pub(crate) trait FromUrl { Self: Sized; } -pub(crate) async fn check_scheme_and_get_socket_addr_ext( - url: &url::Url, - scheme: &str, - ip_version: IpVersion, -) -> Result -where - T: FromUrl, -{ - if url.scheme() != scheme { - return Err(TunnelError::InvalidProtocol(url.scheme().to_string())); - } - - Ok(T::from_url(url.clone(), ip_version).await?) -} - pub(crate) async fn check_scheme_and_get_socket_addr( url: &url::Url, scheme: &str, diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 43956e8..873e70a 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -28,6 +28,15 @@ pub enum UdpPacketType { Data = 3, Fin = 4, HolePunch = 5, + V6HolePunch = 6, // when receiving v6 hole punch packet, the packet contains a socket addr of other peer, we + // will send a hole punch packet to that peer. we only accept this packet from lookback interface. +} + +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct V6HolePunchPacket { + pub dst_ipv6: [u8; 16], + pub dst_port: U16, } #[repr(C, packed)] diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index fbf5b83..20d6344 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -5,7 +5,6 @@ use std::{error::Error, net::SocketAddr, sync::Arc}; use crate::tunnel::{ - check_scheme_and_get_socket_addr_ext, common::{FramedReader, FramedWriter, TunnelWrapper}, TunnelInfo, }; @@ -151,7 +150,7 @@ impl QUICTunnelConnector { impl TunnelConnector for QUICTunnelConnector { async fn connect(&mut self) -> Result, super::TunnelError> { let addr = - check_scheme_and_get_socket_addr_ext::(&self.addr, "quic", self.ip_version) + check_scheme_and_get_socket_addr::(&self.addr, "quic", self.ip_version) .await?; let local_addr = if addr.is_ipv4() { "0.0.0.0:0" diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs index 98e5881..c18e223 100644 --- a/easytier/src/tunnel/tcp.rs +++ b/easytier/src/tunnel/tcp.rs @@ -8,7 +8,7 @@ use super::TunnelInfo; use crate::tunnel::common::setup_sokcet2; use super::{ - check_scheme_and_get_socket_addr, check_scheme_and_get_socket_addr_ext, + check_scheme_and_get_socket_addr, common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper}, IpVersion, Tunnel, TunnelError, TunnelListener, }; @@ -191,7 +191,7 @@ impl TcpTunnelConnector { impl super::TunnelConnector for TcpTunnelConnector { async fn connect(&mut self) -> Result, super::TunnelError> { let addr = - check_scheme_and_get_socket_addr_ext::(&self.addr, "tcp", self.ip_version) + check_scheme_and_get_socket_addr::(&self.addr, "tcp", self.ip_version) .await?; if self.bind_addrs.is_empty() { self.connect_with_default_bind(addr).await diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 73eacd6..8496e6e 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -1,5 +1,6 @@ use std::{ fmt::Debug, + net::{Ipv6Addr, SocketAddrV6}, sync::{Arc, Weak}, }; @@ -9,7 +10,7 @@ use bytes::BytesMut; use dashmap::DashMap; use futures::{stream::FuturesUnordered, StreamExt}; use rand::{Rng, SeedableRng}; -use zerocopy::AsBytes; +use zerocopy::{AsBytes, FromBytes}; use std::net::SocketAddr; use tokio::{ @@ -20,7 +21,7 @@ use tokio::{ use tracing::{instrument, Instrument}; -use super::TunnelInfo; +use super::{packet_def::V6HolePunchPacket, TunnelInfo}; use crate::{ common::{join_joinset_background, scoped_task::ScopedTask}, tunnel::{ @@ -43,7 +44,7 @@ pub const UDP_DATA_MTU: usize = 2000; type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option)>; type UdpCloseEventReceiver = UnboundedReceiver<(SocketAddr, Option)>; -fn new_udp_packet(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket +fn new_udp_packet(f: F, udp_body: Option<&[u8]>) -> ZCPacket where F: FnOnce(&mut UDPTunnelHeader), { @@ -97,6 +98,29 @@ pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket { ) } +pub fn new_v6_hole_punch_packet(dst: &SocketAddrV6) -> ZCPacket { + // generate a 128 bytes vec with random data + let mut body = V6HolePunchPacket::default(); + body.dst_ipv6.copy_from_slice(&dst.ip().octets()); + body.dst_port.set(dst.port()); + new_udp_packet( + |header| { + header.msg_type = UdpPacketType::V6HolePunch as u8; + header.conn_id.set(dst.port() as u32); + header + .len + .set(std::mem::size_of::() as u16); + }, + Some(body.as_bytes()), + ) +} + +fn extrace_dst_addr_from_hole_punch_packet(buf: &[u8]) -> Option { + let body = V6HolePunchPacket::ref_from_prefix(&buf[..])?; + let ip = Ipv6Addr::from(body.dst_ipv6); + Some(SocketAddrV6::new(ip, body.dst_port.get(), 0, 0)) +} + fn is_stun_packet(b: &[u8]) -> bool { // stun has following pattern: // 1. first two bits are 0b00 @@ -104,6 +128,21 @@ fn is_stun_packet(b: &[u8]) -> bool { b[4..8] == [0x21, 0x12, 0xA4, 0x42] && b[0] & 0xC0 == 0 } +pub async fn send_v6_hole_punch_packet( + listener_port: u16, + dst_addr: SocketAddrV6, +) -> Result<(), TunnelError> { + let local_socket = UdpSocket::bind("[::1]:0").await?; + let udp_packet = new_v6_hole_punch_packet(&dst_addr); + let remote_addr = format!("[::1]:{}", listener_port) + .parse::() + .unwrap(); + local_socket + .send_to(&udp_packet.into_bytes(), remote_addr) + .await?; + Ok(()) +} + async fn respond_stun_packet( socket: Arc, addr: SocketAddr, @@ -421,6 +460,27 @@ impl UdpTunnelListenerData { tracing::error!(?e, "udp respond stun packet error"); } }); + } else if header.msg_type == UdpPacketType::V6HolePunch as u8 { + if !addr.ip().is_loopback() { + tracing::warn!(?addr, "v6 hole punch packet should be from loopback"); + return; + } + if !addr.ip().is_ipv6() { + tracing::warn!(?addr, "v6 hole punch packet should be sent from ipv6"); + return; + } + let Some(dst_addr) = extrace_dst_addr_from_hole_punch_packet(zc_packet.udp_payload()) + else { + tracing::warn!("invalid v6 hole punch packet"); + return; + }; + let socket = self.socket.as_ref().unwrap().clone(); + let udp_packet = new_hole_punch_packet(1, 32); + if let Err(e) = socket.try_send_to(&udp_packet.into_bytes(), SocketAddr::V6(dst_addr)) { + tracing::error!(?e, "udp send hole punch packet error"); + } + tracing::debug!(?dst_addr, "udp forward packet send hole punch packet"); + return; } else if header.msg_type != UdpPacketType::HolePunch as u8 { let Some(mut conn) = self.sock_map.get_mut(&addr) else { tracing::trace!(?header, "udp forward packet error, connection not found"); @@ -429,6 +489,8 @@ impl UdpTunnelListenerData { if let Err(e) = conn.handle_packet_from_remote(zc_packet) { tracing::trace!(?e, "udp forward packet error"); } + } else { + tracing::trace!(?header, "udp forward packet ignore hole punch packet"); } } @@ -778,7 +840,7 @@ impl UdpTunnelConnector { #[async_trait] impl super::TunnelConnector for UdpTunnelConnector { async fn connect(&mut self) -> Result, super::TunnelError> { - let addr = super::check_scheme_and_get_socket_addr_ext::( + let addr = super::check_scheme_and_get_socket_addr::( &self.addr, "udp", self.ip_version, @@ -1055,4 +1117,40 @@ mod tests { ) .await; } + + #[tokio::test] + async fn test_v6_hole_punch_packet() { + let mut lis = UdpTunnelListener::new("udp://[::]:0".parse().unwrap()); + lis.listen().await.unwrap(); + + // a socket to receive forwarded hole punch packets + let socket = Arc::new(UdpSocket::bind("[::]:0").await.unwrap()); + let socket_clone = socket.clone(); + let t = tokio::spawn(async move { + let mut buf = BytesMut::new(); + buf.resize(128, 0); + socket_clone.recv_from(&mut buf).await.unwrap(); + }); + + tracing::info!("lis local addr: {:?}", lis.local_url()); + tracing::info!("socket local addr: {:?}", socket.local_addr().unwrap()); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // a socket to send v6 hole punch packets + send_v6_hole_punch_packet( + lis.local_url().port().unwrap(), + match socket.local_addr().unwrap() { + std::net::SocketAddr::V6(addr_v6) => addr_v6, + _ => panic!("Expected an IPv6 address"), + }, + ) + .await + .unwrap(); + + tokio::time::timeout(tokio::time::Duration::from_secs(2), t) + .await + .expect("Timeout waiting for v6 hole punch packet") + .unwrap(); + } } diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index bb27de6..3275b42 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -702,7 +702,7 @@ impl WgTunnelConnector { impl super::TunnelConnector for WgTunnelConnector { #[tracing::instrument] async fn connect(&mut self) -> Result, super::TunnelError> { - let addr = super::check_scheme_and_get_socket_addr_ext::( + let addr = super::check_scheme_and_get_socket_addr::( &self.addr, "wg", self.ip_version,