zstd should reuse ctx to avoid huge mmap cost (#941)

This commit is contained in:
Sijie.Sun
2025-06-06 08:59:06 +08:00
committed by GitHub
parent 0314c66635
commit a6773aa549
5 changed files with 196 additions and 109 deletions

View File

@@ -1,5 +1,10 @@
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
use tokio::io::AsyncWriteExt;
use std::io::{Read, Write};
use dashmap::DashMap;
use std::cell::RefCell;
use zstd::stream::read::Decoder;
use zstd::stream::write::Encoder;
use zstd::zstd_safe::{CCtx, DCtx};
use zerocopy::{AsBytes as _, FromBytes as _};
@@ -29,17 +34,20 @@ impl DefaultCompressor {
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdEncoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
let ret = CTX_MAP.with(|map_cell| {
let map = map_cell.borrow();
let mut ctx_entry = map.entry(compress_algo).or_default();
let writer = Vec::new();
let mut o = Encoder::with_context(writer, ctx_entry.value_mut());
o.write_all(data)?;
o.finish()
});
Ok(ret?)
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
CompressorAlgo::None => Ok(data.to_vec()),
}
}
pub async fn decompress_raw(
@@ -47,17 +55,17 @@ impl DefaultCompressor {
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdDecoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
match compress_algo {
CompressorAlgo::ZstdDefault => DCTX_MAP.with(|map_cell| {
let map = map_cell.borrow();
let mut ctx_entry = map.entry(compress_algo).or_default();
let mut decoder = Decoder::with_context(data, ctx_entry.value_mut());
let mut output = Vec::new();
decoder.read_to_end(&mut output)?;
Ok(output)
}),
CompressorAlgo::None => Ok(data.to_vec()),
}
}
}
@@ -146,6 +154,11 @@ impl Compressor for DefaultCompressor {
}
}
thread_local! {
static CTX_MAP: RefCell<DashMap<CompressorAlgo, CCtx<'static>>> = RefCell::new(DashMap::new());
static DCTX_MAP: RefCell<DashMap<CompressorAlgo, DCtx<'static>>> = RefCell::new(DashMap::new());
}
#[cfg(test)]
pub mod tests {
use super::*;
@@ -158,10 +171,21 @@ pub mod tests {
let compressor = DefaultCompressor {};
println!(
"Uncompressed packet: {:?}, len: {}",
packet,
packet.payload_len()
);
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.await
.unwrap();
println!(
"Compressed packet: {:?}, len: {}",
packet,
packet.payload_len()
);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
compressor.decompress(&mut packet).await.unwrap();

View File

@@ -4,6 +4,7 @@ use std::{
io::Write as _,
sync::{Arc, Mutex},
};
use time::util::refresh_tz;
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
@@ -24,9 +25,7 @@ pub mod stun_codec_ext;
pub fn get_logger_timer<F: time::formatting::Formattable>(
format: F,
) -> tracing_subscriber::fmt::time::OffsetTime<F> {
unsafe {
time::util::local_offset::set_soundness(time::util::local_offset::Soundness::Unsound)
};
refresh_tz();
let local_offset = time::UtcOffset::current_local_offset()
.unwrap_or(time::UtcOffset::from_whole_seconds(0).unwrap());
tracing_subscriber::fmt::time::OffsetTime::new(local_offset, format)

View File

@@ -234,7 +234,7 @@ pub struct AesGcmTail {
}
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
#[derive(AsBytes, FromZeroes, Clone, Debug, Copy)]
#[derive(AsBytes, FromZeroes, Clone, Debug, Copy, PartialEq, Hash, Eq)]
#[repr(u8)]
pub enum CompressorAlgo {
None = 0,