Files
EasyTier/easytier/src/gateway/ip_reassembler.rs
Sijie.Sun fede35cca4 correctly handle ip fragment for udp/icmp proxy (#137)
The ICMP/UDP proxy does not rely on the kernel network stack, but it currently does not handle IP fragmentation correctly.

This patch adds an IP reassembler to merge fragmented IP packets for the UDP/ICMP proxy.
2024-06-09 22:59:50 +08:00

300 lines
9.1 KiB
Rust

use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocol;
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
use pnet::packet::Packet;
use std::net::Ipv4Addr;
use std::time::{Duration, Instant};
use crate::common::error::Error;
/// One fragment of an IPv4 datagram, reduced to the fields needed for reassembly.
#[derive(Debug, Clone)]
pub(crate) struct IpFragment {
    /// Identification value copied from the IPv4 header.
    id: u16,
    /// Byte offset of `data` inside the reassembled payload.
    offset: u16,
    /// Payload bytes carried by this fragment.
    data: Vec<u8>,
}

impl<'a> From<&Ipv4Packet<'a>> for IpFragment {
    /// Copies the identification, fragment offset and payload out of `packet`.
    fn from(packet: &Ipv4Packet<'a>) -> Self {
        Self {
            id: packet.get_identification(),
            // The header stores the offset in units of 8 bytes; convert to bytes.
            offset: packet.get_fragment_offset() * 8,
            data: packet.payload().to_vec(),
        }
    }
}
/// An IPv4 datagram that is being rebuilt from its fragments.
#[derive(Debug, Clone)]
struct IpPacket {
    /// Source address given to [`IpPacket::new`].
    source: Ipv4Addr,
    /// Destination address given to [`IpPacket::new`].
    destination: Ipv4Addr,
    /// Payload length of the complete datagram; `None` until
    /// [`IpPacket::set_total_length`] has been called.
    total_length: Option<u16>,
    /// Fragments accepted so far; they never overlap each other.
    fragments: Vec<IpFragment>,
}

impl IpPacket {
    /// Creates an empty datagram with no fragments and an unknown total length.
    fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
        IpPacket {
            source,
            destination,
            total_length: None,
            fragments: Vec::new(),
        }
    }

    /// Stores `fragment` unless it is empty or overlaps an already stored fragment.
    ///
    /// Rejected fragments are silently dropped.
    fn add_fragment(&mut self, fragment: IpFragment) {
        // An empty fragment contributes no bytes. Storing it would only grow the
        // list and could make a later, valid fragment look like an overlap.
        if fragment.data.is_empty() {
            return;
        }

        // Ranges are computed in `usize`: `offset + len` can exceed `u16::MAX`.
        let new_start = usize::from(fragment.offset);
        let new_end = new_start + fragment.data.len();
        let overlaps = self.fragments.iter().any(|f| {
            let start = usize::from(f.offset);
            let end = start + f.data.len();
            (start <= new_start && new_start < end) || (new_start <= start && start < new_end)
        });
        if !overlaps {
            self.fragments.push(fragment);
        }
    }

    /// Returns `true` when the total length is known, every stored fragment lies
    /// inside it, and the stored fragments add up to exactly that length.
    ///
    /// Because stored fragments never overlap, these conditions together mean the
    /// whole range `0..total_length` is covered.
    fn is_complete(&self) -> bool {
        let expected = match self.total_length {
            Some(len) => usize::from(len),
            None => return false,
        };

        let mut total_length = 0usize;
        let mut in_bounds = true;
        for fragment in &self.fragments {
            total_length += fragment.data.len();
            // A fragment reaching past the declared end can never be assembled.
            in_bounds &= usize::from(fragment.offset) + fragment.data.len() <= expected;
        }

        tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete");
        in_bounds && total_length == expected
    }

    /// Records the payload length of the complete datagram.
    fn set_total_length(&mut self, total_length: u16) {
        self.total_length = Some(total_length);
    }

    /// Returns the reassembled payload, or `None` while the datagram is incomplete.
    fn assemble(&mut self) -> Option<Vec<u8>> {
        if !self.is_complete() {
            return None;
        }
        let total = usize::from(self.total_length?);

        // Keep the stored fragments ordered by offset; the copy below places each
        // fragment by its own offset and does not depend on this order.
        self.fragments.sort_by_key(|f| f.offset);

        let mut packet = vec![0u8; total];
        for fragment in &self.fragments {
            let start = usize::from(fragment.offset);
            let end = start + fragment.data.len();
            // Checked slicing: never panic on data that came from the network.
            packet.get_mut(start..end)?.copy_from_slice(&fragment.data);
        }
        Some(packet)
    }
}
/// Map key identifying one datagram under reassembly.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct IpResemblerKey {
/// Source address supplied by the caller of `IpReassembler::add_fragment`.
source: Ipv4Addr,
/// Destination address supplied by the caller of `IpReassembler::add_fragment`.
destination: Ipv4Addr,
/// Identification value read from the fragment's IPv4 header.
id: u16,
}
/// Map value: a partially reassembled datagram plus the time it was last touched.
#[derive(Debug)]
struct IpResemblerValue {
/// Fragments collected so far for this datagram.
packet: IpPacket,
/// Set when the entry is created and refreshed whenever a fragment arrives
/// without completing the datagram; compared against the timeout on expiry.
timestamp: Instant,
}
/// Collects IPv4 fragments and hands back the payload once a datagram is complete.
#[derive(Debug)]
pub(crate) struct IpReassembler {
    /// Datagrams currently under reassembly, keyed by (source, destination, id).
    packets: DashMap<IpResemblerKey, IpResemblerValue>,
    /// Entries whose timestamp is older than this are dropped by
    /// [`IpReassembler::remove_expired_packets`].
    timeout: Duration,
}

impl IpReassembler {
    /// Creates an empty reassembler whose entries expire after `timeout`.
    pub fn new(timeout: Duration) -> Self {
        IpReassembler {
            packets: DashMap::new(),
            timeout,
        }
    }

    /// Returns `true` if `packet` has a non-zero fragment offset or the
    /// More Fragments flag set.
    pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool {
        packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0
    }

    /// Returns `true` if the More Fragments flag of `packet` is clear.
    pub fn is_last_fragment(packet: &Ipv4Packet) -> bool {
        packet.get_flags() & Ipv4Flags::MoreFragments == 0
    }

    /// Adds one fragment of the datagram identified by `source`, `destination`
    /// and the identification field of `packet`.
    ///
    /// Returns the reassembled payload (IP header excluded) when this fragment
    /// completes the datagram, in which case the entry is removed. Returns `None`
    /// while the datagram is incomplete, or when the fragment is malformed: its
    /// header lengths disagree with the payload actually present, or it would
    /// make the datagram payload longer than `u16::MAX` bytes.
    pub fn add_fragment(
        &self,
        source: Ipv4Addr,
        destination: Ipv4Addr,
        packet: &Ipv4Packet,
    ) -> Option<Vec<u8>> {
        let id = packet.get_identification();

        // Header fields come from the network: a total length smaller than the
        // header length must not underflow, and the comparison with the real
        // payload size is done in `usize` so nothing is truncated.
        let header_len = packet.get_header_length() as u16 * 4;
        let payload_len = packet.payload().len();
        let total_length = match packet.get_total_length().checked_sub(header_len) {
            Some(len) if usize::from(len) == payload_len => len,
            declared => {
                tracing::trace!(
                    ?packet,
                    total_length = ?declared,
                    payload_len = ?payload_len,
                    "unexpected total length",
                );
                return None;
            }
        };

        let fragment: IpFragment = packet.into();

        // Only the last fragment reveals how long the whole payload is. Reject it
        // before touching the map if offset + length does not fit in `u16`.
        let datagram_len = if Self::is_last_fragment(packet) {
            match total_length.checked_add(fragment.offset) {
                Some(len) => Some(len),
                None => {
                    tracing::trace!(?packet, "fragment exceeds maximum ip payload length");
                    return None;
                }
            }
        } else {
            None
        };

        let key = IpResemblerKey {
            source,
            destination,
            id,
        };
        let mut entry = self.packets.entry(key.clone()).or_insert_with(|| {
            let packet = IpPacket::new(source, destination);
            let timestamp = Instant::now();
            IpResemblerValue { packet, timestamp }
        });
        let value_mut = entry.value_mut();
        if let Some(len) = datagram_len {
            value_mut.packet.set_total_length(len);
        }
        value_mut.packet.add_fragment(fragment);
        if let Some(data) = value_mut.packet.assemble() {
            // Release the entry guard before removing the key from the same map.
            drop(entry);
            self.packets.remove(&key);
            Some(data)
        } else {
            // Still incomplete: refresh the timestamp so the entry is not expired yet.
            value_mut.timestamp = Instant::now();
            None
        }
    }

    /// Drops every entry whose timestamp is older than the configured timeout.
    pub fn remove_expired_packets(&self) {
        let timeout = self.timeout;
        self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout);
    }
}
/// Wraps the payload stored in `buf[20..]` into one or more IPv4 packets and
/// passes each finished packet to `cb`.
///
/// The payload is split into pieces of at most `payload_mtu` bytes. Packets are
/// built in place: each 20-byte header is written directly in front of its
/// piece, which overwrites the tail of the previous piece after that piece has
/// already been handed to `cb`. A payload that fits in one piece is sent with
/// the Don't Fragment flag; otherwise every piece but the last carries the
/// More Fragments flag. Nothing is sent when `payload_len` is 0.
///
/// # Errors
///
/// Stops and returns the first error produced by `cb`.
///
/// # Panics
///
/// Panics if `payload_mtu` is 0, if more than one piece is needed and
/// `payload_mtu` is not a multiple of 8, or if `buf` is shorter than
/// `payload_len + 20` bytes.
pub fn compose_ipv4_packet<F>(
    buf: &mut [u8],
    src_v4: &Ipv4Addr,
    dst_v4: &Ipv4Addr,
    next_protocol: IpNextHeaderProtocol,
    payload_len: usize,
    payload_mtu: usize,
    ip_id: u16,
    cb: F,
) -> Result<(), Error>
where
    F: Fn(&[u8]) -> Result<(), Error>,
{
    // Length of an IPv4 header without options (header_length = 5 words).
    const HEADER_LEN: usize = 20;

    let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
    let needs_fragmentation = total_pieces > 1;

    // `offset` is both the payload offset of the current piece and the index in
    // `buf` where its header starts, since the payload itself begins at 20.
    let mut offset = 0;
    let mut piece_index = 0;
    while offset < payload_len {
        let next_offset = std::cmp::min(offset + payload_mtu, payload_len);
        let fragment_len = next_offset - offset;

        let window = &mut buf[offset..offset + fragment_len + HEADER_LEN];
        let mut ipv4_packet = MutableIpv4Packet::new(window).unwrap();
        ipv4_packet.set_version(4);
        ipv4_packet.set_header_length(5);
        ipv4_packet.set_total_length((fragment_len + HEADER_LEN) as u16);
        ipv4_packet.set_identification(ip_id);

        if needs_fragmentation {
            let is_last_piece = piece_index == total_pieces - 1;
            ipv4_packet.set_flags(if is_last_piece {
                0
            } else {
                Ipv4Flags::MoreFragments
            });
            // The header stores the offset in 8-byte units.
            assert_eq!(0, offset % 8);
            ipv4_packet.set_fragment_offset(offset as u16 / 8);
        } else {
            ipv4_packet.set_flags(Ipv4Flags::DontFragment);
            ipv4_packet.set_fragment_offset(0);
        }

        ipv4_packet.set_ecn(0);
        ipv4_packet.set_dscp(0);
        ipv4_packet.set_ttl(32);
        ipv4_packet.set_source(*src_v4);
        ipv4_packet.set_destination(*dst_v4);
        ipv4_packet.set_next_level_protocol(next_protocol);

        // The checksum covers every header field, so it is computed last.
        let checksum = ipv4::checksum(&ipv4_packet.to_immutable());
        ipv4_packet.set_checksum(checksum);

        tracing::trace!(?ipv4_packet, "udp nat packet response send");
        cb(ipv4_packet.packet())?;

        offset = next_offset;
        piece_index += 1;
    }

    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
/// Feeds three out-of-order fragments of one datagram (id 0x1c46) plus a lone
/// fragment of another (id 0x1c47), then checks reassembly and expiry.
#[test]
fn resembler() {
// Each packet is a 20-byte IPv4 header followed by 8 payload bytes.
// Bytes 6-7 of the header hold the flags and the fragment offset (in 8-byte units).
let raw_packets = vec![
// middle fragment, delivered first: More Fragments set, offset 8 bytes
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
],
// last fragment, delivered second: More Fragments clear, offset 16 bytes
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07,
],
// first fragment, delivered third: More Fragments set, offset 0; completes the datagram
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
// first fragment of a different datagram (id 0x1c47) that is never completed
// and must be removed by the expiry pass
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
];
let source = "192.168.0.1".parse().unwrap();
let destination = "192.168.0.2".parse().unwrap();
let resembler = IpReassembler::new(Duration::from_secs(1));
for (idx, raw_packet) in raw_packets.iter().enumerate() {
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
let ret = resembler.add_fragment(source, destination, &packet);
// Only the third packet (idx 2) completes a datagram.
if idx != 2 {
assert!(ret.is_none());
} else {
assert!(ret.is_some());
}
println!(
"packet: {:?}, ret: {:?}, palyload_len: {}",
packet,
ret,
packet.payload().len()
);
}
}
// Right after insertion the incomplete datagram is still within its 1 s timeout.
resembler.remove_expired_packets();
assert_eq!(1, resembler.packets.len());
// Wait past the timeout; the incomplete datagram must then be dropped.
std::thread::sleep(Duration::from_secs(2));
resembler.remove_expired_packets();
assert_eq!(0, resembler.packets.len());
}
}