diff --git a/easytier/src/common/error.rs b/easytier/src/common/error.rs index 17c7574..6ebfa4e 100644 --- a/easytier/src/common/error.rs +++ b/easytier/src/common/error.rs @@ -2,7 +2,7 @@ use std::{io, result}; use thiserror::Error; -use crate::{tunnel, tunnels}; +use crate::tunnel; use super::PeerId; @@ -13,7 +13,7 @@ pub enum Error { #[error("rust tun error {0}")] TunError(#[from] tun::Error), #[error("tunnel error {0}")] - TunnelError(#[from] tunnels::TunnelError), + TunnelError(#[from] tunnel::TunnelError), #[error("Peer has no conn, PeerId: {0}")] PeerNoConnectionError(PeerId), #[error("RouteError: {0:?}")] @@ -42,9 +42,6 @@ pub enum Error { #[error("wait resp error: {0}")] WaitRespError(String), - #[error("tunnel error")] - TunnelErr(#[from] tunnel::TunnelError), - #[error("message decode error: {0}")] MessageDecodeError(String), diff --git a/easytier/src/connector/manual.rs b/easytier/src/connector/manual.rs index 8c7342d..b189a99 100644 --- a/easytier/src/connector/manual.rs +++ b/easytier/src/connector/manual.rs @@ -8,7 +8,7 @@ use tokio::{ }; use crate::{ - common::PeerId, peers::zc_peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector, + common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector, }; use crate::{ diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 81863bb..af987ab 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, tunnel::{ + check_scheme_and_get_socket_addr, quic::QUICTunnelConnector, ring::RingTunnelConnector, tcp::TcpTunnelConnector, @@ -50,8 +51,7 @@ pub async fn create_connector_by_url( let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?; match url.scheme() { "tcp" => { - let dst_addr = - crate::tunnels::check_scheme_and_get_socket_addr::(&url, "tcp")?; + let dst_addr = check_scheme_and_get_socket_addr::(&url, "tcp")?; let mut connector = TcpTunnelConnector::new(url); set_bind_addr_for_peer_connector( &mut connector, @@ -62,8 +62,7 @@ pub async fn create_connector_by_url( return Ok(Box::new(connector)); } "udp" => { - let dst_addr = - crate::tunnels::check_scheme_and_get_socket_addr::(&url, "udp")?; + let dst_addr = check_scheme_and_get_socket_addr::(&url, "udp")?; let mut connector = UdpTunnelConnector::new(url); set_bind_addr_for_peer_connector( &mut connector, @@ -74,13 +73,12 @@ pub async fn create_connector_by_url( return Ok(Box::new(connector)); } "ring" => { - crate::tunnels::check_scheme_and_get_socket_addr::(&url, "ring")?; + check_scheme_and_get_socket_addr::(&url, "ring")?; let connector = RingTunnelConnector::new(url); return Ok(Box::new(connector)); } "quic" => { - let dst_addr = - crate::tunnels::check_scheme_and_get_socket_addr::(&url, "quic")?; + let dst_addr = check_scheme_and_get_socket_addr::(&url, "quic")?; let mut connector = QUICTunnelConnector::new(url); set_bind_addr_for_peer_connector( &mut connector, @@ -91,8 +89,7 @@ pub async fn create_connector_by_url( return Ok(Box::new(connector)); } "wg" => { - let dst_addr = - crate::tunnels::check_scheme_and_get_socket_addr::(&url, "wg")?; + let dst_addr = check_scheme_and_get_socket_addr::(&url, "wg")?; let nid = global_ctx.get_network_identity(); let wg_config = WgConfig::new_from_network_identity( &nid.network_name, diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 25ff5ae..33c53a8 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -10,7 +10,6 @@ mod arch; mod common; mod rpc; mod tunnel; -mod tunnels; mod utils; use crate::{ diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index b1b0c52..f280540 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -17,7 +17,6 @@ mod peer_center; mod peers; mod rpc; mod tunnel; -mod tunnels; mod vpn_portal; use common::{ diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index c611938..ba493c3 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -26,8 +26,10 @@ use tracing::Level; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{peer_manager::PeerManager, PeerPacketFilter}, - tunnel::packet_def::{PacketType, ZCPacket}, - tunnels::common::setup_sokcet2, + tunnel::{ + common::setup_sokcet2, + packet_def::{PacketType, ZCPacket}, + }, }; use super::CidrSet; diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index a736387..29fc243 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -22,9 +22,9 @@ use crate::gateway::icmp_proxy::IcmpProxy; use crate::gateway::tcp_proxy::TcpProxy; use crate::gateway::udp_proxy::UdpProxy; use crate::peer_center::instance::PeerCenterInstance; +use crate::peers::peer_conn::PeerConnId; use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::rpc_service::PeerManagerRpcService; -use crate::peers::zc_peer_conn::PeerConnId; use crate::peers::PacketRecvChanReceiver; use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc; use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; diff --git a/easytier/src/lib.rs b/easytier/src/lib.rs index bbba570..1653e9a 100644 --- a/easytier/src/lib.rs +++ b/easytier/src/lib.rs @@ -9,6 +9,5 @@ pub mod peer_center; pub mod peers; pub mod rpc; pub mod tunnel; -pub mod tunnels; pub mod utils; pub mod vpn_portal; diff --git a/easytier/src/peers/encrypt/ring_aes_gcm.rs b/easytier/src/peers/encrypt/ring_aes_gcm.rs index 059a716..8878e13 100644 --- a/easytier/src/peers/encrypt/ring_aes_gcm.rs +++ b/easytier/src/peers/encrypt/ring_aes_gcm.rs @@ -137,7 +137,7 @@ impl Encryptor for AesGcmCipher { mod tests { use crate::{ peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor}, - tunnel::packet_def::{ZCPacket, ZCPacketType, AES_GCM_ENCRYPTION_RESERVED}, + tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED}, }; #[test] diff --git a/easytier/src/peers/foreign_network_client.rs b/easytier/src/peers/foreign_network_client.rs index cff5008..6323c3d 100644 --- a/easytier/src/peers/foreign_network_client.rs +++ b/easytier/src/peers/foreign_network_client.rs @@ -17,9 +17,9 @@ use crate::{ use super::{ foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID}, + peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager, - zc_peer_conn::PeerConn, PacketRecvChan, }; diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index fbcc579..f0b179b 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -26,9 +26,9 @@ use crate::{ }; use super::{ + peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, - zc_peer_conn::PeerConn, PacketRecvChan, PacketRecvChanReceiver, }; diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index ff6d497..afc6054 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -1,6 +1,7 @@ pub mod packet; pub mod peer; // pub mod peer_conn; +pub mod peer_conn; pub mod peer_conn_ping; pub mod peer_manager; pub mod peer_map; @@ -9,7 +10,6 @@ pub mod peer_rip_route; pub mod peer_rpc; pub mod route_trait; pub mod rpc_service; -pub mod zc_peer_conn; pub mod foreign_network_client; pub mod foreign_network_manager; diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index ea19405..4efa6fa 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -8,7 +8,7 @@ use tokio::{select, sync::mpsc, task::JoinHandle}; use tracing::Instrument; use super::{ - zc_peer_conn::{PeerConn, PeerConnId}, + peer_conn::{PeerConn, PeerConnId}, PacketRecvChan, }; use crate::rpc::PeerConnInfo; @@ -175,7 +175,7 @@ mod tests { use crate::{ common::{global_ctx::tests::get_mock_global_ctx, new_peer_id}, - peers::zc_peer_conn::PeerConn, + peers::peer_conn::PeerConn, tunnel::ring::create_ring_tunnel_pair, }; diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index e8eca28..9043910 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -1,4 +1,5 @@ use std::{ + any::Any, fmt::Debug, pin::Pin, sync::{ @@ -7,8 +8,9 @@ use std::{ }, }; -use futures::{SinkExt, StreamExt}; -use pnet::datalink::NetworkInterface; +use futures::{SinkExt, StreamExt, TryFutureExt}; + +use prost::Message; use tokio::{ sync::{broadcast, mpsc, Mutex}, @@ -16,293 +18,34 @@ use tokio::{ time::{timeout, Duration}, }; -use tokio_util::{bytes::Bytes, sync::PollSender}; +use tokio_util::sync::PollSender; use tracing::Instrument; +use zerocopy::AsBytes; use crate::{ common::{ - global_ctx::{ArcGlobalCtx, NetworkIdentity}, + config::{NetworkIdentity, NetworkSecretDigest}, + error::Error, + global_ctx::ArcGlobalCtx, PeerId, }, - define_tunnel_filter_chain, - peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType}, - rpc::{PeerConnInfo, PeerConnStats}, - tunnels::{ + peers::packet::PacketType, + rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo}, + tunnel::{ + filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter}, + mpsc::{MpscTunnel, MpscTunnelSender}, + packet_def::ZCPacket, stats::{Throughput, WindowLatency}, - tunnel_filter::StatsRecorderTunnelFilter, - DatagramSink, Tunnel, TunnelError, + Tunnel, TunnelError, ZCPacketStream, }, }; -use super::packet::{self, HandShake, Packet}; - -pub type PacketRecvChan = mpsc::Sender; +use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan}; pub type PeerConnId = uuid::Uuid; -macro_rules! wait_response { - ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => { - let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else { - return Err(TunnelError::WaitRespError( - "wait handshake response timeout".to_owned(), - )); - }; - let Some(rsp_vec) = rsp_vec else { - return Err(TunnelError::WaitRespError( - "wait handshake response get none".to_owned(), - )); - }; - let Ok(rsp_vec) = rsp_vec else { - return Err(TunnelError::WaitRespError(format!( - "wait handshake response get error {}", - rsp_vec.err().unwrap() - ))); - }; - - let $out_var; - let rsp_bytes = Packet::decode(&rsp_vec); - if rsp_bytes.packet_type != PacketType::HandShake { - tracing::error!("unexpected packet type: {:?}", rsp_bytes); - return Err(TunnelError::WaitRespError( - "unexpected packet type".to_owned(), - )); - } - let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes); - match &resp_payload { - $pattern => $out_var = $value, - _ => { - tracing::error!( - "unexpected packet: {:?}, pattern: {:?}", - rsp_bytes, - stringify!($pattern) - ); - return Err(TunnelError::WaitRespError("unexpected packet".to_owned())); - } - } - }; -} - -#[derive(Debug)] -pub struct PeerInfo { - magic: u32, - pub my_peer_id: PeerId, - version: u32, - pub features: Vec, - pub interfaces: Vec, - pub network_identity: NetworkIdentity, -} - -impl<'a> From<&HandShake> for PeerInfo { - fn from(hs: &HandShake) -> Self { - PeerInfo { - magic: hs.magic.into(), - my_peer_id: hs.my_peer_id.into(), - version: hs.version.into(), - features: hs.features.iter().map(|x| x.to_string()).collect(), - interfaces: Vec::new(), - network_identity: hs.network_identity.clone(), - } - } -} - -struct PeerConnPinger { - my_peer_id: PeerId, - peer_id: PeerId, - sink: Arc>>>, - ctrl_sender: broadcast::Sender, - latency_stats: Arc, - loss_rate_stats: Arc, - tasks: JoinSet>, -} - -impl Debug for PeerConnPinger { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PeerConnPinger") - .field("my_peer_id", &self.my_peer_id) - .field("peer_id", &self.peer_id) - .finish() - } -} - -impl PeerConnPinger { - pub fn new( - my_peer_id: PeerId, - peer_id: PeerId, - sink: Pin>, - ctrl_sender: broadcast::Sender, - latency_stats: Arc, - loss_rate_stats: Arc, - ) -> Self { - Self { - my_peer_id, - peer_id, - sink: Arc::new(Mutex::new(sink)), - tasks: JoinSet::new(), - latency_stats, - ctrl_sender, - loss_rate_stats, - } - } - - async fn do_pingpong_once( - my_node_id: PeerId, - peer_id: PeerId, - sink: Arc>>>, - receiver: &mut broadcast::Receiver, - seq: u32, - ) -> Result { - // should add seq here. so latency can be calculated more accurately - let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into(); - tracing::trace!("send ping packet: {:?}", req); - sink.lock().await.send(req).await.map_err(|e| { - tracing::warn!("send ping packet error: {:?}", e); - TunnelError::CommonError("send ping packet error".to_owned()) - })?; - - let now = std::time::Instant::now(); - - // wait until we get a pong packet in ctrl_resp_receiver - let resp = timeout(Duration::from_secs(1), async { - loop { - match receiver.recv().await { - Ok(p) => { - let ctrl_payload = - packet::CtrlPacketPayload::from_packet(Packet::decode(&p)); - if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload { - if resp_seq == seq { - break; - } - } - } - Err(e) => { - log::warn!("recv pong resp error: {:?}", e); - return Err(TunnelError::WaitRespError( - "recv pong resp error".to_owned(), - )); - } - } - } - Ok(()) - }) - .await; - - tracing::trace!(?resp, "wait ping response done"); - - if resp.is_err() { - return Err(TunnelError::WaitRespError( - "wait ping response timeout".to_owned(), - )); - } - - if resp.as_ref().unwrap().is_err() { - return Err(resp.unwrap().err().unwrap()); - } - - Ok(now.elapsed().as_micros()) - } - - async fn pingpong(&mut self) { - let sink = self.sink.clone(); - let my_node_id = self.my_peer_id; - let peer_id = self.peer_id; - let latency_stats = self.latency_stats.clone(); - - let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100); - - let stopped = Arc::new(AtomicU32::new(0)); - - // generate a pingpong task every 200ms - let mut pingpong_tasks = JoinSet::new(); - let ctrl_resp_sender = self.ctrl_sender.clone(); - let stopped_clone = stopped.clone(); - self.tasks.spawn(async move { - let mut req_seq = 0; - loop { - let receiver = ctrl_resp_sender.subscribe(); - let ping_res_sender = ping_res_sender.clone(); - let sink = sink.clone(); - - if stopped_clone.load(Ordering::Relaxed) != 0 { - return Ok(()); - } - - while pingpong_tasks.len() > 5 { - pingpong_tasks.join_next().await; - } - - pingpong_tasks.spawn(async move { - let mut receiver = receiver.resubscribe(); - let pingpong_once_ret = Self::do_pingpong_once( - my_node_id, - peer_id, - sink.clone(), - &mut receiver, - req_seq, - ) - .await; - - if let Err(e) = ping_res_sender.send(pingpong_once_ret).await { - tracing::info!(?e, "pingpong task send result error, exit.."); - }; - }); - - req_seq = req_seq.wrapping_add(1); - tokio::time::sleep(Duration::from_millis(1000)).await; - } - }); - - // one with 1% precision - let loss_rate_stats_1 = WindowLatency::new(100); - // one with 20% precision, so we can fast fail this conn. - let loss_rate_stats_20 = WindowLatency::new(5); - - let mut counter: u64 = 0; - - while let Some(ret) = ping_res_receiver.recv().await { - counter += 1; - - if let Ok(lat) = ret { - latency_stats.record_latency(lat as u32); - - loss_rate_stats_1.record_latency(0); - loss_rate_stats_20.record_latency(0); - } else { - loss_rate_stats_1.record_latency(1); - loss_rate_stats_20.record_latency(1); - } - - let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us(); - let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us(); - - tracing::trace!( - ?ret, - ?self, - ?loss_rate_1, - ?loss_rate_20, - "pingpong task recv pingpong_once result" - ); - - if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) { - tracing::warn!( - ?ret, - ?self, - ?loss_rate_1, - ?loss_rate_20, - "pingpong loss rate too high, closing" - ); - break; - } - - self.loss_rate_stats - .store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed); - } - - stopped.store(1, Ordering::Relaxed); - ping_res_receiver.close(); - } -} - -define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter); +const MAGIC: u32 = 0xd1e1a5e1; +const VERSION: u32 = 1; pub struct PeerConn { conn_id: PeerConnId, @@ -310,33 +53,45 @@ pub struct PeerConn { my_peer_id: PeerId, global_ctx: ArcGlobalCtx, - sink: Pin>, - tunnel: Box, + tunnel: Arc>>, + sink: MpscTunnelSender, + recv: Arc>>>>, + tunnel_info: Option, tasks: JoinSet>, - info: Option, + info: Option, close_event_sender: Option>, - ctrl_resp_sender: broadcast::Sender, + ctrl_resp_sender: broadcast::Sender, latency_stats: Arc, throughput: Arc, loss_rate_stats: Arc, } -enum PeerConnPacketType { - Data(Bytes), - CtrlReq(Bytes), - CtrlResp(Bytes), +impl Debug for PeerConn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PeerConn") + .field("conn_id", &self.conn_id) + .field("my_peer_id", &self.my_peer_id) + .field("info", &self.info) + .finish() + } } impl PeerConn { pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { + let tunnel_info = tunnel.info(); let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); - let peer_conn_tunnel = PeerConnTunnel::new(); - let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel); + + let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new(); + let throughput = peer_conn_tunnel_filter.filter_output(); + let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter); + let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel); + + let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink()); PeerConn { conn_id: PeerConnId::new_v4(), @@ -344,8 +99,10 @@ impl PeerConn { my_peer_id, global_ctx, - sink: tunnel.pin_sink(), - tunnel: Box::new(tunnel), + tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))), + sink, + recv: Arc::new(Mutex::new(Some(recv))), + tunnel_info, tasks: JoinSet::new(), @@ -355,7 +112,7 @@ impl PeerConn { ctrl_resp_sender: ctrl_sender, latency_stats: Arc::new(WindowLatency::new(15)), - throughput: peer_conn_tunnel.stats.get_throughput().clone(), + throughput, loss_rate_stats: Arc::new(AtomicU32::new(0)), } } @@ -364,41 +121,97 @@ impl PeerConn { self.conn_id } - #[tracing::instrument] - pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); + async fn wait_handshake(&mut self, need_retry: &mut bool) -> Result { + *need_retry = false; - tracing::info!("waiting for handshake request from client"); - wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x); - self.info = Some(PeerInfo::from(hs_req)); - tracing::info!("handshake request: {:?}", hs_req); + let mut locked = self.recv.lock().await; + let recv = locked.as_mut().unwrap(); + let Some(rsp) = recv.next().await else { + return Err(Error::WaitRespError( + "conn closed during wait handshake response".to_owned(), + )); + }; - let hs_req = self - .global_ctx - .net_ns - .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); - sink.send(hs_req.into()).await?; + *need_retry = true; + + let rsp = rsp?; + let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| { + Error::WaitRespError(format!("decode handshake response error: {:?}", e)) + })?; + + if rsp.network_secret_digrest.len() != std::mem::size_of::() { + return Err(Error::WaitRespError( + "invalid network secret digest".to_owned(), + )); + } + + return Ok(rsp); + } + + async fn wait_handshake_loop(&mut self) -> Result { + timeout(Duration::from_secs(5), async move { + loop { + let mut need_retry = true; + match self.wait_handshake(&mut need_retry).await { + Ok(rsp) => return Ok(rsp), + Err(e) => { + log::warn!("wait handshake error: {:?}", e); + if !need_retry { + return Err(e); + } + } + } + } + }) + .map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e))) + .await? + } + + async fn send_handshake(&mut self) -> Result<(), Error> { + let network = self.global_ctx.get_network_identity(); + let mut req = HandshakeRequest { + magic: MAGIC, + my_peer_id: self.my_peer_id, + version: VERSION, + features: Vec::new(), + network_name: network.network_name.clone(), + ..Default::default() + }; + req.network_secret_digrest + .extend_from_slice(&network.network_secret_digest.unwrap_or_default()); + + let hs_req = req.encode_to_vec(); + let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); + zc_packet.fill_peer_manager_hdr( + self.my_peer_id, + PeerId::default(), + PacketType::HandShake as u8, + ); + + self.sink.send(zc_packet).await.map_err(|e| { + tracing::warn!("send handshake request error: {:?}", e); + Error::WaitRespError("send handshake request error".to_owned()) + })?; Ok(()) } #[tracing::instrument] - pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); - - let hs_req = self - .global_ctx - .net_ns - .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); - sink.send(hs_req.into()).await?; + pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> { + let rsp = self.wait_handshake_loop().await?; + tracing::info!("handshake request: {:?}", rsp); + self.info = Some(rsp); + self.send_handshake().await?; + Ok(()) + } + #[tracing::instrument] + pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> { + self.send_handshake().await?; tracing::info!("waiting for handshake request from server"); - wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x); - self.info = Some(PeerInfo::from(hs_rsp)); - tracing::info!("handshake response: {:?}", hs_rsp); - + let rsp = self.wait_handshake_loop().await?; + tracing::info!("handshake response: {:?}", rsp); + self.info = Some(rsp); Ok(()) } @@ -406,11 +219,72 @@ impl PeerConn { self.info.is_some() } + pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { + let mut stream = self.recv.lock().await.take().unwrap(); + let sink = self.sink.clone(); + let mut sender = PollSender::new(packet_recv_chan.clone()); + let close_event_sender = self.close_event_sender.clone().unwrap(); + let conn_id = self.conn_id; + let ctrl_sender = self.ctrl_resp_sender.clone(); + let _conn_info = self.get_conn_info(); + let conn_info_for_instrument = self.get_conn_info(); + + self.tasks.spawn( + async move { + tracing::info!("start recving peer conn packet"); + let mut task_ret = Ok(()); + while let Some(ret) = stream.next().await { + if ret.is_err() { + tracing::error!(error = ?ret, "peer conn recv error"); + task_ret = Err(ret.err().unwrap()); + break; + } + + let mut zc_packet = ret.unwrap(); + let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else { + tracing::error!( + "unexpected packet: {:?}, cannot decode peer manager hdr", + zc_packet + ); + continue; + }; + + if peer_mgr_hdr.packet_type == PacketType::Ping as u8 { + peer_mgr_hdr.packet_type = PacketType::Pong as u8; + if let Err(e) = sink.send(zc_packet).await { + tracing::error!(?e, "peer conn send req error"); + } + } else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 { + if let Err(e) = ctrl_sender.send(zc_packet) { + tracing::error!(?e, "peer conn send ctrl resp error"); + } + } else { + if sender.send(zc_packet).await.is_err() { + break; + } + } + } + + tracing::info!("end recving peer conn packet"); + + drop(sink); + if let Err(e) = close_event_sender.send(conn_id).await { + tracing::error!(error = ?e, "peer conn close event send error"); + } + + task_ret + } + .instrument( + tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), + ), + ); + } + pub fn start_pingpong(&mut self) { let mut pingpong = PeerConnPinger::new( self.my_peer_id, self.get_peer_id(), - self.tunnel.pin_sink(), + self.sink.clone(), self.ctrl_resp_sender.clone(), self.latency_stats.clone(), self.loss_rate_stats.clone(), @@ -432,79 +306,8 @@ impl PeerConn { }); } - pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); - let mut sender = PollSender::new(packet_recv_chan.clone()); - let close_event_sender = self.close_event_sender.clone().unwrap(); - let conn_id = self.conn_id; - let ctrl_sender = self.ctrl_resp_sender.clone(); - let conn_info = self.get_conn_info(); - let conn_info_for_instrument = self.get_conn_info(); - - self.tasks.spawn( - async move { - tracing::info!("start recving peer conn packet"); - let mut task_ret = Ok(()); - while let Some(ret) = stream.next().await { - if ret.is_err() { - tracing::error!(error = ?ret, "peer conn recv error"); - task_ret = Err(ret.err().unwrap()); - break; - } - - let buf = ret.unwrap(); - let p = Packet::decode(&buf); - match p.packet_type { - ArchivedPacketType::Ping => { - let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p) - else { - log::error!("unexpected packet: {:?}", p); - continue; - }; - - let pong = packet::Packet::new_pong_packet( - conn_info.my_peer_id, - conn_info.peer_id, - seq.into(), - ); - - if let Err(e) = sink.send(pong.into()).await { - tracing::error!(?e, "peer conn send req error"); - } - } - ArchivedPacketType::Pong => { - if let Err(e) = ctrl_sender.send(buf.into()) { - tracing::error!(?e, "peer conn send ctrl resp error"); - } - } - _ => { - if sender.send(buf.into()).await.is_err() { - break; - } - } - } - } - - tracing::info!("end recving peer conn packet"); - - if let Err(close_ret) = sink.close().await { - tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it"); - } - if let Err(e) = close_event_sender.send(conn_id).await { - tracing::error!(error = ?e, "peer conn close event send error"); - } - - task_ret - } - .instrument( - tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), - ), - ); - } - - pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> { - self.sink.send(msg).await + pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { + Ok(self.sink.send(msg).await?) } pub fn get_peer_id(&self) -> PeerId { @@ -512,7 +315,17 @@ impl PeerConn { } pub fn get_network_identity(&self) -> NetworkIdentity { - self.info.as_ref().unwrap().network_identity.clone() + let info = self.info.as_ref().unwrap(); + let mut ret = NetworkIdentity { + network_name: info.network_name.clone(), + ..Default::default() + }; + ret.network_secret_digest = Some([0u8; 32]); + ret.network_secret_digest + .as_mut() + .unwrap() + .copy_from_slice(&info.network_secret_digrest); + ret } pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { @@ -537,34 +350,13 @@ impl PeerConn { my_peer_id: self.my_peer_id, peer_id: self.get_peer_id(), features: self.info.as_ref().unwrap().features.clone(), - tunnel: self.tunnel.info(), + tunnel: self.tunnel_info.clone(), stats: Some(self.get_stats()), loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32, } } } -impl Drop for PeerConn { - fn drop(&mut self) { - let mut sink = self.tunnel.pin_sink(); - tokio::spawn(async move { - let ret = sink.close().await; - tracing::info!(error = ?ret, "peer conn tunnel closed."); - }); - log::info!("peer conn {:?} drop", self.conn_id); - } -} - -impl Debug for PeerConn { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PeerConn") - .field("conn_id", &self.conn_id) - .field("my_peer_id", &self.my_peer_id) - .field("info", &self.info) - .finish() - } -} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -572,12 +364,12 @@ mod tests { use super::*; use crate::common::global_ctx::tests::get_mock_global_ctx; use crate::common::new_peer_id; - use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter; - use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter}; + use crate::tunnel::filter::tests::DropSendTunnelFilter; + use crate::tunnel::filter::PacketRecorderTunnelFilter; + use crate::tunnel::ring::create_ring_tunnel_pair; #[tokio::test] async fn peer_conn_handshake() { - use crate::tunnels::ring_tunnel::create_ring_tunnel_pair; let (c, s) = create_ring_tunnel_pair(); let c_recorder = Arc::new(PacketRecorderTunnelFilter::new()); @@ -614,7 +406,6 @@ mod tests { } async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) { - use crate::tunnels::ring_tunnel::create_ring_tunnel_pair; let (c, s) = create_ring_tunnel_pair(); // drop 1-3 packets should not affect pingpong @@ -633,7 +424,9 @@ mod tests { ); s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0); - s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + s_peer + .start_recv_loop(tokio::sync::mpsc::channel(200).0) + .await; assert!(c_ret.is_ok()); assert!(s_ret.is_ok()); @@ -641,7 +434,9 @@ mod tests { let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1); c_peer.set_close_event_sender(close_send); c_peer.start_pingpong(); - c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + c_peer + .start_recv_loop(tokio::sync::mpsc::channel(200).0) + .await; // wait 5s, conn should not be disconnected tokio::time::sleep(Duration::from_secs(15)).await; @@ -658,4 +453,19 @@ mod tests { peer_conn_pingpong_test_common(3, 5, false).await; peer_conn_pingpong_test_common(5, 12, true).await; } + + #[tokio::test] + async fn close_tunnel_during_handshake() { + let (c, s) = create_ring_tunnel_pair(); + let mut c_peer = PeerConn::new(new_peer_id(), get_mock_global_ctx(), Box::new(c)); + let j = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + drop(s); + }); + timeout(Duration::from_millis(1500), c_peer.do_handshake_as_client()) + .await + .unwrap() + .unwrap_err(); + let _ = tokio::join!(j); + } } diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 267a74d..cde4271 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -22,8 +22,8 @@ use tokio_util::bytes::Bytes; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{ - packet, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface, - zc_peer_conn::PeerConn, PeerPacketFilter, + packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, + route_trait::RouteInterface, PeerPacketFilter, }, tunnel::{ packet_def::{PacketType, ZCPacket}, @@ -35,12 +35,12 @@ use super::{ encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor, NullCipher}, foreign_network_client::ForeignNetworkClient, foreign_network_manager::ForeignNetworkManager, + peer_conn::PeerConnId, peer_map::PeerMap, peer_ospf_route::PeerRoute, peer_rip_route::BasicRoute, peer_rpc::PeerRpcManager, route_trait::{ArcRoute, Route}, - zc_peer_conn::PeerConnId, BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver, }; diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index e6572d8..821c903 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -12,13 +12,13 @@ use crate::{ }, rpc::PeerConnInfo, tunnel::packet_def::ZCPacket, - tunnels::TunnelError, + tunnel::TunnelError, }; use super::{ peer::Peer, + peer_conn::{PeerConn, PeerConnId}, route_trait::ArcRoute, - zc_peer_conn::{PeerConn, PeerConnId}, PacketRecvChan, }; diff --git a/easytier/src/peers/zc_peer_conn.rs b/easytier/src/peers/zc_peer_conn.rs deleted file mode 100644 index c79c1e3..0000000 --- a/easytier/src/peers/zc_peer_conn.rs +++ /dev/null @@ -1,769 +0,0 @@ -use std::{ - any::Any, - fmt::Debug, - pin::Pin, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, -}; - -use futures::{SinkExt, StreamExt, TryFutureExt}; - -use prost::Message; - -use tokio::{ - sync::{broadcast, mpsc, Mutex}, - task::JoinSet, - time::{timeout, Duration}, -}; - -use tokio_util::sync::PollSender; -use tracing::Instrument; -use zerocopy::AsBytes; - -use crate::{ - common::{ - config::{NetworkIdentity, NetworkSecretDigest}, - error::Error, - global_ctx::ArcGlobalCtx, - PeerId, - }, - peers::packet::PacketType, - rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo}, - tunnel::{ - filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter}, - mpsc::{MpscTunnel, MpscTunnelSender}, - packet_def::ZCPacket, - stats::{Throughput, WindowLatency}, - Tunnel, TunnelError, ZCPacketStream, - }, -}; - -use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan}; - -pub type PeerConnId = uuid::Uuid; - -const MAGIC: u32 = 0xd1e1a5e1; -const VERSION: u32 = 1; - -pub struct PeerConn { - conn_id: PeerConnId, - - my_peer_id: PeerId, - global_ctx: ArcGlobalCtx, - - tunnel: Arc>>, - sink: MpscTunnelSender, - recv: Arc>>>>, - tunnel_info: Option, - - tasks: JoinSet>, - - info: Option, - - close_event_sender: Option>, - - ctrl_resp_sender: broadcast::Sender, - - latency_stats: Arc, - throughput: Arc, - loss_rate_stats: Arc, -} - -impl Debug for PeerConn { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PeerConn") - .field("conn_id", &self.conn_id) - .field("my_peer_id", &self.my_peer_id) - .field("info", &self.info) - .finish() - } -} - -impl PeerConn { - pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { - let tunnel_info = tunnel.info(); - let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); - - let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new(); - let throughput = peer_conn_tunnel_filter.filter_output(); - let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter); - let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel); - - let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink()); - - PeerConn { - conn_id: PeerConnId::new_v4(), - - my_peer_id, - global_ctx, - - tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))), - sink, - recv: Arc::new(Mutex::new(Some(recv))), - tunnel_info, - - tasks: JoinSet::new(), - - info: None, - close_event_sender: None, - - ctrl_resp_sender: ctrl_sender, - - latency_stats: Arc::new(WindowLatency::new(15)), - throughput, - loss_rate_stats: Arc::new(AtomicU32::new(0)), - } - } - - pub fn get_conn_id(&self) -> PeerConnId { - self.conn_id - } - - async fn wait_handshake(&mut self) -> Result { - let mut locked = self.recv.lock().await; - let recv = locked.as_mut().unwrap(); - let Some(rsp) = recv.next().await else { - return Err(Error::WaitRespError( - "conn closed during wait handshake response".to_owned(), - )); - }; - let rsp = rsp?; - let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| { - Error::WaitRespError(format!("decode handshake response error: {:?}", e)) - })?; - - if rsp.network_secret_digrest.len() != std::mem::size_of::() { - return Err(Error::WaitRespError( - "invalid network secret digest".to_owned(), - )); - } - - return Ok(rsp); - } - - async fn wait_handshake_loop(&mut self) -> Result { - Ok(timeout(Duration::from_secs(5), async move { - loop { - match self.wait_handshake().await { - Ok(rsp) => return rsp, - Err(e) => { - log::warn!("wait handshake error: {:?}", e); - } - } - } - }) - .map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e))) - .await?) - } - - async fn send_handshake(&mut self) -> Result<(), Error> { - let network = self.global_ctx.get_network_identity(); - let mut req = HandshakeRequest { - magic: MAGIC, - my_peer_id: self.my_peer_id, - version: VERSION, - features: Vec::new(), - network_name: network.network_name.clone(), - ..Default::default() - }; - req.network_secret_digrest - .extend_from_slice(&network.network_secret_digest.unwrap_or_default()); - - let hs_req = req.encode_to_vec(); - let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); - zc_packet.fill_peer_manager_hdr( - self.my_peer_id, - PeerId::default(), - PacketType::HandShake as u8, - ); - - self.sink.send(zc_packet).await.map_err(|e| { - tracing::warn!("send handshake request error: {:?}", e); - Error::WaitRespError("send handshake request error".to_owned()) - })?; - - Ok(()) - } - - #[tracing::instrument] - pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> { - let rsp = self.wait_handshake_loop().await?; - tracing::info!("handshake request: {:?}", rsp); - self.info = Some(rsp); - self.send_handshake().await?; - Ok(()) - } - - #[tracing::instrument] - pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> { - self.send_handshake().await?; - tracing::info!("waiting for handshake request from server"); - let rsp = self.wait_handshake_loop().await?; - tracing::info!("handshake response: {:?}", rsp); - self.info = Some(rsp); - Ok(()) - } - - pub fn handshake_done(&self) -> bool { - self.info.is_some() - } - - pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { - let mut stream = self.recv.lock().await.take().unwrap(); - let sink = self.sink.clone(); - let mut sender = PollSender::new(packet_recv_chan.clone()); - let close_event_sender = self.close_event_sender.clone().unwrap(); - let conn_id = self.conn_id; - let ctrl_sender = self.ctrl_resp_sender.clone(); - let _conn_info = self.get_conn_info(); - let conn_info_for_instrument = self.get_conn_info(); - - self.tasks.spawn( - async move { - tracing::info!("start recving peer conn packet"); - let mut task_ret = Ok(()); - while let Some(ret) = stream.next().await { - if ret.is_err() { - tracing::error!(error = ?ret, "peer conn recv error"); - task_ret = Err(ret.err().unwrap()); - break; - } - - let mut zc_packet = ret.unwrap(); - let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else { - tracing::error!( - "unexpected packet: {:?}, cannot decode peer manager hdr", - zc_packet - ); - continue; - }; - - if peer_mgr_hdr.packet_type == PacketType::Ping as u8 { - peer_mgr_hdr.packet_type = PacketType::Pong as u8; - if let Err(e) = sink.send(zc_packet).await { - tracing::error!(?e, "peer conn send req error"); - } - } else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 { - if let Err(e) = ctrl_sender.send(zc_packet) { - tracing::error!(?e, "peer conn send ctrl resp error"); - } - } else { - if sender.send(zc_packet).await.is_err() { - break; - } - } - } - - tracing::info!("end recving peer conn packet"); - - drop(sink); - if let Err(e) = close_event_sender.send(conn_id).await { - tracing::error!(error = ?e, "peer conn close event send error"); - } - - task_ret - } - .instrument( - tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), - ), - ); - } - - pub fn start_pingpong(&mut self) { - let mut pingpong = PeerConnPinger::new( - self.my_peer_id, - self.get_peer_id(), - self.sink.clone(), - self.ctrl_resp_sender.clone(), - self.latency_stats.clone(), - self.loss_rate_stats.clone(), - ); - - let close_event_sender = self.close_event_sender.clone().unwrap(); - let conn_id = self.conn_id; - - self.tasks.spawn(async move { - pingpong.pingpong().await; - - tracing::warn!(?pingpong, "pingpong task exit"); - - if let Err(e) = close_event_sender.send(conn_id).await { - log::warn!("close event sender error: {:?}", e); - } - - Ok(()) - }); - } - - pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { - Ok(self.sink.send(msg).await?) - } - - pub fn get_peer_id(&self) -> PeerId { - self.info.as_ref().unwrap().my_peer_id - } - - pub fn get_network_identity(&self) -> NetworkIdentity { - let info = self.info.as_ref().unwrap(); - let mut ret = NetworkIdentity { - network_name: info.network_name.clone(), - ..Default::default() - }; - ret.network_secret_digest = Some([0u8; 32]); - ret.network_secret_digest - .as_mut() - .unwrap() - .copy_from_slice(&info.network_secret_digrest); - ret - } - - pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { - self.close_event_sender = Some(sender); - } - - pub fn get_stats(&self) -> PeerConnStats { - PeerConnStats { - latency_us: self.latency_stats.get_latency_us(), - - tx_bytes: self.throughput.tx_bytes(), - rx_bytes: self.throughput.rx_bytes(), - - tx_packets: self.throughput.tx_packets(), - rx_packets: self.throughput.rx_packets(), - } - } - - pub fn get_conn_info(&self) -> PeerConnInfo { - PeerConnInfo { - conn_id: self.conn_id.to_string(), - my_peer_id: self.my_peer_id, - peer_id: self.get_peer_id(), - features: self.info.as_ref().unwrap().features.clone(), - tunnel: self.tunnel_info.clone(), - stats: Some(self.get_stats()), - loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32, - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use crate::common::global_ctx::tests::get_mock_global_ctx; - use crate::common::new_peer_id; - use crate::tunnel::filter::tests::DropSendTunnelFilter; - use crate::tunnel::filter::PacketRecorderTunnelFilter; - use crate::tunnel::ring::create_ring_tunnel_pair; - - #[tokio::test] - async fn peer_conn_handshake() { - let (c, s) = create_ring_tunnel_pair(); - - let c_recorder = Arc::new(PacketRecorderTunnelFilter::new()); - let s_recorder = Arc::new(PacketRecorderTunnelFilter::new()); - - let c = TunnelWithFilter::new(c, c_recorder.clone()); - let s = TunnelWithFilter::new(s, s_recorder.clone()); - - let c_peer_id = new_peer_id(); - let s_peer_id = new_peer_id(); - - let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c)); - - let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s)); - - let (c_ret, s_ret) = tokio::join!( - c_peer.do_handshake_as_client(), - s_peer.do_handshake_as_server() - ); - - c_ret.unwrap(); - s_ret.unwrap(); - - assert_eq!(c_recorder.sent.lock().unwrap().len(), 1); - assert_eq!(c_recorder.received.lock().unwrap().len(), 1); - - assert_eq!(s_recorder.sent.lock().unwrap().len(), 1); - assert_eq!(s_recorder.received.lock().unwrap().len(), 1); - - assert_eq!(c_peer.get_peer_id(), s_peer_id); - assert_eq!(s_peer.get_peer_id(), c_peer_id); - assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity()); - assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default()); - } - - async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) { - let (c, s) = create_ring_tunnel_pair(); - - // drop 1-3 packets should not affect pingpong - let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end)); - let c = TunnelWithFilter::new(c, c_recorder.clone()); - - let c_peer_id = new_peer_id(); - let s_peer_id = new_peer_id(); - - let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c)); - let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s)); - - let (c_ret, s_ret) = tokio::join!( - c_peer.do_handshake_as_client(), - s_peer.do_handshake_as_server() - ); - - s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0); - s_peer - .start_recv_loop(tokio::sync::mpsc::channel(200).0) - .await; - - assert!(c_ret.is_ok()); - assert!(s_ret.is_ok()); - - let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1); - c_peer.set_close_event_sender(close_send); - c_peer.start_pingpong(); - c_peer - .start_recv_loop(tokio::sync::mpsc::channel(200).0) - .await; - - // wait 5s, conn should not be disconnected - tokio::time::sleep(Duration::from_secs(15)).await; - - if conn_closed { - assert!(close_recv.try_recv().is_ok()); - } else { - assert!(close_recv.try_recv().is_err()); - } - } - - #[tokio::test] - async fn peer_conn_pingpong_timeout() { - peer_conn_pingpong_test_common(3, 5, false).await; - peer_conn_pingpong_test_common(5, 12, true).await; - } -} - -/* -use std::{ - fmt::Debug, - pin::Pin, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, -}; - -use futures::{SinkExt, StreamExt}; -use pnet::datalink::NetworkInterface; - -use tokio::{ - sync::{broadcast, mpsc, Mutex}, - task::JoinSet, - time::{timeout, Duration}, -}; - -use tokio_util::{bytes::Bytes, sync::PollSender}; -use tracing::Instrument; - -use crate::{ - common::{ - error::Error, - global_ctx::{ArcGlobalCtx, NetworkIdentity}, - PeerId, - }, - define_tunnel_filter_chain, - peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType}, - rpc::{PeerConnInfo, PeerConnStats}, - tunnel::{mpsc::MpscTunnelSender, stats::WindowLatency, TunnelError}, -}; - -use super::packet::{self, HandShake, Packet}; - -pub type PacketRecvChan = mpsc::Sender; - -macro_rules! wait_response { - ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => { - let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else { - return Err(Error::WaitRespError( - "wait handshake response timeout".to_owned(), - )); - }; - let Some(rsp_vec) = rsp_vec else { - return Err(Error::WaitRespError( - "wait handshake response get none".to_owned(), - )); - }; - let Ok(rsp_vec) = rsp_vec else { - return Err(Error::WaitRespError(format!( - "wait handshake response get error {}", - rsp_vec.err().unwrap() - ))); - }; - - let $out_var; - let rsp_bytes = Packet::decode(&rsp_vec); - if rsp_bytes.packet_type != PacketType::HandShake { - tracing::error!("unexpected packet type: {:?}", rsp_bytes); - return Err(Error::WaitRespError("unexpected packet type".to_owned())); - } - let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes); - match &resp_payload { - $pattern => $out_var = $value, - _ => { - tracing::error!( - "unexpected packet: {:?}, pattern: {:?}", - rsp_bytes, - stringify!($pattern) - ); - return Err(Error::WaitRespError("unexpected packet".to_owned())); - } - } - }; -} - -impl<'a> From<&HandShake> for PeerInfo { - fn from(hs: &HandShake) -> Self { - PeerInfo { - magic: hs.magic.into(), - my_peer_id: hs.my_peer_id.into(), - version: hs.version.into(), - features: hs.features.iter().map(|x| x.to_string()).collect(), - interfaces: Vec::new(), - network_identity: hs.network_identity.clone(), - } - } -} - - -define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter); - -pub struct PeerConn { - conn_id: PeerConnId, - - my_peer_id: PeerId, - global_ctx: ArcGlobalCtx, - - sink: Pin>, - tunnel: Box, - - tasks: JoinSet>, - - info: Option, - - close_event_sender: Option>, - - ctrl_resp_sender: broadcast::Sender, - - latency_stats: Arc, - throughput: Arc, - loss_rate_stats: Arc, -} - -enum PeerConnPacketType { - Data(Bytes), - CtrlReq(Bytes), - CtrlResp(Bytes), -} - -impl PeerConn { - pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { - let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); - let peer_conn_tunnel = PeerConnTunnel::new(); - let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel); - - PeerConn { - conn_id: PeerConnId::new_v4(), - - my_peer_id, - global_ctx, - - sink: tunnel.pin_sink(), - tunnel: Box::new(tunnel), - - tasks: JoinSet::new(), - - info: None, - close_event_sender: None, - - ctrl_resp_sender: ctrl_sender, - - latency_stats: Arc::new(WindowLatency::new(15)), - throughput: peer_conn_tunnel.stats.get_throughput().clone(), - loss_rate_stats: Arc::new(AtomicU32::new(0)), - } - } - - pub fn get_conn_id(&self) -> PeerConnId { - self.conn_id - } - - #[tracing::instrument] - pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); - - tracing::info!("waiting for handshake request from client"); - wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x); - self.info = Some(PeerInfo::from(hs_req)); - tracing::info!("handshake request: {:?}", hs_req); - - let hs_req = self - .global_ctx - .net_ns - .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); - sink.send(hs_req.into()).await?; - - Ok(()) - } - - #[tracing::instrument] - pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); - - let hs_req = self - .global_ctx - .net_ns - .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); - sink.send(hs_req.into()).await?; - - tracing::info!("waiting for handshake request from server"); - wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x); - self.info = Some(PeerInfo::from(hs_rsp)); - tracing::info!("handshake response: {:?}", hs_rsp); - - Ok(()) - } - - pub fn handshake_done(&self) -> bool { - self.info.is_some() - } - - pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { - let mut stream = self.tunnel.pin_stream(); - let mut sink = self.tunnel.pin_sink(); - let mut sender = PollSender::new(packet_recv_chan.clone()); - let close_event_sender = self.close_event_sender.clone().unwrap(); - let conn_id = self.conn_id; - let ctrl_sender = self.ctrl_resp_sender.clone(); - let conn_info = self.get_conn_info(); - let conn_info_for_instrument = self.get_conn_info(); - - self.tasks.spawn( - async move { - tracing::info!("start recving peer conn packet"); - let mut task_ret = Ok(()); - while let Some(ret) = stream.next().await { - if ret.is_err() { - tracing::error!(error = ?ret, "peer conn recv error"); - task_ret = Err(ret.err().unwrap()); - break; - } - - let buf = ret.unwrap(); - let p = Packet::decode(&buf); - match p.packet_type { - ArchivedPacketType::Ping => { - let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p) - else { - log::error!("unexpected packet: {:?}", p); - continue; - }; - - let pong = packet::Packet::new_pong_packet( - conn_info.my_peer_id, - conn_info.peer_id, - seq.into(), - ); - - if let Err(e) = sink.send(pong.into()).await { - tracing::error!(?e, "peer conn send req error"); - } - } - ArchivedPacketType::Pong => { - if let Err(e) = ctrl_sender.send(buf.into()) { - tracing::error!(?e, "peer conn send ctrl resp error"); - } - } - _ => { - if sender.send(buf.into()).await.is_err() { - break; - } - } - } - } - - tracing::info!("end recving peer conn packet"); - - if let Err(close_ret) = sink.close().await { - tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it"); - } - if let Err(e) = close_event_sender.send(conn_id).await { - tracing::error!(error = ?e, "peer conn close event send error"); - } - - task_ret - } - .instrument( - tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), - ), - ); - } - - pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), Error> { - self.sink.send(msg).await - } - - pub fn get_peer_id(&self) -> PeerId { - self.info.as_ref().unwrap().my_peer_id - } - - pub fn get_network_identity(&self) -> NetworkIdentity { - self.info.as_ref().unwrap().network_identity.clone() - } - - pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { - self.close_event_sender = Some(sender); - } - - pub fn get_stats(&self) -> PeerConnStats { - PeerConnStats { - latency_us: self.latency_stats.get_latency_us(), - - tx_bytes: self.throughput.tx_bytes(), - rx_bytes: self.throughput.rx_bytes(), - - tx_packets: self.throughput.tx_packets(), - rx_packets: self.throughput.rx_packets(), - } - } - - pub fn get_conn_info(&self) -> PeerConnInfo { - PeerConnInfo { - conn_id: self.conn_id.to_string(), - my_peer_id: self.my_peer_id, - peer_id: self.get_peer_id(), - features: self.info.as_ref().unwrap().features.clone(), - tunnel: self.tunnel.info(), - stats: Some(self.get_stats()), - loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32, - } - } -} - -impl Drop for PeerConn { - fn drop(&mut self) { - let mut sink = self.tunnel.pin_sink(); - tokio::spawn(async move { - let ret = sink.close().await; - tracing::info!(error = ?ret, "peer conn tunnel closed."); - }); - log::info!("peer conn {:?} drop", self.conn_id); - } -} - -} - */ diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index bede853..00bbf53 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -626,7 +626,10 @@ impl UdpTunnelConnector { ) .await??; - socket.connect(recv_addr).await?; + if recv_addr != addr { + tracing::debug!(?recv_addr, ?addr, "udp connect addr not match"); + } + self.build_tunnel(socket, addr, conn_id).await } diff --git a/easytier/src/tunnels/codec.rs b/easytier/src/tunnels/codec.rs deleted file mode 100644 index 4f91196..0000000 --- a/easytier/src/tunnels/codec.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::result::Result; -use tokio::io; -use tokio_util::{ - bytes::{BufMut, Bytes, BytesMut}, - codec::{Decoder, Encoder}, -}; - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] -pub struct BytesCodec { - capacity: usize, -} - -impl BytesCodec { - /// Creates a new `BytesCodec` for shipping around raw bytes. - pub fn new(capacity: usize) -> BytesCodec { - BytesCodec { capacity } - } -} - -impl Decoder for BytesCodec { - type Item = BytesMut; - type Error = io::Error; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { - if !buf.is_empty() { - let len = buf.len(); - let ret = Some(buf.split_to(len)); - buf.reserve(self.capacity); - Ok(ret) - } else { - Ok(None) - } - } -} - -impl Encoder for BytesCodec { - type Error = io::Error; - - fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { - buf.reserve(data.len()); - buf.put(data); - Ok(()) - } -} - -impl Encoder for BytesCodec { - type Error = io::Error; - - fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> { - buf.reserve(data.len()); - buf.put(data); - Ok(()) - } -} diff --git a/easytier/src/tunnels/common.rs b/easytier/src/tunnels/common.rs deleted file mode 100644 index 56f8908..0000000 --- a/easytier/src/tunnels/common.rs +++ /dev/null @@ -1,483 +0,0 @@ -use std::{ - collections::VecDeque, - net::{IpAddr, SocketAddr}, - sync::Arc, - task::{ready, Context, Poll}, -}; - -use async_stream::stream; -use futures::{stream::FuturesUnordered, Future, FutureExt, Sink, SinkExt, Stream, StreamExt}; -use network_interface::NetworkInterfaceConfig; -use tokio::{sync::Mutex, time::error::Elapsed}; - -use std::pin::Pin; - -use crate::tunnels::{SinkError, TunnelError}; - -use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo}; - -pub struct FramedTunnel { - read: Arc>, - write: Arc>, - - info: Option, -} - -impl FramedTunnel -where - R: Stream> + Send + Sync + Unpin + 'static, - W: Sink + Send + Sync + Unpin + 'static, - RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static, - WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From, -{ - pub fn new(read: R, write: W, info: Option) -> Self { - FramedTunnel { - read: Arc::new(Mutex::new(read)), - write: Arc::new(Mutex::new(write)), - info, - } - } - - pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box { - Box::new(FramedTunnel::new(read, write, Some(info))) - } - - pub fn recv_stream(&self) -> impl DatagramStream { - let read = self.read.clone(); - let info = self.info.clone(); - stream! { - loop { - let read_ret = read.lock().await.next().await; - if read_ret.is_none() { - tracing::info!(?info, "read_ret is none"); - yield Err(TunnelError::CommonError("recv stream closed".to_string())); - } else { - let read_ret = read_ret.unwrap(); - if read_ret.is_err() { - let err = read_ret.err().unwrap(); - tracing::info!(?info, "recv stream read error"); - yield Err(TunnelError::CommonError(err.to_string())); - } else { - yield Ok(read_ret.unwrap()); - } - } - } - } - } - - pub fn send_sink(&self) -> impl DatagramSink { - struct SendSink { - write: Arc>, - max_buffer_size: usize, - sending_buffers: Option>, - send_task: - Option> + Send + Sync + 'static>>>, - close_task: - Option> + Send + Sync + 'static>>>, - } - - impl SendSink - where - W: Sink + Send + Sync + Unpin + 'static, - WE: std::error::Error + std::fmt::Debug + Send + Sync + From, - { - fn try_send_buffser( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - if self.send_task.is_none() { - let mut buffers = self.sending_buffers.take().unwrap(); - let tun = self.write.clone(); - let send_task = async move { - if buffers.is_empty() { - return Ok(()); - } - - let mut locked_tun = tun.lock_owned().await; - while let Some(buf) = buffers.front() { - log::trace!( - "try_send buffer, len: {:?}, buf: {:?}", - buffers.len(), - &buf - ); - let timeout_task = tokio::time::timeout( - std::time::Duration::from_secs(1), - locked_tun.send(buf.clone()), - ); - let send_res = timeout_task.await; - let Ok(send_res) = send_res else { - // panic!("send timeout"); - let err = send_res.err().unwrap(); - return Err(err.into()); - }; - let Ok(_) = send_res else { - let err = send_res.err().unwrap(); - println!("send error: {:?}", err); - return Err(err); - }; - buffers.pop_front(); - } - return Ok(()); - }; - self.send_task = Some(Box::pin(send_task)); - } - - let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx)); - self.send_task = None; - self.sending_buffers = Some(VecDeque::new()); - return Poll::Ready(ret); - } - } - - impl Sink for SendSink - where - W: Sink + Send + Sync + Unpin + 'static, - WE: std::error::Error + std::fmt::Debug + Send + Sync + From, - { - type Error = SinkError; - - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let self_mut = self.get_mut(); - let sending_buf = self_mut.sending_buffers.as_ref(); - // if sending_buffers is None, must already be doing flush - if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size { - return self_mut.poll_flush_unpin(cx); - } else { - return Poll::Ready(Ok(())); - } - } - - fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { - assert!(self.send_task.is_none()); - let self_mut = self.get_mut(); - self_mut.sending_buffers.as_mut().unwrap().push_back(item); - Ok(()) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let self_mut = self.get_mut(); - let ret = self_mut.try_send_buffser(cx); - match ret { - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))), - Poll::Pending => { - return Poll::Pending; - } - } - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let self_mut = self.get_mut(); - if self_mut.close_task.is_none() { - let tun = self_mut.write.clone(); - let close_task = async move { - let mut locked_tun = tun.lock_owned().await; - return locked_tun.close().await; - }; - self_mut.close_task = Some(Box::pin(close_task)); - } - - let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx)); - self_mut.close_task = None; - - if ret.is_err() { - return Poll::Ready(Err(SinkError::CommonError( - ret.err().unwrap().to_string(), - ))); - } else { - return Poll::Ready(Ok(())); - } - } - } - - SendSink { - write: self.write.clone(), - max_buffer_size: 1000, - sending_buffers: Some(VecDeque::new()), - send_task: None, - close_task: None, - } - } -} - -impl Tunnel for FramedTunnel -where - R: Stream> + Send + Sync + Unpin + 'static, - W: Sink + Send + Sync + Unpin + 'static, - RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static, - WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From, -{ - fn stream(&self) -> Box { - Box::new(self.recv_stream()) - } - - fn sink(&self) -> Box { - Box::new(self.send_sink()) - } - - fn info(&self) -> Option { - if self.info.is_none() { - None - } else { - Some(self.info.clone().unwrap()) - } - } -} - -pub struct TunnelWithCustomInfo { - tunnel: Box, - info: TunnelInfo, -} - -impl TunnelWithCustomInfo { - pub fn new(tunnel: Box, info: TunnelInfo) -> Self { - TunnelWithCustomInfo { tunnel, info } - } -} - -impl Tunnel for TunnelWithCustomInfo { - fn stream(&self) -> Box { - self.tunnel.stream() - } - - fn sink(&self) -> Box { - self.tunnel.sink() - } - - fn info(&self) -> Option { - Some(self.info.clone()) - } -} - -pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option { - if local_ip.is_unspecified() || local_ip.is_multicast() { - return None; - } - let ifaces = network_interface::NetworkInterface::show().ok()?; - for iface in ifaces { - for addr in iface.addr { - if addr.ip() == *local_ip { - return Some(iface.name); - } - } - } - - tracing::error!(?local_ip, "can not find interface name by ip"); - None -} - -pub(crate) fn setup_sokcet2_ext( - socket2_socket: &socket2::Socket, - bind_addr: &SocketAddr, - bind_dev: Option, -) -> Result<(), TunnelError> { - #[cfg(target_os = "windows")] - { - let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM); - crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?; - } - - socket2_socket.set_nonblocking(true)?; - socket2_socket.set_reuse_address(true)?; - socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?; - - // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] - // socket2_socket.set_reuse_port(true)?; - - if bind_addr.ip().is_unspecified() { - return Ok(()); - } - - // linux/mac does not use interface of bind_addr to send packet, so we need to bind device - // win can handle this with bind correctly - #[cfg(any(target_os = "ios", target_os = "macos"))] - if let Some(dev_name) = bind_dev { - // use IP_BOUND_IF to bind device - unsafe { - let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8); - tracing::warn!(?dev_idx, ?dev_name, "bind device"); - socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?; - tracing::warn!(?dev_idx, ?dev_name, "bind device doen"); - } - } - - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - if let Some(dev_name) = bind_dev { - tracing::trace!(dev_name = ?dev_name, "bind device"); - socket2_socket.bind_device(Some(dev_name.as_bytes()))?; - } - - Ok(()) -} - -pub(crate) async fn wait_for_connect_futures( - mut futures: FuturesUnordered, -) -> Result -where - Fut: Future> + Send + Sync, - E: std::error::Error + Into + Send + Sync + 'static, -{ - // return last error - let mut last_err = None; - - while let Some(ret) = futures.next().await { - if let Err(e) = ret { - last_err = Some(e.into()); - } else { - return ret.map_err(|e| e.into()); - } - } - - Err(last_err.unwrap_or(super::TunnelError::CommonError( - "no connect futures".to_string(), - ))) -} - -pub(crate) fn setup_sokcet2( - socket2_socket: &socket2::Socket, - bind_addr: &SocketAddr, -) -> Result<(), TunnelError> { - setup_sokcet2_ext( - socket2_socket, - bind_addr, - super::common::get_interface_name_by_ip(&bind_addr.ip()), - ) -} - -pub mod tests { - use std::time::Instant; - - use futures::SinkExt; - use tokio_stream::StreamExt; - use tokio_util::bytes::{BufMut, Bytes, BytesMut}; - - use crate::{ - common::netns::NetNS, - tunnels::{close_tunnel, TunnelConnector, TunnelListener}, - }; - - pub async fn _tunnel_echo_server(tunnel: Box, once: bool) { - let mut recv = Box::into_pin(tunnel.stream()); - let mut send = Box::into_pin(tunnel.sink()); - - while let Some(ret) = recv.next().await { - if ret.is_err() { - log::trace!("recv error: {:?}", ret.err().unwrap()); - break; - } - let res = ret.unwrap(); - log::trace!("recv a msg, try echo back: {:?}", res); - send.send(Bytes::from(res)).await.unwrap(); - if once { - break; - } - } - log::warn!("echo server exit..."); - } - - pub(crate) async fn _tunnel_pingpong(listener: L, connector: C) - where - L: TunnelListener + Send + Sync + 'static, - C: TunnelConnector + Send + Sync + 'static, - { - _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await - } - - pub(crate) async fn _tunnel_pingpong_netns( - mut listener: L, - mut connector: C, - l_netns: NetNS, - c_netns: NetNS, - ) where - L: TunnelListener + Send + Sync + 'static, - C: TunnelConnector + Send + Sync + 'static, - { - l_netns - .run_async(|| async { - listener.listen().await.unwrap(); - }) - .await; - - let lis = tokio::spawn(async move { - let ret = listener.accept().await.unwrap(); - assert_eq!( - ret.info().unwrap().local_addr, - listener.local_url().to_string() - ); - _tunnel_echo_server(ret, false).await - }); - - let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap(); - - assert_eq!( - tunnel.info().unwrap().remote_addr, - connector.remote_url().to_string() - ); - - let mut send = tunnel.pin_sink(); - let mut recv = tunnel.pin_stream(); - let send_data = Bytes::from("12345678abcdefg"); - send.send(send_data).await.unwrap(); - let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next()) - .await - .unwrap() - .unwrap() - .unwrap(); - println!("echo back: {:?}", ret); - assert_eq!(ret, Bytes::from("12345678abcdefg")); - - close_tunnel(&tunnel).await.unwrap(); - - if ["udp", "wg"].contains(&connector.remote_url().scheme()) { - lis.abort(); - } else { - // lis should finish in 1 second - let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await; - assert!(ret.is_ok()); - } - } - - pub(crate) async fn _tunnel_bench(mut listener: L, mut connector: C) - where - L: TunnelListener + Send + Sync + 'static, - C: TunnelConnector + Send + Sync + 'static, - { - listener.listen().await.unwrap(); - - let lis = tokio::spawn(async move { - let ret = listener.accept().await.unwrap(); - _tunnel_echo_server(ret, false).await - }); - - let tunnel = connector.connect().await.unwrap(); - - let mut send = tunnel.pin_sink(); - let mut recv = tunnel.pin_stream(); - - // prepare a 4k buffer with random data - let mut send_buf = BytesMut::new(); - for _ in 0..64 { - send_buf.put_i128(rand::random::()); - } - - let now = Instant::now(); - let mut count = 0; - while now.elapsed().as_secs() < 3 { - send.send(send_buf.clone().freeze()).await.unwrap(); - let _ = recv.next().await.unwrap().unwrap(); - count += 1; - } - println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs()); - - lis.abort(); - } -} diff --git a/easytier/src/tunnels/mod.rs b/easytier/src/tunnels/mod.rs deleted file mode 100644 index ed51173..0000000 --- a/easytier/src/tunnels/mod.rs +++ /dev/null @@ -1,192 +0,0 @@ -pub mod codec; -pub mod common; -// pub mod ring_tunnel; -// pub mod stats; -// pub mod tcp_tunnel; -// pub mod tunnel_filter; -// pub mod udp_tunnel; -// pub mod wireguard; - -use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc}; - -use crate::rpc::TunnelInfo; -use async_trait::async_trait; -use futures::{Sink, SinkExt, Stream}; - -use thiserror::Error; -use tokio_util::bytes::{Bytes, BytesMut}; - -#[derive(Error, Debug)] -pub enum TunnelError { - #[error("Error: {0}")] - CommonError(String), - #[error("io error")] - IOError(#[from] std::io::Error), - #[error("wait resp error {0}")] - WaitRespError(String), - #[error("Connect Error: {0}")] - ConnectError(String), - #[error("Invalid Protocol: {0}")] - InvalidProtocol(String), - #[error("Invalid Addr: {0}")] - InvalidAddr(String), - #[error("Tun Error: {0}")] - TunError(String), - #[error("timeout")] - Timeout(#[from] tokio::time::error::Elapsed), -} - -pub type StreamT = BytesMut; -pub type StreamItem = Result; -pub type SinkItem = Bytes; -pub type SinkError = TunnelError; - -pub trait DatagramStream: Stream + Send + Sync {} -impl DatagramStream for T where T: Stream + Send + Sync {} -pub trait DatagramSink: Sink + Send + Sync {} -impl DatagramSink for T where T: Sink + Send + Sync {} - -#[auto_impl::auto_impl(Box, Arc)] -pub trait Tunnel: Send + Sync { - fn stream(&self) -> Box; - fn sink(&self) -> Box; - - fn pin_stream(&self) -> Pin> { - Box::into_pin(self.stream()) - } - - fn pin_sink(&self) -> Pin> { - Box::into_pin(self.sink()) - } - - fn info(&self) -> Option; -} - -pub async fn close_tunnel(t: &Box) -> Result<(), TunnelError> { - t.pin_sink().close().await -} - -#[auto_impl::auto_impl(Arc)] -pub trait TunnelConnCounter: 'static + Send + Sync + Debug { - fn get(&self) -> u32; -} - -#[async_trait] -#[auto_impl::auto_impl(Box)] -pub trait TunnelListener: Send + Sync { - async fn listen(&mut self) -> Result<(), TunnelError>; - async fn accept(&mut self) -> Result, TunnelError>; - fn local_url(&self) -> url::Url; - fn get_conn_counter(&self) -> Arc> { - #[derive(Debug)] - struct FakeTunnelConnCounter {} - impl TunnelConnCounter for FakeTunnelConnCounter { - fn get(&self) -> u32 { - 0 - } - } - Arc::new(Box::new(FakeTunnelConnCounter {})) - } -} - -#[async_trait] -#[auto_impl::auto_impl(Box)] -pub trait TunnelConnector { - async fn connect(&mut self) -> Result, TunnelError>; - fn remote_url(&self) -> url::Url; - fn set_bind_addrs(&mut self, _addrs: Vec) {} -} - -pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url { - url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap() -} - -impl std::fmt::Debug for dyn Tunnel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Tunnel") - .field("info", &self.info()) - .finish() - } -} - -impl std::fmt::Debug for dyn TunnelConnector + Sync + Send { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TunnelConnector") - .field("remote_url", &self.remote_url()) - .finish() - } -} - -impl std::fmt::Debug for dyn TunnelListener { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TunnelListener") - .field("local_url", &self.local_url()) - .finish() - } -} - -pub(crate) trait FromUrl { - fn from_url(url: url::Url) -> Result - where - Self: Sized; -} - -pub(crate) fn check_scheme_and_get_socket_addr( - url: &url::Url, - scheme: &str, -) -> Result -where - T: FromUrl, -{ - if url.scheme() != scheme { - return Err(TunnelError::InvalidProtocol(url.scheme().to_string())); - } - - Ok(T::from_url(url.clone())?) -} - -impl FromUrl for SocketAddr { - fn from_url(url: url::Url) -> Result { - Ok(url.socket_addrs(|| None)?.pop().unwrap()) - } -} - -impl FromUrl for uuid::Uuid { - fn from_url(url: url::Url) -> Result { - let o = url.host_str().unwrap(); - let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?; - Ok(o) - } -} - -pub struct TunnelUrl { - inner: url::Url, -} - -impl From for TunnelUrl { - fn from(url: url::Url) -> Self { - TunnelUrl { inner: url } - } -} - -impl From for url::Url { - fn from(url: TunnelUrl) -> Self { - url.into_inner() - } -} - -impl TunnelUrl { - pub fn into_inner(self) -> url::Url { - self.inner - } - - pub fn bind_dev(&self) -> Option { - self.inner.path().strip_prefix("/").and_then(|s| { - if s.is_empty() { - None - } else { - Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap()) - } - }) - } -} diff --git a/easytier/src/tunnels/ring_tunnel.rs b/easytier/src/tunnels/ring_tunnel.rs deleted file mode 100644 index 83f85fc..0000000 --- a/easytier/src/tunnels/ring_tunnel.rs +++ /dev/null @@ -1,436 +0,0 @@ -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, AtomicU32}, - Arc, - }, - task::Poll, -}; - -use async_stream::stream; -use crossbeam_queue::ArrayQueue; - -use async_trait::async_trait; -use futures::Sink; -use once_cell::sync::Lazy; -use tokio::sync::{ - mpsc::{UnboundedReceiver, UnboundedSender}, - Mutex, Notify, -}; - -use futures::FutureExt; -use tokio_util::bytes::BytesMut; -use uuid::Uuid; - -use crate::tunnels::{SinkError, SinkItem}; - -use super::{ - build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel, - DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, -}; - -static RING_TUNNEL_CAP: usize = 1000; - -struct Ring { - id: Uuid, - ring: ArrayQueue, - consume_notify: Notify, - produce_notify: Notify, - closed: AtomicBool, -} - -impl Ring { - fn new(cap: usize, id: uuid::Uuid) -> Self { - Self { - id, - ring: ArrayQueue::new(cap), - consume_notify: Notify::new(), - produce_notify: Notify::new(), - closed: AtomicBool::new(false), - } - } - - fn close(&self) { - self.closed - .store(true, std::sync::atomic::Ordering::Relaxed); - self.produce_notify.notify_one(); - } - - fn closed(&self) -> bool { - self.closed.load(std::sync::atomic::Ordering::Relaxed) - } -} - -pub struct RingTunnel { - id: Uuid, - ring: Arc, - sender_counter: Arc, -} - -impl RingTunnel { - pub fn new(cap: usize) -> Self { - let id = Uuid::new_v4(); - RingTunnel { - id: id.clone(), - ring: Arc::new(Ring::new(cap, id)), - sender_counter: Arc::new(AtomicU32::new(1)), - } - } - - pub fn new_with_id(id: Uuid, cap: usize) -> Self { - let mut ret = Self::new(cap); - ret.id = id; - ret - } - - fn recv_stream(&self) -> impl DatagramStream { - let ring = self.ring.clone(); - let id = self.id; - stream! { - loop { - match ring.ring.pop() { - Some(v) => { - let mut out = BytesMut::new(); - out.extend_from_slice(&v); - ring.consume_notify.notify_one(); - log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v); - yield Ok(out); - }, - None => { - if ring.closed() { - log::warn!("ring recv tunnel {:?} closed", id); - yield Err(TunnelError::CommonError("ring closed".to_owned())); - } - log::trace!("waiting recv buffer, id: {}", id); - ring.produce_notify.notified().await; - } - } - } - } - } - - fn send_sink(&self) -> impl DatagramSink { - let ring = self.ring.clone(); - let sender_counter = self.sender_counter.clone(); - use tokio::task::JoinHandle; - - struct T { - ring: Arc, - wait_consume_task: Option>, - sender_counter: Arc, - } - - impl T { - fn wait_ring_consume( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - expected_size: usize, - ) -> std::task::Poll<()> { - let self_mut = self.get_mut(); - if self_mut.ring.ring.len() <= expected_size { - return Poll::Ready(()); - } - if self_mut.wait_consume_task.is_none() { - let id = self_mut.ring.id; - let ring = self_mut.ring.clone(); - let task = async move { - log::trace!( - "waiting ring consume done, expected_size: {}, id: {}", - expected_size, - id - ); - while ring.ring.len() > expected_size { - ring.consume_notify.notified().await; - } - log::trace!( - "ring consume done, expected_size: {}, id: {}", - expected_size, - id - ); - }; - self_mut.wait_consume_task = Some(tokio::spawn(task)); - } - let task = self_mut.wait_consume_task.as_mut().unwrap(); - match task.poll_unpin(cx) { - Poll::Ready(_) => { - self_mut.wait_consume_task = None; - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - } - } - - impl Sink for T { - type Error = SinkError; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.ring.closed() { - return Poll::Ready(Err(TunnelError::CommonError( - "ring closed during ready".to_owned(), - ) - .into())); - } - let expected_size = self.ring.ring.capacity() - 1; - match self.wait_ring_consume(cx, expected_size) { - Poll::Ready(_) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, - } - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: SinkItem, - ) -> Result<(), Self::Error> { - if self.ring.closed() { - return Err( - TunnelError::CommonError("ring closed during send".to_owned()).into(), - ); - } - log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item); - self.ring.ring.push(item).unwrap(); - self.ring.produce_notify.notify_one(); - Ok(()) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.ring.closed() { - return Poll::Ready(Err(TunnelError::CommonError( - "ring closed during flush".to_owned(), - ) - .into())); - } - Poll::Ready(Ok(())) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.ring.close(); - Poll::Ready(Ok(())) - } - } - - impl Drop for T { - fn drop(&mut self) { - let rem = self - .sender_counter - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); - if rem == 1 { - self.ring.close() - } - } - } - - sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - T { - ring, - wait_consume_task: None, - sender_counter, - } - } -} - -impl Drop for RingTunnel { - fn drop(&mut self) { - let rem = self - .sender_counter - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); - if rem == 1 { - self.ring.close() - } - } -} - -struct Connection { - client: RingTunnel, - server: RingTunnel, -} - -impl Tunnel for RingTunnel { - fn stream(&self) -> Box { - Box::new(self.recv_stream()) - } - - fn sink(&self) -> Box { - Box::new(self.send_sink()) - } - - fn info(&self) -> Option { - None - } -} - -static CONNECTION_MAP: Lazy>>>>> = - Lazy::new(|| Arc::new(Mutex::new(HashMap::new()))); - -#[derive(Debug)] -pub struct RingTunnelListener { - listerner_addr: url::Url, - conn_sender: UnboundedSender>, - conn_receiver: UnboundedReceiver>, -} - -impl RingTunnelListener { - pub fn new(key: url::Url) -> Self { - let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel(); - RingTunnelListener { - listerner_addr: key, - conn_sender, - conn_receiver, - } - } -} - -fn get_tunnel_for_client(conn: Arc) -> Box { - FramedTunnel::new_tunnel_with_info( - Box::pin(conn.client.recv_stream()), - conn.server.send_sink(), - TunnelInfo { - tunnel_type: "ring".to_owned(), - local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), - remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), - }, - ) -} - -fn get_tunnel_for_server(conn: Arc) -> Box { - FramedTunnel::new_tunnel_with_info( - Box::pin(conn.server.recv_stream()), - conn.client.send_sink(), - TunnelInfo { - tunnel_type: "ring".to_owned(), - local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), - remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), - }, - ) -} - -impl RingTunnelListener { - fn get_addr(&self) -> Result { - check_scheme_and_get_socket_addr::(&self.listerner_addr, "ring") - } -} - -#[async_trait] -impl TunnelListener for RingTunnelListener { - async fn listen(&mut self) -> Result<(), TunnelError> { - log::info!("listen new conn of key: {}", self.listerner_addr); - CONNECTION_MAP - .lock() - .await - .insert(self.get_addr()?, self.conn_sender.clone()); - Ok(()) - } - - async fn accept(&mut self) -> Result, TunnelError> { - log::info!("waiting accept new conn of key: {}", self.listerner_addr); - let my_addr = self.get_addr()?; - if let Some(conn) = self.conn_receiver.recv().await { - if conn.server.id == my_addr { - log::info!("accept new conn of key: {}", self.listerner_addr); - return Ok(get_tunnel_for_server(conn)); - } else { - tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id"); - return Err(TunnelError::CommonError( - "accept got wrong ring server id".to_owned(), - )); - } - } - - return Err(TunnelError::CommonError("conn receiver stopped".to_owned())); - } - - fn local_url(&self) -> url::Url { - self.listerner_addr.clone() - } -} - -pub struct RingTunnelConnector { - remote_addr: url::Url, -} - -impl RingTunnelConnector { - pub fn new(remote_addr: url::Url) -> Self { - RingTunnelConnector { remote_addr } - } -} - -#[async_trait] -impl TunnelConnector for RingTunnelConnector { - async fn connect(&mut self) -> Result, super::TunnelError> { - let remote_addr = check_scheme_and_get_socket_addr::(&self.remote_addr, "ring")?; - let entry = CONNECTION_MAP - .lock() - .await - .get(&remote_addr) - .unwrap() - .clone(); - log::info!("connecting"); - let conn = Arc::new(Connection { - client: RingTunnel::new(RING_TUNNEL_CAP), - server: RingTunnel::new_with_id(remote_addr.clone(), RING_TUNNEL_CAP), - }); - entry - .send(conn.clone()) - .map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?; - Ok(get_tunnel_for_client(conn)) - } - - fn remote_url(&self) -> url::Url { - self.remote_addr.clone() - } -} - -pub fn create_ring_tunnel_pair() -> (Box, Box) { - let conn = Arc::new(Connection { - client: RingTunnel::new(RING_TUNNEL_CAP), - server: RingTunnel::new(RING_TUNNEL_CAP), - }); - ( - Box::new(get_tunnel_for_server(conn.clone())), - Box::new(get_tunnel_for_client(conn)), - ) -} - -#[cfg(test)] -mod tests { - use futures::StreamExt; - - use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; - - use super::*; - - #[tokio::test] - async fn ring_pingpong() { - let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap(); - let listener = RingTunnelListener::new(id.clone()); - let connector = RingTunnelConnector::new(id.clone()); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - async fn ring_bench() { - let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap(); - let listener = RingTunnelListener::new(id.clone()); - let connector = RingTunnelConnector::new(id); - _tunnel_bench(listener, connector).await - } - - #[tokio::test] - async fn ring_close() { - let (stunnel, ctunnel) = create_ring_tunnel_pair(); - drop(stunnel); - - let mut stream = ctunnel.pin_stream(); - let ret = stream.next().await; - assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret); - } -} diff --git a/easytier/src/tunnels/stats.rs b/easytier/src/tunnels/stats.rs deleted file mode 100644 index 8e8d7a4..0000000 --- a/easytier/src/tunnels/stats.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed}; - -pub struct WindowLatency { - latency_us_window: Vec, - latency_us_window_index: AtomicU32, - latency_us_window_size: u32, - - sum: AtomicU32, - count: AtomicU32, -} - -impl WindowLatency { - pub fn new(window_size: u32) -> Self { - Self { - latency_us_window: (0..window_size).map(|_| AtomicU32::new(0)).collect(), - latency_us_window_index: AtomicU32::new(0), - latency_us_window_size: window_size, - - sum: AtomicU32::new(0), - count: AtomicU32::new(0), - } - } - - pub fn record_latency(&self, latency_us: u32) { - let index = self.latency_us_window_index.fetch_add(1, Relaxed); - if self.count.load(Relaxed) < self.latency_us_window_size { - self.count.fetch_add(1, Relaxed); - } - - let index = index % self.latency_us_window_size; - let old_lat = self.latency_us_window[index as usize].swap(latency_us, Relaxed); - - if old_lat < latency_us { - self.sum.fetch_add(latency_us - old_lat, Relaxed); - } else { - self.sum.fetch_sub(old_lat - latency_us, Relaxed); - } - } - - pub fn get_latency_us + std::ops::Div>(&self) -> T { - let count = self.count.load(Relaxed); - let sum = self.sum.load(Relaxed); - if count == 0 { - 0.into() - } else { - (T::from(sum)) / T::from(count) - } - } -} - -pub struct Throughput { - tx_bytes: AtomicU64, - rx_bytes: AtomicU64, - - tx_packets: AtomicU64, - rx_packets: AtomicU64, -} - -impl Throughput { - pub fn new() -> Self { - Self { - tx_bytes: AtomicU64::new(0), - rx_bytes: AtomicU64::new(0), - - tx_packets: AtomicU64::new(0), - rx_packets: AtomicU64::new(0), - } - } - - pub fn tx_bytes(&self) -> u64 { - self.tx_bytes.load(Relaxed) - } - - pub fn rx_bytes(&self) -> u64 { - self.rx_bytes.load(Relaxed) - } - - pub fn tx_packets(&self) -> u64 { - self.tx_packets.load(Relaxed) - } - - pub fn rx_packets(&self) -> u64 { - self.rx_packets.load(Relaxed) - } - - pub fn record_tx_bytes(&self, bytes: u64) { - self.tx_bytes.fetch_add(bytes, Relaxed); - self.tx_packets.fetch_add(1, Relaxed); - } - - pub fn record_rx_bytes(&self, bytes: u64) { - self.rx_bytes.fetch_add(bytes, Relaxed); - self.rx_packets.fetch_add(1, Relaxed); - } -} diff --git a/easytier/src/tunnels/tcp_tunnel.rs b/easytier/src/tunnels/tcp_tunnel.rs deleted file mode 100644 index 88e82b9..0000000 --- a/easytier/src/tunnels/tcp_tunnel.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::net::SocketAddr; - -use async_trait::async_trait; -use futures::stream::FuturesUnordered; -use tokio::net::{TcpListener, TcpSocket, TcpStream}; -use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; - -use crate::tunnels::common::setup_sokcet2; - -use super::{ - check_scheme_and_get_socket_addr, - common::{wait_for_connect_futures, FramedTunnel}, - Tunnel, TunnelInfo, TunnelListener, -}; - -#[derive(Debug)] -pub struct TcpTunnelListener { - addr: url::Url, - listener: Option, -} - -impl TcpTunnelListener { - pub fn new(addr: url::Url) -> Self { - TcpTunnelListener { - addr, - listener: None, - } - } -} - -#[async_trait] -impl TunnelListener for TcpTunnelListener { - async fn listen(&mut self) -> Result<(), super::TunnelError> { - let addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; - - let socket = if addr.is_ipv4() { - TcpSocket::new_v4()? - } else { - TcpSocket::new_v6()? - }; - - socket.set_reuseaddr(true)?; - // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] - // socket.set_reuseport(true)?; - socket.bind(addr)?; - - self.listener = Some(socket.listen(1024)?); - Ok(()) - } - - async fn accept(&mut self) -> Result, super::TunnelError> { - let listener = self.listener.as_ref().unwrap(); - let (stream, _) = listener.accept().await?; - stream.set_nodelay(true).unwrap(); - let info = TunnelInfo { - tunnel_type: "tcp".to_owned(), - local_addr: self.local_url().into(), - remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp") - .into(), - }; - - let (r, w) = tokio::io::split(stream); - Ok(FramedTunnel::new_tunnel_with_info( - FramedRead::new(r, LengthDelimitedCodec::new()), - FramedWrite::new(w, LengthDelimitedCodec::new()), - info, - )) - } - - fn local_url(&self) -> url::Url { - self.addr.clone() - } -} - -fn get_tunnel_with_tcp_stream( - stream: TcpStream, - remote_url: url::Url, -) -> Result, super::TunnelError> { - stream.set_nodelay(true).unwrap(); - - let info = TunnelInfo { - tunnel_type: "tcp".to_owned(), - local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp") - .into(), - remote_addr: remote_url.into(), - }; - - let (r, w) = tokio::io::split(stream); - Ok(Box::new(FramedTunnel::new_tunnel_with_info( - FramedRead::new(r, LengthDelimitedCodec::new()), - FramedWrite::new(w, LengthDelimitedCodec::new()), - info, - ))) -} - -#[derive(Debug)] -pub struct TcpTunnelConnector { - addr: url::Url, - - bind_addrs: Vec, -} - -impl TcpTunnelConnector { - pub fn new(addr: url::Url) -> Self { - TcpTunnelConnector { - addr, - bind_addrs: vec![], - } - } - - async fn connect_with_default_bind(&mut self) -> Result, super::TunnelError> { - tracing::info!(addr = ?self.addr, "connect tcp start"); - let addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; - let stream = TcpStream::connect(addr).await?; - tracing::info!(addr = ?self.addr, "connect tcp succ"); - return get_tunnel_with_tcp_stream(stream, self.addr.clone().into()); - } - - async fn connect_with_custom_bind(&mut self) -> Result, super::TunnelError> { - let futures = FuturesUnordered::new(); - let dst_addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; - - for bind_addr in self.bind_addrs.iter() { - tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr"); - - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(dst_addr), - socket2::Type::STREAM, - Some(socket2::Protocol::TCP), - )?; - setup_sokcet2(&socket2_socket, bind_addr)?; - - let socket = TcpSocket::from_std_stream(socket2_socket.into()); - futures.push(socket.connect(dst_addr.clone())); - } - - let ret = wait_for_connect_futures(futures).await; - return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into()); - } -} - -#[async_trait] -impl super::TunnelConnector for TcpTunnelConnector { - async fn connect(&mut self) -> Result, super::TunnelError> { - if self.bind_addrs.is_empty() { - self.connect_with_default_bind().await - } else { - self.connect_with_custom_bind().await - } - } - - fn remote_url(&self) -> url::Url { - self.addr.clone() - } - fn set_bind_addrs(&mut self, addrs: Vec) { - self.bind_addrs = addrs; - } -} - -#[cfg(test)] -mod tests { - use futures::{SinkExt, StreamExt}; - - use crate::tunnels::{ - common::tests::{_tunnel_bench, _tunnel_pingpong}, - TunnelConnector, - }; - - use super::*; - - #[tokio::test] - async fn tcp_pingpong() { - let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap()); - let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap()); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - async fn tcp_bench() { - let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap()); - let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap()); - _tunnel_bench(listener, connector).await - } - - #[tokio::test] - async fn tcp_bench_with_bind() { - let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap()); - let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap()); - connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - #[should_panic] - async fn tcp_bench_with_bind_fail() { - let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); - let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); - connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - // test slow send lock in framed tunnel - #[tokio::test] - async fn tcp_multiple_sender_and_slow_receiver() { - // console_subscriber::init(); - let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); - let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); - - listener.listen().await.unwrap(); - let t1 = tokio::spawn(async move { - let t = listener.accept().await.unwrap(); - let mut stream = t.pin_stream(); - - let now = tokio::time::Instant::now(); - - while let Some(Ok(_)) = stream.next().await { - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - if now.elapsed().as_secs() > 5 { - break; - } - } - - tracing::info!("t1 exit"); - }); - - let tunnel = connector.connect().await.unwrap(); - let mut sink1 = tunnel.pin_sink(); - let t2 = tokio::spawn(async move { - for i in 0..1000000 { - let a = sink1.send(b"hello".to_vec().into()).await; - if a.is_err() { - tracing::info!(?a, "t2 exit with err"); - break; - } - - if i % 5000 == 0 { - tracing::info!(i, "send2 1000"); - } - } - - tracing::info!("t2 exit"); - }); - - let mut sink2 = tunnel.pin_sink(); - let t3 = tokio::spawn(async move { - for i in 0..1000000 { - let a = sink2.send(b"hello".to_vec().into()).await; - if a.is_err() { - tracing::info!(?a, "t3 exit with err"); - break; - } - - if i % 5000 == 0 { - tracing::info!(i, "send2 1000"); - } - } - - tracing::info!("t3 exit"); - }); - - let t4 = tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - tracing::info!("closing"); - let close_ret = tunnel.pin_sink().close().await; - tracing::info!("closed {:?}", close_ret); - }); - - let _ = tokio::join!(t1, t2, t3, t4); - } -} diff --git a/easytier/src/tunnels/tunnel_filter.rs b/easytier/src/tunnels/tunnel_filter.rs deleted file mode 100644 index 8209360..0000000 --- a/easytier/src/tunnels/tunnel_filter.rs +++ /dev/null @@ -1,279 +0,0 @@ -use std::{ - sync::Arc, - task::{Context, Poll}, -}; - -use crate::rpc::TunnelInfo; -use futures::{Sink, SinkExt, Stream, StreamExt}; - -use self::stats::Throughput; - -use super::*; -use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel}; - -pub trait TunnelFilter { - fn before_send(&self, data: SinkItem) -> Option> { - Some(Ok(data)) - } - fn after_received(&self, data: StreamItem) -> Option> { - match data { - Ok(v) => Some(Ok(v)), - Err(e) => Some(Err(e)), - } - } -} - -pub struct TunnelWithFilter { - inner: T, - filter: Arc, -} - -impl Tunnel for TunnelWithFilter -where - T: Tunnel + Send + Sync + 'static, - F: TunnelFilter + Send + Sync + 'static, -{ - fn sink(&self) -> Box { - struct SinkWrapper { - sink: Pin>, - filter: Arc, - } - impl Sink for SinkWrapper - where - F: TunnelFilter + Send + Sync + 'static, - { - type Error = SinkError; - - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.get_mut().sink.poll_ready_unpin(cx) - } - - fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { - let Some(item) = self.filter.before_send(item) else { - return Ok(()); - }; - self.get_mut().sink.start_send_unpin(item?) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.get_mut().sink.poll_flush_unpin(cx) - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.get_mut().sink.poll_close_unpin(cx) - } - } - - Box::new(SinkWrapper { - sink: self.inner.pin_sink(), - filter: self.filter.clone(), - }) - } - - fn stream(&self) -> Box { - struct StreamWrapper { - stream: Pin>, - filter: Arc, - } - impl Stream for StreamWrapper - where - F: TunnelFilter + Send + Sync + 'static, - { - type Item = StreamItem; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let self_mut = self.get_mut(); - loop { - match self_mut.stream.poll_next_unpin(cx) { - Poll::Ready(Some(ret)) => { - let Some(ret) = self_mut.filter.after_received(ret) else { - continue; - }; - return Poll::Ready(Some(ret)); - } - Poll::Ready(None) => { - return Poll::Ready(None); - } - Poll::Pending => { - return Poll::Pending; - } - } - } - } - } - - Box::new(StreamWrapper { - stream: self.inner.pin_stream(), - filter: self.filter.clone(), - }) - } - - fn info(&self) -> Option { - self.inner.info() - } -} - -impl TunnelWithFilter -where - T: Tunnel + Send + Sync + 'static, - F: TunnelFilter + Send + Sync + 'static, -{ - pub fn new(inner: T, filter: Arc) -> Self { - Self { inner, filter } - } -} - -pub struct PacketRecorderTunnelFilter { - pub received: Arc>>, - pub sent: Arc>>, -} - -impl TunnelFilter for PacketRecorderTunnelFilter { - fn before_send(&self, data: SinkItem) -> Option> { - self.received.lock().unwrap().push(data.clone()); - Some(Ok(data)) - } - - fn after_received(&self, data: StreamItem) -> Option> { - match data { - Ok(v) => { - self.sent.lock().unwrap().push(v.clone().into()); - Some(Ok(v)) - } - Err(e) => Some(Err(e)), - } - } -} - -impl PacketRecorderTunnelFilter { - pub fn new() -> Self { - Self { - received: Arc::new(std::sync::Mutex::new(Vec::new())), - sent: Arc::new(std::sync::Mutex::new(Vec::new())), - } - } -} - -pub struct StatsRecorderTunnelFilter { - throughput: Arc, -} - -impl TunnelFilter for StatsRecorderTunnelFilter { - fn before_send(&self, data: SinkItem) -> Option> { - self.throughput.record_tx_bytes(data.len() as u64); - Some(Ok(data)) - } - - fn after_received(&self, data: StreamItem) -> Option> { - match data { - Ok(v) => { - self.throughput.record_rx_bytes(v.len() as u64); - Some(Ok(v)) - } - Err(e) => Some(Err(e)), - } - } -} - -impl StatsRecorderTunnelFilter { - pub fn new() -> Self { - Self { - throughput: Arc::new(Throughput::new()), - } - } - - pub fn get_throughput(&self) -> Arc { - self.throughput.clone() - } -} - -#[macro_export] -macro_rules! define_tunnel_filter_chain { - ($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => ( - pub struct $type_name { - $($field_name: std::sync::Arc<$filter_type>,)+ - } - - impl $type_name { - pub fn new() -> Self { - Self { - $($field_name: std::sync::Arc::new(<$filter_type>::new()),)+ - } - } - - pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel { - $( - let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone()); - )+ - tunnel - } - } - ) -} - -#[cfg(test)] -pub mod tests { - use std::sync::atomic::{AtomicU32, Ordering}; - - use super::*; - use crate::tunnels::ring_tunnel::RingTunnel; - - pub struct DropSendTunnelFilter { - start: AtomicU32, - end: AtomicU32, - cur: AtomicU32, - } - - impl TunnelFilter for DropSendTunnelFilter { - fn before_send(&self, data: SinkItem) -> Option> { - self.cur.fetch_add(1, Ordering::SeqCst); - if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst) - && self.cur.load(std::sync::atomic::Ordering::SeqCst) - < self.end.load(Ordering::SeqCst) - { - tracing::trace!("drop packet: {:?}", data); - return None; - } - Some(Ok(data)) - } - } - - impl DropSendTunnelFilter { - pub fn new(start: u32, end: u32) -> Self { - Self { - start: AtomicU32::new(start), - end: AtomicU32::new(end), - cur: AtomicU32::new(0), - } - } - } - - #[tokio::test] - async fn test_nested_filter() { - define_tunnel_filter_chain!( - Filter, - a = PacketRecorderTunnelFilter, - b = PacketRecorderTunnelFilter, - c = PacketRecorderTunnelFilter - ); - - let filter = Filter::new(); - let tunnel = filter.wrap_tunnel(RingTunnel::new(1)); - - let mut s = tunnel.pin_sink(); - s.send(Bytes::from("hello")).await.unwrap(); - - assert_eq!(1, filter.a.received.lock().unwrap().len()); - assert_eq!(1, filter.b.received.lock().unwrap().len()); - assert_eq!(1, filter.c.received.lock().unwrap().len()); - } -} diff --git a/easytier/src/tunnels/udp_tunnel.rs b/easytier/src/tunnels/udp_tunnel.rs deleted file mode 100644 index 48ff187..0000000 --- a/easytier/src/tunnels/udp_tunnel.rs +++ /dev/null @@ -1,768 +0,0 @@ -use std::{fmt::Debug, pin::Pin, sync::Arc}; - -use async_trait::async_trait; -use dashmap::DashMap; -use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; -use rkyv::{Archive, Deserialize, Serialize}; -use std::net::SocketAddr; -use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; -use tokio_util::{ - bytes::{Buf, Bytes, BytesMut}, - udp::UdpFramed, -}; -use tracing::Instrument; - -use crate::{ - common::{ - join_joinset_background, - rkyv_util::{self, encode_to_bytes, vec_to_string}, - }, - rpc::TunnelInfo, - tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector}, -}; - -use super::{ - codec::BytesCodec, - common::{ - setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures, FramedTunnel, - TunnelWithCustomInfo, - }, - ring_tunnel::create_ring_tunnel_pair, - DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl, -}; - -pub const UDP_DATA_MTU: usize = 65000; - -#[derive(Archive, Deserialize, Serialize)] -#[archive(compare(PartialEq), check_bytes)] -// Derives can be passed through to the generated type: -pub enum UdpPacketPayload { - Syn, - Sack, - HolePunch(String), - Data(String), -} - -impl std::fmt::Debug for UdpPacketPayload { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); - match self { - UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(), - UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(), - UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(), - UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(), - } - } -} - -#[derive(Archive, Deserialize, Serialize, Debug)] -#[archive(compare(PartialEq), check_bytes)] -#[archive_attr(derive(Debug))] -pub struct UdpPacket { - pub conn_id: u32, - pub magic: u32, - pub payload: UdpPacketPayload, -} - -const UDP_PACKET_MAGIC: u32 = 0x19941126; - -impl std::fmt::Debug for ArchivedUdpPacketPayload { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); - match self { - ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(), - ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(), - ArchivedUdpPacketPayload::HolePunch(s) => { - tmp.field("HolePunch", &s.as_bytes()).finish() - } - ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(), - } - } -} - -impl UdpPacket { - pub fn new_data_packet(conn_id: u32, data: Vec) -> Self { - Self { - conn_id, - magic: UDP_PACKET_MAGIC, - payload: UdpPacketPayload::Data(vec_to_string(data)), - } - } - - pub fn new_hole_punch_packet(data: Vec) -> Self { - Self { - conn_id: 0, - magic: UDP_PACKET_MAGIC, - payload: UdpPacketPayload::HolePunch(vec_to_string(data)), - } - } - - pub fn new_syn_packet(conn_id: u32) -> Self { - Self { - conn_id, - magic: UDP_PACKET_MAGIC, - payload: UdpPacketPayload::Syn, - } - } - - pub fn new_sack_packet(conn_id: u32) -> Self { - Self { - conn_id, - magic: UDP_PACKET_MAGIC, - payload: UdpPacketPayload::Sack, - } - } -} - -fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option { - let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { - tracing::warn!(?buf, "udp decode error"); - return None; - }; - - if udp_packet.conn_id != conn_id.clone() { - tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match"); - return None; - } - - if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::trace!(?udp_packet, "udp magic not match"); - return None; - } - - let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else { - tracing::warn!(?udp_packet, "udp payload not data"); - return None; - }; - - let offset = payload.as_ptr() as usize - buf.as_ptr() as usize; - let len = payload.len(); - if offset + len > buf.len() { - tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range"); - return None; - } - - buf.advance(offset); - buf.truncate(len); - tracing::trace!(?offset, ?len, ?buf, "udp payload data"); - - Some(buf) -} - -fn get_tunnel_from_socket( - socket: Arc, - addr: SocketAddr, - conn_id: u32, -) -> Box { - let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU)); - let (sink, stream) = udp.split(); - - let recv_addr = addr; - let stream = stream.filter_map(move |v| async move { - tracing::trace!(?v, "udp stream recv something"); - if v.is_err() { - tracing::warn!(?v, "udp stream error"); - return Some(Err(super::TunnelError::CommonError( - "udp stream error".to_owned(), - ))); - } - - let (buf, addr) = v.unwrap(); - if recv_addr != addr { - tracing::warn!(?addr, ?recv_addr, "udp recv addr not match"); - return None; - } - - Some(Ok(try_get_data_payload(buf, conn_id.clone())?)) - }); - let stream = Box::pin(stream); - - let sender_addr = addr; - let sink = Box::pin(sink.with(move |v: Bytes| async move { - if false { - return Err(super::TunnelError::CommonError("udp sink error".to_owned())); - } - - // TODO: two copy here, how to avoid? - let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec()); - let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); - tracing::trace!(?udp_packet, ?v, "udp send packet"); - - Ok((v, sender_addr)) - })); - - FramedTunnel::new_tunnel_with_info( - stream, - sink, - // TODO: this remote addr is not a url - super::TunnelInfo { - tunnel_type: "udp".to_owned(), - local_addr: super::build_url_from_socket_addr( - &socket.local_addr().unwrap().to_string(), - "udp", - ) - .into(), - remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(), - }, - ) -} - -pub(crate) struct StreamSinkPair( - pub Pin>, - pub Pin>, - pub u32, -); -pub(crate) type ArcStreamSinkPair = Arc>; - -pub struct UdpTunnelListener { - addr: url::Url, - socket: Option>, - - sock_map: Arc>, - forward_tasks: Arc>>, - - conn_recv: tokio::sync::mpsc::Receiver>, - conn_send: Option>>, -} - -impl UdpTunnelListener { - pub fn new(addr: url::Url) -> Self { - let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100); - Self { - addr, - socket: None, - sock_map: Arc::new(DashMap::new()), - forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), - conn_recv, - conn_send: Some(conn_send), - } - } - - async fn try_forward_packet( - sock_map: &DashMap, - buf: BytesMut, - addr: SocketAddr, - ) -> Result<(), super::TunnelError> { - let entry = sock_map.get_mut(&addr); - if entry.is_none() { - log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf); - return Ok(()); - } - - log::trace!("udp forward packet: {:?}, {:?}", addr, buf); - let entry = entry.unwrap(); - let pair = entry.value().clone(); - drop(entry); - - let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else { - return Ok(()); - }; - pair.lock().await.1.send(buf.freeze()).await?; - Ok(()) - } - - async fn handle_connect( - socket: Arc, - addr: SocketAddr, - forward_tasks: Arc>>, - sock_map: Arc>, - local_url: url::Url, - conn_id: u32, - ) -> Result, super::TunnelError> { - tracing::info!(?conn_id, ?addr, "udp connection accept handling",); - - let udp_packet = UdpPacket::new_sack_packet(conn_id); - let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); - socket.send_to(&sack_buf, addr).await?; - - let (ctunnel, stunnel) = create_ring_tunnel_pair(); - let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id); - let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id); - let addr_copy = addr.clone(); - sock_map.insert(addr, Arc::new(Mutex::new(ss_pair))); - let ctunnel_stream = ctunnel.pin_stream(); - forward_tasks.lock().unwrap().spawn(async move { - let ret = ctunnel_stream - .map(|v| { - tracing::trace!(?v, "udp stream recv something in forward task"); - if v.is_err() { - return Err(super::TunnelError::CommonError( - "udp stream error".to_owned(), - )); - } - Ok(v.unwrap().freeze()) - }) - .forward(udp_tunnel.pin_sink()) - .await; - if let None = sock_map.remove(&addr_copy) { - log::warn!("udp forward packet: {:?}, no entry", addr_copy); - } - close_tunnel(&ctunnel).await.unwrap(); - log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret); - }); - - Ok(Box::new(TunnelWithCustomInfo::new( - stunnel, - TunnelInfo { - tunnel_type: "udp".to_owned(), - local_addr: local_url.into(), - remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(), - }, - ))) - } - - pub fn get_socket(&self) -> Option> { - self.socket.clone() - } -} - -#[async_trait] -impl TunnelListener for UdpTunnelListener { - async fn listen(&mut self) -> Result<(), super::TunnelError> { - let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "udp")?; - - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - - let tunnel_url: TunnelUrl = self.addr.clone().into(); - if let Some(bind_dev) = tunnel_url.bind_dev() { - setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; - } else { - setup_sokcet2(&socket2_socket, &addr)?; - } - - self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); - - let socket = self.socket.as_ref().unwrap().clone(); - let forward_tasks = self.forward_tasks.clone(); - let sock_map = self.sock_map.clone(); - let conn_send = self.conn_send.take().unwrap(); - let local_url = self.local_url().clone(); - self.forward_tasks.lock().unwrap().spawn( - async move { - loop { - let mut buf = BytesMut::new(); - buf.resize(UDP_DATA_MTU, 0); - let (_size, addr) = socket.recv_from(&mut buf).await.unwrap(); - let _ = buf.split_off(_size); - log::trace!( - "udp recv packet: {:?}, buf: {:?}, size: {}", - addr, - buf, - _size - ); - - let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { - tracing::warn!(?buf, "udp decode error in forward task"); - continue; - }; - - if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::trace!(?udp_packet, "udp magic not match"); - continue; - } - - if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) { - let Ok(conn) = Self::handle_connect( - socket.clone(), - addr, - forward_tasks.clone(), - sock_map.clone(), - local_url.clone(), - udp_packet.conn_id.into(), - ) - .await - else { - tracing::error!(?addr, "udp handle connect error"); - continue; - }; - if let Err(e) = conn_send.send(conn).await { - tracing::warn!(?e, "udp send conn to accept channel error"); - } - } else { - Self::try_forward_packet(sock_map.as_ref(), buf, addr) - .await - .unwrap(); - } - } - } - .instrument(tracing::info_span!("udp forward task", ?self.socket)), - ); - - join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned()); - - Ok(()) - } - - async fn accept(&mut self) -> Result, super::TunnelError> { - log::info!("start udp accept: {:?}", self.addr); - while let Some(conn) = self.conn_recv.recv().await { - return Ok(conn); - } - return Err(super::TunnelError::CommonError( - "udp accept error".to_owned(), - )); - } - - fn local_url(&self) -> url::Url { - self.addr.clone() - } - - fn get_conn_counter(&self) -> Arc> { - struct UdpTunnelConnCounter { - sock_map: Arc>, - } - - impl Debug for UdpTunnelConnCounter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("UdpTunnelConnCounter") - .field("sock_map_len", &self.sock_map.len()) - .finish() - } - } - - impl TunnelConnCounter for UdpTunnelConnCounter { - fn get(&self) -> u32 { - self.sock_map.len() as u32 - } - } - - Arc::new(Box::new(UdpTunnelConnCounter { - sock_map: self.sock_map.clone(), - })) - } -} - -pub struct UdpTunnelConnector { - addr: url::Url, - bind_addrs: Vec, -} - -impl UdpTunnelConnector { - pub fn new(addr: url::Url) -> Self { - Self { - addr, - bind_addrs: vec![], - } - } - - async fn wait_sack( - socket: &UdpSocket, - addr: SocketAddr, - conn_id: u32, - ) -> Result<(), super::TunnelError> { - let mut buf = BytesMut::new(); - buf.resize(128, 0); - - let (usize, recv_addr) = tokio::time::timeout( - tokio::time::Duration::from_secs(3), - socket.recv_from(&mut buf), - ) - .await??; - - if recv_addr != addr { - return Err(super::TunnelError::ConnectError(format!( - "udp connect error, unexpected sack addr: {:?}, {:?}", - recv_addr, addr - ))); - } - - let _ = buf.split_off(usize); - - let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { - tracing::warn!(?buf, "udp decode error in wait sack"); - return Err(super::TunnelError::ConnectError(format!( - "udp connect error, decode error. buf: {:?}", - buf - ))); - }; - - if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::trace!(?udp_packet, "udp magic not match"); - return Err(super::TunnelError::ConnectError(format!( - "udp connect error, magic not match. magic: {:?}", - udp_packet.magic - ))); - } - - if conn_id != udp_packet.conn_id { - return Err(super::TunnelError::ConnectError(format!( - "udp connect error, conn id not match. conn_id: {:?}, {:?}", - conn_id, udp_packet.conn_id - ))); - } - - if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) { - return Err(super::TunnelError::ConnectError(format!( - "udp connect error, unexpected payload. payload: {:?}", - udp_packet.payload - ))); - } - - Ok(()) - } - - async fn wait_sack_loop( - socket: &UdpSocket, - addr: SocketAddr, - conn_id: u32, - ) -> Result<(), super::TunnelError> { - while let Err(err) = Self::wait_sack(socket, addr, conn_id).await { - tracing::warn!(?err, "udp wait sack error"); - } - Ok(()) - } - - pub async fn try_connect_with_socket( - &self, - socket: UdpSocket, - ) -> Result, super::TunnelError> { - let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "udp")?; - log::warn!("udp connect: {:?}", self.addr); - - #[cfg(target_os = "windows")] - crate::arch::windows::disable_connection_reset(&socket)?; - - // send syn - let conn_id = rand::random(); - let udp_packet = UdpPacket::new_syn_packet(conn_id); - let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); - let ret = socket.send_to(&b, &addr).await?; - tracing::warn!(?udp_packet, ?ret, "udp send syn"); - - // wait sack - tokio::time::timeout( - tokio::time::Duration::from_secs(3), - Self::wait_sack_loop(&socket, addr, conn_id), - ) - .await??; - - // sack done - let local_addr = socket.local_addr().unwrap().to_string(); - Ok(Box::new(TunnelWithCustomInfo::new( - get_tunnel_from_socket(Arc::new(socket), addr, conn_id), - TunnelInfo { - tunnel_type: "udp".to_owned(), - local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(), - remote_addr: self.remote_url().into(), - }, - ))) - } - - async fn connect_with_default_bind(&mut self) -> Result, super::TunnelError> { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - return self.try_connect_with_socket(socket).await; - } - - async fn connect_with_custom_bind(&mut self) -> Result, super::TunnelError> { - let futures = FuturesUnordered::new(); - - for bind_addr in self.bind_addrs.iter() { - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(*bind_addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - setup_sokcet2(&socket2_socket, &bind_addr)?; - let socket = UdpSocket::from_std(socket2_socket.into())?; - futures.push(self.try_connect_with_socket(socket)); - } - wait_for_connect_futures(futures).await - } -} - -#[async_trait] -impl super::TunnelConnector for UdpTunnelConnector { - async fn connect(&mut self) -> Result, super::TunnelError> { - if self.bind_addrs.is_empty() { - self.connect_with_default_bind().await - } else { - self.connect_with_custom_bind().await - } - } - - fn remote_url(&self) -> url::Url { - self.addr.clone() - } - - fn set_bind_addrs(&mut self, addrs: Vec) { - self.bind_addrs = addrs; - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use rand::Rng; - use tokio::time::timeout; - - use crate::{ - common::global_ctx::tests::get_mock_global_ctx, - tunnels::{ - check_scheme_and_get_socket_addr, - common::{ - get_interface_name_by_ip, setup_sokcet2_ext, - tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}, - }, - }, - }; - - use super::*; - - #[tokio::test] - async fn udp_pingpong() { - let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); - let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - async fn udp_bench() { - let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap()); - let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap()); - _tunnel_bench(listener, connector).await - } - - #[tokio::test] - async fn udp_bench_with_bind() { - let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap()); - let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap()); - connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - #[should_panic] - async fn udp_bench_with_bind_fail() { - let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap()); - let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap()); - connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - async fn send_random_data_to_socket(remote_url: url::Url) { - let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); - socket - .connect(format!( - "{}:{}", - remote_url.host().unwrap(), - remote_url.port().unwrap() - )) - .await - .unwrap(); - - // get a random 100-len buf - loop { - let mut buf = vec![0u8; 100]; - rand::thread_rng().fill(&mut buf[..]); - socket.send(&buf).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - } - } - - #[tokio::test] - async fn udp_multiple_conns() { - let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap()); - listener.listen().await.unwrap(); - - let _lis = tokio::spawn(async move { - loop { - let ret = listener.accept().await.unwrap(); - assert_eq!( - ret.info().unwrap().local_addr, - listener.local_url().to_string() - ); - tokio::spawn(async move { _tunnel_echo_server(ret, false).await }); - } - }); - - let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); - let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); - - let t1 = connector1.connect().await.unwrap(); - let t2 = connector2.connect().await.unwrap(); - - tokio::spawn(timeout( - Duration::from_secs(2), - send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()), - )); - tokio::spawn(timeout( - Duration::from_secs(2), - send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()), - )); - tokio::spawn(timeout( - Duration::from_secs(2), - send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()), - )); - - let sender1 = tokio::spawn(async move { - let mut sink = t1.pin_sink(); - let mut stream = t1.pin_stream(); - - for i in 0..10 { - sink.send(Bytes::from("hello1")).await.unwrap(); - let recv = stream.next().await.unwrap().unwrap(); - println!("t1 recv: {:?}, {:?}", recv, i); - assert_eq!(recv, Bytes::from("hello1")); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - }); - - let sender2 = tokio::spawn(async move { - let mut sink = t2.pin_sink(); - let mut stream = t2.pin_stream(); - - for i in 0..10 { - sink.send(Bytes::from("hello2")).await.unwrap(); - let recv = stream.next().await.unwrap().unwrap(); - println!("t2 recv: {:?}, {:?}", recv, i); - assert_eq!(recv, Bytes::from("hello2")); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - }); - - let _ = tokio::join!(sender1, sender2); - } - - #[tokio::test] - async fn udp_packet_print() { - let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]); - let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); - let a_udp_packet = rkyv_util::decode_from_bytes::(&b).unwrap(); - println!("{:?}, {:?}", udp_packet, a_udp_packet); - } - - #[tokio::test] - async fn bind_multi_ip_to_same_dev() { - let global_ctx = get_mock_global_ctx(); - let ips = global_ctx - .get_ip_collector() - .collect_ip_addrs() - .await - .interface_ipv4s; - if ips.is_empty() { - return; - } - let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap()); - - for ip in ips { - println!("bind to ip: {:?}, {:?}", ip, bind_dev); - let addr = check_scheme_and_get_socket_addr::( - &format!("udp://{}:11111", ip).parse().unwrap(), - "udp", - ) - .unwrap(); - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - ) - .unwrap(); - setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap(); - } - } -} diff --git a/easytier/src/tunnels/wireguard.rs b/easytier/src/tunnels/wireguard.rs deleted file mode 100644 index c6e79c6..0000000 --- a/easytier/src/tunnels/wireguard.rs +++ /dev/null @@ -1,841 +0,0 @@ -use std::{ - collections::hash_map::DefaultHasher, - fmt::{Debug, Formatter}, - hash::Hasher, - net::SocketAddr, - pin::Pin, - sync::{atomic::AtomicBool, Arc}, - time::Duration, -}; - -use anyhow::Context; -use async_recursion::async_recursion; -use async_trait::async_trait; -use boringtun::{ - noise::{errors::WireGuardError, Tunn, TunnResult}, - x25519::{PublicKey, StaticSecret}, -}; -use dashmap::DashMap; -use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; -use rand::RngCore; -use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; - -use crate::{ - rpc::TunnelInfo, - tunnels::{build_url_from_socket_addr, common::TunnelWithCustomInfo}, -}; - -use super::{ - check_scheme_and_get_socket_addr, - common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, - ring_tunnel::create_ring_tunnel_pair, - DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl, -}; - -const MAX_PACKET: usize = 65500; - -#[derive(Debug, Clone)] -enum WgType { - // used by easytier peer, need remove/add ip header for in/out wg msg - InternalUse, - // used by wireguard peer, keep original ip header - ExternalUse, -} - -#[derive(Clone)] -pub struct WgConfig { - my_secret_key: StaticSecret, - my_public_key: PublicKey, - - peer_secret_key: StaticSecret, - peer_public_key: PublicKey, - - wg_type: WgType, -} - -impl WgConfig { - pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self { - let mut my_sec = [0u8; 32]; - let mut hasher = DefaultHasher::new(); - hasher.write(network_name.as_bytes()); - hasher.write(network_secret.as_bytes()); - my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..8]); - my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..16]); - my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..24]); - my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes()); - - let my_secret_key = StaticSecret::from(my_sec); - let my_public_key = PublicKey::from(&my_secret_key); - let peer_secret_key = StaticSecret::from(my_sec); - let peer_public_key = my_public_key.clone(); - - WgConfig { - my_secret_key, - my_public_key, - peer_secret_key, - peer_public_key, - - wg_type: WgType::InternalUse, - } - } - - pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self { - let server_cfg = Self::new_from_network_identity("server", server_key_seed); - let client_cfg = Self::new_from_network_identity("client", client_key_seed); - Self { - my_secret_key: server_cfg.my_secret_key, - my_public_key: server_cfg.my_public_key, - peer_secret_key: client_cfg.my_secret_key, - peer_public_key: client_cfg.my_public_key, - - wg_type: WgType::ExternalUse, - } - } - - pub fn my_secret_key(&self) -> &[u8] { - self.my_secret_key.as_bytes() - } - - pub fn peer_secret_key(&self) -> &[u8] { - self.peer_secret_key.as_bytes() - } - - pub fn my_public_key(&self) -> &[u8] { - self.my_public_key.as_bytes() - } - - pub fn peer_public_key(&self) -> &[u8] { - self.peer_public_key.as_bytes() - } -} - -#[derive(Clone)] -struct WgPeerData { - udp: Arc, // only for send - endpoint: SocketAddr, - tunn: Arc>, - sink: Arc>>>, - stream: Arc>>>, - wg_type: WgType, - stopped: Arc, -} - -impl Debug for WgPeerData { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WgPeerData") - .field("endpoint", &self.endpoint) - .field("local", &self.udp.local_addr()) - .finish() - } -} - -impl WgPeerData { - #[tracing::instrument] - async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> { - let mut send_buf = vec![0u8; MAX_PACKET]; - - let encapsulate_result = { - let mut peer = self.tunn.lock().await; - if matches!(self.wg_type, WgType::InternalUse) { - peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf) - } else { - peer.encapsulate(&packet, &mut send_buf) - } - }; - - tracing::trace!( - ?encapsulate_result, - "Received {} bytes from me", - packet.len() - ); - - match encapsulate_result { - TunnResult::WriteToNetwork(packet) => { - self.udp - .send_to(packet, self.endpoint) - .await - .context("Failed to send encrypted IP packet to WireGuard endpoint.")?; - tracing::debug!( - "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", - packet.len() - ); - } - TunnResult::Err(e) => { - tracing::error!("Failed to encapsulate IP packet: {:?}", e); - } - TunnResult::Done => { - // Ignored - } - other => { - tracing::error!( - "Unexpected WireGuard state during encapsulation: {:?}", - other - ); - } - }; - Ok(()) - } - - /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, - /// decapsulates them, and dispatches newly received IP packets. - #[tracing::instrument] - pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) { - let mut send_buf = vec![0u8; MAX_PACKET]; - let data = &recv_buf[..]; - let decapsulate_result = { - let mut peer = self.tunn.lock().await; - peer.decapsulate(None, data, &mut send_buf) - }; - - tracing::debug!("Decapsulation result: {:?}", decapsulate_result); - - match decapsulate_result { - TunnResult::WriteToNetwork(packet) => { - match self.udp.send_to(packet, self.endpoint).await { - Ok(_) => {} - Err(e) => { - tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); - return; - } - }; - let mut peer = self.tunn.lock().await; - loop { - let mut send_buf = vec![0u8; MAX_PACKET]; - match peer.decapsulate(None, &[], &mut send_buf) { - TunnResult::WriteToNetwork(packet) => { - match self.udp.send_to(packet, self.endpoint).await { - Ok(_) => {} - Err(e) => { - tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); - break; - } - }; - } - _ => { - break; - } - } - } - } - TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { - tracing::debug!( - "WireGuard endpoint sent IP packet of {} bytes", - packet.len() - ); - let ret = self - .sink - .lock() - .await - .send( - if matches!(self.wg_type, WgType::InternalUse) { - self.remove_ip_header(packet, packet[0] >> 4 == 4) - } else { - packet - } - .to_vec() - .into(), - ) - .await; - if ret.is_err() { - tracing::error!("Failed to send packet to tunnel: {:?}", ret); - } - } - _ => { - tracing::warn!( - "Unexpected WireGuard state during decapsulation: {:?}", - decapsulate_result - ); - } - } - } - - #[tracing::instrument] - #[async_recursion] - async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () { - match result { - TunnResult::WriteToNetwork(packet) => { - tracing::debug!( - "Sending routine packet of {} bytes to WireGuard endpoint", - packet.len() - ); - match self.udp.send_to(packet, self.endpoint).await { - Ok(_) => {} - Err(e) => { - tracing::error!( - "Failed to send routine packet to WireGuard endpoint: {:?}", - e - ); - } - }; - } - TunnResult::Err(WireGuardError::ConnectionExpired) => { - tracing::warn!("Wireguard handshake has expired!"); - - let mut buf = vec![0u8; MAX_PACKET]; - let result = self - .tunn - .lock() - .await - .format_handshake_initiation(&mut buf[..], false); - - self.handle_routine_tun_result(result).await - } - TunnResult::Err(e) => { - tracing::error!( - "Failed to prepare routine packet for WireGuard endpoint: {:?}", - e - ); - } - TunnResult::Done => { - // Sleep for a bit - tokio::time::sleep(Duration::from_millis(250)).await; - } - other => { - tracing::warn!("Unexpected WireGuard routine task state: {:?}", other); - tokio::time::sleep(Duration::from_millis(250)).await; - } - }; - } - - /// WireGuard Routine task. Handles Handshake, keep-alive, etc. - pub async fn routine_task(self) { - loop { - let mut send_buf = vec![0u8; MAX_PACKET]; - let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) }; - self.handle_routine_tun_result(tun_result).await; - } - } - - fn add_ip_header(&self, packet: &[u8]) -> Vec { - let mut ret = vec![0u8; packet.len() + 20]; - let ip_header = ret.as_mut_slice(); - ip_header[0] = 0x45; - ip_header[1] = 0; - ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes()); - ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); - ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); - ip_header[8] = 64; - ip_header[9] = 0; - ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); - ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); - ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); - ip_header[20..].copy_from_slice(packet); - ret - } - - fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] { - if is_v4 { - return &packet[20..]; - } else { - return &packet[40..]; - } - } -} - -struct WgPeer { - udp: Arc, // only for send - config: WgConfig, - endpoint: SocketAddr, - - data: Option, - tasks: JoinSet<()>, - - access_time: std::time::Instant, -} - -impl WgPeer { - fn new(udp: Arc, config: WgConfig, endpoint: SocketAddr) -> Self { - WgPeer { - udp, - config, - endpoint, - - data: None, - tasks: JoinSet::new(), - - access_time: std::time::Instant::now(), - } - } - - async fn handle_packet_from_me(data: WgPeerData) { - while let Some(Ok(packet)) = data.stream.lock().await.next().await { - let ret = data.handle_one_packet_from_me(&packet).await; - if let Err(e) = ret { - tracing::error!("Failed to handle packet from me: {}", e); - } - } - data.stopped - .store(true, std::sync::atomic::Ordering::Relaxed); - } - - async fn handle_packet_from_peer(&mut self, packet: &[u8]) { - self.access_time = std::time::Instant::now(); - tracing::trace!("Received {} bytes from peer", packet.len()); - let data = self.data.as_ref().unwrap(); - data.handle_one_packet_from_peer(packet).await; - } - - fn start_and_get_tunnel(&mut self) -> Box { - let (stunnel, ctunnel) = create_ring_tunnel_pair(); - - let data = WgPeerData { - udp: self.udp.clone(), - endpoint: self.endpoint, - tunn: Arc::new(Mutex::new( - Tunn::new( - self.config.my_secret_key.clone(), - self.config.peer_public_key.clone(), - None, - None, - rand::thread_rng().next_u32(), - None, - ) - .unwrap(), - )), - sink: Arc::new(Mutex::new(stunnel.pin_sink())), - stream: Arc::new(Mutex::new(stunnel.pin_stream())), - wg_type: self.config.wg_type.clone(), - stopped: Arc::new(AtomicBool::new(false)), - }; - - self.data = Some(data.clone()); - self.tasks.spawn(Self::handle_packet_from_me(data.clone())); - self.tasks.spawn(data.routine_task()); - - ctunnel - } - - fn stopped(&self) -> bool { - self.data - .as_ref() - .unwrap() - .stopped - .load(std::sync::atomic::Ordering::Relaxed) - } -} - -impl Drop for WgPeer { - fn drop(&mut self) { - self.tasks.abort_all(); - if let Some(data) = self.data.clone() { - tokio::spawn(async move { - let _ = data.sink.lock().await.close().await; - }); - } - } -} - -type ConnSender = tokio::sync::mpsc::UnboundedSender>; -type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver>; - -pub struct WgTunnelListener { - addr: url::Url, - config: WgConfig, - - udp: Option>, - conn_recv: ConnReceiver, - conn_send: Option, - - wg_peer_map: Arc>, - - tasks: JoinSet<()>, -} - -impl WgTunnelListener { - pub fn new(addr: url::Url, config: WgConfig) -> Self { - let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel(); - WgTunnelListener { - addr, - config, - - udp: None, - conn_recv, - conn_send: Some(conn_send), - - wg_peer_map: Arc::new(DashMap::new()), - - tasks: JoinSet::new(), - } - } - - fn get_udp_socket(&self) -> Arc { - self.udp.as_ref().unwrap().clone() - } - - async fn handle_udp_incoming( - socket: Arc, - config: WgConfig, - conn_sender: ConnSender, - peer_map: Arc>, - ) { - let mut tasks = JoinSet::new(); - - let peer_map_clone = peer_map.clone(); - tasks.spawn(async move { - loop { - peer_map_clone - .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()); - tokio::time::sleep(Duration::from_secs(1)).await; - } - }); - - let mut buf = vec![0u8; MAX_PACKET]; - loop { - let Ok((n, addr)) = socket.recv_from(&mut buf).await else { - tracing::error!("Failed to receive from UDP socket"); - break; - }; - - let data = &buf[..n]; - tracing::trace!("Received {} bytes from {}", n, addr); - - if !peer_map.contains_key(&addr) { - tracing::info!("New peer: {}", addr); - let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone()); - let tunnel = Box::new(TunnelWithCustomInfo::new( - wg.start_and_get_tunnel(), - TunnelInfo { - tunnel_type: "wg".to_owned(), - local_addr: build_url_from_socket_addr( - &socket.local_addr().unwrap().to_string(), - "wg", - ) - .into(), - remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(), - }, - )); - if let Err(e) = conn_sender.send(tunnel) { - tracing::error!("Failed to send tunnel to conn_sender: {}", e); - } - peer_map.insert(addr, wg); - } - - let mut peer = peer_map.get_mut(&addr).unwrap(); - peer.handle_packet_from_peer(data).await; - } - } -} - -#[async_trait] -impl TunnelListener for WgTunnelListener { - async fn listen(&mut self) -> Result<(), super::TunnelError> { - let addr = check_scheme_and_get_socket_addr::(&self.addr, "wg")?; - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - - let tunnel_url: TunnelUrl = self.addr.clone().into(); - if let Some(bind_dev) = tunnel_url.bind_dev() { - setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; - } else { - setup_sokcet2(&socket2_socket, &addr)?; - } - - self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); - self.tasks.spawn(Self::handle_udp_incoming( - self.get_udp_socket(), - self.config.clone(), - self.conn_send.take().unwrap(), - self.wg_peer_map.clone(), - )); - - Ok(()) - } - - async fn accept(&mut self) -> Result, super::TunnelError> { - while let Some(tunnel) = self.conn_recv.recv().await { - tracing::info!(?tunnel, "Accepted tunnel"); - return Ok(tunnel); - } - Err(TunnelError::CommonError( - "Failed to accept tunnel".to_string(), - )) - } - - fn local_url(&self) -> url::Url { - self.addr.clone() - } -} - -pub struct WgClientTunnel { - wg_peer: WgPeer, - tunnel: Box, - info: TunnelInfo, -} - -impl Tunnel for WgClientTunnel { - fn stream(&self) -> Box { - self.tunnel.stream() - } - - fn sink(&self) -> Box { - self.tunnel.sink() - } - - fn info(&self) -> Option { - Some(self.info.clone()) - } -} - -#[derive(Clone)] -pub struct WgTunnelConnector { - addr: url::Url, - config: WgConfig, - udp: Option>, - - bind_addrs: Vec, -} - -impl Debug for WgTunnelConnector { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WgTunnelConnector") - .field("addr", &self.addr) - .field("udp", &self.udp) - .finish() - } -} - -impl WgTunnelConnector { - pub fn new(addr: url::Url, config: WgConfig) -> Self { - WgTunnelConnector { - addr, - config, - udp: None, - bind_addrs: vec![], - } - } - - fn create_handshake_init(tun: &mut Tunn) -> Vec { - let mut dst = vec![0u8; 2048]; - let handshake_init = tun.format_handshake_initiation(&mut dst, false); - assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); - let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { - sent - } else { - unreachable!(); - }; - - handshake_init.into() - } - - fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { - let mut dst = vec![0u8; 2048]; - let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); - assert!( - matches!(keepalive, TunnResult::WriteToNetwork(_)), - "Failed to parse handshake response, {:?}", - keepalive - ); - - let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { - sent - } else { - unreachable!(); - }; - - keepalive.into() - } - - #[tracing::instrument(skip(config))] - async fn connect_with_socket( - addr_url: url::Url, - config: WgConfig, - udp: UdpSocket, - ) -> Result, super::TunnelError> { - let addr = super::check_scheme_and_get_socket_addr::(&addr_url, "wg")?; - tracing::warn!("wg connect: {:?}", addr); - let local_addr = udp.local_addr().unwrap().to_string(); - - let mut my_tun = Tunn::new( - config.my_secret_key.clone(), - config.peer_public_key.clone(), - None, - None, - rand::thread_rng().next_u32(), - None, - ) - .unwrap(); - - let init = Self::create_handshake_init(&mut my_tun); - udp.send_to(&init, addr).await?; - - let mut buf = vec![0u8; MAX_PACKET]; - let (n, _) = udp.recv_from(&mut buf).await.unwrap(); - let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]); - udp.send_to(&keepalive, addr).await?; - - let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr); - let tunnel = wg_peer.start_and_get_tunnel(); - - let data = wg_peer.data.as_ref().unwrap().clone(); - wg_peer.tasks.spawn(async move { - loop { - let mut buf = vec![0u8; MAX_PACKET]; - let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); - if recv_addr != addr { - continue; - } - data.handle_one_packet_from_peer(&buf[..n]).await; - } - }); - - let ret = Box::new(WgClientTunnel { - wg_peer, - tunnel, - info: TunnelInfo { - tunnel_type: "wg".to_owned(), - local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(), - remote_addr: addr_url.to_string(), - }, - }); - - Ok(ret) - } -} - -#[async_trait] -impl super::TunnelConnector for WgTunnelConnector { - #[tracing::instrument] - async fn connect(&mut self) -> Result, super::TunnelError> { - let bind_addrs = if self.bind_addrs.is_empty() { - vec!["0.0.0.0:0".parse().unwrap()] - } else { - self.bind_addrs.clone() - }; - let futures = FuturesUnordered::new(); - - for bind_addr in bind_addrs.into_iter() { - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(bind_addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - setup_sokcet2(&socket2_socket, &bind_addr)?; - let socket = UdpSocket::from_std(socket2_socket.into())?; - tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task"); - futures.push(Self::connect_with_socket( - self.addr.clone(), - self.config.clone(), - socket, - )); - } - - wait_for_connect_futures(futures).await - } - - fn remote_url(&self) -> url::Url { - self.addr.clone() - } - - fn set_bind_addrs(&mut self, addrs: Vec) { - self.bind_addrs = addrs; - } -} - -#[cfg(test)] -pub mod tests { - use boringtun::*; - - use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; - use crate::tunnels::{wireguard::*, TunnelConnector}; - - pub fn create_wg_config() -> (WgConfig, WgConfig) { - let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); - let my_public_key = x25519::PublicKey::from(&my_secret_key); - - let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); - let their_public_key = x25519::PublicKey::from(&their_secret_key); - - let server_cfg = WgConfig { - my_secret_key: my_secret_key.clone(), - my_public_key, - peer_secret_key: their_secret_key.clone(), - peer_public_key: their_public_key.clone(), - wg_type: WgType::InternalUse, - }; - - let client_cfg = WgConfig { - my_secret_key: their_secret_key, - my_public_key: their_public_key, - peer_secret_key: my_secret_key, - peer_public_key: my_public_key, - wg_type: WgType::InternalUse, - }; - - (server_cfg, client_cfg) - } - - #[tokio::test] - async fn wg_pingpong() { - let (server_cfg, client_cfg) = create_wg_config(); - let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg); - let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - async fn wg_bench() { - let (server_cfg, client_cfg) = create_wg_config(); - let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg); - let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg); - _tunnel_bench(listener, connector).await - } - - #[tokio::test] - async fn wg_bench_with_bind() { - let (server_cfg, client_cfg) = create_wg_config(); - let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg); - let mut connector = - WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg); - connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - #[should_panic] - async fn wg_bench_with_bind_fail() { - let (server_cfg, client_cfg) = create_wg_config(); - let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg); - let mut connector = - WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg); - connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); - _tunnel_pingpong(listener, connector).await - } - - #[tokio::test] - async fn wg_server_erase_from_map_after_close() { - let (server_cfg, client_cfg) = create_wg_config(); - let mut listener = - WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg); - listener.listen().await.unwrap(); - - const CONN_COUNT: usize = 10; - - tokio::spawn(async move { - for _ in 0..CONN_COUNT { - let mut connector = WgTunnelConnector::new( - "wg://127.0.0.1:5595".parse().unwrap(), - client_cfg.clone(), - ); - let ret = connector.connect().await; - assert!(ret.is_ok()); - drop(ret); - } - }); - - for _ in 0..CONN_COUNT { - let conn = listener.accept().await; - assert!(conn.is_ok()); - drop(conn); - } - - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; - - assert_eq!(0, listener.wg_peer_map.len()); - } -}