mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-12 04:37:23 +08:00
225 lines
6.9 KiB
Rust
225 lines
6.9 KiB
Rust
use anyhow::Context;
|
|
use dashmap::DashMap;
|
|
use std::cell::RefCell;
|
|
use zstd::bulk;
|
|
|
|
use zerocopy::{AsBytes as _, FromBytes as _};
|
|
|
|
use crate::tunnel::packet_def::{CompressorAlgo, CompressorTail, ZCPacket, COMPRESSOR_TAIL_SIZE};
|
|
|
|
type Error = anyhow::Error;
|
|
|
|
#[async_trait::async_trait]
|
|
pub trait Compressor {
|
|
async fn compress(
|
|
&self,
|
|
packet: &mut ZCPacket,
|
|
compress_algo: CompressorAlgo,
|
|
) -> Result<(), Error>;
|
|
async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>;
|
|
}
|
|
|
|
/// Zstd-backed implementation of [`Compressor`].
///
/// The struct itself carries no state: zstd contexts are cached per thread, so values of this
/// type are free to create, copy and share.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultCompressor {}
|
|
|
|
impl DefaultCompressor {
|
|
pub fn new() -> Self {
|
|
DefaultCompressor {}
|
|
}
|
|
|
|
pub async fn compress_raw(
|
|
&self,
|
|
data: &[u8],
|
|
compress_algo: CompressorAlgo,
|
|
) -> Result<Vec<u8>, Error> {
|
|
match compress_algo {
|
|
CompressorAlgo::ZstdDefault => CTX_MAP.with(|map_cell| {
|
|
let map = map_cell.borrow();
|
|
let mut ctx_entry = map.entry(compress_algo).or_default();
|
|
ctx_entry.compress(data).with_context(|| {
|
|
format!(
|
|
"Failed to compress data with algorithm: {:?}",
|
|
compress_algo
|
|
)
|
|
})
|
|
}),
|
|
CompressorAlgo::None => Ok(data.to_vec()),
|
|
}
|
|
}
|
|
|
|
pub async fn decompress_raw(
|
|
&self,
|
|
data: &[u8],
|
|
compress_algo: CompressorAlgo,
|
|
) -> Result<Vec<u8>, Error> {
|
|
match compress_algo {
|
|
CompressorAlgo::ZstdDefault => DCTX_MAP.with(|map_cell| {
|
|
let map = map_cell.borrow();
|
|
let mut ctx_entry = map.entry(compress_algo).or_default();
|
|
for i in 1..=5 {
|
|
let mut len = data.len() * 2usize.pow(i);
|
|
if i == 5 && len < 64 * 1024 {
|
|
len = 64 * 1024; // Ensure a minimum buffer size
|
|
}
|
|
match ctx_entry.decompress(data, len) {
|
|
Ok(buf) => return Ok(buf),
|
|
Err(e) if e.to_string().contains("buffer is too small") => {
|
|
continue; // Try with a larger buffer
|
|
}
|
|
Err(e) => return Err(e.into()),
|
|
}
|
|
}
|
|
Err(anyhow::anyhow!(
|
|
"Failed to decompress data after multiple attempts with algorithm: {:?}",
|
|
compress_algo
|
|
))
|
|
}),
|
|
CompressorAlgo::None => Ok(data.to_vec()),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl Compressor for DefaultCompressor {
|
|
async fn compress(
|
|
&self,
|
|
zc_packet: &mut ZCPacket,
|
|
compress_algo: CompressorAlgo,
|
|
) -> Result<(), Error> {
|
|
if matches!(compress_algo, CompressorAlgo::None) {
|
|
return Ok(());
|
|
}
|
|
|
|
let pm_header = zc_packet.peer_manager_header().unwrap();
|
|
if pm_header.is_compressed() {
|
|
return Ok(());
|
|
}
|
|
|
|
let tail = CompressorTail::new(compress_algo);
|
|
let buf = self
|
|
.compress_raw(zc_packet.payload(), compress_algo)
|
|
.await?;
|
|
|
|
if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize {
|
|
// Compressed data is larger than original data, don't compress
|
|
return Ok(());
|
|
}
|
|
|
|
zc_packet
|
|
.mut_peer_manager_header()
|
|
.unwrap()
|
|
.set_compressed(true);
|
|
|
|
let payload_offset = zc_packet.payload_offset();
|
|
zc_packet.mut_inner().truncate(payload_offset);
|
|
zc_packet.mut_inner().extend_from_slice(&buf);
|
|
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
|
|
let pm_header = zc_packet.peer_manager_header().unwrap();
|
|
if !pm_header.is_compressed() {
|
|
return Ok(());
|
|
}
|
|
|
|
let payload_len = zc_packet.payload().len();
|
|
if payload_len < COMPRESSOR_TAIL_SIZE {
|
|
return Err(anyhow::anyhow!("Packet too short: {}", payload_len));
|
|
}
|
|
|
|
let text_len = payload_len - COMPRESSOR_TAIL_SIZE;
|
|
|
|
let tail = CompressorTail::ref_from_suffix(zc_packet.payload())
|
|
.unwrap()
|
|
.clone();
|
|
|
|
let algo = tail
|
|
.get_algo()
|
|
.ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?;
|
|
|
|
let buf = self
|
|
.decompress_raw(&zc_packet.payload()[..text_len], algo)
|
|
.await?;
|
|
|
|
if buf.len() != pm_header.len.get() as usize {
|
|
anyhow::bail!(
|
|
"Decompressed length mismatch: decompressed len {} != pm header len {}",
|
|
buf.len(),
|
|
pm_header.len.get()
|
|
);
|
|
}
|
|
|
|
zc_packet
|
|
.mut_peer_manager_header()
|
|
.unwrap()
|
|
.set_compressed(false);
|
|
|
|
let payload_offset = zc_packet.payload_offset();
|
|
zc_packet.mut_inner().truncate(payload_offset);
|
|
zc_packet.mut_inner().extend_from_slice(&buf);
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
thread_local! {
|
|
static CTX_MAP: RefCell<DashMap<CompressorAlgo, bulk::Compressor<'static>>> = RefCell::new(DashMap::new());
|
|
static DCTX_MAP: RefCell<DashMap<CompressorAlgo, bulk::Decompressor<'static>>> = RefCell::new(DashMap::new());
|
|
}
|
|
|
|
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Builds a packet carrying `payload`, with its peer-manager header filled in via
    /// `fill_peer_manager_hdr(0, 0, 0)`.
    fn packet_with_payload(payload: &[u8]) -> ZCPacket {
        let mut packet = ZCPacket::new_with_payload(payload);
        packet.fill_peer_manager_hdr(0, 0, 0);
        packet
    }

    #[tokio::test]
    async fn test_compress() {
        let text = b"12345670000000000000000000";
        let mut packet = packet_with_payload(text);

        let compressor = DefaultCompressor {};

        println!(
            "Uncompressed packet: {:?}, len: {}",
            packet,
            packet.payload_len()
        );

        compressor
            .compress(&mut packet, CompressorAlgo::ZstdDefault)
            .await
            .unwrap();
        println!(
            "Compressed packet: {:?}, len: {}",
            packet,
            packet.payload_len()
        );
        assert!(packet.peer_manager_header().unwrap().is_compressed());

        compressor.decompress(&mut packet).await.unwrap();
        assert_eq!(packet.payload(), text);
        assert!(!packet.peer_manager_header().unwrap().is_compressed());
    }

    #[tokio::test]
    async fn test_short_text_compress() {
        let text = b"1234";
        let mut packet = packet_with_payload(text);

        let compressor = DefaultCompressor {};

        // short text can't be compressed
        compressor
            .compress(&mut packet, CompressorAlgo::ZstdDefault)
            .await
            .unwrap();
        assert!(!packet.peer_manager_header().unwrap().is_compressed());

        compressor.decompress(&mut packet).await.unwrap();
        assert_eq!(packet.payload(), text);
        assert!(!packet.peer_manager_header().unwrap().is_compressed());
    }

    #[tokio::test]
    async fn test_none_algo_is_noop() {
        let text = b"12345670000000000000000000";
        let mut packet = packet_with_payload(text);

        let compressor = DefaultCompressor {};

        compressor
            .compress(&mut packet, CompressorAlgo::None)
            .await
            .unwrap();
        assert!(!packet.peer_manager_header().unwrap().is_compressed());
        assert_eq!(packet.payload(), text);
    }

    #[tokio::test]
    async fn test_highly_compressible_roundtrip() {
        // A run of zeros compresses far better than 32x. Decompression therefore needs the
        // later, larger capacity guesses in `decompress_raw`, which exercises its retry loop.
        let text = vec![0u8; 4096];
        let mut packet = packet_with_payload(&text);

        let compressor = DefaultCompressor {};

        compressor
            .compress(&mut packet, CompressorAlgo::ZstdDefault)
            .await
            .unwrap();
        assert!(packet.peer_manager_header().unwrap().is_compressed());
        assert!(packet.payload().len() < text.len());

        compressor.decompress(&mut packet).await.unwrap();
        assert_eq!(packet.payload(), text.as_slice());
        assert!(!packet.peer_manager_header().unwrap().is_compressed());
    }
}
|