From f9e6264f31f0ba56c7b1b4a8e7bb6ce2340605e7 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Tue, 4 Jun 2024 18:50:30 +0800 Subject: [PATCH] fix upx and udp conn counter (#131) * fix upx in workflow * fix udp conn counter --- .github/workflows/core.yml | 2 +- easytier/src/connector/udp_hole_punch.rs | 2 +- easytier/src/peer_center/instance.rs | 5 +- easytier/src/peers/peer_manager.rs | 3 +- easytier/src/peers/peer_ospf_route.rs | 3 +- easytier/src/peers/peer_rpc.rs | 9 +- easytier/src/peers/tests.rs | 17 -- easytier/src/tests/three_node.rs | 2 +- easytier/src/tunnel/common.rs | 17 +- easytier/src/tunnel/udp.rs | 234 +++++++++++++---------- 10 files changed, 158 insertions(+), 136 deletions(-) diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 433b00f..3c5fb2b 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -121,7 +121,7 @@ jobs: TAG=$GITHUB_SHA fi - if [[ $OS =~ ^windows.*$ ]]; then + if [[ ! $OS =~ ^macos.*$ ]]; then upx --lzma --best ./target/$TARGET/release/easytier-core"$SUFFIX" upx --lzma --best ./target/$TARGET/release/easytier-cli"$SUFFIX" fi diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs index f410f83..33980c4 100644 --- a/easytier/src/connector/udp_hole_punch.rs +++ b/easytier/src/connector/udp_hole_punch.rs @@ -1008,8 +1008,8 @@ pub mod tests { use tokio::net::UdpSocket; use crate::connector::udp_hole_punch::UdpHolePunchListener; - use crate::peers::tests::wait_for_condition; use crate::rpc::{NatType, StunInfo}; + use crate::tunnel::common::tests::wait_for_condition; use crate::{ common::{error::Error, stun::StunInfoCollectorTrait}, diff --git a/easytier/src/peer_center/instance.rs b/easytier/src/peer_center/instance.rs index 4bfebe2..06a4826 100644 --- a/easytier/src/peer_center/instance.rs +++ b/easytier/src/peer_center/instance.rs @@ -352,9 +352,8 @@ impl PeerCenterInstance { mod tests { use crate::{ peer_center::server::get_global_data, - peers::tests::{ - connect_peer_manager, create_mock_peer_manager, wait_for_condition, wait_route_appear, - }, + peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, + tunnel::common::tests::wait_for_condition, }; use super::*; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 4ee0e18..17a4439 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -758,9 +758,10 @@ mod tests { peers::{ peer_manager::RouteAlgoType, peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient}, - tests::{connect_peer_manager, wait_for_condition, wait_route_appear}, + tests::{connect_peer_manager, wait_route_appear}, }, rpc::NatType, + tunnel::common::tests::wait_for_condition, tunnel::{TunnelConnector, TunnelListener}, }; diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 95e8234..3a9b3c5 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -1467,9 +1467,10 @@ mod tests { peers::{ peer_manager::{PeerManager, RouteAlgoType}, route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, - tests::{connect_peer_manager, wait_for_condition}, + tests::connect_peer_manager, }, rpc::NatType, + tunnel::common::tests::wait_for_condition, }; use super::PeerRoute; diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index fcdde0a..7d87459 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -557,14 +557,11 @@ pub mod tests { common::{error::Error, new_peer_id, PeerId}, peers::{ peer_rpc::PeerRpcManager, - tests::{ - connect_peer_manager, create_mock_peer_manager, wait_for_condition, - wait_route_appear, - }, + tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, }, tunnel::{ - packet_def::ZCPacket, ring::create_ring_tunnel_pair, Tunnel, ZCPacketSink, - ZCPacketStream, + common::tests::wait_for_condition, packet_def::ZCPacket, ring::create_ring_tunnel_pair, + Tunnel, ZCPacketSink, ZCPacketStream, }, }; diff --git a/easytier/src/peers/tests.rs b/easytier/src/peers/tests.rs index 6cef861..a2764e2 100644 --- a/easytier/src/peers/tests.rs +++ b/easytier/src/peers/tests.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use futures::Future; - use crate::{ common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId}, tunnel::ring::create_ring_tunnel_pair, @@ -58,18 +56,3 @@ pub async fn wait_route_appear( wait_route_appear_with_cost(peer_mgr.clone(), target_peer.my_peer_id(), None).await?; wait_route_appear_with_cost(target_peer, peer_mgr.my_peer_id(), None).await } - -pub async fn wait_for_condition(mut condition: F, timeout: std::time::Duration) -> () -where - F: FnMut() -> FRet + Send, - FRet: Future, -{ - let now = std::time::Instant::now(); - while now.elapsed() < timeout { - if condition().await { - return; - } - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - } - assert!(condition().await, "Timeout") -} diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 04c88b2..7f6f9d7 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -13,7 +13,7 @@ use crate::{ netns::{NetNS, ROOT_NETNS_NAME}, }, instance::instance::Instance, - peers::tests::wait_for_condition, + tunnel::common::tests::wait_for_condition, tunnel::{ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector}, }; diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 5a2a4f4..ec35446 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -419,7 +419,7 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) { pub mod tests { use std::time::Instant; - use futures::{SinkExt, StreamExt, TryStreamExt}; + use futures::{Future, SinkExt, StreamExt, TryStreamExt}; use tokio_util::bytes::{BufMut, Bytes, BytesMut}; use crate::{ @@ -595,4 +595,19 @@ pub mod tests { .with_env_filter(filter) .init(); } + + pub async fn wait_for_condition(mut condition: F, timeout: std::time::Duration) -> () + where + F: FnMut() -> FRet + Send, + FRet: Future, + { + let now = std::time::Instant::now(); + while now.elapsed() < timeout { + if condition().await { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + assert!(condition().await, "Timeout") + } } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index c92af75..59a61c2 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -35,8 +35,8 @@ use super::{ pub const UDP_DATA_MTU: usize = 2000; -type UdpCloseEventSender = UnboundedSender>; -type UdpCloseEventReceiver = UnboundedReceiver>; +type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option)>; +type UdpCloseEventReceiver = UnboundedReceiver<(SocketAddr, Option)>; fn new_udp_packet(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket where @@ -151,6 +151,33 @@ async fn forward_from_ring_to_udp( } } +async fn udp_recv_from_socket_forward_task(socket: Arc, f: F) +where + F: Fn(ZCPacket, SocketAddr) -> (), +{ + let mut buf = BytesMut::new(); + loop { + reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); + let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap(); + tracing::trace!( + "udp recv packet: {:?}, buf: {:?}, size: {}", + addr, + buf, + dg_size + ); + + let zc_packet = match get_zcpacket_from_buf(buf.split()) { + Ok(v) => v, + Err(e) => { + tracing::warn!(?e, "udp get zc packet from buf error"); + continue; + } + }; + + f(zc_packet, addr); + } +} + struct UdpConnection { socket: Arc, conn_id: u32, @@ -173,7 +200,7 @@ impl UdpConnection { let forward_task = tokio::spawn(async move { let close_event_sender = close_event_sender; let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await; - if let Err(e) = close_event_sender.send(err) { + if let Err(e) = close_event_sender.send((dst_addr, err)) { tracing::error!(?e, "udp send close event error"); } }); @@ -186,6 +213,27 @@ impl UdpConnection { forward_task, } } + + pub fn handle_packet_from_remote(&self, zc_packet: ZCPacket) -> Result<(), TunnelError> { + let header = zc_packet.udp_tunnel_header().unwrap(); + let conn_id = header.conn_id.get(); + + if header.msg_type != UdpPacketType::Data as u8 { + return Err(TunnelError::InvalidPacket("not data packet".to_owned())); + } + + if self.conn_id != conn_id { + return Err(TunnelError::ConnIdNotMatch(self.conn_id, conn_id)); + } + + if !self.ring_sender.has_empty_slot() { + return Err(TunnelError::BufferFull); + } + + self.ring_sender.push_no_check(zc_packet)?; + + Ok(()) + } } impl Drop for UdpConnection { @@ -275,40 +323,16 @@ impl UdpTunnelListenerData { } } - async fn try_forward_packet( - self: &Self, - remote_addr: &SocketAddr, - conn_id: u32, - p: ZCPacket, - ) -> Result<(), TunnelError> { - let Some(conn) = self.sock_map.get(remote_addr) else { - return Err(TunnelError::InternalError( - "udp connection not found".to_owned(), - )); - }; - - if conn.conn_id != conn_id { - return Err(TunnelError::ConnIdNotMatch(conn.conn_id, conn_id)); - } - - if !conn.ring_sender.has_empty_slot() { - return Err(TunnelError::BufferFull); - } - - conn.ring_sender.push_no_check(p)?; - - Ok(()) - } - - async fn process_forward_packet(&self, zc_packet: ZCPacket, addr: &SocketAddr) { + fn do_forward_one_packet_to_conn(&self, zc_packet: ZCPacket, addr: SocketAddr) { let header = zc_packet.udp_tunnel_header().unwrap(); if header.msg_type == UdpPacketType::Syn as u8 { - tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet)); + tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet)); } else if header.msg_type != UdpPacketType::HolePunch as u8 { - if let Err(e) = self - .try_forward_packet(addr, header.conn_id.get(), zc_packet) - .await - { + let Some(conn) = self.sock_map.get(&addr) else { + tracing::trace!(?header, "udp forward packet error, connection not found"); + return; + }; + if let Err(e) = conn.handle_packet_from_remote(zc_packet) { tracing::trace!(?e, "udp forward packet error"); } } @@ -316,26 +340,10 @@ impl UdpTunnelListenerData { async fn do_forward_task(self: Self) { let socket = self.socket.as_ref().unwrap().clone(); - let mut buf = BytesMut::new(); - loop { - reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); - let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap(); - tracing::trace!( - "udp recv packet: {:?}, buf: {:?}, size: {}", - addr, - buf, - dg_size - ); - - let zc_packet = match get_zcpacket_from_buf(buf.split()) { - Ok(v) => v, - Err(e) => { - tracing::warn!(?e, "udp get zc packet from buf error"); - continue; - } - }; - self.process_forward_packet(zc_packet, &addr).await; - } + udp_recv_from_socket_forward_task(socket, |zc_packet, addr| { + self.do_forward_one_packet_to_conn(zc_packet, addr); + }) + .await; } } @@ -346,7 +354,7 @@ pub struct UdpTunnelListener { conn_recv: Receiver>, data: UdpTunnelListenerData, forward_tasks: Arc>>, - close_event_recv: UdpCloseEventReceiver, + close_event_recv: Option, } impl UdpTunnelListener { @@ -359,7 +367,7 @@ impl UdpTunnelListener { conn_recv, data: UdpTunnelListenerData::new(addr, conn_send, close_event_send), forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), - close_event_recv, + close_event_recv: Some(close_event_recv), } } @@ -398,6 +406,17 @@ impl TunnelListener for UdpTunnelListener { .unwrap() .spawn(self.data.clone().do_forward_task()); + let sock_map = Arc::downgrade(&self.data.sock_map.clone()); + let mut close_recv = self.close_event_recv.take().unwrap(); + self.forward_tasks.lock().unwrap().spawn(async move { + while let Some((dst_addr, err)) = close_recv.recv().await { + if let Some(err) = err { + tracing::error!(?err, "udp close event error"); + } + sock_map.upgrade().map(|v| v.remove(&dst_addr)); + } + }); + join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned()); Ok(()) @@ -538,62 +557,44 @@ impl UdpTunnelConnector { "udp build tunnel for connector" ); - let (close_event_send, mut close_event_recv) = tokio::sync::mpsc::unbounded_channel(); + let (close_event_sender, mut close_event_recv) = tokio::sync::mpsc::unbounded_channel(); - // forward from ring to udp - let socket_sender = socket.clone(); let ring_recv = RingStream::new(ring_for_send_udp.clone()); - tokio::spawn(async move { - let err = forward_from_ring_to_udp(ring_recv, &socket_sender, &dst_addr, conn_id).await; - tracing::debug!(?err, "udp forward from ring to udp done"); - close_event_send.send(err).unwrap(); - }); - - let socket_recv = socket.clone(); let ring_sender = RingSink::new(ring_for_recv_udp.clone()); - tokio::spawn(async move { - let mut buf = BytesMut::new(); - loop { - reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); - let ret; + let udp_conn = UdpConnection::new( + socket.clone(), + conn_id, + dst_addr, + ring_sender, + ring_recv, + close_event_sender, + ); + + let socket_clone = socket.clone(); + tokio::spawn( + async move { tokio::select! { _ = close_event_recv.recv() => { tracing::debug!("connector udp close event"); - break; + return; } - recv_res = socket_recv.recv_buf_from(&mut buf) => ret = Some(recv_res.unwrap()), - } - let (dg_size, addr) = ret.unwrap(); - tracing::trace!( - "connector udp recv packet: {:?}, buf: {:?}, size: {}", - addr, - buf, - dg_size - ); - - let zc_packet = match get_zcpacket_from_buf(buf.split()) { - Ok(v) => v, - Err(e) => { - tracing::warn!(?e, "connector udp get zc packet from buf error"); - continue; - } - }; - let header = zc_packet.udp_tunnel_header().unwrap(); - if header.conn_id.get() != conn_id { - tracing::trace!( - "connector udp conn id not match: {:?}, {:?}", - header.conn_id.get(), - conn_id - ); - } - - if header.msg_type == UdpPacketType::Data as u8 { - if let Err(e) = ring_sender.push_no_check(zc_packet) { - tracing::trace!(?e, "udp forward packet error"); + _ = udp_recv_from_socket_forward_task(socket_clone, |zc_packet, addr| { + tracing::debug!(?addr, "connector udp forward task done"); + if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) { + tracing::trace!(?e, ?addr, "udp forward packet error"); + } + }) => { + tracing::debug!("connector udp forward task done"); + return; } } } - }.instrument(tracing::info_span!("udp connector forward from udp to ring", ?ring_for_recv_udp))); + .instrument(tracing::info_span!( + "udp forward from udp to ring", + ?conn_id, + ?dst_addr, + )), + ); Ok(Box::new(TunnelWrapper::new( Box::new(RingStream::new(ring_for_recv_udp)), @@ -713,7 +714,7 @@ mod tests { check_scheme_and_get_socket_addr, common::{ get_interface_name_by_ip, - tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}, + tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong, wait_for_condition}, }, TunnelConnector, }, @@ -723,7 +724,7 @@ mod tests { async fn udp_pingpong() { let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); - _tunnel_pingpong(listener, connector).await + _tunnel_pingpong(listener, connector).await; } #[tokio::test] @@ -911,4 +912,29 @@ mod tests { let port = listener.local_url().port().unwrap(); assert!(port > 0); } + + #[tokio::test] + async fn test_conn_counter() { + let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); + let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + let _c1 = connector.connect().await.unwrap(); + let _c2 = connector.connect().await.unwrap(); + }); + + let conn_counter = listener.get_conn_counter(); + + listener.listen().await.unwrap(); + let c1 = listener.accept().await.unwrap(); + assert_eq!(conn_counter.get(), 1); + let c2 = listener.accept().await.unwrap(); + assert_eq!(conn_counter.get(), 2); + + drop(c2); + wait_for_condition(|| async { conn_counter.get() == 1 }, Duration::from_secs(1)).await; + + drop(c1); + wait_for_condition(|| async { conn_counter.get() == 0 }, Duration::from_secs(1)).await; + } }