diff --git a/easytier/proto/cli.proto b/easytier/proto/cli.proto index 84f85fb..971002b 100644 --- a/easytier/proto/cli.proto +++ b/easytier/proto/cli.proto @@ -59,6 +59,9 @@ message StunInfo { NatType udp_nat_type = 1; NatType tcp_nat_type = 2; int64 last_update_time = 3; + repeated string public_ip = 4; + uint32 min_port = 5; + uint32 max_port = 6; } message Route { diff --git a/easytier/src/common/defer.rs b/easytier/src/common/defer.rs new file mode 100644 index 0000000..1132c9b --- /dev/null +++ b/easytier/src/common/defer.rs @@ -0,0 +1,24 @@ +#[doc(hidden)] +pub struct Defer { + // internal struct used by defer! macro + func: Option, +} + +impl Defer { + pub fn new(func: F) -> Self { + Self { func: Some(func) } + } +} + +impl Drop for Defer { + fn drop(&mut self) { + self.func.take().map(|f| f()); + } +} + +#[macro_export] +macro_rules! defer { + ( $($tt:tt)* ) => { + let _deferred = $crate::common::defer::Defer::new(|| { $($tt)* }); + }; +} diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs index a7c7a46..1c6012f 100644 --- a/easytier/src/common/mod.rs +++ b/easytier/src/common/mod.rs @@ -8,6 +8,7 @@ use tracing::Instrument; pub mod config; pub mod constants; +pub mod defer; pub mod error; pub mod global_ctx; pub mod ifcfg; diff --git a/easytier/src/common/netns.rs b/easytier/src/common/netns.rs index 7ce384a..4433893 100644 --- a/easytier/src/common/netns.rs +++ b/easytier/src/common/netns.rs @@ -75,7 +75,7 @@ impl NetNSGuard { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct NetNS { name: Option, } diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index a296b6d..6a8d9e6 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -1,18 +1,20 @@ -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; +use std::collections::BTreeSet; +use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, Instant}; use crate::rpc::{NatType, StunInfo}; use anyhow::Context; +use chrono::Local; use crossbeam::atomic::AtomicCell; +use rand::seq::IteratorRandom; use tokio::net::{lookup_host, UdpSocket}; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, Mutex}; use tokio::task::JoinSet; -use tracing::Level; +use tracing::{Instrument, Level}; use bytecodec::{DecodeExt, EncodeExt}; use stun_codec::rfc5389::methods::BINDING; -use stun_codec::rfc5780::attributes::ChangeRequest; use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder}; use crate::common::error::Error; @@ -22,13 +24,15 @@ use super::stun_codec_ext::*; struct HostResolverIter { hostnames: Vec, ips: Vec, + max_ip_per_domain: u32, } impl HostResolverIter { - fn new(hostnames: Vec) -> Self { + fn new(hostnames: Vec, max_ip_per_domain: u32) -> Self { Self { hostnames, ips: vec![], + max_ip_per_domain, } } @@ -40,9 +44,17 @@ impl HostResolverIter { } let host = self.hostnames.remove(0); + let host = if host.contains(':') { + host + } else { + format!("{}:3478", host) + }; + match lookup_host(&host).await { Ok(ips) => { - self.ips = ips.collect(); + self.ips = ips + .filter(|x| x.is_ipv4()) + .choose_multiple(&mut rand::thread_rng(), self.max_ip_per_domain as usize); } Err(e) => { tracing::warn!(?host, ?e, "lookup host for stun failed"); @@ -55,19 +67,30 @@ impl HostResolverIter { } } +#[derive(Debug, Clone)] +struct StunPacket { + data: Vec, + addr: SocketAddr, +} + +type StunPacketReceiver = tokio::sync::broadcast::Receiver; + #[derive(Debug, Clone, Copy)] struct BindRequestResponse { - source_addr: SocketAddr, - send_to_addr: SocketAddr, + local_addr: SocketAddr, + stun_server_addr: SocketAddr, + recv_from_addr: SocketAddr, mapped_socket_addr: Option, changed_socket_addr: Option, - ip_changed: bool, - port_changed: bool, + change_ip: bool, + change_port: bool, real_ip_changed: bool, real_port_changed: bool, + + latency_us: u32, } impl BindRequestResponse { @@ -77,18 +100,26 @@ impl BindRequestResponse { } #[derive(Debug, Clone)] -struct Stun { +struct StunClient { stun_server: SocketAddr, - req_repeat: u8, resp_timeout: Duration, + req_repeat: u32, + socket: Arc, + stun_packet_receiver: Arc>, } -impl Stun { - pub fn new(stun_server: SocketAddr) -> Self { +impl StunClient { + pub fn new( + stun_server: SocketAddr, + socket: Arc, + stun_packet_receiver: StunPacketReceiver, + ) -> Self { Self { stun_server, - req_repeat: 2, resp_timeout: Duration::from_millis(3000), + req_repeat: 2, + socket, + stun_packet_receiver: Arc::new(Mutex::new(stun_packet_receiver)), } } @@ -96,7 +127,6 @@ impl Stun { async fn wait_stun_response<'a, const N: usize>( &self, buf: &'a mut [u8; N], - udp: &UdpSocket, tids: &Vec, expected_ip_changed: bool, expected_port_changed: bool, @@ -106,16 +136,20 @@ impl Stun { let deadline = now + self.resp_timeout; while now < deadline { - let mut udp_buf = [0u8; 1500]; - let (len, remote_addr) = - tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice())) - .await??; + let mut locked_receiver = self.stun_packet_receiver.lock().await; + let stun_packet_raw = tokio::time::timeout(deadline - now, locked_receiver.recv()) + .await? + .with_context(|| "recv stun packet from broadcast channel error")?; now = tokio::time::Instant::now(); + let (len, remote_addr) = (stun_packet_raw.data.len(), stun_packet_raw.addr); + if len < 20 { continue; } + let udp_buf = stun_packet_raw.data; + // TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) }; @@ -136,18 +170,6 @@ impl Stun { continue; } - // some stun server use changed socket even we don't ask for. - if expected_ip_changed && stun_host.ip() == remote_addr.ip() { - continue; - } - - if expected_port_changed - && stun_host.ip() == remote_addr.ip() - && stun_host.port() == remote_addr.port() - { - continue; - } - return Ok((msg, remote_addr)); } @@ -196,16 +218,14 @@ impl Stun { #[tracing::instrument(ret, err, level = Level::DEBUG)] pub async fn bind_request( - &self, - source_port: u16, + self, change_ip: bool, change_port: bool, ) -> Result { let stun_host = self.stun_server; - let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?; - // repeat req in case of packet loss let mut tids = vec![]; + for _ in 0..self.req_repeat { let tid = rand::random::(); // let tid = 1; @@ -222,16 +242,19 @@ impl Stun { let msg = encoder .encode_into_bytes(message.clone()) .with_context(|| "encode stun message")?; - tids.push(tid as u128); - tracing::trace!(?message, ?msg, tid, "send stun request"); - udp.send_to(msg.as_slice().into(), &stun_host).await?; + tracing::debug!(?message, ?msg, tid, "send stun request"); + self.socket + .send_to(msg.as_slice().into(), &stun_host) + .await?; } + let now = Instant::now(); + tracing::trace!("waiting stun response"); let mut buf = [0; 1620]; let (msg, recv_addr) = self - .wait_stun_response(&mut buf, &udp, &tids, change_ip, change_port, &stun_host) + .wait_stun_response(&mut buf, &tids, change_ip, change_port, &stun_host) .await?; let changed_socket_addr = Self::extract_changed_addr(&msg); @@ -239,16 +262,18 @@ impl Stun { let real_port_changed = stun_host.port() != recv_addr.port(); let resp = BindRequestResponse { - source_addr: udp.local_addr()?, - send_to_addr: stun_host, + local_addr: self.socket.local_addr()?, + stun_server_addr: stun_host, recv_from_addr: recv_addr, mapped_socket_addr: Self::extrace_mapped_addr(&msg), changed_socket_addr, - ip_changed: change_ip, - port_changed: change_port, + change_ip, + change_port, real_ip_changed, real_port_changed, + + latency_us: now.elapsed().as_micros() as u32, }; tracing::debug!( @@ -262,105 +287,256 @@ impl Stun { } } -pub struct UdpNatTypeDetector { - stun_servers: Vec, +struct StunClientBuilder { + udp: Arc, + task_set: JoinSet<()>, + stun_packet_sender: broadcast::Sender, } -impl UdpNatTypeDetector { - pub fn new(stun_servers: Vec) -> Self { - Self { stun_servers } - } +impl StunClientBuilder { + pub fn new(udp: Arc) -> Self { + let (stun_packet_sender, _) = broadcast::channel(1024); + let mut task_set = JoinSet::new(); - pub async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType { - // Like classic STUN (rfc3489). Detect NAT behavior for UDP. - // Modified from rfc3489. Requires at least two STUN servers. - let mut ret_test1_1 = None; - let mut ret_test1_2 = None; - let mut ret_test2 = None; - let mut ret_test3 = None; - - if source_port == 0 { - let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap(); - source_port = udp.local_addr().unwrap().port(); - } - - let mut succ = false; - let mut ips = HostResolverIter::new(self.stun_servers.clone()); - while let Some(server_ip) = ips.next().await { - let stun = Stun::new(server_ip.clone()); - let ret = stun.bind_request(source_port, false, false).await; - if ret.is_err() { - // Try another STUN server - continue; - } - if ret_test1_1.is_none() { - ret_test1_1 = ret.ok(); - continue; - } - ret_test1_2 = ret.ok(); - let ret = stun.bind_request(source_port, true, true).await; - if let Ok(resp) = ret { - if !resp.real_ip_changed || !resp.real_port_changed { - tracing::debug!( - ?server_ip, - ?ret, - "stun bind request return with unchanged ip and port" - ); - // Try another STUN server - continue; + let udp_clone = udp.clone(); + let stun_packet_sender_clone = stun_packet_sender.clone(); + task_set.spawn( + async move { + let mut buf = [0; 1620]; + tracing::info!("start stun packet listener"); + loop { + let Ok((len, addr)) = udp_clone.recv_from(&mut buf).await else { + tracing::error!("udp recv_from error"); + break; + }; + let data = buf[..len].to_vec(); + tracing::debug!(?addr, ?data, "recv udp stun packet"); + let _ = stun_packet_sender_clone.send(StunPacket { data, addr }); } } - ret_test2 = ret.ok(); - ret_test3 = stun.bind_request(source_port, false, true).await.ok(); - tracing::debug!(?ret_test3, "stun bind request with changed port"); - succ = true; - break; - } + .instrument(tracing::info_span!("stun_packet_listener")), + ); - if !succ { + Self { + udp, + task_set, + stun_packet_sender, + } + } + + pub fn new_stun_client(&self, stun_server: SocketAddr) -> StunClient { + StunClient::new( + stun_server, + self.udp.clone(), + self.stun_packet_sender.subscribe(), + ) + } + + pub async fn stop(&mut self) { + self.task_set.abort_all(); + while let Some(_) = self.task_set.join_next().await {} + } +} + +#[derive(Debug, Clone)] +pub struct UdpNatTypeDetectResult { + source_addr: SocketAddr, + stun_resps: Vec, +} + +impl UdpNatTypeDetectResult { + fn new(source_addr: SocketAddr, stun_resps: Vec) -> Self { + Self { + source_addr, + stun_resps, + } + } + + fn has_ip_changed_resp(&self) -> bool { + for resp in self.stun_resps.iter() { + if resp.real_ip_changed { + return true; + } + } + false + } + + fn has_port_changed_resp(&self) -> bool { + for resp in self.stun_resps.iter() { + if resp.real_port_changed { + return true; + } + } + false + } + + fn is_open_internet(&self) -> bool { + for resp in self.stun_resps.iter() { + if resp.mapped_socket_addr == Some(self.source_addr) { + return true; + } + } + return false; + } + + fn is_pat(&self) -> bool { + for resp in self.stun_resps.iter() { + if resp.mapped_socket_addr.map(|x| x.port()) == Some(self.source_addr.port()) { + return true; + } + } + false + } + + fn stun_server_count(&self) -> usize { + // find resp with distinct stun server + self.stun_resps + .iter() + .map(|x| x.stun_server_addr) + .collect::>() + .len() + } + + fn is_cone(&self) -> bool { + // if unique mapped addr count is less than stun server count, it is cone + let mapped_addr_count = self + .stun_resps + .iter() + .filter_map(|x| x.mapped_socket_addr) + .collect::>() + .len(); + mapped_addr_count < self.stun_server_count() + } + + pub fn nat_type(&self) -> NatType { + if self.stun_server_count() < 2 { return NatType::Unknown; } - tracing::debug!( - ?ret_test1_1, - ?ret_test1_2, - ?ret_test2, - ?ret_test3, - "finish stun test, try to detect nat type" - ); - - let ret_test1_1 = ret_test1_1.unwrap(); - let ret_test1_2 = ret_test1_2.unwrap(); - - if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr { - return NatType::Symmetric; - } - - if ret_test1_1.mapped_socket_addr.is_some() - && ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap() - { - if !ret_test2.is_none() { - return NatType::OpenInternet; - } else { - return NatType::SymUdpFirewall; - } - } else { - if let Some(ret_test2) = ret_test2 { - if source_port == ret_test2.get_mapped_addr_no_check().port() - && source_port == ret_test1_1.get_mapped_addr_no_check().port() - { + if self.is_cone() { + if self.has_ip_changed_resp() { + if self.is_open_internet() { + return NatType::OpenInternet; + } else if self.is_pat() { return NatType::NoPat; } else { return NatType::FullCone; } + } else if self.has_port_changed_resp() { + return NatType::Restricted; } else { - if !ret_test3.is_none() { - return NatType::Restricted; - } else { - return NatType::PortRestricted; - } + return NatType::PortRestricted; + } + } else if !self.stun_resps.is_empty() { + return NatType::Symmetric; + } else { + return NatType::Unknown; + } + } + + pub fn public_ips(&self) -> Vec { + self.stun_resps + .iter() + .filter_map(|x| x.mapped_socket_addr.map(|x| x.ip())) + .collect::>() + .into_iter() + .collect() + } + + pub fn collect_available_stun_server(&self) -> Vec { + let mut ret = vec![]; + for resp in self.stun_resps.iter() { + if !ret.contains(&resp.stun_server_addr) { + ret.push(resp.stun_server_addr); } } + ret + } + + pub fn local_addr(&self) -> SocketAddr { + self.source_addr + } + + pub fn extend_result(&mut self, other: UdpNatTypeDetectResult) { + self.stun_resps.extend(other.stun_resps); + } + + pub fn min_port(&self) -> u16 { + self.stun_resps + .iter() + .filter_map(|x| x.mapped_socket_addr.map(|x| x.port())) + .min() + .unwrap_or(0) + } + + pub fn max_port(&self) -> u16 { + self.stun_resps + .iter() + .filter_map(|x| x.mapped_socket_addr.map(|x| x.port())) + .max() + .unwrap_or(u16::MAX) + } +} + +pub struct UdpNatTypeDetector { + stun_server_hosts: Vec, + max_ip_per_domain: u32, +} + +impl UdpNatTypeDetector { + pub fn new(stun_server_hosts: Vec, max_ip_per_domain: u32) -> Self { + Self { + stun_server_hosts, + max_ip_per_domain, + } + } + + pub async fn detect_nat_type(&self, source_port: u16) -> Result { + let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?); + self.detect_nat_type_with_socket(udp).await + } + + #[tracing::instrument(skip(self))] + pub async fn detect_nat_type_with_socket( + &self, + udp: Arc, + ) -> Result { + let mut stun_servers = vec![]; + let mut host_resolver = + HostResolverIter::new(self.stun_server_hosts.clone(), self.max_ip_per_domain); + while let Some(addr) = host_resolver.next().await { + stun_servers.push(addr); + } + + let client_builder = StunClientBuilder::new(udp.clone()); + let mut stun_task_set = JoinSet::new(); + + for stun_server in stun_servers.iter() { + stun_task_set.spawn( + client_builder + .new_stun_client(*stun_server) + .bind_request(false, false), + ); + stun_task_set.spawn( + client_builder + .new_stun_client(*stun_server) + .bind_request(false, true), + ); + stun_task_set.spawn( + client_builder + .new_stun_client(*stun_server) + .bind_request(true, true), + ); + } + + let mut bind_resps = vec![]; + while let Some(resp) = stun_task_set.join_next().await { + if let Ok(Ok(resp)) = resp { + bind_resps.push(resp); + } + } + + Ok(UdpNatTypeDetectResult::new(udp.local_addr()?, bind_resps)) } } @@ -373,7 +549,8 @@ pub trait StunInfoCollectorTrait: Send + Sync { pub struct StunInfoCollector { stun_servers: Arc>>, - udp_nat_type: Arc>, + udp_nat_test_result: Arc>>, + nat_test_result_time: Arc>>, redetect_notify: Arc, tasks: JoinSet<()>, } @@ -381,27 +558,47 @@ pub struct StunInfoCollector { #[async_trait::async_trait] impl StunInfoCollectorTrait for StunInfoCollector { fn get_stun_info(&self) -> StunInfo { - let (typ, time) = self.udp_nat_type.load(); + let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else { + return Default::default(); + }; StunInfo { - udp_nat_type: typ as i32, + udp_nat_type: result.nat_type() as i32, tcp_nat_type: 0, - last_update_time: time.elapsed().as_secs() as i64, + last_update_time: self.nat_test_result_time.load().timestamp(), + public_ip: result.public_ips().iter().map(|x| x.to_string()).collect(), + min_port: result.min_port() as u32, + max_port: result.max_port() as u32, } } async fn get_udp_port_mapping(&self, local_port: u16) -> Result { - let stun_servers = self.stun_servers.read().await.clone(); - let mut ips = HostResolverIter::new(stun_servers.clone()); - while let Some(server) = ips.next().await { - let stun = Stun::new(server.clone()); - let Ok(ret) = stun.bind_request(local_port, false, false).await else { + let stun_servers = self + .udp_nat_test_result + .read() + .unwrap() + .clone() + .map(|x| x.collect_available_stun_server()) + .ok_or(Error::NotFound)?; + + let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", local_port)).await?); + let mut client_builder = StunClientBuilder::new(udp.clone()); + + for server in stun_servers.iter() { + let Ok(ret) = client_builder + .new_stun_client(*server) + .bind_request(false, false) + .await + else { tracing::warn!(?server, "stun bind request failed"); continue; }; if let Some(mapped_addr) = ret.mapped_socket_addr { + // make sure udp socket is available after return ok. + client_builder.stop().await; return Ok(mapped_addr); } } + Err(Error::NotFound) } } @@ -410,10 +607,8 @@ impl StunInfoCollector { pub fn new(stun_servers: Vec) -> Self { let mut ret = Self { stun_servers: Arc::new(RwLock::new(stun_servers)), - udp_nat_type: Arc::new(AtomicCell::new(( - NatType::Unknown, - std::time::Instant::now(), - ))), + udp_nat_test_result: Arc::new(RwLock::new(None)), + nat_test_result_time: Arc::new(AtomicCell::new(Local::now())), redetect_notify: Arc::new(tokio::sync::Notify::new()), tasks: JoinSet::new(), }; @@ -431,46 +626,78 @@ impl StunInfoCollector { // NOTICE: we may need to choose stun stun server based on geo location // stun server cross nation may return a external ip address with high latency and loss rate vec![ - "stun.miwifi.com:3478".to_string(), - "stun.chat.bilibili.com:3478".to_string(), // bilibili's stun server doesn't repond to change_ip and change_port - "stun.cloudflare.com:3478".to_string(), - "stun.syncthing.net:3478".to_string(), - "stun.isp.net.au:3478".to_string(), - "stun.nextcloud.com:3478".to_string(), - "stun.freeswitch.org:3478".to_string(), - "stun.voip.blackberry.com:3478".to_string(), - "stunserver.stunprotocol.org:3478".to_string(), - "stun.sipnet.com:3478".to_string(), - "stun.radiojar.com:3478".to_string(), - "stun.sonetel.com:3478".to_string(), - "stun.voipgate.com:3478".to_string(), + "stun.miwifi.com", + "stun.cdnbye.com", + "stun.hitv.com", + "stun.chat.bilibili.com", + "stun.douyucdn.cn:18000", + "fwa.lifesizecloud.com", + "global.turn.twilio.com", + "turn.cloudflare.com", + "stun.isp.net.au", + "stun.nextcloud.com", + "stun.freeswitch.org", + "stun.voip.blackberry.com", + "stunserver.stunprotocol.org", + "stun.sipnet.com", + "stun.radiojar.com", + "stun.sonetel.com", ] + .iter() + .map(|x| x.to_string()) + .collect() } fn start_stun_routine(&mut self) { let stun_servers = self.stun_servers.clone(); - let udp_nat_type = self.udp_nat_type.clone(); + let udp_nat_test_result = self.udp_nat_test_result.clone(); + let udp_test_time = self.nat_test_result_time.clone(); let redetect_notify = self.redetect_notify.clone(); self.tasks.spawn(async move { loop { - let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone()); - let old_nat_type = udp_nat_type.load().0; - let mut ret = NatType::Unknown; - for _ in 1..5 { - // if nat type degrade, sleep and retry. so result can be relatively stable. - ret = detector.get_udp_nat_type(0).await; - if ret == NatType::Unknown || ret <= old_nat_type { - break; + let servers = stun_servers.read().unwrap().clone(); + // use first three and random choose one from the rest + let servers = servers + .iter() + .take(2) + .chain(servers.iter().skip(2).choose(&mut rand::thread_rng())) + .map(|x| x.to_string()) + .collect(); + let detector = UdpNatTypeDetector::new(servers, 1); + let ret = detector.detect_nat_type(0).await; + tracing::debug!(?ret, "finish udp nat type detect"); + let mut nat_type = NatType::Unknown; + let sleep_sec = match &ret { + Ok(resp) => { + *udp_nat_test_result.write().unwrap() = Some(resp.clone()); + udp_test_time.store(Local::now()); + nat_type = resp.nat_type(); + if nat_type == NatType::Unknown { + 15 + } else { + 600 + } } - tokio::time::sleep(Duration::from_secs(5)).await; - } - udp_nat_type.store((ret, std::time::Instant::now())); - - let sleep_sec = match ret { - NatType::Unknown => 15, - _ => 60, + _ => 15, }; - tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect"); + + // if nat type is symmtric, detect with another port to gather more info + if nat_type == NatType::Symmetric { + let old_resp = ret.unwrap(); + let old_local_port = old_resp.local_addr().port(); + let new_port = if old_local_port >= 65535 { + old_local_port - 1 + } else { + old_local_port + 1 + }; + let ret = detector.detect_nat_type(new_port).await; + tracing::debug!(?ret, "finish udp nat type detect with another port"); + if let Ok(resp) = ret { + udp_nat_test_result.write().unwrap().as_mut().map(|x| { + x.extend_result(resp); + }); + } + } tokio::select! { _ = redetect_notify.notified() => {} @@ -483,53 +710,26 @@ impl StunInfoCollector { pub fn update_stun_info(&self) { self.redetect_notify.notify_one(); } - - pub async fn set_stun_servers(&self, stun_servers: Vec) { - *self.stun_servers.write().await = stun_servers; - self.update_stun_info(); - } } #[cfg(test)] mod tests { use super::*; - pub fn enable_log() { - let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) - .from_env() - .unwrap() - .add_directive("tarpc=error".parse().unwrap()); - tracing_subscriber::fmt::fmt() - .pretty() - .with_env_filter(filter) - .init(); - } - #[tokio::test] - async fn test_stun_bind_request() { - // miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port. - // let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]); - let mut ips_ = HostResolverIter::new(vec!["stun.canets.org:3478".to_string()]); - let mut ips = vec![]; - while let Some(ip) = ips_.next().await { - ips.push(ip); + async fn test_udp_nat_type_detector() { + let collector = StunInfoCollector::new_with_default_servers(); + collector.update_stun_info(); + loop { + let ret = collector.get_stun_info(); + if ret.udp_nat_type != NatType::Unknown as i32 { + println!("{:#?}", ret); + break; + } + tokio::time::sleep(Duration::from_secs(1)).await; } - println!("ip: {:?}", ips); - for ip in ips.iter() { - let stun = Stun::new(ip.clone()); - let _rs = stun.bind_request(12345, true, true).await; - } - } - #[tokio::test] - async fn test_udp_nat_type_detect() { - let detector = UdpNatTypeDetector::new(vec![ - "stun.counterpath.com:3478".to_string(), - "180.235.108.91:3478".to_string(), - ]); - let ret = detector.get_udp_nat_type(0).await; - - assert_ne!(ret, NatType::Unknown); + let port_mapping = collector.get_udp_port_mapping(3000).await; + println!("{:#?}", port_mapping); } } diff --git a/easytier/src/common/stun_codec_ext.rs b/easytier/src/common/stun_codec_ext.rs index ca7750b..3ea28b5 100644 --- a/easytier/src/common/stun_codec_ext.rs +++ b/easytier/src/common/stun_codec_ext.rs @@ -1,11 +1,12 @@ use std::net::SocketAddr; +use bytecodec::fixnum::{U32beDecoder, U32beEncoder}; use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder}; use stun_codec::rfc5389::attributes::{ MappedAddress, Software, XorMappedAddress, XorMappedAddress2, }; -use stun_codec::rfc5780::attributes::{ChangeRequest, OtherAddress, ResponseOrigin}; +use stun_codec::rfc5780::attributes::{OtherAddress, ResponseOrigin}; use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId}; use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode}; @@ -197,6 +198,75 @@ impl_encode!(SourceAddressEncoder, SourceAddress, |item: Self::Item| { item.0 }); +/// `CHANGE-REQUEST` attribute. +/// +/// See [RFC 5780 -- 7.2. CHANGE-REQUEST] about this attribute. +/// +/// [RFC 5780 -- 7.2. CHANGE-REQUEST]: https://tools.ietf.org/html/rfc5780#section-7.2 +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ChangeRequest(bool, bool); + +impl ChangeRequest { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x0003; + + /// Makes a new `ChangeRequest` instance. + pub fn new(ip: bool, port: bool) -> Self { + ChangeRequest(ip, port) + } + + /// Returns whether the client requested the server to send the Binding Response with a + /// different IP address than the one the Binding Request was received on + pub fn ip(&self) -> bool { + self.0 + } + + /// Returns whether the client requested the server to send the Binding Response with a + /// different port than the one the Binding Request was received on + pub fn port(&self) -> bool { + self.1 + } +} + +impl stun_codec::Attribute for ChangeRequest { + type Decoder = ChangeRequestDecoder; + type Encoder = ChangeRequestEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } +} + +/// [`ChangeRequest`] decoder. +#[derive(Debug, Default)] +pub struct ChangeRequestDecoder(U32beDecoder); + +impl ChangeRequestDecoder { + /// Makes a new `ChangeRequestDecoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_decode!(ChangeRequestDecoder, ChangeRequest, |item| { + Ok(ChangeRequest((item & 0x4) != 0, (item & 0x2) != 0)) +}); + +/// [`ChangeRequest`] encoder. +#[derive(Debug, Default)] +pub struct ChangeRequestEncoder(U32beEncoder); + +impl ChangeRequestEncoder { + /// Makes a new `ChangeRequestEncoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_encode!(ChangeRequestEncoder, ChangeRequest, |item: Self::Item| { + let ip = item.0 as u8; + let port = item.1 as u8; + ((ip << 1 | port) << 1) as u32 +}); + pub fn tid_to_u128(tid: &TransactionId) -> u128 { let mut tid_buf = [0u8; 16]; // copy bytes from msg_tid to tid_buf diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs index 79eeec4..f410f83 100644 --- a/easytier/src/connector/udp_hole_punch.rs +++ b/easytier/src/connector/udp_hole_punch.rs @@ -1,20 +1,35 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; use anyhow::Context; use crossbeam::atomic::AtomicCell; -use rand::{seq::SliceRandom, SeedableRng}; -use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; -use tracing::Instrument; +use dashmap::{DashMap, DashSet}; +use rand::{seq::SliceRandom, Rng}; +use tokio::{ + net::UdpSocket, + sync::{Mutex, Notify}, + task::JoinSet, +}; +use tracing::{instrument, Instrument, Level}; +use zerocopy::FromBytes; use crate::{ common::{ - constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, + constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, stun::StunInfoCollectorTrait, PeerId, }, + defer, peers::peer_manager::PeerManager, rpc::NatType, tunnel::{ common::setup_sokcet2, + packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE}, udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener, }, @@ -22,9 +37,168 @@ use crate::{ use super::direct::PeerManagerForDirectConnector; +const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16; + +fn generate_shuffled_port_vec() -> Vec { + let mut rng = rand::thread_rng(); + let mut port_vec: Vec = (1..=65535).collect(); + port_vec.shuffle(&mut rng); + port_vec +} + +// used for symmetric hole punching, binding to multiple ports to increase the chance of success +struct UdpSocketArray { + sockets: Arc>>, + max_socket_count: usize, + net_ns: NetNS, + tasks: Arc>>, + + intreast_tids: Arc>, + tid_to_socket: Arc>>>, +} + +impl UdpSocketArray { + pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self { + let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); + join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned()); + + Self { + sockets: Arc::new(DashMap::new()), + max_socket_count, + net_ns, + tasks, + + intreast_tids: Arc::new(DashSet::new()), + tid_to_socket: Arc::new(DashMap::new()), + } + } + + pub fn started(&self) -> bool { + !self.sockets.is_empty() + } + + async fn add_new_socket(&self) -> Result<(), anyhow::Error> { + let socket = { + let _g = self.net_ns.guard(); + Arc::new(UdpSocket::bind("0.0.0.0:0").await?) + }; + let local_addr = socket.local_addr()?; + self.sockets.insert(local_addr, socket.clone()); + + let intreast_tids = self.intreast_tids.clone(); + let tid_to_socket = self.tid_to_socket.clone(); + self.tasks.lock().unwrap().spawn( + async move { + let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize]; + tracing::trace!(?local_addr, "udp socket added"); + loop { + let Ok((len, addr)) = socket.recv_from(&mut buf).await else { + break; + }; + + tracing::debug!(?len, ?addr, "got raw packet"); + + if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize { + continue; + } + + let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else { + continue; + }; + + tracing::debug!(?p, ?addr, "got udp hole punch packet"); + + if p.msg_type != UdpPacketType::HolePunch as u8 + || p.len.get() != HOLE_PUNCH_PACKET_BODY_LEN + { + continue; + } + + let tid = p.conn_id.get(); + if intreast_tids.contains(&tid) { + tracing::info!(?addr, "got hole punching packet with intreast tid"); + tid_to_socket + .entry(tid) + .or_insert_with(Vec::new) + .push(socket); + break; + } + } + tracing::debug!(?local_addr, "udp socket recv loop end"); + } + .instrument(tracing::info_span!("udp array socket recv loop")), + ); + Ok(()) + } + + #[instrument(err)] + pub async fn start(&self) -> Result<(), anyhow::Error> { + if self.started() { + return Ok(()); + } + + tracing::info!("starting udp socket array"); + + while self.sockets.len() < self.max_socket_count { + self.add_new_socket().await?; + } + + Ok(()) + } + + #[instrument(err)] + pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> { + tracing::info!(?addr, "sending hole punching packet"); + + for socket in self.sockets.iter() { + let socket = socket.value(); + socket.send_to(data, addr).await?; + } + + Ok(()) + } + + #[instrument(ret(level = Level::DEBUG))] + pub fn try_fetch_punched_socket(&self, tid: u32) -> Option> { + tracing::debug!(?tid, "try fetch punched socket"); + self.tid_to_socket.get_mut(&tid)?.value_mut().pop() + } + + pub fn add_intreast_tid(&self, tid: u32) { + self.intreast_tids.insert(tid); + } + + pub fn remove_intreast_tid(&self, tid: u32) { + self.intreast_tids.remove(&tid); + self.tid_to_socket.remove(&tid); + } +} + +impl std::fmt::Debug for UdpSocketArray { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UdpSocketArray") + .field("sockets", &self.sockets.len()) + .field("max_socket_count", &self.max_socket_count) + .field("started", &self.started()) + .field("intreast_tids", &self.intreast_tids.len()) + .field("tid_to_socket", &self.tid_to_socket.len()) + .finish() + } +} + #[tarpc::service] pub trait UdpHolePunchService { async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option; + async fn try_punch_symmetric( + listener_addr: SocketAddr, + port: u16, + public_ips: Vec, + min_port: u16, + max_port: u16, + transaction_id: u32, + round: u32, + last_port_index: usize, + ) -> Option; } #[derive(Debug)] @@ -37,7 +211,7 @@ struct UdpHolePunchListener { listen_time: std::time::Instant, last_select_time: AtomicCell, - last_connected_time: Arc>, + last_active_time: Arc>, } impl UdpHolePunchListener { @@ -46,6 +220,7 @@ impl UdpHolePunchListener { Ok(socket.local_addr()?.port()) } + #[instrument(err)] pub async fn new(peer_mgr: Arc) -> Result { let port = Self::get_avail_port().await?; let listen_url = format!("udp://0.0.0.0:{}", port); @@ -65,15 +240,11 @@ impl UdpHolePunchListener { let running = Arc::new(AtomicCell::new(true)); let running_clone = running.clone(); - let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now())); - let last_connected_time_clone = last_connected_time.clone(); - let conn_counter = listener.get_conn_counter(); let mut tasks = JoinSet::new(); tasks.spawn(async move { while let Ok(conn) = listener.accept().await { - last_connected_time_clone.store(std::time::Instant::now()); tracing::warn!(?conn, "udp hole punching listener got peer connection"); let peer_mgr = peer_mgr.clone(); tokio::spawn(async move { @@ -89,6 +260,18 @@ impl UdpHolePunchListener { running_clone.store(false); }); + let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now())); + let conn_counter_clone = conn_counter.clone(); + let last_active_time_clone = last_active_time.clone(); + tasks.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + if conn_counter_clone.get() != 0 { + last_active_time_clone.store(std::time::Instant::now()); + } + } + }); + tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started"); Ok(Self { @@ -100,7 +283,7 @@ impl UdpHolePunchListener { listen_time: std::time::Instant::now(), last_select_time: AtomicCell::new(std::time::Instant::now()), - last_connected_time, + last_active_time, }) } @@ -110,11 +293,34 @@ impl UdpHolePunchListener { } } -#[derive(Debug)] struct UdpHolePunchConnectorData { global_ctx: ArcGlobalCtx, peer_mgr: Arc, listeners: Arc>>, + shuffled_port_vec: Arc>, + + udp_array: Arc>>>, + try_direct_connect: AtomicBool, + punch_predicablely: AtomicBool, + punch_randomly: AtomicBool, + udp_array_size: AtomicUsize, +} + +impl std::fmt::Debug for UdpHolePunchConnectorData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // print peer id listener count + let peer_id = self.peer_mgr.my_peer_id(); + f.debug_struct("UdpHolePunchConnectorData") + .field("peer_id", &peer_id) + .finish() + } +} + +impl UdpHolePunchConnectorData { + fn my_nat_type(&self) -> NatType { + let stun_info = self.global_ctx.get_stun_info_collector().get_stun_info(); + NatType::try_from(stun_info.udp_nat_type).unwrap() + } } #[derive(Clone)] @@ -126,41 +332,146 @@ struct UdpHolePunchRpcServer { #[tarpc::server] impl UdpHolePunchService for UdpHolePunchRpcServer { + #[tracing::instrument(skip(self))] async fn try_punch_hole( self, _: tarpc::context::Context, local_mapped_addr: SocketAddr, ) -> Option { - let (socket, mapped_addr) = self.select_listener().await?; + // local mapped addr will be unspecified if peer is symmetric + let peer_is_symmetric = local_mapped_addr.ip().is_unspecified(); + let (socket, mapped_addr) = self.select_listener(peer_is_symmetric).await?; tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching"); - let my_udp_nat_type = self - .data - .global_ctx - .get_stun_info_collector() - .get_stun_info() - .udp_nat_type; + if !peer_is_symmetric { + let my_udp_nat_type = self + .data + .global_ctx + .get_stun_info_collector() + .get_stun_info() + .udp_nat_type; - // if we are restricted, we need to send hole punching resp to client - if my_udp_nat_type == NatType::PortRestricted as i32 - || my_udp_nat_type == NatType::Restricted as i32 - { - // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second - self.tasks.lock().unwrap().spawn(async move { - for _ in 0..10 { - tracing::info!(?local_mapped_addr, "sending hole punching packet"); + // if we are cone, we need to send hole punching resp to client + if my_udp_nat_type == NatType::PortRestricted as i32 + || my_udp_nat_type == NatType::Restricted as i32 + || my_udp_nat_type == NatType::FullCone as i32 + { + let notifier = Arc::new(Notify::new()); - let udp_packet = new_hole_punch_packet(); - let _ = socket - .send_to(&udp_packet.into_bytes(), local_mapped_addr) - .await; - tokio::time::sleep(std::time::Duration::from_millis(300)).await; - } - }); + let n = notifier.clone(); + // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second + self.tasks.lock().unwrap().spawn(async move { + for i in 0..10 { + tracing::info!(?local_mapped_addr, "sending hole punching packet"); + + let udp_packet = new_hole_punch_packet(100, HOLE_PUNCH_PACKET_BODY_LEN); + let _ = socket + .send_to(&udp_packet.into_bytes(), local_mapped_addr) + .await; + let sleep_ms = if i < 4 { 10 } else { 500 }; + tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await; + if i == 3 { + n.notify_one(); + } + } + }); + + notifier.notified().await; + } } Some(mapped_addr) } + + #[instrument(skip(self))] + async fn try_punch_symmetric( + self, + _: tarpc::context::Context, + listener_addr: SocketAddr, + port: u16, + public_ips: Vec, + mut min_port: u16, + mut max_port: u16, + transaction_id: u32, + round: u32, + last_port_index: usize, + ) -> Option { + tracing::info!("try_punch_symmetric start"); + + let punch_predictablely = self.data.punch_predicablely.load(Ordering::Relaxed); + let punch_randomly = self.data.punch_randomly.load(Ordering::Relaxed); + let total_port_count = self.data.shuffled_port_vec.len(); + let listener = self.find_listener(&listener_addr).await?; + let ip_count = public_ips.len(); + if ip_count == 0 { + tracing::warn!("try_punch_symmetric got zero len public ip"); + return None; + } + + min_port = std::cmp::max(1, min_port); + if max_port == 0 { + max_port = u16::MAX; + } + if max_port < min_port { + std::mem::swap(&mut min_port, &mut max_port); + } + + // send max k1 packets if we are predicting the dst port + let max_k1 = 180; + // send max k2 packets if we are sending to random port + let max_k2 = rand::thread_rng().gen_range(600..800); + + // this means the NAT is allocating port in a predictable way + if max_port.abs_diff(min_port) <= max_k1 && round <= 6 && punch_predictablely { + let (min_port, max_port) = { + // round begin from 0. if round is even, we guess port in increasing order + let port_delta = (max_k1 as u32) / ip_count as u32; + let port_diff_for_min = std::cmp::min((round / 2) * port_delta, u16::MAX as u32); + if round % 2 == 0 { + let lower = std::cmp::max(1, port.saturating_add(port_diff_for_min as u16)); + let upper = lower.saturating_add(port_delta as u16); + (lower, upper) + } else { + let upper = std::cmp::max(1, port.saturating_sub(port_diff_for_min as u16)); + let lower = std::cmp::max(1, upper.saturating_sub(port_delta as u16)); + (lower, upper) + } + }; + let mut ports = (min_port..=max_port).collect::>(); + ports.push(max_port); + ports.shuffle(&mut rand::thread_rng()); + self.send_symmetric_hole_punch_packet( + listener.clone(), + transaction_id, + &public_ips, + &ports, + ) + .await + .ok()?; + } + + if punch_randomly { + let start = last_port_index % total_port_count; + let diff = std::cmp::max(10, max_k2 / ip_count); + let end = std::cmp::min(start + diff, self.data.shuffled_port_vec.len()); + self.send_symmetric_hole_punch_packet( + listener.clone(), + transaction_id, + &public_ips, + &self.data.shuffled_port_vec[start..end], + ) + .await + .ok()?; + + return if end >= self.data.shuffled_port_vec.len() { + Some(1) + } else { + Some(end) + }; + } + + return Some(1); + } } impl UdpHolePunchRpcServer { @@ -170,17 +481,30 @@ impl UdpHolePunchRpcServer { Self { data, tasks } } - async fn select_listener(&self) -> Option<(Arc, SocketAddr)> { + async fn find_listener(&self, addr: &SocketAddr) -> Option> { + let all_listener_sockets = self.data.listeners.lock().await; + + let listener = all_listener_sockets + .iter() + .find(|listener| listener.mapped_addr == *addr && listener.running.load())?; + + Some(listener.get_socket().await) + } + + async fn select_listener( + &self, + use_new_listener: bool, + ) -> Option<(Arc, SocketAddr)> { let all_listener_sockets = &self.data.listeners; - // remove listener that not have connection in for 20 seconds + // remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds all_listener_sockets.lock().await.retain(|listener| { - listener.last_connected_time.load().elapsed().as_secs() < 20 - && listener.conn_counter.get() > 0 + listener.last_active_time.load().elapsed().as_secs() < 40 + || listener.last_select_time.load().elapsed().as_secs() < 30 }); let mut use_last = false; - if all_listener_sockets.lock().await.len() < 4 { + if all_listener_sockets.lock().await.len() < 4 || use_new_listener { tracing::warn!("creating new udp hole punching listener"); all_listener_sockets.lock().await.push( UdpHolePunchListener::new(self.data.peer_mgr.clone()) @@ -195,11 +519,38 @@ impl UdpHolePunchRpcServer { let listener = if use_last { locked.last()? } else { - locked.choose(&mut rand::rngs::StdRng::from_entropy())? + // use the listener that is active most recently + locked + .iter() + .max_by_key(|listener| listener.last_active_time.load())? }; Some((listener.get_socket().await, listener.mapped_addr)) } + + #[tracing::instrument(err, ret(level=Level::DEBUG), skip(self, ports))] + async fn send_symmetric_hole_punch_packet( + &self, + udp: Arc, + transaction_id: u32, + public_ips: &Vec, + ports: &[u16], + ) -> Result<(), Error> { + tracing::debug!( + ?public_ips, + "sending symmetric hole punching packet, ports len: {}", + ports.len(), + ); + for port in ports { + for pub_ip in public_ips { + let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, *port)); + let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN); + udp.send_to(&packet.into_bytes(), addr).await?; + tokio::time::sleep(Duration::from_millis(2)).await; + } + } + Ok(()) + } } pub struct UdpHolePunchConnector { @@ -221,6 +572,12 @@ impl UdpHolePunchConnector { global_ctx, peer_mgr, listeners: Arc::new(Mutex::new(Vec::new())), + shuffled_port_vec: Arc::new(generate_shuffled_port_vec()), + udp_array: Arc::new(Mutex::new(None)), + try_direct_connect: AtomicBool::new(true), + punch_predicablely: AtomicBool::new(true), + punch_randomly: AtomicBool::new(true), + udp_array_size: AtomicUsize::new(80), }), tasks: JoinSet::new(), } @@ -251,20 +608,15 @@ impl UdpHolePunchConnector { Ok(()) } - async fn collect_peer_to_connect(data: Arc) -> Vec { + async fn collect_peer_to_connect( + data: Arc, + ) -> Vec<(PeerId, NatType)> { let mut peers_to_connect = Vec::new(); // do not do anything if: // 1. our stun test has not finished // 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us - let my_nat_type = data - .global_ctx - .get_stun_info_collector() - .get_stun_info() - .udp_nat_type; - - let my_nat_type = NatType::try_from(my_nat_type).unwrap(); - + let my_nat_type = data.my_nat_type(); if my_nat_type == NatType::Unknown || my_nat_type == NatType::OpenInternet || my_nat_type == NatType::NoPat @@ -300,10 +652,9 @@ impl UdpHolePunchConnector { continue; } - // if we are symmetric, we can only connect to full cone - // TODO: can also connect to restricted full cone, with some extra work + // if we are symmetric, we can only connect to cone peer if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall) - && peer_nat_type != NatType::FullCone + && (peer_nat_type == NatType::Symmetric || peer_nat_type == NatType::SymUdpFirewall) { continue; } @@ -329,14 +680,34 @@ impl UdpHolePunchConnector { "found peer to do hole punching" ); - peers_to_connect.push(peer_id); + peers_to_connect.push((peer_id, peer_nat_type)); } peers_to_connect } - #[tracing::instrument] - async fn do_hole_punching( + async fn try_connect_with_socket( + socket: Arc, + remote_mapped_addr: SocketAddr, + ) -> Result, Error> { + let connector = UdpTunnelConnector::new( + format!( + "udp://{}:{}", + remote_mapped_addr.ip(), + remote_mapped_addr.port() + ) + .to_string() + .parse() + .unwrap(), + ); + connector + .try_connect_with_socket(socket, remote_mapped_addr) + .await + .map_err(|e| Error::from(e)) + } + + #[tracing::instrument(err)] + async fn do_hole_punching_cone( data: Arc, dst_peer_id: PeerId, ) -> Result, anyhow::Error> { @@ -382,18 +753,6 @@ impl UdpHolePunchConnector { // server: will send some punching resps, total 10 packets. // client: use the socket to create UdpTunnel with UdpTunnelConnector // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote. - - let connector = UdpTunnelConnector::new( - format!( - "udp://{}:{}", - remote_mapped_addr.ip(), - remote_mapped_addr.port() - ) - .to_string() - .parse() - .unwrap(), - ); - let _g = data.global_ctx.net_ns.guard(); let socket2_socket = socket2::Socket::new( socket2::Domain::for_address(local_socket_addr), @@ -401,62 +760,240 @@ impl UdpHolePunchConnector { Some(socket2::Protocol::UDP), )?; setup_sokcet2(&socket2_socket, &local_socket_addr)?; - let socket = UdpSocket::from_std(socket2_socket.into())?; + let socket = Arc::new(UdpSocket::from_std(socket2_socket.into())?); - Ok(connector - .try_connect_with_socket(socket, remote_mapped_addr) + Ok(Self::try_connect_with_socket(socket, remote_mapped_addr) .await .with_context(|| "UdpTunnelConnector failed to connect remote")?) } - async fn main_loop(data: Arc) { + #[tracing::instrument(err(level = Level::ERROR))] + async fn do_hole_punching_symmetric( + data: Arc, + dst_peer_id: PeerId, + ) -> Result, anyhow::Error> { + let Some(udp_array) = data.udp_array.lock().await.clone() else { + return Err(anyhow::anyhow!("udp array not started")); + }; + + let Some(remote_mapped_addr) = data + .peer_mgr + .get_peer_rpc_mgr() + .do_client_rpc_scoped( + constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, + dst_peer_id, + |c| async { + let client = + UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn(); + let remote_mapped_addr = client + .try_punch_hole(tarpc::context::current(), "0.0.0.0:0".parse().unwrap()) + .await; + tracing::debug!( + ?remote_mapped_addr, + ?dst_peer_id, + "hole punching symmetric got remote mapped addr" + ); + remote_mapped_addr + }, + ) + .await? + else { + return Err(anyhow::anyhow!("failed to get remote mapped addr")); + }; + + // try direct connect first + if data.try_direct_connect.load(Ordering::Relaxed) { + if let Ok(tunnel) = Self::try_connect_with_socket( + Arc::new(UdpSocket::bind("0.0.0.0:0").await?), + remote_mapped_addr, + ) + .await + { + return Ok(tunnel); + } + } + + let tid = rand::thread_rng().gen(); + let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(); + udp_array.add_intreast_tid(tid); + defer! { udp_array.remove_intreast_tid(tid);} + udp_array.send_with_all(&packet, remote_mapped_addr).await?; + + // get latest port mapping + let local_mapped_addr = data + .global_ctx + .get_stun_info_collector() + .get_udp_port_mapping(0) + .await?; + let port = local_mapped_addr.port(); + let IpAddr::V4(ipv4) = local_mapped_addr.ip() else { + return Err(anyhow::anyhow!("failed to get local mapped addr")); + }; + let stun_info = data.global_ctx.get_stun_info_collector().get_stun_info(); + let mut public_ips: Vec = stun_info + .public_ip + .iter() + .map(|x| x.parse().unwrap()) + .collect(); + if !public_ips.contains(&ipv4) { + public_ips.push(ipv4); + } + if public_ips.is_empty() { + return Err(anyhow::anyhow!("failed to get public ips")); + } + + let mut last_port_idx = 0; + + for round in 0..30 { + let Some(next_last_port_idx) = data + .peer_mgr + .get_peer_rpc_mgr() + .do_client_rpc_scoped( + constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, + dst_peer_id, + |c| async { + let client = + UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c) + .spawn(); + let last_port_idx = client + .try_punch_symmetric( + tarpc::context::current(), + remote_mapped_addr, + port, + public_ips.clone(), + stun_info.min_port as u16, + stun_info.max_port as u16, + tid, + round, + last_port_idx, + ) + .await; + tracing::info!(?last_port_idx, ?dst_peer_id, "punch symmetric return"); + last_port_idx + }, + ) + .await? + else { + return Err(anyhow::anyhow!("failed to get remote mapped addr")); + }; + + while let Some(socket) = udp_array.try_fetch_punched_socket(tid) { + if let Ok(tunnel) = Self::try_connect_with_socket(socket, remote_mapped_addr).await + { + return Ok(tunnel); + } + } + + last_port_idx = next_last_port_idx; + } + + return Err(anyhow::anyhow!("udp array not started")); + } + + async fn peer_punching_task( + data: Arc, + peer_id: PeerId, + ) -> Result<(), anyhow::Error> { + const MAX_BACKOFF_TIME: u64 = 600; + let mut backoff_time = vec![15, 15, 30, 30, 60, 120, 300, MAX_BACKOFF_TIME]; + let my_nat_type = data.my_nat_type(); + loop { - let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await; - tracing::trace!(?peers_to_connect, "peers to connect"); - if peers_to_connect.len() == 0 { - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - continue; - } + let ret = if my_nat_type == NatType::FullCone + || my_nat_type == NatType::Restricted + || my_nat_type == NatType::PortRestricted + { + Self::do_hole_punching_cone(data.clone(), peer_id).await + } else { + Self::do_hole_punching_symmetric(data.clone(), peer_id).await + }; - let mut tasks: JoinSet> = JoinSet::new(); - for peer_id in peers_to_connect { - let data = data.clone(); - tasks.spawn( - async move { - let tunnel = Self::do_hole_punching(data.clone(), peer_id) - .await - .with_context(|| "failed to do hole punching")?; - - let _ = - data.peer_mgr - .add_client_tunnel(tunnel) - .await - .with_context(|| { - "failed to add tunnel as client in hole punch connector" - })?; - - Ok(()) - } - .instrument(tracing::info_span!("doing hole punching client", ?peer_id)), - ); - } - - while let Some(res) = tasks.join_next().await { - if let Err(e) = res { - tracing::error!(?e, "failed to join hole punching job"); + match ret { + Err(_) => { + tokio::time::sleep(Duration::from_secs( + backoff_time.pop().unwrap_or(MAX_BACKOFF_TIME), + )) + .await; continue; } - match res.unwrap() { - Err(e) => { - tracing::error!(?e, "failed to do hole punching job"); - } - Ok(_) => { - tracing::info!("hole punching job succeed"); + Ok(tunnel) => { + let _ = data + .peer_mgr + .add_client_tunnel(tunnel) + .await + .with_context(|| { + "failed to add tunnel as client in hole punch connector" + })?; + break; + } + } + } + + Ok(()) + } + + async fn main_loop(data: Arc) { + type JoinTaskRet = Result<(), anyhow::Error>; + type JoinTask = tokio::task::JoinHandle; + let punching_task = Arc::new(DashMap::<(PeerId, NatType), JoinTask>::new()); + let mut last_my_nat_type = NatType::Unknown; + + loop { + let my_nat_type = data.my_nat_type(); + let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await; + + // remove task not in peers_to_connect + let mut to_remove = vec![]; + for item in punching_task.iter() { + if !peers_to_connect.contains(item.key()) + || item.value().is_finished() + || my_nat_type != last_my_nat_type + { + to_remove.push(item.key().clone()); + } + } + for key in to_remove { + if let Some((_, task)) = punching_task.remove(&key) { + task.abort(); + match task.await { + Ok(Ok(_)) => {} + Ok(Err(task_ret)) => { + tracing::error!(?task_ret, "hole punching task failed"); + } + Err(e) => { + tracing::error!(?e, "hole punching task aborted"); + } } } } + last_my_nat_type = my_nat_type; + + if !peers_to_connect.is_empty() { + let my_nat_type = data.my_nat_type(); + if my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall { + let mut udp_array = data.udp_array.lock().await; + if udp_array.is_none() { + *udp_array = Some(Arc::new(UdpSocketArray::new( + data.udp_array_size.load(Ordering::Relaxed), + data.global_ctx.net_ns.clone(), + ))); + } + let udp_array = udp_array.as_ref().unwrap(); + udp_array.start().await.unwrap(); + } + + for item in peers_to_connect { + punching_task.insert( + item, + tokio::spawn(Self::peer_punching_task(data.clone(), item.0)), + ); + } + } else if punching_task.is_empty() { + data.udp_array.lock().await.take(); + } + tokio::time::sleep(std::time::Duration::from_secs(10)).await; } } @@ -464,8 +1001,14 @@ impl UdpHolePunchConnector { #[cfg(test)] pub mod tests { + use std::sync::atomic::AtomicU32; use std::sync::Arc; + use std::time::Duration; + use tokio::net::UdpSocket; + + use crate::connector::udp_hole_punch::UdpHolePunchListener; + use crate::peers::tests::wait_for_condition; use crate::rpc::{NatType, StunInfo}; use crate::{ @@ -491,10 +1034,16 @@ pub mod tests { udp_nat_type: self.udp_nat_type as i32, tcp_nat_type: NatType::Unknown as i32, last_update_time: std::time::Instant::now().elapsed().as_secs() as i64, + min_port: 100, + max_port: 200, + ..Default::default() } } - async fn get_udp_port_mapping(&self, port: u16) -> Result { + async fn get_udp_port_mapping(&self, mut port: u16) -> Result { + if port == 0 { + port = 40144; + } Ok(format!("127.0.0.1:{}", port).parse().unwrap()) } } @@ -515,10 +1064,10 @@ pub mod tests { } #[tokio::test] - async fn hole_punching() { - let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; - let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + async fn hole_punching_cone() { + let p_a = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; connect_peer_manager(p_a.clone(), p_b.clone()).await; connect_peer_manager(p_b.clone(), p_c.clone()).await; @@ -537,4 +1086,128 @@ pub mod tests { .unwrap(); println!("{:?}", p_a.list_routes().await); } + + #[tokio::test] + async fn hole_punching_symmetric_only_random() { + let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); + + hole_punching_a + .data + .try_direct_connect + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_c + .data + .punch_predicablely + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a.run().await.unwrap(); + hole_punching_c.run().await.unwrap(); + + wait_for_condition( + || async { hole_punching_a.data.udp_array.lock().await.is_some() }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || async { + wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) + .await + .is_ok() + }, + Duration::from_secs(30), + ) + .await; + println!("{:?}", p_a.list_routes().await); + + wait_for_condition( + || async { hole_punching_a.data.udp_array.lock().await.is_none() }, + Duration::from_secs(20), + ) + .await; + } + + #[tokio::test] + async fn hole_punching_symmetric_only_predict() { + let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); + + hole_punching_a + .data + .try_direct_connect + .store(false, std::sync::atomic::Ordering::Relaxed); + hole_punching_a + .data + .udp_array_size + .store(0, std::sync::atomic::Ordering::Relaxed); + + hole_punching_c + .data + .punch_randomly + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a.run().await.unwrap(); + hole_punching_c.run().await.unwrap(); + + let udp_self = Arc::new(UdpSocket::bind("0.0.0.0:40144").await.unwrap()); + let udp_inc = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap()); + let udp_inc2 = Arc::new(UdpSocket::bind("0.0.0.0:40400").await.unwrap()); + let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap()); + let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40350").await.unwrap()); + let udps = vec![udp_self, udp_inc, udp_inc2, udp_dec, udp_dec2]; + + let counter = Arc::new(AtomicU32::new(0)); + + // all these sockets should receive hole punching packet + for udp in udps.iter().map(Arc::clone) { + let counter = counter.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + let (len, addr) = udp.recv_from(&mut buf).await.unwrap(); + println!("{:?} {:?}", len, addr); + counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + }); + } + + let udp_len = udps.len(); + wait_for_condition( + || async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 }, + Duration::from_secs(30), + ) + .await; + } + + #[tokio::test] + async fn udp_listener() { + let p_a = create_mock_peer_manager().await; + wait_for_condition( + || async { + p_a.get_global_ctx() + .get_stun_info_collector() + .get_stun_info() + .udp_nat_type + != NatType::Unknown as i32 + }, + Duration::from_secs(20), + ) + .await; + let l = UdpHolePunchListener::new(p_a.clone()).await.unwrap(); + println!("{:#?}", l.mapped_addr); + } } diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 97541ba..f9f8365 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -1,9 +1,11 @@ #![allow(dead_code)] -use std::{net::SocketAddr, vec}; +use std::{net::SocketAddr, time::Duration, vec}; use clap::{command, Args, Parser, Subcommand}; +use common::stun::StunInfoCollectorTrait; use rpc::vpn_portal_rpc_client::VpnPortalRpcClient; +use tokio::time::timeout; use utils::{list_peer_route_pair, PeerRoutePair}; mod arch; @@ -13,7 +15,7 @@ mod tunnel; mod utils; use crate::{ - common::stun::{StunInfoCollector, UdpNatTypeDetector}, + common::stun::StunInfoCollector, rpc::{ connector_manage_rpc_client::ConnectorManageRpcClient, peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient, @@ -309,8 +311,19 @@ async fn main() -> Result<(), Error> { handler.handle_route_list().await?; } SubCommand::Stun => { - let stun = UdpNatTypeDetector::new(StunInfoCollector::get_default_servers()); - println!("udp type: {:?}", stun.get_udp_nat_type(0).await); + timeout(Duration::from_secs(5), async move { + let collector = StunInfoCollector::new_with_default_servers(); + loop { + let ret = collector.get_stun_info(); + if ret.udp_nat_type != NatType::Unknown as i32 { + println!("stun info: {:#?}", ret); + break; + } + tokio::time::sleep(Duration::from_millis(200)).await; + } + }) + .await + .unwrap(); } SubCommand::PeerCenter => { let mut peer_center_client = handler.get_peer_center_client().await?; diff --git a/easytier/src/peers/peer_conn_ping.rs b/easytier/src/peers/peer_conn_ping.rs index e06c2a2..dd450be 100644 --- a/easytier/src/peers/peer_conn_ping.rs +++ b/easytier/src/peers/peer_conn_ping.rs @@ -76,7 +76,7 @@ impl PeerConnPinger { let now = std::time::Instant::now(); // wait until we get a pong packet in ctrl_resp_receiver - let resp = timeout(Duration::from_secs(1), async { + let resp = timeout(Duration::from_secs(2), async { loop { match receiver.recv().await { Ok(p) => { diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 9e4709b..c92af75 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -77,16 +77,16 @@ fn new_sack_packet(conn_id: u32, magic: u64) -> ZCPacket { ) } -pub fn new_hole_punch_packet() -> ZCPacket { +pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket { // generate a 128 bytes vec with random data let mut rng = rand::rngs::StdRng::from_entropy(); - let mut buf = vec![0u8; 128]; + let mut buf = vec![0u8; buf_len as usize]; rng.fill(&mut buf[..]); new_udp_packet( |header| { header.msg_type = UdpPacketType::HolePunch as u8; - header.conn_id.set(0); - header.len.set(0); + header.conn_id.set(tid); + header.len.set(buf_len); }, Some(&mut buf), ) @@ -304,7 +304,7 @@ impl UdpTunnelListenerData { let header = zc_packet.udp_tunnel_header().unwrap(); if header.msg_type == UdpPacketType::Syn as u8 { tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet)); - } else { + } else if header.msg_type != UdpPacketType::HolePunch as u8 { if let Err(e) = self .try_forward_packet(addr, header.conn_id.get(), zc_packet) .await @@ -526,11 +526,10 @@ impl UdpTunnelConnector { async fn build_tunnel( &self, - socket: UdpSocket, + socket: Arc, dst_addr: SocketAddr, conn_id: u32, ) -> Result, super::TunnelError> { - let socket = Arc::new(socket); let ring_for_send_udp = Arc::new(RingTunnel::new(128)); let ring_for_recv_udp = Arc::new(RingTunnel::new(128)); tracing::debug!( @@ -610,13 +609,13 @@ impl UdpTunnelConnector { pub async fn try_connect_with_socket( &self, - socket: UdpSocket, + socket: Arc, addr: SocketAddr, ) -> Result, super::TunnelError> { log::warn!("udp connect: {:?}", self.addr); #[cfg(target_os = "windows")] - crate::arch::windows::disable_connection_reset(&socket)?; + crate::arch::windows::disable_connection_reset(socket.as_ref())?; // send syn let conn_id = rand::random(); @@ -649,7 +648,7 @@ impl UdpTunnelConnector { UdpSocket::bind("[::]:0").await? }; - return self.try_connect_with_socket(socket, addr).await; + return self.try_connect_with_socket(Arc::new(socket), addr).await; } async fn connect_with_custom_bind( @@ -666,7 +665,7 @@ impl UdpTunnelConnector { )?; setup_sokcet2(&socket2_socket, &bind_addr)?; let socket = UdpSocket::from_std(socket2_socket.into())?; - futures.push(self.try_connect_with_socket(socket, addr)); + futures.push(self.try_connect_with_socket(Arc::new(socket), addr)); } wait_for_connect_futures(futures).await }