mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-16 14:47:25 +08:00
correctly handle ip fragment for udp/icmp proxy (#137)
icmp/udp proxy do not rely on kernel net stack, but currently not handle ip fragmentation correctly. this patch add ip resembler to merge fragmented ip packet for udp/icmp proxy
This commit is contained in:
299
easytier/src/gateway/ip_reassembler.rs
Normal file
299
easytier/src/gateway/ip_reassembler.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
use dashmap::DashMap;
|
||||
use pnet::packet::ip::IpNextHeaderProtocol;
|
||||
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::Packet;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::common::error::Error;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct IpFragment {
|
||||
id: u16,
|
||||
offset: u16,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<'a> From<&Ipv4Packet<'a>> for IpFragment {
|
||||
fn from(packet: &Ipv4Packet<'a>) -> Self {
|
||||
let id = packet.get_identification();
|
||||
let offset = packet.get_fragment_offset() * 8;
|
||||
let data = packet.payload().to_vec();
|
||||
IpFragment { id, offset, data }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct IpPacket {
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
total_length: Option<u16>,
|
||||
fragments: Vec<IpFragment>,
|
||||
}
|
||||
|
||||
impl IpPacket {
|
||||
fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
|
||||
IpPacket {
|
||||
source,
|
||||
destination,
|
||||
total_length: None,
|
||||
fragments: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_fragment(&mut self, fragment: IpFragment) {
|
||||
// make sure the fragment doesn't overlap with existing fragments
|
||||
for f in &self.fragments {
|
||||
if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 {
|
||||
return;
|
||||
}
|
||||
if fragment.offset <= f.offset
|
||||
&& f.offset < fragment.offset + fragment.data.len() as u16
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
self.fragments.push(fragment);
|
||||
}
|
||||
|
||||
fn is_complete(&self) -> bool {
|
||||
if self.total_length.is_none() {
|
||||
return false;
|
||||
}
|
||||
let mut total_length = 0;
|
||||
for fragment in &self.fragments {
|
||||
total_length += fragment.data.len() as u16;
|
||||
}
|
||||
tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete");
|
||||
Some(total_length) == self.total_length
|
||||
}
|
||||
|
||||
fn set_total_length(&mut self, total_length: u16) {
|
||||
self.total_length = Some(total_length);
|
||||
}
|
||||
|
||||
fn assemble(&mut self) -> Option<Vec<u8>> {
|
||||
if !self.is_complete() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// sort fragments by offset
|
||||
self.fragments.sort_by_key(|f| f.offset);
|
||||
|
||||
let mut packet = vec![0u8; self.total_length.unwrap() as usize];
|
||||
for fragment in &self.fragments {
|
||||
let start = fragment.offset as usize;
|
||||
let end = start + fragment.data.len();
|
||||
packet[start..end].copy_from_slice(&fragment.data);
|
||||
}
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
|
||||
struct IpResemblerKey {
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
id: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IpResemblerValue {
|
||||
packet: IpPacket,
|
||||
timestamp: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct IpReassembler {
|
||||
packets: DashMap<IpResemblerKey, IpResemblerValue>,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl IpReassembler {
|
||||
pub fn new(timeout: Duration) -> Self {
|
||||
IpReassembler {
|
||||
packets: DashMap::new(),
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool {
|
||||
packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0
|
||||
}
|
||||
|
||||
pub fn is_last_fragment(packet: &Ipv4Packet) -> bool {
|
||||
packet.get_flags() & Ipv4Flags::MoreFragments == 0
|
||||
}
|
||||
|
||||
pub fn add_fragment(
|
||||
&self,
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
packet: &Ipv4Packet,
|
||||
) -> Option<Vec<u8>> {
|
||||
let id = packet.get_identification();
|
||||
let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4;
|
||||
if total_length != packet.payload().len() as u16 {
|
||||
tracing::trace!(
|
||||
?packet,
|
||||
?total_length,
|
||||
payload_len = ?packet.payload().len(),
|
||||
"unexpected total length",
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let fragment: IpFragment = packet.into();
|
||||
let key = IpResemblerKey {
|
||||
source,
|
||||
destination,
|
||||
id,
|
||||
};
|
||||
|
||||
let mut entry = self.packets.entry(key.clone()).or_insert_with(|| {
|
||||
let packet = IpPacket::new(source, destination);
|
||||
let timestamp = Instant::now();
|
||||
IpResemblerValue { packet, timestamp }
|
||||
});
|
||||
let value_mut = entry.value_mut();
|
||||
|
||||
if Self::is_last_fragment(packet) {
|
||||
value_mut
|
||||
.packet
|
||||
.set_total_length(total_length + fragment.offset);
|
||||
}
|
||||
|
||||
value_mut.packet.add_fragment(fragment);
|
||||
if let Some(data) = value_mut.packet.assemble() {
|
||||
drop(entry);
|
||||
self.packets.remove(&key);
|
||||
Some(data)
|
||||
} else {
|
||||
value_mut.timestamp = Instant::now();
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_expired_packets(&self) {
|
||||
let timeout = self.timeout;
|
||||
self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout);
|
||||
}
|
||||
}
|
||||
|
||||
// ip payload should be in buf[20..]
|
||||
pub fn compose_ipv4_packet<F>(
|
||||
buf: &mut [u8],
|
||||
src_v4: &Ipv4Addr,
|
||||
dst_v4: &Ipv4Addr,
|
||||
next_protocol: IpNextHeaderProtocol,
|
||||
payload_len: usize,
|
||||
payload_mtu: usize,
|
||||
ip_id: u16,
|
||||
cb: F,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
F: Fn(&[u8]) -> Result<(), Error>,
|
||||
{
|
||||
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
|
||||
let mut buf_offset = 0;
|
||||
let mut fragment_offset = 0;
|
||||
let mut cur_piece = 0;
|
||||
while fragment_offset < payload_len {
|
||||
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
|
||||
let fragment_len = next_fragment_offset - fragment_offset;
|
||||
let mut ipv4_packet =
|
||||
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap();
|
||||
ipv4_packet.set_version(4);
|
||||
ipv4_packet.set_header_length(5);
|
||||
ipv4_packet.set_total_length((fragment_len + 20) as u16);
|
||||
ipv4_packet.set_identification(ip_id);
|
||||
if total_pieces > 1 {
|
||||
if cur_piece != total_pieces - 1 {
|
||||
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
|
||||
} else {
|
||||
ipv4_packet.set_flags(0);
|
||||
}
|
||||
assert_eq!(0, fragment_offset % 8);
|
||||
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
|
||||
} else {
|
||||
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
|
||||
ipv4_packet.set_fragment_offset(0);
|
||||
}
|
||||
ipv4_packet.set_ecn(0);
|
||||
ipv4_packet.set_dscp(0);
|
||||
ipv4_packet.set_ttl(32);
|
||||
ipv4_packet.set_source(src_v4.clone());
|
||||
ipv4_packet.set_destination(dst_v4.clone());
|
||||
ipv4_packet.set_next_level_protocol(next_protocol);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
||||
|
||||
cb(ipv4_packet.packet())?;
|
||||
|
||||
buf_offset += next_fragment_offset - fragment_offset;
|
||||
fragment_offset = next_fragment_offset;
|
||||
cur_piece += 1;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resembler() {
|
||||
let raw_packets = vec![
|
||||
// last packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// 1st packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// 2nd packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// expired packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
];
|
||||
|
||||
let source = "192.168.0.1".parse().unwrap();
|
||||
let destination = "192.168.0.2".parse().unwrap();
|
||||
let resembler = IpReassembler::new(Duration::from_secs(1));
|
||||
|
||||
for (idx, raw_packet) in raw_packets.iter().enumerate() {
|
||||
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
|
||||
let ret = resembler.add_fragment(source, destination, &packet);
|
||||
if idx != 2 {
|
||||
assert!(ret.is_none());
|
||||
} else {
|
||||
assert!(ret.is_some());
|
||||
}
|
||||
println!(
|
||||
"packet: {:?}, ret: {:?}, palyload_len: {}",
|
||||
packet,
|
||||
ret,
|
||||
packet.payload().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
resembler.remove_expired_packets();
|
||||
assert_eq!(1, resembler.packets.len());
|
||||
|
||||
std::thread::sleep(Duration::from_secs(2));
|
||||
resembler.remove_expired_packets();
|
||||
assert_eq!(0, resembler.packets.len());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user