correctly handle ip fragment for udp/icmp proxy (#137)

icmp/udp proxy do not rely on kernel net stack, but currently not handle ip fragmentation correctly.

this patch add ip resembler to merge fragmented ip packet for udp/icmp proxy
This commit is contained in:
Sijie.Sun
2024-06-09 22:59:50 +08:00
committed by GitHub
parent b2100b78d3
commit fede35cca4
8 changed files with 480 additions and 88 deletions

View File

@@ -3,12 +3,13 @@ use std::{
net::{IpAddr, Ipv4Addr, SocketAddrV4},
sync::Arc,
thread,
time::Duration,
};
use pnet::packet::{
icmp::{self, IcmpTypes},
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
ipv4::Ipv4Packet,
Packet,
};
use socket2::Socket;
@@ -25,7 +26,10 @@ use crate::{
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::CidrSet;
use super::{
ip_reassembler::{compose_ipv4_packet, IpReassembler},
CidrSet,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct IcmpNatKey {
@@ -68,6 +72,8 @@ pub struct IcmpProxy {
nat_table: IcmpNatTable,
tasks: Mutex<JoinSet<()>>,
ip_resemmbler: Arc<IpReassembler>,
}
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
@@ -80,7 +86,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, I
}
fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender<ZCPacket>) {
let mut buf = [0u8; 2048];
let mut buf = [0u8; 8192];
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[..]) };
loop {
@@ -92,7 +98,7 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
continue;
}
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[..len]) else {
let Some(ipv4_packet) = Ipv4Packet::new(&buf[..len]) else {
continue;
};
@@ -120,24 +126,31 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
continue;
};
ipv4_packet.set_destination(dest_ip);
let src_v4 = ipv4_packet.get_source();
let payload_len = len - ipv4_packet.get_header_length() as usize * 4;
let id = ipv4_packet.get_identification();
let _ = compose_ipv4_packet(
&mut buf[..],
&src_v4,
&dest_ip,
IpNextHeaderProtocols::Icmp,
payload_len,
1200,
id,
|buf| {
let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
);
// MacOS do not correctly set ip length when receiving from raw socket
ipv4_packet.set_total_length(len as u16);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
p.fill_peer_manager_hdr(
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
if let Err(e) = sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
}
Ok(())
},
);
if let Err(e) = sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
break;
}
}
}
@@ -166,6 +179,8 @@ impl IcmpProxy {
nat_table: Arc::new(dashmap::DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
};
Ok(Arc::new(ret))
@@ -226,6 +241,14 @@ impl IcmpProxy {
.instrument(tracing::info_span!("icmp proxy send loop")),
);
let ip_resembler = self.ip_resemmbler.clone();
self.tasks.lock().await.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
ip_resembler.remove_expired_packets();
}
});
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
@@ -269,7 +292,18 @@ impl IcmpProxy {
return None;
}
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
let resembled_buf: Option<Vec<u8>>;
let icmp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
resembled_buf =
self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() {
return None;
};
icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())?
} else {
icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?
};
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
// drop it because we do not support other icmp types

View File

@@ -0,0 +1,299 @@
use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocol;
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
use pnet::packet::Packet;
use std::net::Ipv4Addr;
use std::time::{Duration, Instant};
use crate::common::error::Error;
#[derive(Debug, Clone)]
pub(crate) struct IpFragment {
id: u16,
offset: u16,
data: Vec<u8>,
}
impl<'a> From<&Ipv4Packet<'a>> for IpFragment {
fn from(packet: &Ipv4Packet<'a>) -> Self {
let id = packet.get_identification();
let offset = packet.get_fragment_offset() * 8;
let data = packet.payload().to_vec();
IpFragment { id, offset, data }
}
}
#[derive(Debug, Clone)]
struct IpPacket {
source: Ipv4Addr,
destination: Ipv4Addr,
total_length: Option<u16>,
fragments: Vec<IpFragment>,
}
impl IpPacket {
fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
IpPacket {
source,
destination,
total_length: None,
fragments: Vec::new(),
}
}
fn add_fragment(&mut self, fragment: IpFragment) {
// make sure the fragment doesn't overlap with existing fragments
for f in &self.fragments {
if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 {
return;
}
if fragment.offset <= f.offset
&& f.offset < fragment.offset + fragment.data.len() as u16
{
return;
}
}
self.fragments.push(fragment);
}
fn is_complete(&self) -> bool {
if self.total_length.is_none() {
return false;
}
let mut total_length = 0;
for fragment in &self.fragments {
total_length += fragment.data.len() as u16;
}
tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete");
Some(total_length) == self.total_length
}
fn set_total_length(&mut self, total_length: u16) {
self.total_length = Some(total_length);
}
fn assemble(&mut self) -> Option<Vec<u8>> {
if !self.is_complete() {
return None;
}
// sort fragments by offset
self.fragments.sort_by_key(|f| f.offset);
let mut packet = vec![0u8; self.total_length.unwrap() as usize];
for fragment in &self.fragments {
let start = fragment.offset as usize;
let end = start + fragment.data.len();
packet[start..end].copy_from_slice(&fragment.data);
}
Some(packet)
}
}
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct IpResemblerKey {
source: Ipv4Addr,
destination: Ipv4Addr,
id: u16,
}
#[derive(Debug)]
struct IpResemblerValue {
packet: IpPacket,
timestamp: Instant,
}
#[derive(Debug)]
pub(crate) struct IpReassembler {
packets: DashMap<IpResemblerKey, IpResemblerValue>,
timeout: Duration,
}
impl IpReassembler {
pub fn new(timeout: Duration) -> Self {
IpReassembler {
packets: DashMap::new(),
timeout,
}
}
pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool {
packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0
}
pub fn is_last_fragment(packet: &Ipv4Packet) -> bool {
packet.get_flags() & Ipv4Flags::MoreFragments == 0
}
pub fn add_fragment(
&self,
source: Ipv4Addr,
destination: Ipv4Addr,
packet: &Ipv4Packet,
) -> Option<Vec<u8>> {
let id = packet.get_identification();
let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4;
if total_length != packet.payload().len() as u16 {
tracing::trace!(
?packet,
?total_length,
payload_len = ?packet.payload().len(),
"unexpected total length",
);
return None;
}
let fragment: IpFragment = packet.into();
let key = IpResemblerKey {
source,
destination,
id,
};
let mut entry = self.packets.entry(key.clone()).or_insert_with(|| {
let packet = IpPacket::new(source, destination);
let timestamp = Instant::now();
IpResemblerValue { packet, timestamp }
});
let value_mut = entry.value_mut();
if Self::is_last_fragment(packet) {
value_mut
.packet
.set_total_length(total_length + fragment.offset);
}
value_mut.packet.add_fragment(fragment);
if let Some(data) = value_mut.packet.assemble() {
drop(entry);
self.packets.remove(&key);
Some(data)
} else {
value_mut.timestamp = Instant::now();
None
}
}
pub fn remove_expired_packets(&self) {
let timeout = self.timeout;
self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout);
}
}
// ip payload should be in buf[20..]
pub fn compose_ipv4_packet<F>(
buf: &mut [u8],
src_v4: &Ipv4Addr,
dst_v4: &Ipv4Addr,
next_protocol: IpNextHeaderProtocol,
payload_len: usize,
payload_mtu: usize,
ip_id: u16,
cb: F,
) -> Result<(), Error>
where
F: Fn(&[u8]) -> Result<(), Error>,
{
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
} else {
ipv4_packet.set_flags(0);
}
assert_eq!(0, fragment_offset % 8);
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
} else {
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
ipv4_packet.set_fragment_offset(0);
}
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.clone());
ipv4_packet.set_destination(dst_v4.clone());
ipv4_packet.set_next_level_protocol(next_protocol);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
tracing::trace!(?ipv4_packet, "udp nat packet response send");
cb(ipv4_packet.packet())?;
buf_offset += next_fragment_offset - fragment_offset;
fragment_offset = next_fragment_offset;
cur_piece += 1;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resembler() {
let raw_packets = vec![
// last packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
],
// 1st packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07,
],
// 2nd packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
// expired packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
];
let source = "192.168.0.1".parse().unwrap();
let destination = "192.168.0.2".parse().unwrap();
let resembler = IpReassembler::new(Duration::from_secs(1));
for (idx, raw_packet) in raw_packets.iter().enumerate() {
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
let ret = resembler.add_fragment(source, destination, &packet);
if idx != 2 {
assert!(ret.is_none());
} else {
assert!(ret.is_some());
}
println!(
"packet: {:?}, ret: {:?}, palyload_len: {}",
packet,
ret,
packet.payload().len()
);
}
}
resembler.remove_expired_packets();
assert_eq!(1, resembler.packets.len());
std::thread::sleep(Duration::from_secs(2));
resembler.remove_expired_packets();
assert_eq!(0, resembler.packets.len());
}
}

View File

@@ -4,6 +4,7 @@ use tokio::task::JoinSet;
use crate::common::global_ctx::ArcGlobalCtx;
pub mod icmp_proxy;
pub mod ip_reassembler;
pub mod tcp_proxy;
pub mod udp_proxy;

View File

@@ -7,7 +7,7 @@ use std::{
use dashmap::DashMap;
use pnet::packet::{
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
ipv4::Ipv4Packet,
udp::{self, MutableUdpPacket},
Packet,
};
@@ -25,6 +25,7 @@ use tracing::Level;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
gateway::ip_reassembler::compose_ipv4_packet,
peers::{peer_manager::PeerManager, PeerPacketFilter},
tunnel::{
common::setup_sokcet2,
@@ -32,7 +33,7 @@ use crate::{
},
};
use super::CidrSet;
use super::{ip_reassembler::IpReassembler, CidrSet};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey {
@@ -105,60 +106,31 @@ impl UdpNatEntry {
nat_src_v4.ip(),
));
let payload_len = payload_len + 8; // include udp header
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
.unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
} else {
ipv4_packet.set_flags(0);
compose_ipv4_packet(
&mut buf[..],
src_v4.ip(),
nat_src_v4.ip(),
IpNextHeaderProtocols::Udp,
payload_len + 8, // include udp header
payload_mtu,
ip_id,
|buf| {
let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
if let Err(e) = packet_sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
return Err(Error::AnyhowError(e.into()));
}
assert_eq!(0, fragment_offset % 8);
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
} else {
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
ipv4_packet.set_fragment_offset(0);
}
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.ip().clone());
ipv4_packet.set_destination(nat_src_v4.ip().clone());
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
Ok(())
},
)?;
tracing::trace!(?ipv4_packet, "udp nat packet response send");
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
if let Err(e) = packet_sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
return Err(Error::AnyhowError(e.into()));
}
buf_offset += next_fragment_offset - fragment_offset;
fragment_offset = next_fragment_offset;
cur_piece += 1;
}
Ok(())
}
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<ZCPacket>) {
let mut buf = [0u8; 8192];
let mut buf = [0u8; 65536];
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
let mut ip_id = 1;
@@ -223,6 +195,8 @@ pub struct UdpProxy {
receiver: Mutex<Option<UnboundedReceiver<ZCPacket>>>,
tasks: Mutex<JoinSet<()>>,
ip_resemmbler: Arc<IpReassembler>,
}
impl UdpProxy {
@@ -247,7 +221,18 @@ impl UdpProxy {
return None;
}
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
let resembled_buf: Option<Vec<u8>>;
let udp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
resembled_buf =
self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() {
return None;
};
udp::UdpPacket::new(resembled_buf.as_ref().unwrap())?
} else {
udp::UdpPacket::new(ipv4.payload())?
};
tracing::trace!(
?packet,
@@ -336,6 +321,7 @@ impl UdpProxy {
sender,
receiver: Mutex::new(Some(receiver)),
tasks: Mutex::new(JoinSet::new()),
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
};
Ok(Arc::new(ret))
}
@@ -362,6 +348,14 @@ impl UdpProxy {
}
});
let ip_resembler = self.ip_resemmbler.clone();
self.tasks.lock().await.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
ip_resembler.remove_expired_packets();
}
});
// forward packets to peer manager
let mut receiver = self.receiver.lock().await.take().unwrap();
let peer_manager = self.peer_manager.clone();