use dashmap::DashMap; use pnet::packet::ip::IpNextHeaderProtocol; use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet}; use pnet::packet::Packet; use std::net::Ipv4Addr; use std::time::{Duration, Instant}; use crate::common::error::Error; #[derive(Debug, Clone)] pub(crate) struct IpFragment { id: u16, offset: u16, data: Vec, } impl<'a> From<&Ipv4Packet<'a>> for IpFragment { fn from(packet: &Ipv4Packet<'a>) -> Self { let id = packet.get_identification(); let offset = packet.get_fragment_offset() * 8; let data = packet.payload().to_vec(); IpFragment { id, offset, data } } } #[derive(Debug, Clone)] struct IpPacket { source: Ipv4Addr, destination: Ipv4Addr, total_length: Option, fragments: Vec, } impl IpPacket { fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self { IpPacket { source, destination, total_length: None, fragments: Vec::new(), } } fn add_fragment(&mut self, fragment: IpFragment) { // make sure the fragment doesn't overlap with existing fragments for f in &self.fragments { if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 { return; } if fragment.offset <= f.offset && f.offset < fragment.offset + fragment.data.len() as u16 { return; } } self.fragments.push(fragment); } fn is_complete(&self) -> bool { if self.total_length.is_none() { return false; } let mut total_length = 0; for fragment in &self.fragments { total_length += fragment.data.len() as u16; } tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete"); Some(total_length) == self.total_length } fn set_total_length(&mut self, total_length: u16) { self.total_length = Some(total_length); } fn assemble(&mut self) -> Option> { if !self.is_complete() { return None; } // sort fragments by offset self.fragments.sort_by_key(|f| f.offset); let mut packet = vec![0u8; self.total_length.unwrap() as usize]; for fragment in &self.fragments { let start = fragment.offset as usize; let end = start + fragment.data.len(); packet[start..end].copy_from_slice(&fragment.data); } Some(packet) } } #[derive(Hash, Eq, PartialEq, Clone, Debug)] struct IpResemblerKey { source: Ipv4Addr, destination: Ipv4Addr, id: u16, } #[derive(Debug)] struct IpResemblerValue { packet: IpPacket, timestamp: Instant, } #[derive(Debug)] pub(crate) struct IpReassembler { packets: DashMap, timeout: Duration, } impl IpReassembler { pub fn new(timeout: Duration) -> Self { IpReassembler { packets: DashMap::new(), timeout, } } pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool { packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0 } pub fn is_last_fragment(packet: &Ipv4Packet) -> bool { packet.get_flags() & Ipv4Flags::MoreFragments == 0 } pub fn add_fragment( &self, source: Ipv4Addr, destination: Ipv4Addr, packet: &Ipv4Packet, ) -> Option> { let id = packet.get_identification(); let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4; if total_length != packet.payload().len() as u16 { tracing::trace!( ?packet, ?total_length, payload_len = ?packet.payload().len(), "unexpected total length", ); return None; } let fragment: IpFragment = packet.into(); let key = IpResemblerKey { source, destination, id, }; let mut entry = self.packets.entry(key.clone()).or_insert_with(|| { let packet = IpPacket::new(source, destination); let timestamp = Instant::now(); IpResemblerValue { packet, timestamp } }); let value_mut = entry.value_mut(); if Self::is_last_fragment(packet) { value_mut .packet .set_total_length(total_length + fragment.offset); } value_mut.packet.add_fragment(fragment); if let Some(data) = value_mut.packet.assemble() { drop(entry); self.packets.remove(&key); Some(data) } else { value_mut.timestamp = Instant::now(); None } } pub fn remove_expired_packets(&self) { let timeout = self.timeout; self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout); } } // ip payload should be in buf[20..] pub fn compose_ipv4_packet( buf: &mut [u8], src_v4: &Ipv4Addr, dst_v4: &Ipv4Addr, next_protocol: IpNextHeaderProtocol, payload_len: usize, payload_mtu: usize, ip_id: u16, cb: F, ) -> Result<(), Error> where F: Fn(&[u8]) -> Result<(), Error>, { let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu; let mut buf_offset = 0; let mut fragment_offset = 0; let mut cur_piece = 0; while fragment_offset < payload_len { let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len); let fragment_len = next_fragment_offset - fragment_offset; let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap(); ipv4_packet.set_version(4); ipv4_packet.set_header_length(5); ipv4_packet.set_total_length((fragment_len + 20) as u16); ipv4_packet.set_identification(ip_id); if total_pieces > 1 { if cur_piece != total_pieces - 1 { ipv4_packet.set_flags(Ipv4Flags::MoreFragments); } else { ipv4_packet.set_flags(0); } assert_eq!(0, fragment_offset % 8); ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8); } else { ipv4_packet.set_flags(Ipv4Flags::DontFragment); ipv4_packet.set_fragment_offset(0); } ipv4_packet.set_ecn(0); ipv4_packet.set_dscp(0); ipv4_packet.set_ttl(32); ipv4_packet.set_source(src_v4.clone()); ipv4_packet.set_destination(dst_v4.clone()); ipv4_packet.set_next_level_protocol(next_protocol); ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); tracing::trace!(?ipv4_packet, "udp nat packet response send"); cb(ipv4_packet.packet())?; buf_offset += next_fragment_offset - fragment_offset; fragment_offset = next_fragment_offset; cur_piece += 1; } Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn resembler() { let raw_packets = vec![ // last packet vec![ 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07, ], // 1st packet vec![ 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07, ], // 2nd packet vec![ 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, ], // expired packet vec![ 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, 0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, ], ]; let source = "192.168.0.1".parse().unwrap(); let destination = "192.168.0.2".parse().unwrap(); let resembler = IpReassembler::new(Duration::from_secs(1)); for (idx, raw_packet) in raw_packets.iter().enumerate() { if let Some(packet) = Ipv4Packet::new(&raw_packet) { let ret = resembler.add_fragment(source, destination, &packet); if idx != 2 { assert!(ret.is_none()); } else { assert!(ret.is_some()); } println!( "packet: {:?}, ret: {:?}, palyload_len: {}", packet, ret, packet.payload().len() ); } } resembler.remove_expired_packets(); assert_eq!(1, resembler.packets.len()); std::thread::sleep(Duration::from_secs(2)); resembler.remove_expired_packets(); assert_eq!(0, resembler.packets.len()); } }