add bps limiter (#1015)

* add token bucket
* remove quinn-proto
Sijie.Sun
2025-06-19 21:15:04 +08:00
committed by GitHub
parent 72d5ed908e
commit 40601bd05b
13 changed files with 463 additions and 38 deletions

View File

@@ -2,6 +2,7 @@ use std::{
net::{Ipv4Addr, SocketAddr},
path::PathBuf,
sync::{Arc, Mutex},
u64,
};
use anyhow::Context;
@@ -41,6 +42,7 @@ pub fn gen_default_flags() -> Flags {
private_mode: false,
enable_quic_proxy: false,
disable_quic_input: false,
foreign_relay_bps_limit: u64::MAX,
}
}
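The default of u64::MAX effectively disables the limit: the derived bucket capacity is so large it never empties. A minimal sketch of checking for that sentinel (hypothetical helper, not part of this commit):

fn relay_limit_enabled(flags: &Flags) -> bool {
    // u64::MAX is the "no limit" default set in gen_default_flags above.
    flags.foreign_relay_bps_limit != u64::MAX
}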

View File

@@ -5,6 +5,7 @@ use std::{
};
use crate::common::config::ProxyNetworkConfig;
use crate::common::token_bucket::TokenBucketManager;
use crate::proto::cli::PeerConnInfo;
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
use crossbeam::atomic::AtomicCell;
@@ -77,6 +78,8 @@ pub struct GlobalCtx {
feature_flags: AtomicCell<PeerFeatureFlag>,
quic_proxy_port: AtomicCell<Option<u16>>,
token_bucket_manager: TokenBucketManager,
}
impl std::fmt::Debug for GlobalCtx {
@@ -140,6 +143,8 @@ impl GlobalCtx {
feature_flags: AtomicCell::new(feature_flags),
quic_proxy_port: AtomicCell::new(None),
token_bucket_manager: TokenBucketManager::new(),
}
}
@@ -292,6 +297,10 @@ impl GlobalCtx {
pub fn set_quic_proxy_port(&self, port: Option<u16>) {
self.quic_proxy_port.store(port);
}
pub fn token_bucket_manager(&self) -> &TokenBucketManager {
&self.token_bucket_manager
}
}
#[cfg(test)]

View File

@@ -23,6 +23,7 @@ pub mod network;
pub mod scoped_task;
pub mod stun;
pub mod stun_codec_ext;
pub mod token_bucket;
pub fn get_logger_timer<F: time::formatting::Formattable>(
format: F,

View File

@@ -0,0 +1,312 @@
use atomic_shim::AtomicU64;
use dashmap::DashMap;
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::time;
use crate::common::scoped_task::ScopedTask;
use crate::proto::common::LimiterConfig;
/// Token Bucket rate limiter using atomic operations
pub struct TokenBucket {
available_tokens: AtomicU64, // Current token count (atomic)
last_refill_time: AtomicU64, // Last refill time as micros since epoch
config: BucketConfig, // Immutable configuration
refill_task: Mutex<Option<ScopedTask<()>>>, // Background refill task
start_time: Instant, // Bucket creation time
}
#[derive(Clone, Copy)]
pub struct BucketConfig {
capacity: u64, // Maximum token capacity
fill_rate: u64, // Tokens added per second
refill_interval: Duration, // Time between refill operations
}
impl From<LimiterConfig> for BucketConfig {
fn from(cfg: LimiterConfig) -> Self {
let burst_rate = 1.max(cfg.burst_rate.unwrap_or(1));
let fill_rate = 8196.max(cfg.bps.unwrap_or(u64::MAX / burst_rate));
let refill_interval = cfg
.fill_duration_ms
.map(|x| Duration::from_millis(1.max(x)))
.unwrap_or(Duration::from_millis(10));
BucketConfig {
capacity: burst_rate * fill_rate,
fill_rate,
refill_interval,
}
}
}
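// Worked example (illustrative): LimiterConfig { burst_rate: None, bps: Some(102_400),
// fill_duration_ms: None } yields burst_rate = 1, fill_rate = max(8196, 102_400) = 102_400
// tokens/s, capacity = 102_400 tokens, refilled every 10 ms: a 100 KiB/s limit with at
// most one second of burst. Leaving bps unset gives fill_rate = u64::MAX / burst_rate,
// an effectively unlimited bucket.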
impl TokenBucket {
/// Creates a new Token Bucket rate limiter
///
/// # Arguments
/// * `capacity` - Bucket capacity in bytes
/// * `bps` - Bandwidth limit in bytes per second
/// * `refill_interval` - Refill interval (recommended 10-50ms)
pub fn new(capacity: u64, bps: u64, refill_interval: Duration) -> Arc<Self> {
let config = BucketConfig {
capacity,
fill_rate: bps,
refill_interval,
};
Self::new_from_cfg(config)
}
/// Creates a new Token Bucket rate limiter from a prepared `BucketConfig`,
/// spawning the background refill task.
pub fn new_from_cfg(config: BucketConfig) -> Arc<Self> {
// Create Arc instance with placeholder task
let arc_self = Arc::new(Self {
available_tokens: AtomicU64::new(config.capacity),
last_refill_time: AtomicU64::new(0),
config,
refill_task: Mutex::new(None),
start_time: std::time::Instant::now(),
});
// Start background refill task
let arc_clone = arc_self.clone();
let refill_task = tokio::spawn(async move {
let mut interval = time::interval(arc_clone.config.refill_interval);
loop {
interval.tick().await;
arc_clone.refill();
}
});
// Replace placeholder task with actual one
arc_self
.refill_task
.lock()
.unwrap()
.replace(refill_task.into());
arc_self
}
/// Internal refill method (called only by background task)
fn refill(&self) {
let now_micros = self.elapsed_micros();
let prev_time = self.last_refill_time.swap(now_micros, Ordering::Acquire);
// Calculate elapsed time in seconds
let elapsed_secs = (now_micros.saturating_sub(prev_time)) as f64 / 1_000_000.0;
// Calculate tokens to add
let tokens_to_add = (self.config.fill_rate as f64 * elapsed_secs) as u64;
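// Illustrative numbers: with fill_rate = 1000 tokens/s and the default 10 ms tick,
// each refill pass adds roughly 1000 * 0.01 = 10 tokens.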
if tokens_to_add == 0 {
return;
}
// Add tokens without exceeding capacity
let mut current = self.available_tokens.load(Ordering::Relaxed);
loop {
let new = current
.saturating_add(tokens_to_add)
.min(self.config.capacity);
match self.available_tokens.compare_exchange_weak(
current,
new,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
}
/// Calculate microseconds since bucket creation
fn elapsed_micros(&self) -> u64 {
self.start_time.elapsed().as_micros() as u64
}
/// Attempt to consume tokens without blocking
///
/// # Returns
/// `true` if tokens were consumed, `false` if insufficient tokens
pub fn try_consume(&self, tokens: u64) -> bool {
// Fast path for oversized packets
if tokens > self.config.capacity {
return false;
}
let mut current = self.available_tokens.load(Ordering::Relaxed);
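// CAS loop: if another thread consumed or refilled tokens concurrently (or the
// weak exchange fails spuriously), retry with the freshly observed value.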
loop {
if current < tokens {
return false;
}
let new = current - tokens;
match self.available_tokens.compare_exchange_weak(
current,
new,
Ordering::AcqRel,
Ordering::Relaxed,
) {
Ok(_) => return true,
Err(actual) => current = actual,
}
}
}
}
pub struct TokenBucketManager {
buckets: Arc<DashMap<String, Arc<TokenBucket>>>,
retain_task: ScopedTask<()>,
}
impl TokenBucketManager {
/// Creates a new TokenBucketManager
pub fn new() -> Self {
let buckets = Arc::new(DashMap::new());
let buckets_clone = buckets.clone();
let retain_task = tokio::spawn(async move {
loop {
// Retain only buckets that are still referenced outside the map
buckets_clone.retain(|_, bucket| Arc::<TokenBucket>::strong_count(bucket) > 1);
// Sleep for a while before next retention check
tokio::time::sleep(Duration::from_secs(60)).await;
}
});
Self {
buckets,
retain_task: retain_task.into(),
}
}
/// Get or create a token bucket for the given key
pub fn get_or_create(&self, key: &str, cfg: BucketConfig) -> Arc<TokenBucket> {
self.buckets
.entry(key.to_string())
.or_insert_with(|| TokenBucket::new_from_cfg(cfg))
.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{sleep, Duration};
/// Test initial state after creation
#[tokio::test]
async fn test_initial_state() {
let bucket = TokenBucket::new(1000, 1000, Duration::from_millis(10));
// Should have full capacity initially
assert!(bucket.try_consume(1000));
assert!(!bucket.try_consume(1)); // Should be empty now
}
/// Test token consumption behavior
#[tokio::test]
async fn test_consumption() {
let bucket = TokenBucket::new(1500, 1000, Duration::from_millis(10));
// First packet should succeed
assert!(bucket.try_consume(1000));
// Second packet should fail (only 500 left)
assert!(!bucket.try_consume(600));
// Should be able to take remaining tokens
assert!(bucket.try_consume(500));
}
/// Test background refill functionality
#[tokio::test]
async fn test_refill() {
let bucket = TokenBucket::new(1000, 1000, Duration::from_millis(10));
// Drain the bucket
assert!(bucket.try_consume(1000));
assert!(!bucket.try_consume(1));
// Wait for refill (two 10 ms refill intervals plus buffer)
sleep(Duration::from_millis(25)).await;
// Should have roughly 20 tokens (1000 tokens/s, refilled at the 10 ms and 20 ms ticks)
assert!(bucket.try_consume(15));
assert!(!bucket.try_consume(10)); // only ~5 tokens should remain
}
/// Test capacity enforcement
#[tokio::test]
async fn test_capacity_limit() {
let bucket = TokenBucket::new(500, 1000, Duration::from_millis(10));
// Wait longer than refill interval
sleep(Duration::from_millis(50)).await;
// Should not exceed capacity despite time passed
assert!(bucket.try_consume(500));
assert!(!bucket.try_consume(1));
}
/// Test high load with concurrent access
#[tokio::test]
async fn test_concurrent_access() {
let bucket = TokenBucket::new(10_000, 1_000_000, Duration::from_millis(10));
let mut handles = vec![];
// Spawn 100 tasks to consume tokens concurrently
for _ in 0..100 {
let bucket = bucket.clone();
handles.push(tokio::spawn(async move {
for _ in 0..100 {
let _ = bucket.try_consume(10);
}
}));
}
// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
// Verify we didn't exceed capacity
let tokens_left = bucket.available_tokens.load(Ordering::Relaxed);
assert!(
tokens_left <= 10_000,
"Tokens exceeded capacity: {}",
tokens_left
);
}
/// Test behavior when packet size exceeds capacity
#[tokio::test]
async fn test_oversized_packet() {
let bucket = TokenBucket::new(1500, 1000, Duration::from_millis(10));
// Packet larger than capacity should be rejected
assert!(!bucket.try_consume(1600));
// Regular packets should still work
assert!(bucket.try_consume(1000));
}
/// Test refill precision with small intervals
#[tokio::test]
async fn test_refill_precision() {
let bucket = TokenBucket::new(10_000, 10_000, Duration::from_micros(100)); // 100μs interval
// Drain most tokens
assert!(bucket.try_consume(9900));
// Wait for multiple refills
sleep(Duration::from_millis(1)).await;
// ~100 tokens remained after draining, plus roughly 10 refilled (10,000 tokens/s * 0.001s)
let tokens = bucket.available_tokens.load(Ordering::Relaxed);
assert!(
tokens >= 100 && tokens <= 200,
"Unexpected token count: {}",
tokens
);
}
}
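Taken together, a minimal usage sketch (illustrative only: the key, the limit, and the packet size are invented here; the real call site is the foreign-network relay later in this diff). It assumes a tokio runtime is running, since bucket creation spawns the refill task:

fn relay_with_limit(manager: &TokenBucketManager, pkt_len: u64) -> bool {
    // One shared bucket per key; repeated calls return the same Arc<TokenBucket>.
    let bucket = manager.get_or_create(
        "example-network", // hypothetical key; the relay uses the network name
        LimiterConfig {
            burst_rate: None,       // default: capacity equals one second of traffic
            bps: Some(100 * 1024),  // hypothetical 100 KiB/s cap
            fill_duration_ms: None, // default: 10 ms refill tick
        }
        .into(),
    );
    // Non-blocking: either the packet fits in the current budget or it is dropped.
    bucket.try_consume(pkt_len)
}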

View File

@@ -481,6 +481,13 @@ struct NetworkOptions {
help = t!("core_clap.private_mode").to_string(),
)]
private_mode: Option<bool>,
#[arg(
long,
env = "ET_FOREIGN_RELAY_BPS_LIMIT",
help = t!("core_clap.foreign_relay_bps_limit").to_string(),
)]
foreign_relay_bps_limit: Option<u64>,
}
#[derive(Parser, Debug)]
@@ -803,6 +810,9 @@ impl NetworkOptions {
f.disable_quic_input = self.disable_quic_input.unwrap_or(f.disable_quic_input);
f.accept_dns = self.accept_dns.unwrap_or(f.accept_dns);
f.private_mode = self.private_mode.unwrap_or(f.private_mode);
f.foreign_relay_bps_limit = self
.foreign_relay_bps_limit
.unwrap_or(f.foreign_relay_bps_limit);
cfg.set_flags(f);
if !self.exit_nodes.is_empty() {

View File

@@ -26,12 +26,13 @@ use crate::{
global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent, NetworkIdentity},
join_joinset_background,
stun::MockStunInfoCollector,
token_bucket::TokenBucket,
PeerId,
},
peers::route_trait::{Route, RouteInterface},
proto::{
cli::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo},
common::{LimiterConfig, NatType},
peer_rpc::DirectConnectorRpcServer,
},
tunnel::packet_def::{PacketType, ZCPacket},
@@ -69,6 +70,8 @@ struct ForeignNetworkEntry {
packet_recv: Mutex<Option<PacketRecvChanReceiver>>,
bps_limiter: Arc<TokenBucket>,
tasks: Mutex<JoinSet<()>>,
pub lock: Mutex<()>,
@@ -102,6 +105,16 @@ impl ForeignNetworkEntry {
&network.network_name,
);
let relay_bps_limit = global_ctx.config.get_flags().foreign_relay_bps_limit;
let limiter_config = LimiterConfig {
burst_rate: None,
bps: Some(relay_bps_limit),
fill_duration_ms: None,
};
let bps_limiter = global_ctx
.token_bucket_manager()
.get_or_create(&network.network_name, limiter_config.into());
Self {
my_peer_id,
@@ -116,6 +129,8 @@ impl ForeignNetworkEntry {
packet_recv: Mutex::new(Some(packet_recv)),
bps_limiter,
tasks: Mutex::new(JoinSet::new()),
lock: Mutex::new(()),
@@ -265,6 +280,7 @@ impl ForeignNetworkEntry {
let relay_data = self.relay_data;
let pm_sender = self.pm_packet_sender.lock().await.take().unwrap();
let network_name = self.network.network_name.clone();
let bps_limiter = self.bps_limiter.clone();
self.tasks.lock().await.spawn(async move {
while let Ok(zc_packet) = recv_packet_from_chan(&mut recv).await {
@@ -284,8 +300,13 @@ impl ForeignNetworkEntry {
}
tracing::trace!(?hdr, "ignore packet in foreign network");
} else {
if hdr.packet_type == PacketType::Data as u8 {
if !relay_data {
continue;
}
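// Enforce the per-network relay limit: if the bucket cannot cover this
// packet, drop it instead of queueing.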
if !bps_limiter.try_consume(hdr.len.into()) {
continue;
}
}
let gateway_peer_id = peer_map

View File

@@ -1066,7 +1066,7 @@ impl PeerManager {
pub fn get_directly_connections(&self, peer_id: PeerId) -> DashSet<uuid::Uuid> {
if let Some(peer) = self.peers.get_peer_by_id(peer_id) {
return peer.get_directly_connections();
}
DashSet::new()

View File

@@ -40,6 +40,9 @@ message FlagsInConfig {
bool enable_quic_proxy = 24;
// does this peer allow quic input
bool disable_quic_input = 25;
// a global relay limit; only applies to foreign networks
uint64 foreign_relay_bps_limit = 26;
}
message RpcDescriptor {
@@ -180,3 +183,9 @@ message PortForwardConfigPb {
message ProxyDstInfo {
SocketAddr dst_addr = 1;
}
message LimiterConfig {
optional uint64 burst_rate = 1; // default 1 means no burst (capacity equals bps)
optional uint64 bps = 2; // unset means no limit (unit is B/s)
optional uint64 fill_duration_ms = 3; // default 10ms, the interval at which the bucket is refilled
}
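Cross-checking these defaults against the From<LimiterConfig> impl in the new token_bucket module (illustrative numbers):

// burst_rate = Some(2), bps = Some(50_000), fill_duration_ms = None
// => fill_rate = max(8196, 50_000) = 50_000 tokens/s,
//    capacity = 2 * 50_000 = 100_000 tokens (about two seconds of burst),
//    refilled every 10 ms.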

View File

@@ -17,7 +17,9 @@ use crate::{
instance::instance::Instance,
proto::common::CompressionAlgoPb,
tunnel::{
common::tests::{_tunnel_bench_netns, wait_for_condition},
ring::RingTunnelConnector,
tcp::{TcpTunnelConnector, TcpTunnelListener},
udp::UdpTunnelConnector,
},
};
@@ -1195,3 +1197,46 @@ pub async fn port_forward_test(
drop_insts(_insts).await;
}
#[rstest::rstest]
#[serial_test::serial]
#[tokio::test]
pub async fn relay_bps_limit_test(#[values(100, 200, 400, 800)] bps_limit: u64) {
let insts = init_three_node_ex(
"udp",
|cfg| {
if cfg.get_inst_name() == "inst2" {
cfg.set_network_identity(NetworkIdentity::new(
"public".to_string(),
"public".to_string(),
));
let mut f = cfg.get_flags();
f.foreign_relay_bps_limit = bps_limit * 1024;
cfg.set_flags(f);
}
cfg
},
true,
)
.await;
// connect to virtual ip (no tun mode)
let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:22223".parse().unwrap());
let tcp_connector = TcpTunnelConnector::new("tcp://10.144.144.3:22223".parse().unwrap());
let bps = _tunnel_bench_netns(
tcp_listener,
tcp_connector,
NetNS::new(Some("net_c".into())),
NetNS::new(Some("net_a".into())),
)
.await;
println!("bps: {}", bps);
let bps = bps as u64 / 1024;
// allow 50 KB/s of jitter in the measured rate
assert!(bps >= bps_limit - 50 && bps <= bps_limit + 50);
drop_insts(insts).await;
}
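A worked instance of the assertion (illustrative): with bps_limit = 100, the relay cap is 100 * 1024 = 102,400 B/s; the measured rate is divided by 1024 back into KB/s, so any throughput in [50, 150] KB/s passes.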

View File

@@ -436,9 +436,10 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
}
pub mod tests {
use atomic_shim::AtomicU64;
use std::{sync::Arc, time::Instant};
use futures::{Future, SinkExt, StreamExt};
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
use crate::{
@@ -554,21 +555,56 @@ pub mod tests {
}
}
pub(crate) async fn _tunnel_bench<L, C>(listener: L, connector: C)
where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
_tunnel_bench_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await;
}
pub(crate) async fn _tunnel_bench_netns<L, C>(
mut listener: L,
mut connector: C,
netns_l: NetNS,
netns_c: NetNS,
) -> usize
where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
{
let _g = netns_l.guard();
listener.listen().await.unwrap();
}
let bps = Arc::new(AtomicU64::new(0));
let bps_clone = bps.clone();
let lis = tokio::spawn(async move {
let ret = listener.accept().await.unwrap();
// _tunnel_echo_server(ret, false).await
let (mut r, _s) = ret.split();
let now = Instant::now();
let mut count = 0;
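// Running measurement: total bytes received divided by elapsed whole seconds,
// published through the shared atomic so the caller can read it after the loop.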
while let Some(Ok(p)) = r.next().await {
count += p.payload_len();
let elapsed_sec = now.elapsed().as_secs();
if elapsed_sec > 0 {
bps_clone.store(
count as u64 / elapsed_sec,
std::sync::atomic::Ordering::Relaxed,
);
}
}
});
let tunnel = {
let _g = netns_c.guard();
connector.connect().await.unwrap()
};
let (_recv, mut send) = tunnel.split();
// prepare a 4k buffer with random data
let mut send_buf = BytesMut::new();
@@ -576,22 +612,6 @@ pub mod tests {
send_buf.put_i128(rand::random::<i128>());
}
let now = Instant::now();
while now.elapsed().as_secs() < 10 {
// send.feed(item)
@@ -605,11 +625,11 @@ pub mod tests {
drop(tunnel);
tracing::warn!("wait for recv to finish...");
let bps = bps.load(std::sync::atomic::Ordering::Acquire);
println!("bps: {}", bps);
lis.abort();
let _ = tokio::join!(lis);
bps as usize
}
pub fn enable_log() {

View File

@@ -11,8 +11,8 @@ use crate::tunnel::{
use anyhow::Context;
use quinn::{
congestion::BbrConfig, crypto::rustls::QuicClientConfig, ClientConfig, Connection, Endpoint,
ServerConfig, TransportConfig,
};
use super::{
@@ -20,8 +20,6 @@ use super::{
insecure_tls::{get_insecure_tls_cert, get_insecure_tls_client_config},
IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
pub fn configure_client() -> ClientConfig {
let client_crypto = QuicClientConfig::try_from(get_insecure_tls_client_config()).unwrap();