From 69651ae3fdf723b3644b1ba620352a6d42ab45c9 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Fri, 26 Apr 2024 23:02:07 +0800 Subject: [PATCH] Perf improve (#59) * improve perf * fix forward --- Cargo.lock | 27 +++++ easytier/Cargo.toml | 2 + easytier/src/connector/mod.rs | 13 +++ easytier/src/connector/udp_hole_punch.rs | 3 +- easytier/src/easytier-core.rs | 6 ++ easytier/src/instance/listeners.rs | 41 +++---- easytier/src/instance/virtual_nic.rs | 30 ++++-- easytier/src/peers/peer.rs | 24 ++--- easytier/src/peers/peer_manager.rs | 130 +++++++++++++++++++---- easytier/src/peers/peer_rpc.rs | 6 +- easytier/src/peers/zc_peer_conn.rs | 27 +++-- easytier/src/tunnel/common.rs | 31 +++--- easytier/src/tunnel/packet_def.rs | 129 +++++++++++++++------- easytier/src/tunnel/stats.rs | 41 +++---- easytier/src/tunnel/udp.rs | 10 +- easytier/src/tunnel/wireguard.rs | 12 +-- 16 files changed, 370 insertions(+), 162 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 65a4552..205f60c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,6 +983,12 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "cty" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" + [[package]] name = "curve25519-dalek" version = "4.0.0-rc.3" @@ -1298,6 +1304,7 @@ dependencies = [ "gethostname", "humansize", "log", + "mimalloc-rust", "network-interface", "nix 0.27.1", "once_cell", @@ -2672,6 +2679,26 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mimalloc-rust" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb726c8298efb4010b2c46d8050e4be36cf807b9d9e98cb112f830914fc9bbe" +dependencies = [ + "cty", + "mimalloc-rust-sys", +] + +[[package]] +name = "mimalloc-rust-sys" +version = "1.7.9-source" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6413e13241a9809f291568133eca6694572cf528c1a6175502d090adce5dd5db" +dependencies = [ + "cc", + "cty", +] + [[package]] name = "mime" version = "0.3.17" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 452b85c..acad937 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -140,6 +140,8 @@ base64 = "0.21.7" derivative = "2.2.0" +mimalloc-rust = "0.2.1" + [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.52", features = [ "Win32_Networking_WinSock", diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 7df015d..3b61f26 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, tunnel::{ + quic::QUICTunnelConnector, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, @@ -77,6 +78,18 @@ pub async fn create_connector_by_url( let connector = RingTunnelConnector::new(url); return Ok(Box::new(connector)); } + "quic" => { + let dst_addr = + crate::tunnels::check_scheme_and_get_socket_addr::(&url, "quic")?; + let mut connector = QUICTunnelConnector::new(url); + set_bind_addr_for_peer_connector( + &mut connector, + dst_addr.is_ipv4(), + &global_ctx.get_ip_collector(), + ) + .await; + return Ok(Box::new(connector)); + } "wg" => { let dst_addr = crate::tunnels::check_scheme_and_get_socket_addr::(&url, "wg")?; diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs index 53f138e..42e1f6f 100644 --- a/easytier/src/connector/udp_hole_punch.rs +++ b/easytier/src/connector/udp_hole_punch.rs @@ -15,7 +15,6 @@ use crate::{ rpc::NatType, tunnel::{ common::setup_sokcet2, - packet_def::ZCPacketType, udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener, }, @@ -153,7 +152,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer { let udp_packet = new_hole_punch_packet(); let _ = socket - .send_to(&udp_packet.into_bytes(ZCPacketType::UDP), local_mapped_addr) + .send_to(&udp_packet.into_bytes(), local_mapped_addr) .await; tokio::time::sleep(std::time::Duration::from_millis(300)).await; } diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index c80ed52..4c56410 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -33,6 +33,11 @@ use crate::common::{ global_ctx::GlobalCtxEvent, }; +use mimalloc_rust::*; + +#[global_allocator] +static GLOBAL_MIMALLOC: GlobalMiMalloc = GlobalMiMalloc; + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Cli { @@ -437,6 +442,7 @@ fn main() { if cli.multi_thread { tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) .enable_all() .build() .unwrap() diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index ee8e185..8148ae9 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -12,6 +12,7 @@ use crate::{ }, peers::peer_manager::PeerManager, tunnel::{ + quic::QUICTunnelListener, ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, @@ -20,6 +21,26 @@ use crate::{ }, }; +pub fn get_listener_by_url( + l: &url::Url, + ctx: ArcGlobalCtx, +) -> Result, Error> { + Ok(match l.scheme() { + "tcp" => Box::new(TcpTunnelListener::new(l.clone())), + "udp" => Box::new(UdpTunnelListener::new(l.clone())), + "wg" => { + let nid = ctx.get_network_identity(); + let wg_config = + WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); + Box::new(WgTunnelListener::new(l.clone(), wg_config)) + } + "quic" => Box::new(QUICTunnelListener::new(l.clone())), + _ => { + unreachable!("unsupported listener uri"); + } + }) +} + #[async_trait] pub trait TunnelHandlerForListener { async fn handle_tunnel(&self, tunnel: Box) -> Result<(), Error>; @@ -62,24 +83,8 @@ impl ListenerManage .await?; for l in self.global_ctx.config.get_listener_uris().iter() { - match l.scheme() { - "tcp" => { - self.add_listener(TcpTunnelListener::new(l.clone())).await?; - } - "udp" => { - self.add_listener(UdpTunnelListener::new(l.clone())).await?; - } - "wg" => { - let nid = self.global_ctx.get_network_identity(); - let wg_config = - WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); - self.add_listener(WgTunnelListener::new(l.clone(), wg_config)) - .await?; - } - _ => { - log::warn!("unsupported listener uri: {}", l); - } - } + let lis = get_listener_by_url(l, self.global_ctx.clone())?; + self.add_listener(lis).await?; } Ok(()) diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index 3c50058..ad6b27d 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -19,11 +19,11 @@ use crate::{ }; use byteorder::WriteBytesExt as _; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut}; use futures::{lock::BiLock, ready, Stream}; use pin_project_lite::pin_project; -use tokio::io::AsyncWrite; -use tokio_util::{bytes::Bytes, io::poll_read_buf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_util::bytes::Bytes; use tun::{create_as_async, AsyncDevice, Configuration, Device as _, Layer}; use zerocopy::{NativeEndian, NetworkEndian}; @@ -64,12 +64,24 @@ impl Stream for TunStream { self_mut.cur_buf.set_len(*self_mut.payload_offset); } } - match ready!(poll_read_buf(g.as_pin_mut(), cx, &mut self_mut.cur_buf)) { - Ok(0) => Poll::Ready(None), - Ok(_n) => Poll::Ready(Some(Ok(ZCPacket::new_from_buf( - self_mut.cur_buf.split(), - ZCPacketType::NIC, - )))), + let buf = self_mut.cur_buf.chunk_mut().as_mut_ptr(); + let buf = unsafe { std::slice::from_raw_parts_mut(buf, 2500) }; + let mut buf = ReadBuf::new(buf); + + let ret = ready!(g.as_pin_mut().poll_read(cx, &mut buf)); + let len = buf.filled().len(); + unsafe { self_mut.cur_buf.advance_mut(len) }; + + match ret { + Ok(_) => { + if len == 0 { + return Poll::Ready(None); + } + Poll::Ready(Some(Ok(ZCPacket::new_from_buf( + self_mut.cur_buf.split(), + ZCPacketType::NIC, + )))) + } Err(err) => { println!("tun stream error: {:?}", err); Poll::Ready(None) diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index 6e10229..ea19405 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -3,11 +3,7 @@ use std::sync::Arc; use crossbeam::atomic::AtomicCell; use dashmap::DashMap; -use tokio::{ - select, - sync::{mpsc, Mutex}, - task::JoinHandle, -}; +use tokio::{select, sync::mpsc, task::JoinHandle}; use tracing::Instrument; @@ -25,7 +21,7 @@ use crate::{ tunnel::packet_def::ZCPacket, }; -type ArcPeerConn = Arc>; +type ArcPeerConn = Arc; type ConnMap = Arc>; pub struct Peer { @@ -73,7 +69,7 @@ impl Peer { if let Some((_, conn)) = conns_copy.remove(&ret) { global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved( - conn.lock().await.get_conn_info(), + conn.get_conn_info(), )); } } @@ -108,12 +104,11 @@ impl Peer { pub async fn add_peer_conn(&self, mut conn: PeerConn) { conn.set_close_event_sender(self.close_event_sender.clone()); - conn.start_recv_loop(self.packet_recv_chan.clone()); + conn.start_recv_loop(self.packet_recv_chan.clone()).await; conn.start_pingpong(); self.global_ctx .issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info())); - self.conns - .insert(conn.get_conn_id(), Arc::new(Mutex::new(conn))); + self.conns.insert(conn.get_conn_id(), Arc::new(conn)); } async fn select_conn(&self) -> Option { @@ -128,7 +123,7 @@ impl Peer { } let conn = conn.unwrap().clone(); - self.default_conn_id.store(conn.lock().await.get_conn_id()); + self.default_conn_id.store(conn.get_conn_id()); Some(conn) } @@ -136,10 +131,7 @@ impl Peer { let Some(conn) = self.select_conn().await else { return Err(Error::PeerNoConnectionError(self.peer_node_id)); }; - - let conn_clone = conn.clone(); - drop(conn); - conn_clone.lock().await.send_msg(msg).await?; + conn.send_msg(msg).await?; Ok(()) } @@ -162,7 +154,7 @@ impl Peer { let mut ret = Vec::new(); for conn in conns { - ret.push(conn.lock().await.get_conn_info()); + ret.push(conn.get_conn_info()); } ret } diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 2a65336..d659181 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -279,19 +279,10 @@ impl PeerManager { let from_peer_id = hdr.from_peer_id.get(); let to_peer_id = hdr.to_peer_id.get(); if to_peer_id != my_peer_id { - log::trace!( - "need forward: to_peer_id: {:?}, my_peer_id: {:?}", - to_peer_id, - my_peer_id - ); + tracing::trace!(?to_peer_id, ?my_peer_id, "need forward"); let ret = peers.send_msg(ret, to_peer_id).await; if ret.is_err() { - log::error!( - "forward packet error: {:?}, dst: {:?}, from: {:?}", - ret, - to_peer_id, - from_peer_id - ); + tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } } else { let mut processed = false; @@ -516,15 +507,11 @@ impl PeerManager { msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8); if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await { - if let Err(e) = self.peers.send_msg_directly(msg.clone(), gateway).await { + if let Err(e) = self.peers.send_msg_directly(msg, gateway).await { errs.push(e); } } else if self.foreign_network_client.has_next_hop(*peer_id) { - if let Err(e) = self - .foreign_network_client - .send_msg(msg.clone(), *peer_id) - .await - { + if let Err(e) = self.foreign_network_client.send_msg(msg, *peer_id).await { errs.push(e); } } @@ -622,12 +609,23 @@ impl PeerManager { #[cfg(test)] mod tests { + use std::{fmt::Debug, sync::Arc}; + use crate::{ - connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, - peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear}, + connector::{ + create_connector_by_url, udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, + }, + instance::listeners::get_listener_by_url, + peers::{ + peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient}, + tests::{connect_peer_manager, wait_for_condition, wait_route_appear}, + }, rpc::NatType, + tunnel::{TunnelConnector, TunnelListener}, }; + use super::PeerManager; + #[tokio::test] async fn drop_peer_manager() { let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; @@ -659,4 +657,98 @@ mod tests { ) .await; } + + async fn connect_peer_manager_with( + client_mgr: Arc, + server_mgr: &Arc, + mut client: C, + server: &mut L, + ) { + server.listen().await.unwrap(); + + tokio::spawn(async move { + client.set_bind_addrs(vec![]); + client_mgr.try_connect(client).await.unwrap(); + }); + + server_mgr + .add_client_tunnel(server.accept().await.unwrap()) + .await + .unwrap(); + } + + #[rstest::rstest] + #[tokio::test] + #[serial_test::serial(forward_packet_test)] + async fn forward_packet( + #[values("tcp", "udp", "wg", "quic")] proto1: &str, + #[values("tcp", "udp", "wg", "quic")] proto2: &str, + ) { + let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + peer_mgr_a.get_peer_rpc_mgr().run_service( + 100, + MockService { + prefix: "hello a".to_owned(), + } + .serve(), + ); + + let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + + let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + peer_mgr_c.get_peer_rpc_mgr().run_service( + 100, + MockService { + prefix: "hello c".to_owned(), + } + .serve(), + ); + + let mut listener1 = get_listener_by_url( + &format!("{}://0.0.0.0:31013", proto1).parse().unwrap(), + peer_mgr_b.get_global_ctx(), + ) + .unwrap(); + let connector1 = create_connector_by_url( + format!("{}://127.0.0.1:31013", proto1).as_str(), + &peer_mgr_a.get_global_ctx(), + ) + .await + .unwrap(); + connect_peer_manager_with(peer_mgr_a.clone(), &peer_mgr_b, connector1, &mut listener1) + .await; + + wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone()) + .await + .unwrap(); + + let mut listener2 = get_listener_by_url( + &format!("{}://0.0.0.0:31014", proto2).parse().unwrap(), + peer_mgr_c.get_global_ctx(), + ) + .unwrap(); + let connector2 = create_connector_by_url( + format!("{}://127.0.0.1:31014", proto2).as_str(), + &peer_mgr_b.get_global_ctx(), + ) + .await + .unwrap(); + connect_peer_manager_with(peer_mgr_b.clone(), &peer_mgr_c, connector2, &mut listener2) + .await; + + wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone()) + .await + .unwrap(); + + let ret = peer_mgr_a + .get_peer_rpc_mgr() + .do_client_rpc_scoped(100, peer_mgr_c.my_peer_id(), |c| async { + let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); + let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; + ret + }) + .await + .unwrap(); + assert_eq!(ret, "hello c abc"); + } } diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index 1c8368a..901d0d8 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -389,7 +389,7 @@ impl PeerRpcManager { } #[cfg(test)] -mod tests { +pub mod tests { use std::{pin::Pin, sync::Arc}; use futures::{SinkExt, StreamExt}; @@ -415,8 +415,8 @@ mod tests { } #[derive(Clone)] - struct MockService { - prefix: String, + pub struct MockService { + pub prefix: String, } #[tarpc::server] diff --git a/easytier/src/peers/zc_peer_conn.rs b/easytier/src/peers/zc_peer_conn.rs index 637ebd9..5d29b26 100644 --- a/easytier/src/peers/zc_peer_conn.rs +++ b/easytier/src/peers/zc_peer_conn.rs @@ -13,7 +13,7 @@ use futures::{SinkExt, StreamExt, TryFutureExt}; use prost::Message; use tokio::{ - sync::{broadcast, mpsc}, + sync::{broadcast, mpsc, Mutex}, task::JoinSet, time::{timeout, Duration}, }; @@ -52,9 +52,9 @@ pub struct PeerConn { my_peer_id: PeerId, global_ctx: ArcGlobalCtx, - tunnel: Box, + tunnel: Arc>>, sink: MpscTunnelSender, - recv: Option>>, + recv: Arc>>>>, tunnel_info: Option, tasks: JoinSet>, @@ -98,9 +98,9 @@ impl PeerConn { my_peer_id, global_ctx, - tunnel: Box::new(mpsc_tunnel), + tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))), sink, - recv: Some(recv), + recv: Arc::new(Mutex::new(Some(recv))), tunnel_info, tasks: JoinSet::new(), @@ -121,7 +121,8 @@ impl PeerConn { } async fn wait_handshake(&mut self) -> Result { - let recv = self.recv.as_mut().unwrap(); + let mut locked = self.recv.lock().await; + let recv = locked.as_mut().unwrap(); let Some(rsp) = recv.next().await else { return Err(Error::WaitRespError( "conn closed during wait handshake response".to_owned(), @@ -199,8 +200,8 @@ impl PeerConn { self.info.is_some() } - pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { - let mut stream = self.recv.take().unwrap(); + pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { + let mut stream = self.recv.lock().await.take().unwrap(); let sink = self.sink.clone(); let mut sender = PollSender::new(packet_recv_chan.clone()); let close_event_sender = self.close_event_sender.clone().unwrap(); @@ -286,7 +287,7 @@ impl PeerConn { }); } - pub async fn send_msg(&mut self, msg: ZCPacket) -> Result<(), Error> { + pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { Ok(self.sink.send(msg).await?) } @@ -398,7 +399,9 @@ mod tests { ); s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0); - s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + s_peer + .start_recv_loop(tokio::sync::mpsc::channel(200).0) + .await; assert!(c_ret.is_ok()); assert!(s_ret.is_ok()); @@ -406,7 +409,9 @@ mod tests { let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1); c_peer.set_close_event_sender(close_send); c_peer.start_pingpong(); - c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + c_peer + .start_recv_loop(tokio::sync::mpsc::channel(200).0) + .await; // wait 5s, conn should not be disconnected tokio::time::sleep(Duration::from_secs(15)).await; diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 4133d76..cfefc25 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -9,11 +9,11 @@ use std::{ use futures::{stream::FuturesUnordered, Future, Sink, Stream}; use network_interface::NetworkInterfaceConfig as _; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio_stream::StreamExt; -use tokio_util::io::{poll_read_buf, poll_write_buf}; +use tokio_util::io::poll_write_buf; use zerocopy::FromBytes as _; use crate::{ @@ -149,13 +149,18 @@ where *self_mut.max_packet_size * 64, ); - match ready!(poll_read_buf( - self_mut.reader.as_mut(), - cx, - &mut self_mut.buf - )) { - Ok(size) => { - if size == 0 { + let cap = self_mut.buf.capacity() - self_mut.buf.len(); + let buf = self_mut.buf.chunk_mut().as_mut_ptr(); + let buf = unsafe { std::slice::from_raw_parts_mut(buf, cap) }; + let mut buf = ReadBuf::new(buf); + + let ret = ready!(self_mut.reader.as_mut().poll_read(cx, &mut buf)); + let len = buf.filled().len(); + unsafe { self_mut.buf.advance_mut(len) }; + + match ret { + Ok(_) => { + if len == 0 { return Poll::Ready(None); } } @@ -173,14 +178,16 @@ pub trait ZCPacketToBytes { pub struct TcpZCPacketToBytes; impl ZCPacketToBytes for TcpZCPacketToBytes { - fn into_bytes(&self, mut item: ZCPacket) -> Result { + fn into_bytes(&self, item: ZCPacket) -> Result { + let mut item = item.convert_type(ZCPacketType::TCP); + let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len(); let Some(header) = item.mut_tcp_tunnel_header() else { return Err(TunnelError::InvalidPacket("packet too short".to_string())); }; header.len.set(tcp_len.try_into().unwrap()); - Ok(item.into_bytes(ZCPacketType::TCP)) + Ok(item.into_bytes()) } } diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 7a95bc8..1979572 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -96,23 +96,57 @@ const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max( WG_TUNNEL_HEADER_SIZE, ) + PEER_MANAGER_HEADER_SIZE; +const INVALID_OFFSET: usize = usize::MAX; + +const fn get_converted_offset(old_hdr_size: usize, new_hdr_size: usize) -> usize { + if old_hdr_size < new_hdr_size { + INVALID_OFFSET + } else { + old_hdr_size - new_hdr_size + } +} + impl ZCPacketType { pub fn get_packet_offsets(&self) -> ZCPacketOffsets { match self { ZCPacketType::TCP => ZCPacketOffsets { payload_offset: TCP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, peer_manager_header_offset: TCP_TUNNEL_HEADER_SIZE, - ..Default::default() + tcp_tunnel_header_offset: 0, + udp_tunnel_header_offset: get_converted_offset( + TCP_TUNNEL_HEADER_SIZE, + UDP_TUNNEL_HEADER_SIZE, + ), + wg_tunnel_header_offset: get_converted_offset( + TCP_TUNNEL_HEADER_SIZE, + WG_TUNNEL_HEADER_SIZE, + ), }, ZCPacketType::UDP => ZCPacketOffsets { payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, peer_manager_header_offset: UDP_TUNNEL_HEADER_SIZE, - ..Default::default() + tcp_tunnel_header_offset: get_converted_offset( + UDP_TUNNEL_HEADER_SIZE, + TCP_TUNNEL_HEADER_SIZE, + ), + udp_tunnel_header_offset: 0, + wg_tunnel_header_offset: get_converted_offset( + UDP_TUNNEL_HEADER_SIZE, + WG_TUNNEL_HEADER_SIZE, + ), }, ZCPacketType::WG => ZCPacketOffsets { payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, peer_manager_header_offset: WG_TUNNEL_HEADER_SIZE, - ..Default::default() + tcp_tunnel_header_offset: get_converted_offset( + WG_TUNNEL_HEADER_SIZE, + TCP_TUNNEL_HEADER_SIZE, + ), + udp_tunnel_header_offset: get_converted_offset( + WG_TUNNEL_HEADER_SIZE, + UDP_TUNNEL_HEADER_SIZE, + ), + wg_tunnel_header_offset: 0, }, ZCPacketType::NIC => ZCPacketOffsets { payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET, @@ -155,8 +189,10 @@ impl ZCPacket { pub fn new_with_payload(payload: &[u8]) -> Self { let mut ret = Self::new_nic_packet(); - let total_len = ret.packet_type.get_packet_offsets().payload_offset + payload.len(); - ret.inner.resize(total_len, 0); + let payload_off = ret.packet_type.get_packet_offsets().payload_offset; + let total_len = payload_off + payload.len(); + ret.inner.reserve(total_len); + unsafe { ret.inner.set_len(total_len) }; ret.mut_payload()[..payload.len()].copy_from_slice(&payload); ret } @@ -165,7 +201,7 @@ impl ZCPacket { let mut ret = Self::new_nic_packet(); ret.inner.reserve(cap); let total_len = ret.packet_type.get_packet_offsets().payload_offset - packet_info_len; - ret.inner.resize(total_len, 0); + unsafe { ret.inner.set_len(total_len) }; ret } @@ -275,45 +311,56 @@ impl ZCPacket { hdr.len.set(payload_len as u32); } - pub fn into_bytes(mut self, target_packet_type: ZCPacketType) -> Bytes { + fn tunnel_payload(&self) -> &[u8] { + &self.inner[self + .packet_type + .get_packet_offsets() + .peer_manager_header_offset..] + } + + pub fn convert_type(mut self, target_packet_type: ZCPacketType) -> Self { if target_packet_type == self.packet_type { - return self.inner.freeze(); - } else { - assert_eq!( - self.packet_type, - ZCPacketType::NIC, - "only support NIC, got {:?}", - self - ); + return self; } - match target_packet_type { - ZCPacketType::TCP => self - .inner - .split_off( - self.packet_type - .get_packet_offsets() - .tcp_tunnel_header_offset, - ) - .freeze(), - ZCPacketType::UDP => self - .inner - .split_off( - self.packet_type - .get_packet_offsets() - .udp_tunnel_header_offset, - ) - .freeze(), - ZCPacketType::WG => self - .inner - .split_off( - self.packet_type - .get_packet_offsets() - .wg_tunnel_header_offset, - ) - .freeze(), + let new_offset = match target_packet_type { + ZCPacketType::TCP => { + self.packet_type + .get_packet_offsets() + .tcp_tunnel_header_offset + } + ZCPacketType::UDP => { + self.packet_type + .get_packet_offsets() + .udp_tunnel_header_offset + } + ZCPacketType::WG => { + self.packet_type + .get_packet_offsets() + .wg_tunnel_header_offset + } ZCPacketType::NIC => unreachable!(), + }; + + tracing::debug!(?self.packet_type, ?target_packet_type, ?new_offset, "convert zc packet type"); + + if new_offset == INVALID_OFFSET { + // copy peer manager header and payload to new buffer + let tunnel_payload = self.tunnel_payload(); + let new_pm_offset = target_packet_type + .get_packet_offsets() + .peer_manager_header_offset; + let mut buf = BytesMut::with_capacity(new_pm_offset + tunnel_payload.len()); + unsafe { buf.set_len(new_pm_offset) }; + buf.extend_from_slice(tunnel_payload); + return Self::new_from_buf(buf, target_packet_type); } + + return Self::new_from_buf(self.inner.split_off(new_offset), target_packet_type); + } + + pub fn into_bytes(self) -> Bytes { + self.inner.freeze() } pub fn inner(self) -> BytesMut { @@ -349,7 +396,7 @@ mod tests { assert_eq!(packet.payload_len(), 11); println!("{:?}", packet.inner); - let tcp_packet = packet.into_bytes(ZCPacketType::TCP); + let tcp_packet = packet.convert_type(ZCPacketType::TCP).into_bytes(); assert_eq!(&tcp_packet[..1], b"\x0b"); println!("{:?}", tcp_packet); } diff --git a/easytier/src/tunnel/stats.rs b/easytier/src/tunnel/stats.rs index 8e8d7a4..be89327 100644 --- a/easytier/src/tunnel/stats.rs +++ b/easytier/src/tunnel/stats.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed}; +use std::sync::atomic::{AtomicU32, Ordering::Relaxed}; pub struct WindowLatency { latency_us_window: Vec, @@ -48,48 +48,49 @@ impl WindowLatency { } } +#[derive(Default)] pub struct Throughput { - tx_bytes: AtomicU64, - rx_bytes: AtomicU64, + tx_bytes: u64, + rx_bytes: u64, - tx_packets: AtomicU64, - rx_packets: AtomicU64, + tx_packets: u64, + rx_packets: u64, } impl Throughput { pub fn new() -> Self { - Self { - tx_bytes: AtomicU64::new(0), - rx_bytes: AtomicU64::new(0), - - tx_packets: AtomicU64::new(0), - rx_packets: AtomicU64::new(0), - } + Self::default() } pub fn tx_bytes(&self) -> u64 { - self.tx_bytes.load(Relaxed) + self.tx_bytes } pub fn rx_bytes(&self) -> u64 { - self.rx_bytes.load(Relaxed) + self.rx_bytes } pub fn tx_packets(&self) -> u64 { - self.tx_packets.load(Relaxed) + self.tx_packets } pub fn rx_packets(&self) -> u64 { - self.rx_packets.load(Relaxed) + self.rx_packets } pub fn record_tx_bytes(&self, bytes: u64) { - self.tx_bytes.fetch_add(bytes, Relaxed); - self.tx_packets.fetch_add(1, Relaxed); + #[allow(invalid_reference_casting)] + unsafe { + *(&self.tx_bytes as *const u64 as *mut u64) += bytes; + *(&self.tx_packets as *const u64 as *mut u64) += 1; + } } pub fn record_rx_bytes(&self, bytes: u64) { - self.rx_bytes.fetch_add(bytes, Relaxed); - self.rx_packets.fetch_add(1, Relaxed); + #[allow(invalid_reference_casting)] + unsafe { + *(&self.rx_bytes as *const u64 as *mut u64) += bytes; + *(&self.rx_packets as *const u64 as *mut u64) += 1; + } } } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 81ceab4..bede853 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -125,20 +125,21 @@ async fn forward_from_ring_to_udp( let Some(buf) = ring_recv.next().await else { return None; }; - let mut packet = match buf { + let packet = match buf { Ok(v) => v, Err(e) => { return Some(e); } }; + let mut packet = packet.convert_type(ZCPacketType::UDP); let udp_payload_len = packet.udp_payload().len(); let header = packet.mut_udp_tunnel_header().unwrap(); header.conn_id.set(conn_id); header.len.set(udp_payload_len as u16); header.msg_type = UdpPacketType::Data as u8; - let buf = packet.into_bytes(ZCPacketType::UDP); + let buf = packet.into_bytes(); tracing::trace!(?udp_payload_len, ?buf, "udp forward from ring to udp"); let ret = socket.send_to(&buf, &addr).await; if ret.is_err() { @@ -232,7 +233,7 @@ impl UdpTunnelListenerData { tracing::info!(?conn_id, ?remote_addr, "udp connection accept handling",); let socket = self.socket.as_ref().unwrap().clone(); - let sack_buf = new_sack_packet(conn_id, magic).into_bytes(ZCPacketType::UDP); + let sack_buf = new_sack_packet(conn_id, magic).into_bytes(); if let Err(e) = socket.send_to(&sack_buf, remote_addr).await { tracing::error!(?e, "udp send sack packet error"); return; @@ -436,6 +437,7 @@ impl TunnelListener for UdpTunnelListener { } } +#[derive(Debug)] pub struct UdpTunnelConnector { addr: url::Url, bind_addrs: Vec, @@ -613,7 +615,7 @@ impl UdpTunnelConnector { // send syn let conn_id = rand::random(); let magic = rand::random(); - let udp_packet = new_syn_packet(conn_id, magic).into_bytes(ZCPacketType::UDP); + let udp_packet = new_syn_packet(conn_id, magic).into_bytes(); let ret = socket.send_to(&udp_packet, &addr).await?; tracing::warn!(?udp_packet, ?ret, "udp send syn"); diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index 02091ae..da5d85b 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -138,17 +138,15 @@ impl Debug for WgPeerData { impl WgPeerData { #[tracing::instrument] - async fn handle_one_packet_from_me( - &self, - mut zc_packet: ZCPacket, - ) -> Result<(), anyhow::Error> { + async fn handle_one_packet_from_me(&self, zc_packet: ZCPacket) -> Result<(), anyhow::Error> { let mut send_buf = vec![0u8; MAX_PACKET]; let packet = if matches!(self.wg_type, WgType::InternalUse) { + let mut zc_packet = zc_packet.convert_type(ZCPacketType::WG); Self::fill_ip_header(&mut zc_packet); - zc_packet.into_bytes(ZCPacketType::WG) + zc_packet.into_bytes() } else { - zc_packet.into_bytes(ZCPacketType::WG) + zc_packet.convert_type(ZCPacketType::WG).into_bytes() }; tracing::trace!(?packet, "Sending packet to peer"); @@ -650,7 +648,7 @@ impl WgTunnelConnector { let mut buf = vec![0u8; MAX_PACKET]; let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); if recv_addr != addr { - continue; + tracing::warn!(?recv_addr, "Received packet from changed address"); } data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await; }