use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::{atomic::AtomicBool, Arc}, time::Duration, }; use dashmap::DashMap; use pnet::packet::{ ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, udp::{self, MutableUdpPacket}, Packet, }; use tokio::{ net::UdpSocket, sync::{ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, Mutex, }, task::{JoinHandle, JoinSet}, time::timeout, }; use tracing::Level; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, gateway::ip_reassembler::compose_ipv4_packet, peers::{peer_manager::PeerManager, PeerPacketFilter}, tunnel::{ common::setup_sokcet2, packet_def::{PacketType, ZCPacket}, }, }; use super::{ip_reassembler::IpReassembler, CidrSet}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct UdpNatKey { src_socket: SocketAddr, } #[derive(Debug)] struct UdpNatEntry { src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr, socket: UdpSocket, forward_task: Mutex>>, stopped: AtomicBool, start_time: std::time::Instant, } impl UdpNatEntry { #[tracing::instrument(err(level = Level::WARN))] fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr) -> Result { // TODO: try use src port, so we will be ip restricted nat type let socket2_socket = socket2::Socket::new( socket2::Domain::IPV4, socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; let dst_socket_addr = "0.0.0.0:0".parse().unwrap(); setup_sokcet2(&socket2_socket, &dst_socket_addr)?; let socket = UdpSocket::from_std(socket2_socket.into())?; Ok(Self { src_peer_id, my_peer_id, src_socket, socket, forward_task: Mutex::new(None), stopped: AtomicBool::new(false), start_time: std::time::Instant::now(), }) } pub fn stop(&self) { self.stopped .store(true, std::sync::atomic::Ordering::Relaxed); } async fn compose_ipv4_packet( self: &Arc, packet_sender: &mut UnboundedSender, buf: &mut [u8], src_v4: &SocketAddrV4, payload_len: usize, payload_mtu: usize, ip_id: u16, ) -> Result<(), Error> { let SocketAddr::V4(nat_src_v4) = self.src_socket else { return Err(Error::Unknown); }; assert_eq!(0, payload_mtu % 8); // udp payload is in buf[20 + 8..] let mut udp_packet = MutableUdpPacket::new(&mut buf[20..28 + payload_len]).unwrap(); udp_packet.set_source(src_v4.port()); udp_packet.set_destination(self.src_socket.port()); udp_packet.set_length(payload_len as u16 + 8); udp_packet.set_checksum(udp::ipv4_checksum( &udp_packet.to_immutable(), src_v4.ip(), nat_src_v4.ip(), )); compose_ipv4_packet( &mut buf[..], src_v4.ip(), nat_src_v4.ip(), IpNextHeaderProtocols::Udp, payload_len + 8, // include udp header payload_mtu, ip_id, |buf| { let mut p = ZCPacket::new_with_payload(buf); p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8); if let Err(e) = packet_sender.send(p) { tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); return Err(Error::AnyhowError(e.into())); } Ok(()) }, )?; Ok(()) } async fn forward_task( self: Arc, mut packet_sender: UnboundedSender, virtual_ipv4: Ipv4Addr, ) { let mut buf = [0u8; 65536]; let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) }; let mut ip_id = 1; loop { let (len, src_socket) = match timeout( Duration::from_secs(30), self.socket.recv_from(&mut udp_body), ) .await { Ok(Ok(x)) => x, Ok(Err(err)) => { tracing::error!(?err, "udp nat recv failed"); break; } Err(err) => { tracing::error!(?err, "udp nat recv timeout"); break; } }; tracing::trace!(?len, ?src_socket, "udp nat packet response received"); if self.stopped.load(std::sync::atomic::Ordering::Relaxed) { break; } let SocketAddr::V4(mut src_v4) = src_socket else { continue; }; if src_v4.ip().is_loopback() { src_v4.set_ip(virtual_ipv4); } let Ok(_) = Self::compose_ipv4_packet( &self, &mut packet_sender, &mut buf, &src_v4, len, 1200, ip_id, ) .await else { break; }; ip_id = ip_id.wrapping_add(1); } self.stop(); } } #[derive(Debug)] pub struct UdpProxy { global_ctx: ArcGlobalCtx, peer_manager: Arc, cidr_set: CidrSet, nat_table: Arc>>, sender: UnboundedSender, receiver: Mutex>>, tasks: Mutex>, ip_resemmbler: Arc, } impl UdpProxy { async fn try_handle_packet(&self, packet: &ZCPacket) -> Option<()> { if self.cidr_set.is_empty() && !self.global_ctx.enable_exit_node() && !self.global_ctx.no_tun() { return None; } let _ = self.global_ctx.get_ipv4()?; let hdr = packet.peer_manager_header().unwrap(); let is_exit_node = hdr.is_exit_node(); if hdr.packet_type != PacketType::Data as u8 { return None; }; let ipv4 = Ipv4Packet::new(packet.payload())?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp { return None; } if !self.cidr_set.contains_v4(ipv4.get_destination()) && !is_exit_node && !(self.global_ctx.no_tun() && Some(ipv4.get_destination()) == self.global_ctx.get_ipv4()) { return None; } let resembled_buf: Option>; let udp_packet = if IpReassembler::is_packet_fragmented(&ipv4) { resembled_buf = self.ip_resemmbler .add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4); if resembled_buf.is_none() { return None; }; udp::UdpPacket::new(resembled_buf.as_ref().unwrap())? } else { udp::UdpPacket::new(ipv4.payload())? }; tracing::trace!( ?packet, ?ipv4, ?udp_packet, "udp nat packet request received" ); let nat_key = UdpNatKey { src_socket: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()), }; let nat_entry = self .nat_table .entry(nat_key) .or_try_insert_with::(|| { tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created"); let _g = self.global_ctx.net_ns.guard(); Ok(Arc::new(UdpNatEntry::new( hdr.from_peer_id.get(), hdr.to_peer_id.get(), nat_key.src_socket, )?)) }) .ok()? .clone(); if nat_entry.forward_task.lock().await.is_none() { nat_entry .forward_task .lock() .await .replace(tokio::spawn(UdpNatEntry::forward_task( nat_entry.clone(), self.sender.clone(), self.global_ctx.get_ipv4()?, ))); } // TODO: should it be async. let dst_socket = if Some(ipv4.get_destination()) == self.global_ctx.get_ipv4() { format!("127.0.0.1:{}", udp_packet.get_destination()) .parse() .unwrap() } else { SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination()) }; let send_ret = { let _g = self.global_ctx.net_ns.guard(); nat_entry .socket .send_to(udp_packet.payload(), dst_socket) .await }; if let Err(send_err) = send_ret { tracing::error!( ?send_err, ?nat_key, ?nat_entry, ?send_err, "udp nat send failed" ); } Some(()) } } #[async_trait::async_trait] impl PeerPacketFilter for UdpProxy { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { if let Some(_) = self.try_handle_packet(&packet).await { return None; } else { return Some(packet); } } } impl UdpProxy { pub fn new( global_ctx: ArcGlobalCtx, peer_manager: Arc, ) -> Result, Error> { let cidr_set = CidrSet::new(global_ctx.clone()); let (sender, receiver) = unbounded_channel(); let ret = Self { global_ctx, peer_manager, cidr_set, nat_table: Arc::new(DashMap::new()), sender, receiver: Mutex::new(Some(receiver)), tasks: Mutex::new(JoinSet::new()), ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))), }; Ok(Arc::new(ret)) } pub async fn start(self: &Arc) -> Result<(), Error> { self.peer_manager .add_packet_process_pipeline(Box::new(self.clone())) .await; // clean up nat table let nat_table = self.nat_table.clone(); self.tasks.lock().await.spawn(async move { loop { tokio::time::sleep(Duration::from_secs(15)).await; nat_table.retain(|_, v| { if v.start_time.elapsed().as_secs() > 120 { tracing::info!(?v, "udp nat table entry removed"); v.stop(); false } else { true } }); } }); let ip_resembler = self.ip_resemmbler.clone(); self.tasks.lock().await.spawn(async move { loop { tokio::time::sleep(Duration::from_secs(1)).await; ip_resembler.remove_expired_packets(); } }); // forward packets to peer manager let mut receiver = self.receiver.lock().await.take().unwrap(); let peer_manager = self.peer_manager.clone(); self.tasks.lock().await.spawn(async move { while let Some(msg) = receiver.recv().await { let to_peer_id: PeerId = msg.peer_manager_header().unwrap().to_peer_id.get(); tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send"); let ret = peer_manager.send_msg(msg, to_peer_id).await; if ret.is_err() { tracing::error!("send icmp packet to peer failed: {:?}", ret); } } }); Ok(()) } } impl Drop for UdpProxy { fn drop(&mut self) { for v in self.nat_table.iter() { v.stop(); } } }