diff --git a/easytier/src/gateway/icmp_proxy.rs b/easytier/src/gateway/icmp_proxy.rs index 8f3c385..3e0f4d4 100644 --- a/easytier/src/gateway/icmp_proxy.rs +++ b/easytier/src/gateway/icmp_proxy.rs @@ -3,12 +3,13 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddrV4}, sync::Arc, thread, + time::Duration, }; use pnet::packet::{ icmp::{self, IcmpTypes}, ip::IpNextHeaderProtocols, - ipv4::{self, Ipv4Packet, MutableIpv4Packet}, + ipv4::Ipv4Packet, Packet, }; use socket2::Socket; @@ -25,7 +26,10 @@ use crate::{ tunnel::packet_def::{PacketType, ZCPacket}, }; -use super::CidrSet; +use super::{ + ip_reassembler::{compose_ipv4_packet, IpReassembler}, + CidrSet, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct IcmpNatKey { @@ -68,6 +72,8 @@ pub struct IcmpProxy { nat_table: IcmpNatTable, tasks: Mutex>, + + ip_resemmbler: Arc, } fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit]) -> Result<(usize, IpAddr), Error> { @@ -80,7 +86,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit]) -> Result<(usize, I } fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender) { - let mut buf = [0u8; 2048]; + let mut buf = [0u8; 8192]; let data: &mut [MaybeUninit] = unsafe { std::mem::transmute(&mut buf[..]) }; loop { @@ -92,7 +98,7 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe continue; } - let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[..len]) else { + let Some(ipv4_packet) = Ipv4Packet::new(&buf[..len]) else { continue; }; @@ -120,24 +126,31 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe continue; }; - ipv4_packet.set_destination(dest_ip); + let src_v4 = ipv4_packet.get_source(); + let payload_len = len - ipv4_packet.get_header_length() as usize * 4; + let id = ipv4_packet.get_identification(); + let _ = compose_ipv4_packet( + &mut buf[..], + &src_v4, + &dest_ip, + IpNextHeaderProtocols::Icmp, + payload_len, + 1200, + id, + |buf| { + let mut p = ZCPacket::new_with_payload(buf); + p.fill_peer_manager_hdr( + v.my_peer_id.into(), + v.src_peer_id.into(), + PacketType::Data as u8, + ); - // MacOS do not correctly set ip length when receiving from raw socket - ipv4_packet.set_total_length(len as u16); - - ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); - - let mut p = ZCPacket::new_with_payload(ipv4_packet.packet()); - p.fill_peer_manager_hdr( - v.my_peer_id.into(), - v.src_peer_id.into(), - PacketType::Data as u8, + if let Err(e) = sender.send(p) { + tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); + } + Ok(()) + }, ); - - if let Err(e) = sender.send(p) { - tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); - break; - } } } @@ -166,6 +179,8 @@ impl IcmpProxy { nat_table: Arc::new(dashmap::DashMap::new()), tasks: Mutex::new(JoinSet::new()), + + ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))), }; Ok(Arc::new(ret)) @@ -226,6 +241,14 @@ impl IcmpProxy { .instrument(tracing::info_span!("icmp proxy send loop")), ); + let ip_resembler = self.ip_resemmbler.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(1)).await; + ip_resembler.remove_expired_packets(); + } + }); + self.peer_manager .add_packet_process_pipeline(Box::new(self.clone())) .await; @@ -269,7 +292,18 @@ impl IcmpProxy { return None; } - let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?; + let resembled_buf: Option>; + let icmp_packet = if IpReassembler::is_packet_fragmented(&ipv4) { + resembled_buf = + self.ip_resemmbler + .add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4); + if resembled_buf.is_none() { + return None; + }; + icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())? + } else { + icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())? + }; if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { // drop it because we do not support other icmp types diff --git a/easytier/src/gateway/ip_reassembler.rs b/easytier/src/gateway/ip_reassembler.rs new file mode 100644 index 0000000..7f20c9e --- /dev/null +++ b/easytier/src/gateway/ip_reassembler.rs @@ -0,0 +1,299 @@ +use dashmap::DashMap; +use pnet::packet::ip::IpNextHeaderProtocol; +use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet}; +use pnet::packet::Packet; +use std::net::Ipv4Addr; +use std::time::{Duration, Instant}; + +use crate::common::error::Error; + +#[derive(Debug, Clone)] +pub(crate) struct IpFragment { + id: u16, + offset: u16, + data: Vec, +} + +impl<'a> From<&Ipv4Packet<'a>> for IpFragment { + fn from(packet: &Ipv4Packet<'a>) -> Self { + let id = packet.get_identification(); + let offset = packet.get_fragment_offset() * 8; + let data = packet.payload().to_vec(); + IpFragment { id, offset, data } + } +} + +#[derive(Debug, Clone)] +struct IpPacket { + source: Ipv4Addr, + destination: Ipv4Addr, + total_length: Option, + fragments: Vec, +} + +impl IpPacket { + fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self { + IpPacket { + source, + destination, + total_length: None, + fragments: Vec::new(), + } + } + + fn add_fragment(&mut self, fragment: IpFragment) { + // make sure the fragment doesn't overlap with existing fragments + for f in &self.fragments { + if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 { + return; + } + if fragment.offset <= f.offset + && f.offset < fragment.offset + fragment.data.len() as u16 + { + return; + } + } + self.fragments.push(fragment); + } + + fn is_complete(&self) -> bool { + if self.total_length.is_none() { + return false; + } + let mut total_length = 0; + for fragment in &self.fragments { + total_length += fragment.data.len() as u16; + } + tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete"); + Some(total_length) == self.total_length + } + + fn set_total_length(&mut self, total_length: u16) { + self.total_length = Some(total_length); + } + + fn assemble(&mut self) -> Option> { + if !self.is_complete() { + return None; + } + + // sort fragments by offset + self.fragments.sort_by_key(|f| f.offset); + + let mut packet = vec![0u8; self.total_length.unwrap() as usize]; + for fragment in &self.fragments { + let start = fragment.offset as usize; + let end = start + fragment.data.len(); + packet[start..end].copy_from_slice(&fragment.data); + } + + Some(packet) + } +} + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +struct IpResemblerKey { + source: Ipv4Addr, + destination: Ipv4Addr, + id: u16, +} + +#[derive(Debug)] +struct IpResemblerValue { + packet: IpPacket, + timestamp: Instant, +} + +#[derive(Debug)] +pub(crate) struct IpReassembler { + packets: DashMap, + timeout: Duration, +} + +impl IpReassembler { + pub fn new(timeout: Duration) -> Self { + IpReassembler { + packets: DashMap::new(), + timeout, + } + } + + pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool { + packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0 + } + + pub fn is_last_fragment(packet: &Ipv4Packet) -> bool { + packet.get_flags() & Ipv4Flags::MoreFragments == 0 + } + + pub fn add_fragment( + &self, + source: Ipv4Addr, + destination: Ipv4Addr, + packet: &Ipv4Packet, + ) -> Option> { + let id = packet.get_identification(); + let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4; + if total_length != packet.payload().len() as u16 { + tracing::trace!( + ?packet, + ?total_length, + payload_len = ?packet.payload().len(), + "unexpected total length", + ); + return None; + } + + let fragment: IpFragment = packet.into(); + let key = IpResemblerKey { + source, + destination, + id, + }; + + let mut entry = self.packets.entry(key.clone()).or_insert_with(|| { + let packet = IpPacket::new(source, destination); + let timestamp = Instant::now(); + IpResemblerValue { packet, timestamp } + }); + let value_mut = entry.value_mut(); + + if Self::is_last_fragment(packet) { + value_mut + .packet + .set_total_length(total_length + fragment.offset); + } + + value_mut.packet.add_fragment(fragment); + if let Some(data) = value_mut.packet.assemble() { + drop(entry); + self.packets.remove(&key); + Some(data) + } else { + value_mut.timestamp = Instant::now(); + None + } + } + + pub fn remove_expired_packets(&self) { + let timeout = self.timeout; + self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout); + } +} + +// ip payload should be in buf[20..] +pub fn compose_ipv4_packet( + buf: &mut [u8], + src_v4: &Ipv4Addr, + dst_v4: &Ipv4Addr, + next_protocol: IpNextHeaderProtocol, + payload_len: usize, + payload_mtu: usize, + ip_id: u16, + cb: F, +) -> Result<(), Error> +where + F: Fn(&[u8]) -> Result<(), Error>, +{ + let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu; + let mut buf_offset = 0; + let mut fragment_offset = 0; + let mut cur_piece = 0; + while fragment_offset < payload_len { + let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len); + let fragment_len = next_fragment_offset - fragment_offset; + let mut ipv4_packet = + MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap(); + ipv4_packet.set_version(4); + ipv4_packet.set_header_length(5); + ipv4_packet.set_total_length((fragment_len + 20) as u16); + ipv4_packet.set_identification(ip_id); + if total_pieces > 1 { + if cur_piece != total_pieces - 1 { + ipv4_packet.set_flags(Ipv4Flags::MoreFragments); + } else { + ipv4_packet.set_flags(0); + } + assert_eq!(0, fragment_offset % 8); + ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8); + } else { + ipv4_packet.set_flags(Ipv4Flags::DontFragment); + ipv4_packet.set_fragment_offset(0); + } + ipv4_packet.set_ecn(0); + ipv4_packet.set_dscp(0); + ipv4_packet.set_ttl(32); + ipv4_packet.set_source(src_v4.clone()); + ipv4_packet.set_destination(dst_v4.clone()); + ipv4_packet.set_next_level_protocol(next_protocol); + ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); + + tracing::trace!(?ipv4_packet, "udp nat packet response send"); + + cb(ipv4_packet.packet())?; + + buf_offset += next_fragment_offset - fragment_offset; + fragment_offset = next_fragment_offset; + cur_piece += 1; + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resembler() { + let raw_packets = vec![ + // last packet + vec![ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07, + ], + // 1st packet + vec![ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07, + ], + // 2nd packet + vec![ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ], + // expired packet + vec![ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ], + ]; + + let source = "192.168.0.1".parse().unwrap(); + let destination = "192.168.0.2".parse().unwrap(); + let resembler = IpReassembler::new(Duration::from_secs(1)); + + for (idx, raw_packet) in raw_packets.iter().enumerate() { + if let Some(packet) = Ipv4Packet::new(&raw_packet) { + let ret = resembler.add_fragment(source, destination, &packet); + if idx != 2 { + assert!(ret.is_none()); + } else { + assert!(ret.is_some()); + } + println!( + "packet: {:?}, ret: {:?}, palyload_len: {}", + packet, + ret, + packet.payload().len() + ); + } + } + + resembler.remove_expired_packets(); + assert_eq!(1, resembler.packets.len()); + + std::thread::sleep(Duration::from_secs(2)); + resembler.remove_expired_packets(); + assert_eq!(0, resembler.packets.len()); + } +} diff --git a/easytier/src/gateway/mod.rs b/easytier/src/gateway/mod.rs index df8c8ad..c1007d8 100644 --- a/easytier/src/gateway/mod.rs +++ b/easytier/src/gateway/mod.rs @@ -4,6 +4,7 @@ use tokio::task::JoinSet; use crate::common::global_ctx::ArcGlobalCtx; pub mod icmp_proxy; +pub mod ip_reassembler; pub mod tcp_proxy; pub mod udp_proxy; diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index 0714e6a..7f8d55c 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -7,7 +7,7 @@ use std::{ use dashmap::DashMap; use pnet::packet::{ ip::IpNextHeaderProtocols, - ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet}, + ipv4::Ipv4Packet, udp::{self, MutableUdpPacket}, Packet, }; @@ -25,6 +25,7 @@ use tracing::Level; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, + gateway::ip_reassembler::compose_ipv4_packet, peers::{peer_manager::PeerManager, PeerPacketFilter}, tunnel::{ common::setup_sokcet2, @@ -32,7 +33,7 @@ use crate::{ }, }; -use super::CidrSet; +use super::{ip_reassembler::IpReassembler, CidrSet}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct UdpNatKey { @@ -105,60 +106,31 @@ impl UdpNatEntry { nat_src_v4.ip(), )); - let payload_len = payload_len + 8; // include udp header - let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu; - let mut buf_offset = 0; - let mut fragment_offset = 0; - let mut cur_piece = 0; - while fragment_offset < payload_len { - let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len); - let fragment_len = next_fragment_offset - fragment_offset; - let mut ipv4_packet = - MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]) - .unwrap(); - ipv4_packet.set_version(4); - ipv4_packet.set_header_length(5); - ipv4_packet.set_total_length((fragment_len + 20) as u16); - ipv4_packet.set_identification(ip_id); - if total_pieces > 1 { - if cur_piece != total_pieces - 1 { - ipv4_packet.set_flags(Ipv4Flags::MoreFragments); - } else { - ipv4_packet.set_flags(0); + compose_ipv4_packet( + &mut buf[..], + src_v4.ip(), + nat_src_v4.ip(), + IpNextHeaderProtocols::Udp, + payload_len + 8, // include udp header + payload_mtu, + ip_id, + |buf| { + let mut p = ZCPacket::new_with_payload(buf); + p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8); + + if let Err(e) = packet_sender.send(p) { + tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); + return Err(Error::AnyhowError(e.into())); } - assert_eq!(0, fragment_offset % 8); - ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8); - } else { - ipv4_packet.set_flags(Ipv4Flags::DontFragment); - ipv4_packet.set_fragment_offset(0); - } - ipv4_packet.set_ecn(0); - ipv4_packet.set_dscp(0); - ipv4_packet.set_ttl(32); - ipv4_packet.set_source(src_v4.ip().clone()); - ipv4_packet.set_destination(nat_src_v4.ip().clone()); - ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp); - ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); + Ok(()) + }, + )?; - tracing::trace!(?ipv4_packet, "udp nat packet response send"); - - let mut p = ZCPacket::new_with_payload(ipv4_packet.packet()); - p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8); - - if let Err(e) = packet_sender.send(p) { - tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); - return Err(Error::AnyhowError(e.into())); - } - - buf_offset += next_fragment_offset - fragment_offset; - fragment_offset = next_fragment_offset; - cur_piece += 1; - } Ok(()) } async fn forward_task(self: Arc, mut packet_sender: UnboundedSender) { - let mut buf = [0u8; 8192]; + let mut buf = [0u8; 65536]; let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) }; let mut ip_id = 1; @@ -223,6 +195,8 @@ pub struct UdpProxy { receiver: Mutex>>, tasks: Mutex>, + + ip_resemmbler: Arc, } impl UdpProxy { @@ -247,7 +221,18 @@ impl UdpProxy { return None; } - let udp_packet = udp::UdpPacket::new(ipv4.payload())?; + let resembled_buf: Option>; + let udp_packet = if IpReassembler::is_packet_fragmented(&ipv4) { + resembled_buf = + self.ip_resemmbler + .add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4); + if resembled_buf.is_none() { + return None; + }; + udp::UdpPacket::new(resembled_buf.as_ref().unwrap())? + } else { + udp::UdpPacket::new(ipv4.payload())? + }; tracing::trace!( ?packet, @@ -336,6 +321,7 @@ impl UdpProxy { sender, receiver: Mutex::new(Some(receiver)), tasks: Mutex::new(JoinSet::new()), + ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))), }; Ok(Arc::new(ret)) } @@ -362,6 +348,14 @@ impl UdpProxy { } }); + let ip_resembler = self.ip_resemmbler.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(1)).await; + ip_resembler.remove_expired_packets(); + } + }); + // forward packets to peer manager let mut receiver = self.receiver.lock().await.take().unwrap(); let peer_manager = self.peer_manager.clone(); diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index 9774bbb..86ecec1 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -292,7 +292,10 @@ impl VirtualNic { config.platform(|config| { config.skip_config(true); config.guid(None); - config.ring_cap(Some(config.min_ring_cap() * 2)); + config.ring_cap(Some(std::cmp::min( + config.min_ring_cap() * 32, + config.max_ring_cap(), + ))); }); } diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 7f6f9d7..2a1de6f 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -136,7 +136,7 @@ pub async fn init_three_node(proto: &str) -> Vec { vec![inst1, inst2, inst3] } -async fn ping_test(from_netns: &str, target_ip: &str) -> bool { +async fn ping_test(from_netns: &str, target_ip: &str, payload_size: Option) -> bool { let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); let code = tokio::process::Command::new("ip") .args(&[ @@ -146,6 +146,8 @@ async fn ping_test(from_netns: &str, target_ip: &str) -> bool { "ping", "-c", "1", + "-s", + payload_size.unwrap_or(56).to_string().as_str(), "-W", "1", target_ip.to_string().as_str(), @@ -175,7 +177,7 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr ); wait_for_condition( - || async { ping_test("net_c", "10.144.144.1").await }, + || async { ping_test("net_c", "10.144.144.1", None).await }, Duration::from_secs(5000), ) .await; @@ -185,6 +187,8 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr #[tokio::test] #[serial_test::serial] pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + use rand::Rng; + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; let mut insts = init_three_node(proto).await; @@ -210,11 +214,15 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap()); let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap()); + let mut buf = vec![0; 32]; + rand::thread_rng().fill(&mut buf[..]); + _tunnel_pingpong_netns( tcp_listener, tcp_connector, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), + buf, ) .await; } @@ -241,7 +249,13 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st .await; wait_for_condition( - || async { ping_test("net_a", "10.1.2.4").await }, + || async { ping_test("net_a", "10.1.2.4", None).await }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || async { ping_test("net_a", "10.1.2.4", Some(5 * 1024)).await }, Duration::from_secs(5), ) .await; @@ -318,6 +332,8 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str #[tokio::test] #[serial_test::serial] pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + use rand::Rng; + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; let mut insts = init_three_node(proto).await; @@ -343,11 +359,32 @@ pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); + // NOTE: this should not excced udp tunnel max buffer size + let mut buf = vec![0; 20 * 1024]; + rand::thread_rng().fill(&mut buf[..]); + _tunnel_pingpong_netns( tcp_listener, tcp_connector, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + // no fragment + let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); + let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); + + let mut buf = vec![0; 1 * 1024]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + tcp_listener, + tcp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf, ) .await; } @@ -443,7 +480,7 @@ pub async fn foreign_network_forward_nic_data() { .await; wait_for_condition( - || async { ping_test("net_b", "10.144.145.2").await }, + || async { ping_test("net_b", "10.144.145.2", None).await }, Duration::from_secs(5), ) .await; @@ -531,19 +568,19 @@ pub async fn wireguard_vpn_portal() { // ping other node in network wait_for_condition( - || async { ping_test("net_d", "10.144.144.1").await }, + || async { ping_test("net_d", "10.144.144.1", None).await }, Duration::from_secs(5), ) .await; wait_for_condition( - || async { ping_test("net_d", "10.144.144.2").await }, + || async { ping_test("net_d", "10.144.144.2", None).await }, Duration::from_secs(5), ) .await; // ping portal node wait_for_condition( - || async { ping_test("net_d", "10.144.144.3").await }, + || async { ping_test("net_d", "10.144.144.3", None).await }, Duration::from_secs(5), ) .await; diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index ec35446..a7c4c11 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -107,7 +107,10 @@ impl FramedReader { } } - fn extract_one_packet(buf: &mut BytesMut) -> Option { + fn extract_one_packet( + buf: &mut BytesMut, + max_packet_size: usize, + ) -> Option> { if buf.len() < TCP_TUNNEL_HEADER_SIZE { // header is not complete return None; @@ -115,6 +118,11 @@ impl FramedReader { let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap(); let body_len = header.len.get() as usize; + if body_len > max_packet_size { + // body is too long + return Some(Err(TunnelError::InvalidPacket("body too long".to_string()))); + } + if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len { // body is not complete return None; @@ -122,7 +130,7 @@ impl FramedReader { // extract one packet let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len); - Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)) + Some(Ok(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))) } } @@ -139,8 +147,10 @@ where let mut self_mut = self.project(); loop { - while let Some(packet) = Self::extract_one_packet(self_mut.buf) { - return Poll::Ready(Some(Ok(packet))); + while let Some(packet) = + Self::extract_one_packet(self_mut.buf, *self_mut.max_packet_size) + { + return Poll::Ready(Some(packet)); } reserve_buf( @@ -465,7 +475,14 @@ pub mod tests { L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, { - _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await + _tunnel_pingpong_netns( + listener, + connector, + NetNS::new(None), + NetNS::new(None), + "12345678abcdefg".as_bytes().to_vec(), + ) + .await; } pub(crate) async fn _tunnel_pingpong_netns( @@ -473,6 +490,7 @@ pub mod tests { mut connector: C, l_netns: NetNS, c_netns: NetNS, + buf: Vec, ) where L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, @@ -503,7 +521,7 @@ pub mod tests { let (mut recv, mut send) = tunnel.split(); - send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes())) + send.send(ZCPacket::new_with_payload(buf.as_slice())) .await .unwrap(); @@ -513,7 +531,7 @@ pub mod tests { .unwrap() .unwrap(); println!("echo back: {:?}", ret); - assert_eq!(ret.payload(), Bytes::from("12345678abcdefg")); + assert_eq!(ret.payload(), Bytes::from(buf)); send.close().await.unwrap(); diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 59a61c2..fbf4bed 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -158,7 +158,13 @@ where let mut buf = BytesMut::new(); loop { reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); - let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap(); + let (dg_size, addr) = match socket.recv_buf_from(&mut buf).await { + Ok(v) => v, + Err(e) => { + tracing::error!(?e, "udp recv from socket error"); + break; + } + }; tracing::trace!( "udp recv packet: {:?}, buf: {:?}, size: {}", addr,