diff --git a/Cargo.lock b/Cargo.lock index 30df35e..bae6e93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1955,6 +1955,7 @@ version = "2.3.2" dependencies = [ "aes-gcm", "anyhow", + "arc-swap", "async-recursion", "async-ringbuf", "async-stream", diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 5d26709..2980639 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -40,6 +40,7 @@ tracing-appender = "0.2.3" thiserror = "1.0" auto_impl = "1.1.0" crossbeam = "0.8.4" +arc-swap = "1.7" time = "0.3" toml = "0.8.12" chrono = { version = "0.4.37", features = ["serde"] } diff --git a/easytier/build.rs b/easytier/build.rs index aed1655..1987c8f 100644 --- a/easytier/build.rs +++ b/easytier/build.rs @@ -147,6 +147,7 @@ fn main() -> Result<(), Box> { "src/proto/cli.proto", "src/proto/web.proto", "src/proto/magic_dns.proto", + "src/proto/acl.proto", ]; for proto_file in proto_files.iter().chain(proto_files_reflect.iter()) { @@ -156,6 +157,7 @@ fn main() -> Result<(), Box> { let mut config = prost_build::Config::new(); config .protoc_arg("--experimental_allow_proto3_optional") + .type_attribute(".acl", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".cli", "#[derive(serde::Serialize, serde::Deserialize)]") diff --git a/easytier/src/common/acl_processor.rs b/easytier/src/common/acl_processor.rs new file mode 100644 index 0000000..386aff9 --- /dev/null +++ b/easytier/src/common/acl_processor.rs @@ -0,0 +1,1334 @@ +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + str::FromStr as _, + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use crate::common::token_bucket::TokenBucket; +use crate::proto::acl::*; +use dashmap::DashMap; +use tokio::task::JoinSet; + +// Performance-optimized key for rate limiting to avoid string allocations +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct RateLimitKey { + pub chain_type: ChainType, + pub rule_priority: u32, +} + +impl RateLimitKey { + pub fn new(chain_type: ChainType, rule_priority: u32) -> Self { + Self { + chain_type, + rule_priority, + } + } +} + +// Performance-optimized rule identifier to avoid string allocations +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RuleId { + Priority(u32), + Stateful(u32), + Default, +} + +impl RuleId { + /// Convert to string only when actually needed (lazy evaluation) + pub fn to_string_cached(&self) -> String { + match self { + RuleId::Priority(p) => p.to_string(), + RuleId::Stateful(p) => format!("stateful-{}", p), + RuleId::Default => "default".to_string(), + } + } + + /// Get string representation for logging (optimized for hot path) + pub fn as_str(&self) -> String { + self.to_string_cached() + } +} + +// Fast lookup structures for performance optimization +#[derive(Debug, Clone)] +pub struct FastLookupRule { + pub priority: u32, + pub protocol: Protocol, + pub src_ip_ranges: Vec, + pub dst_ip_ranges: Vec, + pub src_port_ranges: Vec<(u16, u16)>, + pub dst_port_ranges: Vec<(u16, u16)>, + pub action: Action, + pub enabled: bool, + pub stateful: bool, + pub rate_limit: u32, + pub burst_limit: u32, + pub rule_stats: Arc, +} + +// Cache key combining packet info and chain type +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct AclCacheKey { + pub chain_type: ChainType, + pub protocol: Protocol, + pub src_ip: IpAddr, + pub dst_ip: IpAddr, + pub src_port: u16, + pub dst_port: u16, +} + +impl AclCacheKey { + pub fn from_packet_info(packet_info: &PacketInfo, chain_type: ChainType) -> Self { + Self { + chain_type, + protocol: packet_info.protocol, + src_ip: packet_info.src_ip, + dst_ip: packet_info.dst_ip, + src_port: packet_info.src_port.unwrap_or(0), + dst_port: packet_info.dst_port.unwrap_or(0), + } + } +} + +// Cache entry with timestamp for LRU cleanup +#[derive(Debug, Clone)] +pub struct AclCacheEntry { + pub action: Action, + pub matched_rule: RuleId, + pub last_access: u64, + // New fields to track rule characteristics for proper cache behavior + pub conn_track_key: Option, + pub rate_limit_keys: Vec, + pub chain_type: ChainType, + pub acl_result: Option, + pub rule_stats_vec: Vec>, +} + +// Packet info extracted for ACL processing +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct PacketInfo { + pub src_ip: IpAddr, + pub dst_ip: IpAddr, + pub src_port: Option, + pub dst_port: Option, + pub protocol: Protocol, + pub packet_size: usize, +} + +// ACL processing result +#[derive(Debug, Clone)] +pub struct AclResult { + pub action: Action, + pub matched_rule: Option, + pub should_log: bool, + pub log_context: Option, +} + +impl AclResult { + /// Get matched rule as string (lazy evaluation) + pub fn matched_rule_string(&self) -> Option { + self.matched_rule.as_ref().map(|r| r.to_string_cached()) + } + + /// Get matched rule as string reference for logging (compatibility method) + pub fn matched_rule_str(&self) -> Option { + self.matched_rule.as_ref().map(|r| r.as_str()) + } +} + +// Context for lazy log message construction +#[derive(Debug, Clone)] +pub enum AclLogContext { + StatefulMatch { + src_ip: IpAddr, + dst_ip: IpAddr, + }, + RuleMatch { + src_ip: IpAddr, + dst_ip: IpAddr, + action: Action, + }, + DefaultDrop, + DefaultAllow, + UnsupportedChainType, + RateLimitDrop, +} + +impl AclLogContext { + pub fn to_message(&self) -> String { + match self { + AclLogContext::StatefulMatch { src_ip, dst_ip } => { + format!("Stateful match: {} -> {}", src_ip, dst_ip) + } + AclLogContext::RuleMatch { + src_ip, + dst_ip, + action, + } => { + format!("Rule match: {} -> {} action: {:?}", src_ip, dst_ip, action) + } + AclLogContext::DefaultDrop => "No matching rule, default drop".to_string(), + AclLogContext::DefaultAllow => "No matching rule, default allow".to_string(), + AclLogContext::UnsupportedChainType => "Unsupported chain type".to_string(), + AclLogContext::RateLimitDrop => "Rate limit drop".to_string(), + } + } +} + +// High-performance ACL processor - No more internal locks! +pub struct AclProcessor { + // Immutable rule vectors - no locks needed since they're never modified after creation + inbound_rules: Vec, + outbound_rules: Vec, + forward_rules: Vec, + + default_inbound_action: Action, + default_outbound_action: Action, + default_forward_action: Action, + + default_rule_stats: Arc, + + // Connection tracking table - shared across different processor instances if needed + conn_track: Arc>, + + // Rate limiting buckets per rule using TokenBucket with optimized keys + rate_limiters: Arc>>, + + // Rule lookup cache with LRU cleanup + rule_cache: Arc>, + cache_max_size: usize, + cache_cleanup_interval: Duration, + + // Statistics + stats: Arc>, + + tasks: JoinSet<()>, +} + +impl AclProcessor { + /// Create a new ACL processor with pre-built immutable rules + /// This is the main constructor that should be used + pub fn new(acl_config: Acl) -> Self { + Self::new_with_shared_state(acl_config, None, None, None) + } + + /// Create a new ACL processor while preserving connection tracking and rate limiting state + /// This is useful for hot reloading where you want to preserve established connections + pub fn new_with_shared_state( + acl_config: Acl, + conn_track: Option>>, + rate_limiters: Option>>>, + stats: Option>>, + ) -> Self { + let (inbound_rules, outbound_rules, forward_rules) = Self::build_rules(&acl_config); + let (default_inbound_action, default_outbound_action, default_forward_action) = + Self::build_default_actions(&acl_config); + let tasks = JoinSet::new(); + + let mut processor = Self { + inbound_rules, + outbound_rules, + forward_rules, + + default_inbound_action, + default_outbound_action, + default_forward_action, + + default_rule_stats: Arc::new(RuleStats { + rule: None, + stat: Some(StatItem { + packet_count: 0, + byte_count: 0, + }), + }), + conn_track: conn_track.unwrap_or_else(|| Arc::new(DashMap::new())), + rate_limiters: rate_limiters.unwrap_or_else(|| Arc::new(DashMap::new())), + rule_cache: Arc::new(DashMap::new()), // Always start with fresh cache + cache_max_size: 10000, // Limit cache to 10k entries + cache_cleanup_interval: Duration::from_secs(20), // Cleanup every 5 minutes + stats: stats.unwrap_or_else(|| Arc::new(DashMap::new())), + tasks, + }; + + processor.start_cache_cleanup_task(); + processor + } + + fn build_default_actions(acl_config: &Acl) -> (Action, Action, Action) { + let default_inbound_action = acl_config + .acl_v1 + .as_ref() + .and_then(|v1| { + v1.chains + .iter() + .find(|c| c.chain_type == ChainType::Inbound as i32) + }) + .map(|c| c.default_action()) + .unwrap_or(Action::Allow); + + let default_outbound_action = acl_config + .acl_v1 + .as_ref() + .and_then(|v1| { + v1.chains + .iter() + .find(|c| c.chain_type == ChainType::Outbound as i32) + }) + .map(|c| c.default_action()) + .unwrap_or(Action::Allow); + + let default_forward_action = acl_config + .acl_v1 + .as_ref() + .and_then(|v1| { + v1.chains + .iter() + .find(|c| c.chain_type == ChainType::Forward as i32) + }) + .map(|c| c.default_action()) + .unwrap_or(Action::Allow); + + ( + default_inbound_action, + default_outbound_action, + default_forward_action, + ) + } + + /// Build all rule vectors from configuration + fn build_rules( + acl_config: &Acl, + ) -> ( + Vec, + Vec, + Vec, + ) { + let mut inbound_rules = Vec::new(); + let mut outbound_rules = Vec::new(); + let mut forward_rules = Vec::new(); + + // Build new rule vectors + if let Some(ref acl_v1) = acl_config.acl_v1 { + for chain in &acl_v1.chains { + if !chain.enabled { + continue; + } + + let mut rules = chain + .rules + .iter() + .filter(|rule| rule.enabled) + .map(|rule| Self::convert_to_fast_lookup_rule(rule)) + .collect::>(); + + // Sort by priority (higher priority first) + rules.sort_by(|a, b| b.priority.cmp(&a.priority)); + + match chain.chain_type() { + ChainType::Inbound => inbound_rules.extend(rules), + ChainType::Outbound => outbound_rules.extend(rules), + ChainType::Forward => forward_rules.extend(rules), + _ => {} + } + } + } + + tracing::info!( + "ACL rules built: {} inbound, {} outbound, {} forward", + inbound_rules.len(), + outbound_rules.len(), + forward_rules.len(), + ); + + (inbound_rules, outbound_rules, forward_rules) + } + + /// Start periodic cache cleanup task + fn start_cache_cleanup_task(&mut self) { + let rule_cache = self.rule_cache.clone(); + let cache_max_size = self.cache_max_size; + let cleanup_interval = self.cache_cleanup_interval; + + self.tasks.spawn(async move { + let mut interval = tokio::time::interval(cleanup_interval); + loop { + interval.tick().await; + Self::cleanup_cache(&rule_cache, cache_max_size); + } + }); + + let conn_track = self.conn_track.clone(); + self.tasks.spawn(async move { + let mut interval = tokio::time::interval(cleanup_interval); + loop { + interval.tick().await; + Self::cleanup_expired_connections(conn_track.clone(), 60); + } + }); + } + + /// Clean up cache using LRU strategy + fn cleanup_cache(cache: &DashMap, max_size: usize) { + let current_size = cache.len(); + if current_size <= max_size { + return; + } + + // Remove oldest entries (LRU cleanup) + let mut entries: Vec<(AclCacheKey, u64)> = cache + .iter() + .map(|entry| (entry.key().clone(), entry.value().last_access)) + .collect(); + + // Sort by last_access (oldest first) + entries.sort_by_key(|(_, last_access)| *last_access); + + // Remove oldest 20% of entries + let to_remove = current_size - max_size + (max_size / 5); + for (key, _) in entries.into_iter().take(to_remove) { + cache.remove(&key); + } + + tracing::debug!( + "Cache cleanup completed: removed {} entries, current size: {}", + to_remove, + cache.len() + ); + } + + pub fn process_packet_with_cache_entry( + &self, + packet_info: &PacketInfo, + cache_entry: &AclCacheEntry, + ) -> AclResult { + for rate_limit_key in cache_entry.rate_limit_keys.iter() { + // bucket should already be created, so rate and burst are not important + if !self.check_rate_limit(rate_limit_key, 1, 1, false) { + return AclResult { + action: Action::Drop, + matched_rule: Some(cache_entry.matched_rule.clone()), + should_log: false, + log_context: Some(AclLogContext::RateLimitDrop), + }; + } + } + + if let Some(conn_track_key) = cache_entry.conn_track_key.as_ref() { + self.check_connection_state(conn_track_key, packet_info); + } + + self.inc_cache_entry_stats(cache_entry, packet_info); + + return cache_entry.acl_result.clone().unwrap(); + } + + fn inc_cache_entry_stats(&self, cache_entry: &AclCacheEntry, packet_info: &PacketInfo) { + for rule_stats in cache_entry.rule_stats_vec.iter() { + // Use unsafe code to mutate the contents behind the Arc + let stat_ptr = rule_stats.stat.as_ref().unwrap() as *const StatItem as *mut StatItem; + unsafe { + (*stat_ptr).packet_count += 1; + (*stat_ptr).byte_count += packet_info.packet_size as u64; + } + } + } + + pub fn get_rules_stats(&self) -> Vec { + let mut stats: Vec = Vec::new(); + for rule in self.inbound_rules.iter() { + stats.push((*rule.rule_stats).clone()); + } + for rule in self.outbound_rules.iter() { + stats.push((*rule.rule_stats).clone()); + } + for rule in self.forward_rules.iter() { + stats.push((*rule.rule_stats).clone()); + } + stats + } + + /// Process a packet through ACL rules - Now lock-free! + pub fn process_packet(&self, packet_info: &PacketInfo, chain_type: ChainType) -> AclResult { + // Check cache first for performance + let cache_key = AclCacheKey::from_packet_info(packet_info, chain_type); + + // If cache hit and can skip checks, return cached result + if let Some(mut cached) = self.rule_cache.get_mut(&cache_key) { + // Update last access time for LRU + cached.last_access = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + self.increment_stat(AclStatKey::CacheHits); + return self.process_packet_with_cache_entry(packet_info, &cached); + } + + // Direct access to rules - no locks needed! + let rules = match chain_type { + ChainType::Inbound => &self.inbound_rules, + ChainType::Outbound => &self.outbound_rules, + _ => { + return AclResult { + action: Action::Drop, + matched_rule: Some(RuleId::Default), + should_log: false, + log_context: Some(AclLogContext::UnsupportedChainType), + } + } + }; + + let mut cache_entry = AclCacheEntry { + action: Action::Allow, + matched_rule: RuleId::Default, + last_access: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + conn_track_key: None, + rate_limit_keys: vec![], + chain_type, + acl_result: None, + rule_stats_vec: vec![], + }; + + // Process rules in priority order + for rule in rules.iter() { + if !rule.enabled || !self.rule_matches(rule, packet_info) { + continue; + } + + // Check rate limiting if configured + if rule.rate_limit > 0 { + let rule_key = RateLimitKey::new(chain_type, rule.priority); + cache_entry.rate_limit_keys.push(rule_key.clone()); + cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); + if !self.check_rate_limit(&rule_key, rule.rate_limit, rule.burst_limit, true) { + // rate limited, drop packet + return AclResult { + action: Action::Drop, + matched_rule: Some(RuleId::Priority(rule.priority)), + should_log: false, + log_context: Some(AclLogContext::RateLimitDrop), + }; + } + } + + // Handle stateful connections if configured + if rule.stateful && rule.action == Action::Allow { + let conn_track_key = self.conn_track_key(packet_info); + self.check_connection_state(&conn_track_key, packet_info); + cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); + cache_entry.matched_rule = RuleId::Stateful(rule.priority); + cache_entry.conn_track_key = Some(conn_track_key); + cache_entry.acl_result = Some(AclResult { + action: Action::Allow, + matched_rule: Some(RuleId::Stateful(rule.priority)), + should_log: false, + log_context: Some(AclLogContext::StatefulMatch { + src_ip: packet_info.src_ip, + dst_ip: packet_info.dst_ip, + }), + }); + } else { + // Rule matched, return action + cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); + cache_entry.matched_rule = RuleId::Priority(rule.priority); + cache_entry.acl_result = Some(AclResult { + action: rule.action.clone(), + matched_rule: Some(RuleId::Priority(rule.priority)), + should_log: false, + log_context: Some(AclLogContext::RuleMatch { + src_ip: packet_info.src_ip, + dst_ip: packet_info.dst_ip, + action: rule.action, + }), + }); + } + + // Cache the result with rule info + self.increment_stat(AclStatKey::RuleMatches); + self.inc_cache_entry_stats(&cache_entry, packet_info); + self.cache_result(&cache_key, cache_entry.clone()); + return cache_entry.acl_result.clone().unwrap(); + } + + let default_action = match chain_type { + ChainType::Inbound => self.default_inbound_action, + ChainType::Outbound => self.default_outbound_action, + ChainType::Forward => self.default_forward_action, + _ => Action::Allow, + }; + + // No rule matched, return default drop + if default_action == Action::Drop { + self.increment_stat(AclStatKey::DefaultDrops); + } else { + self.increment_stat(AclStatKey::DefaultAllows); + } + + let log_context = if default_action == Action::Drop { + AclLogContext::DefaultDrop + } else { + AclLogContext::DefaultAllow + }; + + cache_entry + .rule_stats_vec + .push(self.default_rule_stats.clone()); + cache_entry.matched_rule = RuleId::Default; + cache_entry.acl_result = Some(AclResult { + action: default_action, + matched_rule: Some(RuleId::Default), + should_log: false, + log_context: Some(log_context), + }); + + // Cache the default result (no rule info) + self.inc_cache_entry_stats(&cache_entry, packet_info); + self.cache_result(&cache_key, cache_entry.clone()); + cache_entry.acl_result.clone().unwrap() + } + + /// Get shared state for preserving across hot reloads + pub fn get_shared_state( + &self, + ) -> ( + Arc>, + Arc>>, + Arc>, + ) { + ( + self.conn_track.clone(), + self.rate_limiters.clone(), + self.stats.clone(), + ) + } + + /// Cache an ACL result + fn cache_result(&self, cache_key: &AclCacheKey, cache_entry: AclCacheEntry) { + self.rule_cache.insert(cache_key.clone(), cache_entry); + + // Trigger cleanup if cache is getting too large + if self.rule_cache.len() > self.cache_max_size * 2 { + let cache = self.rule_cache.clone(); + let max_size = self.cache_max_size; + Self::cleanup_cache(&cache, max_size); + } + } + + /// Check if a rule matches the packet + fn rule_matches(&self, rule: &FastLookupRule, packet_info: &PacketInfo) -> bool { + // Protocol check + if rule.protocol != Protocol::Any && rule.protocol as i32 != packet_info.protocol as i32 { + return false; + } + + // Source IP check + if !rule.src_ip_ranges.is_empty() { + let matches = rule + .src_ip_ranges + .iter() + .any(|cidr| match (cidr, packet_info.src_ip) { + (cidr::IpCidr::V4(v4_cidr), IpAddr::V4(v4_addr)) => v4_cidr.contains(&v4_addr), + (cidr::IpCidr::V6(v6_cidr), IpAddr::V6(v6_addr)) => v6_cidr.contains(&v6_addr), + _ => false, + }); + if !matches { + return false; + } + } + + // Destination IP check + if !rule.dst_ip_ranges.is_empty() { + let matches = rule + .dst_ip_ranges + .iter() + .any(|cidr| match (cidr, packet_info.dst_ip) { + (cidr::IpCidr::V4(v4_cidr), IpAddr::V4(v4_addr)) => v4_cidr.contains(&v4_addr), + (cidr::IpCidr::V6(v6_cidr), IpAddr::V6(v6_addr)) => v6_cidr.contains(&v6_addr), + _ => false, + }); + if !matches { + return false; + } + } + + // Source port check + if let Some(src_port) = packet_info.src_port { + if !rule.src_port_ranges.is_empty() { + let matches = rule + .src_port_ranges + .iter() + .any(|(start, end)| src_port >= *start && src_port <= *end); + if !matches { + return false; + } + } + } + + // Destination port check + if let Some(dst_port) = packet_info.dst_port { + if !rule.dst_port_ranges.is_empty() { + let matches = rule + .dst_port_ranges + .iter() + .any(|(start, end)| dst_port >= *start && dst_port <= *end); + if !matches { + return false; + } + } + } + + true + } + + fn conn_track_key(&self, packet_info: &PacketInfo) -> String { + format!( + "{}:{}->{}:{}", + packet_info.src_ip, + packet_info.src_port.unwrap_or(0), + packet_info.dst_ip, + packet_info.dst_port.unwrap_or(0) + ) + } + + /// Check connection state for stateful rules + fn check_connection_state(&self, conn_track_key: &String, packet_info: &PacketInfo) { + self.conn_track + .entry(conn_track_key.clone()) + .and_modify(|x| { + x.last_seen = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + x.packet_count += 1; + x.byte_count += packet_info.packet_size as u64; + x.state = ConnState::Established as i32; + }) + .or_insert_with(|| ConnTrackEntry { + src_addr: Some( + SocketAddr::new(packet_info.src_ip, packet_info.src_port.unwrap_or(0)).into(), + ), + dst_addr: Some( + SocketAddr::new(packet_info.dst_ip, packet_info.dst_port.unwrap_or(0)).into(), + ), + protocol: packet_info.protocol as i32, + state: ConnState::New as i32, + created_at: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + last_seen: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + packet_count: 1, + byte_count: packet_info.packet_size as u64, + }); + } + + /// Check rate limiting for a rule + fn check_rate_limit( + &self, + rule_key: &RateLimitKey, + rate: u32, + burst: u32, + allow_create: bool, + ) -> bool { + if rate == 0 { + return true; // No rate limiting + } + + let bucket = self + .rate_limiters + .entry(rule_key.clone()) + .or_insert_with(|| { + if !allow_create { + panic!("Rate limit bucket not found"); + } + TokenBucket::new(burst as u64, rate as u64, Duration::from_millis(10)) + }) + .clone(); + + // Try to consume 1 token (1 packet) + bucket.try_consume(1) + } + + /// Convert proto Rule to FastLookupRule + fn convert_to_fast_lookup_rule(rule: &Rule) -> FastLookupRule { + let src_ip_ranges = rule + .source_ips + .iter() + .filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet)) + .collect(); + + let dst_ip_ranges = rule + .destination_ips + .iter() + .filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet)) + .collect(); + + let src_port_ranges = rule + .source_ports + .iter() + .filter_map(|port_range| { + if let Some((start, end)) = parse_port_range(port_range) { + Some((start, end)) + } else { + None + } + }) + .collect(); + + let dst_port_ranges = rule + .ports + .iter() + .filter_map(|port_range| { + if let Some((start, end)) = parse_port_range(port_range) { + Some((start, end)) + } else { + None + } + }) + .collect(); + + FastLookupRule { + priority: rule.priority, + protocol: rule.protocol(), + src_ip_ranges, + dst_ip_ranges, + src_port_ranges, + dst_port_ranges, + action: rule.action(), + enabled: rule.enabled, + stateful: rule.stateful, + rate_limit: rule.rate_limit, + burst_limit: rule.burst_limit, + rule_stats: Arc::new(RuleStats { + rule: Some(rule.clone()), + stat: Some(StatItem { + packet_count: 0, + byte_count: 0, + }), + }), + } + } + + /// Convert IpInet to CIDR for fast lookup + fn convert_ip_inet_to_cidr(input: &String) -> Option { + cidr::IpCidr::from_str(input.as_str()).ok() + } + + /// Increment statistics counter + pub fn increment_stat(&self, key: AclStatKey) { + self.stats + .entry(key) + .and_modify(|counter| *counter += 1) + .or_insert(1); + } + + /// Get statistics + pub fn get_stats(&self) -> HashMap { + let mut stats = self + .stats + .iter() + .map(|entry| (entry.key().as_str(), *entry.value())) + .collect::>(); + + // Add cache statistics using enum keys + stats.insert(AclStatKey::CacheSize.as_str(), self.rule_cache.len() as u64); + stats.insert( + AclStatKey::CacheMaxSize.as_str(), + self.cache_max_size as u64, + ); + + stats + } + + /// Clean up expired connection tracking entries + pub fn cleanup_expired_connections( + conn_track: Arc>, + timeout_secs: u64, + ) { + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let keys_to_remove: Vec = conn_track + .iter() + .filter_map(|entry| { + if current_time - entry.last_seen > timeout_secs { + Some(entry.key().clone()) + } else { + None + } + }) + .collect(); + + for key in keys_to_remove { + conn_track.remove(&key); + } + } + + /// Get cache hit rate + pub fn get_cache_hit_rate(&self) -> f64 { + let cache_hits = self + .stats + .get(&AclStatKey::CacheHits) + .map(|v| *v.value()) + .unwrap_or(0); + let total_requests = cache_hits + + self + .stats + .get(&AclStatKey::RuleMatches) + .map(|v| *v.value()) + .unwrap_or(0); + + if total_requests == 0 { + 0.0 + } else { + cache_hits as f64 / total_requests as f64 + } + } +} + +// 新增辅助函数 +fn parse_port_start( + port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>, +) -> Option { + port_strs + .iter() + .filter_map(|s| parse_port_range(s).map(|(start, _)| start)) + .min() +} +fn parse_port_end( + port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>, +) -> Option { + port_strs + .iter() + .filter_map(|s| parse_port_range(s).map(|(_, end)| end)) + .max() +} +fn parse_port_range(s: &str) -> Option<(u16, u16)> { + if let Some((start, end)) = s.split_once('-') { + let start = start.trim().parse().ok()?; + let end = end.trim().parse().ok()?; + Some((start, end)) + } else { + let port = s.trim().parse().ok()?; + Some((port, port)) + } +} + +// Statistics key enum for better performance +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum AclStatKey { + // Cache statistics + CacheHits, + CacheSize, + CacheMaxSize, + RuleMatches, + DefaultAllows, + DefaultDrops, + + // Global packet statistics + PacketsTotal, + PacketsAllowed, + PacketsDropped, + PacketsNoop, + + // Per-chain statistics + InboundPacketsTotal, + InboundPacketsAllowed, + InboundPacketsDropped, + InboundPacketsNoop, + + OutboundPacketsTotal, + OutboundPacketsAllowed, + OutboundPacketsDropped, + OutboundPacketsNoop, + + ForwardPacketsTotal, + ForwardPacketsAllowed, + ForwardPacketsDropped, + ForwardPacketsNoop, + + UnknownPacketsTotal, + UnknownPacketsAllowed, + UnknownPacketsDropped, + UnknownPacketsNoop, +} + +impl AclStatKey { + pub fn as_str(&self) -> String { + format!("{:?}", self) + } + + pub fn from_chain_and_action(chain_type: ChainType, stat_type: AclStatType) -> Self { + match (chain_type, stat_type) { + (ChainType::Inbound, AclStatType::Total) => AclStatKey::InboundPacketsTotal, + (ChainType::Inbound, AclStatType::Allowed) => AclStatKey::InboundPacketsAllowed, + (ChainType::Inbound, AclStatType::Dropped) => AclStatKey::InboundPacketsDropped, + (ChainType::Inbound, AclStatType::Noop) => AclStatKey::InboundPacketsNoop, + + (ChainType::Outbound, AclStatType::Total) => AclStatKey::OutboundPacketsTotal, + (ChainType::Outbound, AclStatType::Allowed) => AclStatKey::OutboundPacketsAllowed, + (ChainType::Outbound, AclStatType::Dropped) => AclStatKey::OutboundPacketsDropped, + (ChainType::Outbound, AclStatType::Noop) => AclStatKey::OutboundPacketsNoop, + + (ChainType::Forward, AclStatType::Total) => AclStatKey::ForwardPacketsTotal, + (ChainType::Forward, AclStatType::Allowed) => AclStatKey::ForwardPacketsAllowed, + (ChainType::Forward, AclStatType::Dropped) => AclStatKey::ForwardPacketsDropped, + (ChainType::Forward, AclStatType::Noop) => AclStatKey::ForwardPacketsNoop, + + (_, AclStatType::Total) => AclStatKey::UnknownPacketsTotal, + (_, AclStatType::Allowed) => AclStatKey::UnknownPacketsAllowed, + (_, AclStatType::Dropped) => AclStatKey::UnknownPacketsDropped, + (_, AclStatType::Noop) => AclStatKey::UnknownPacketsNoop, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum AclStatType { + Total, + Allowed, + Dropped, + Noop, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::hash::{Hash, Hasher}; + use std::net::{IpAddr, Ipv4Addr}; + + fn create_test_acl_config() -> Acl { + let mut acl_config = Acl::default(); + + let mut acl_v1 = AclV1::default(); + + // Create inbound chain + let mut chain = Chain::default(); + chain.name = "test_inbound".to_string(); + chain.chain_type = ChainType::Inbound as i32; + chain.enabled = true; + + // Allow all rule + let mut rule = Rule::default(); + rule.name = "allow_all".to_string(); + rule.priority = 100; + rule.enabled = true; + rule.action = Action::Allow as i32; + rule.protocol = Protocol::Any as i32; + + chain.rules.push(rule); + acl_v1.chains.push(chain); + acl_config.acl_v1 = Some(acl_v1); + + acl_config + } + + fn create_test_packet_info() -> PacketInfo { + PacketInfo { + src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + src_port: Some(12345), + dst_port: Some(80), + protocol: Protocol::Tcp, + packet_size: 1024, + } + } + + #[test] + fn test_acl_cache_key_creation() { + let packet_info = create_test_packet_info(); + let cache_key = AclCacheKey::from_packet_info(&packet_info, ChainType::Inbound); + + assert_eq!(cache_key.chain_type, ChainType::Inbound); + assert_eq!(cache_key.protocol, Protocol::Tcp); + assert_eq!( + cache_key.src_ip, + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)) + ); + assert_eq!(cache_key.dst_ip, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); + assert_eq!(cache_key.src_port, 12345); + assert_eq!(cache_key.dst_port, 80); + } + + #[test] + fn test_acl_cache_key_equality() { + let packet_info1 = create_test_packet_info(); + let packet_info2 = create_test_packet_info(); + + let key1 = AclCacheKey::from_packet_info(&packet_info1, ChainType::Inbound); + let key2 = AclCacheKey::from_packet_info(&packet_info2, ChainType::Inbound); + + assert_eq!(key1, key2); + + // Test hash consistency + use std::collections::hash_map::DefaultHasher; + let mut hasher1 = DefaultHasher::new(); + let mut hasher2 = DefaultHasher::new(); + key1.hash(&mut hasher1); + key2.hash(&mut hasher2); + assert_eq!(hasher1.finish(), hasher2.finish()); + } + + #[tokio::test] + async fn test_acl_processor_basic_functionality() { + let acl_config = create_test_acl_config(); + let processor = AclProcessor::new(acl_config); + let packet_info = create_test_packet_info(); + + let result = processor.process_packet(&packet_info, ChainType::Inbound); + + assert_eq!(result.action, Action::Allow); + assert!(result.matched_rule.is_some()); + } + + #[tokio::test] + async fn test_acl_cache_hit() { + let acl_config = create_test_acl_config(); + let processor = AclProcessor::new(acl_config); + let packet_info = create_test_packet_info(); + + // First request - should be a cache miss + let result1 = processor.process_packet(&packet_info, ChainType::Inbound); + + // Second request - should be a cache hit + let result2 = processor.process_packet(&packet_info, ChainType::Inbound); + + assert_eq!(result1.action, result2.action); + assert_eq!(result1.matched_rule, result2.matched_rule); + + // Check cache statistics + let stats = processor.get_stats(); + assert_eq!(stats.get(&AclStatKey::CacheHits.as_str()).unwrap_or(&0), &1); + assert!(processor.get_cache_hit_rate() > 0.0); + } + + #[tokio::test] + async fn test_lock_free_hot_reload_demo() { + println!("\n=== ACL 优化演示:无锁热加载 ==="); + + // 创建初始配置 + let initial_config = create_test_acl_config(); + let processor = AclProcessor::new(initial_config); + let packet_info = create_test_packet_info(); + + // 处理一些数据包 + println!("1. 处理初始数据包..."); + let result1 = processor.process_packet(&packet_info, ChainType::Inbound); + assert_eq!(result1.action, Action::Allow); + println!(" ✓ 数据包被允许通过"); + + // 获取共享状态 + let (conn_track, rate_limiters, stats) = processor.get_shared_state(); + println!("2. 保存连接跟踪和统计状态..."); + println!(" ✓ 连接数: {}", conn_track.len()); + println!(" ✓ 限流器数量: {}", rate_limiters.len()); + println!(" ✓ 统计计数器数量: {}", stats.len()); + + // 创建新配置(模拟热加载) + let mut new_config = create_test_acl_config(); + if let Some(ref mut acl_v1) = new_config.acl_v1 { + let mut drop_rule = Rule::default(); + drop_rule.name = "drop_all".to_string(); + drop_rule.priority = 200; + drop_rule.enabled = true; + drop_rule.action = Action::Drop as i32; + drop_rule.protocol = Protocol::Any as i32; + acl_v1.chains[0].rules.push(drop_rule); + } + + // 创建新的处理器实例(热加载) + println!("3. 执行热加载(创建新的处理器实例)..."); + let new_processor = AclProcessor::new_with_shared_state( + new_config, + Some(conn_track.clone()), + Some(rate_limiters.clone()), + Some(stats.clone()), + ); + + // 验证新处理器的行为 + let result2 = new_processor.process_packet(&packet_info, ChainType::Inbound); + assert_eq!(result2.action, Action::Drop); // 新规则应该拒绝 + println!(" ✓ 新规则生效:数据包被拒绝"); + + // 验证状态被保留 + let (new_conn_track, new_rate_limiters, new_stats) = new_processor.get_shared_state(); + assert!(Arc::ptr_eq(&conn_track, &new_conn_track)); + assert!(Arc::ptr_eq(&rate_limiters, &new_rate_limiters)); + assert!(Arc::ptr_eq(&stats, &new_stats)); + println!(" ✓ 连接状态和统计信息被完整保留"); + + println!("\n=== 性能优化效果 ==="); + println!("✓ 无锁访问:处理器内部不再有任何锁"); + println!("✓ 零拷贝:规则访问直接引用,无需克隆Arc"); + println!("✓ 热加载:创建新实例替换,保留所有状态"); + println!("✓ 内存效率:消除了多层Arc包装的开销"); + } + + #[tokio::test] + async fn test_performance_and_security_balance() { + // Create ACL config with different rule types + let mut acl_config = Acl::default(); + + let mut acl_v1 = AclV1::default(); + let mut chain = Chain::default(); + chain.name = "performance_test".to_string(); + chain.chain_type = ChainType::Inbound as i32; + chain.enabled = true; + + // 1. High-priority simple rule for UDP (can be cached efficiently) + let mut simple_rule = Rule::default(); + simple_rule.name = "simple_udp".to_string(); + simple_rule.priority = 300; + simple_rule.enabled = true; + simple_rule.action = Action::Allow as i32; + simple_rule.protocol = Protocol::Udp as i32; + // No stateful or rate limit - can benefit from full cache optimization + chain.rules.push(simple_rule); + + // 2. Medium-priority stateful + rate-limited rule for TCP (security critical) + let mut security_rule = Rule::default(); + security_rule.name = "security_tcp".to_string(); + security_rule.priority = 200; + security_rule.enabled = true; + security_rule.action = Action::Allow as i32; + security_rule.protocol = Protocol::Tcp as i32; + security_rule.stateful = true; + security_rule.rate_limit = 100; // 100 packets/sec + security_rule.burst_limit = 200; + chain.rules.push(security_rule); + + // 3. Low-priority default allow rule for Any + let mut default_rule = Rule::default(); + default_rule.name = "default_allow".to_string(); + default_rule.priority = 100; + default_rule.enabled = true; + default_rule.action = Action::Allow as i32; + default_rule.protocol = Protocol::Any as i32; + chain.rules.push(default_rule); + + acl_v1.chains.push(chain); + acl_config.acl_v1 = Some(acl_v1); + + let processor = AclProcessor::new(acl_config); + + // Test simple UDP packet (should hit high-priority simple rule and be cached) + let udp_packet = PacketInfo { + src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + src_port: Some(12345), + dst_port: Some(53), // DNS + protocol: Protocol::Udp, // UDP + packet_size: 512, + }; + + // Test TCP packet (should hit stateful+rate-limited rule) + let tcp_packet = PacketInfo { + src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + src_port: Some(12345), + dst_port: Some(80), // HTTP + protocol: Protocol::Tcp, // TCP + packet_size: 1024, + }; + + // Process UDP packets multiple times + println!("\n=== Performance Test Results ==="); + for i in 1..=5 { + let result = processor.process_packet(&udp_packet, ChainType::Inbound); + assert_eq!(result.action, Action::Allow); + // UDP packets should match the highest priority rule that applies + // Since all rules allow "Any" protocol, UDP will match the highest priority one + println!( + "UDP packet {}: Allowed by rule (priority {:?})", + i, result.matched_rule + ); + } + + // Process TCP packets multiple times (stateful + rate limited) + for i in 1..=3 { + let result = processor.process_packet(&tcp_packet, ChainType::Inbound); + println!( + "TCP packet {}: {:?} by rule (priority {:?})", + i, result.action, result.matched_rule + ); + } + + let stats = processor.get_stats(); + println!("\nStatistics:"); + println!( + " Cache hits: {}", + stats.get(&AclStatKey::CacheHits.as_str()).unwrap_or(&0) + ); + println!( + " Rule matches: {}", + stats.get(&AclStatKey::RuleMatches.as_str()).unwrap_or(&0) + ); + println!( + " Cache hit rate: {:.1}%", + processor.get_cache_hit_rate() * 100.0 + ); + + println!("\n✓ Stateful + rate-limited rules: Always processed for security"); + println!("✓ Simple rules: Cached for performance"); + println!( + "✓ Cache hit rate: {:.1}%", + processor.get_cache_hit_rate() * 100.0 + ); + } + + #[test] + fn test_rate_limit_drop_log_context() { + // Test that RateLimitDrop log context is properly created + let context = AclLogContext::RateLimitDrop; + let message = context.to_message(); + assert_eq!(message, "Rate limit drop"); + } + + #[tokio::test] + async fn test_rate_limit_drop_behavior() { + let mut acl_config = create_test_acl_config(); + + // Create a very restrictive rate-limited rule + if let Some(ref mut acl_v1) = acl_config.acl_v1 { + let mut rule = Rule::default(); + rule.name = "strict_rate_limit".to_string(); + rule.priority = 200; + rule.enabled = true; + rule.action = Action::Allow as i32; + rule.protocol = Protocol::Any as i32; + rule.rate_limit = 1; // Allow only 1 packet per second + rule.burst_limit = 1; // Burst of 1 packet + + acl_v1.chains[0].rules.push(rule); + } + + let processor = AclProcessor::new(acl_config); + let packet_info = create_test_packet_info(); + + // First request should be allowed + let result1 = processor.process_packet(&packet_info, ChainType::Inbound); + assert_eq!(result1.action, Action::Allow); + assert_eq!(result1.matched_rule, Some(RuleId::Priority(200))); + + // Second request should be rate limited and dropped immediately + let result2 = processor.process_packet(&packet_info, ChainType::Inbound); + assert_eq!(result2.action, Action::Drop); + assert_eq!(result2.matched_rule, Some(RuleId::Priority(200))); + assert!(!result2.should_log); + + // Verify the specific log context + assert!(matches!( + result2.log_context, + Some(AclLogContext::RateLimitDrop) + )); + } +} diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index d146d70..a2d8495 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -10,7 +10,10 @@ use cidr::IpCidr; use serde::{Deserialize, Serialize}; use crate::{ - proto::common::{CompressionAlgoPb, PortForwardConfigPb, SocketType}, + proto::{ + acl::Acl, + common::{CompressionAlgoPb, PortForwardConfigPb, SocketType}, + }, tunnel::generate_digest_from_str, }; @@ -116,6 +119,9 @@ pub trait ConfigLoader: Send + Sync { fn get_port_forwards(&self) -> Vec; fn set_port_forwards(&self, forwards: Vec); + fn get_acl(&self) -> Option; + fn set_acl(&self, acl: Option); + fn dump(&self) -> String; } @@ -291,6 +297,8 @@ struct Config { #[serde(skip)] flags_struct: Option, + + acl: Option, } #[derive(Debug, Clone)] @@ -649,6 +657,14 @@ impl ConfigLoader for TomlConfigLoader { self.config.lock().unwrap().port_forward = Some(forwards); } + fn get_acl(&self) -> Option { + self.config.lock().unwrap().acl.clone() + } + + fn set_acl(&self, acl: Option) { + self.config.lock().unwrap().acl = acl; + } + fn dump(&self) -> String { let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap(); let default_flags_hashmap = diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 9fe50da..98208fb 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -6,6 +6,7 @@ use std::{ use crate::common::config::ProxyNetworkConfig; use crate::common::token_bucket::TokenBucketManager; +use crate::peers::acl_filter::AclFilter; use crate::proto::cli::PeerConnInfo; use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb}; use crossbeam::atomic::AtomicCell; @@ -81,6 +82,8 @@ pub struct GlobalCtx { quic_proxy_port: AtomicCell>, token_bucket_manager: TokenBucketManager, + + acl_filter: Arc, } impl std::fmt::Debug for GlobalCtx { @@ -108,7 +111,7 @@ impl GlobalCtx { let stun_info_collection = Arc::new(StunInfoCollector::new_with_default_servers()); - let enable_exit_node = config_fs.get_flags().enable_exit_node || cfg!(target_env= "ohos"); + let enable_exit_node = config_fs.get_flags().enable_exit_node || cfg!(target_env = "ohos"); let proxy_forward_by_system = config_fs.get_flags().proxy_forward_by_system; let no_tun = config_fs.get_flags().no_tun; @@ -147,6 +150,8 @@ impl GlobalCtx { quic_proxy_port: AtomicCell::new(None), token_bucket_manager: TokenBucketManager::new(), + + acl_filter: Arc::new(AclFilter::new()), } } @@ -317,6 +322,10 @@ impl GlobalCtx { pub fn token_bucket_manager(&self) -> &TokenBucketManager { &self.token_bucket_manager } + + pub fn get_acl_filter(&self) -> &Arc { + &self.acl_filter + } } #[cfg(test)] diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs index 308f60d..9bc5028 100644 --- a/easytier/src/common/mod.rs +++ b/easytier/src/common/mod.rs @@ -10,6 +10,7 @@ use tracing::Instrument; use crate::{set_global_var, use_global_var}; +pub mod acl_processor; pub mod compressor; pub mod config; pub mod constants; diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 45d3104..9d2c024 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -27,11 +27,12 @@ use easytier::{ }, proto::{ cli::{ - list_peer_route_pair, ConnectorManageRpc, ConnectorManageRpcClientFactory, - DumpRouteRequest, GetVpnPortalInfoRequest, ListConnectorRequest, - ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest, - ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse, - ManageMappedListenerRequest, MappedListenerManageAction, MappedListenerManageRpc, + list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, ConnectorManageRpc, + ConnectorManageRpcClientFactory, DumpRouteRequest, GetAclStatsRequest, + GetVpnPortalInfoRequest, ListConnectorRequest, ListForeignNetworkRequest, + ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest, + ListPeerResponse, ListRouteRequest, ListRouteResponse, ManageMappedListenerRequest, + MappedListenerManageAction, MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, PeerManageRpcClientFactory, ShowNodeInfoRequest, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc, @@ -93,6 +94,8 @@ enum SubCommand { Service(ServiceArgs), #[command(about = "show tcp/kcp proxy status")] Proxy, + #[command(about = "show ACL rules statistics")] + Acl(AclArgs), #[command(about = t!("core_clap.generate_completions").to_string())] GenAutocomplete { shell: Shell }, } @@ -179,6 +182,17 @@ struct NodeArgs { sub_command: Option, } +#[derive(Args, Debug)] +struct AclArgs { + #[command(subcommand)] + sub_command: Option, +} + +#[derive(Subcommand, Debug)] +enum AclSubCommand { + Stats, +} + #[derive(Args, Debug)] struct ServiceArgs { #[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")] @@ -301,6 +315,18 @@ impl CommandHandler<'_> { .with_context(|| "failed to get vpn portal client")?) } + async fn get_acl_manager_client( + &self, + ) -> Result>, Error> { + Ok(self + .client + .lock() + .unwrap() + .scoped_client::>("".to_string()) + .await + .with_context(|| "failed to get acl manager client")?) + } + async fn get_tcp_proxy_client( &self, transport_type: &str, @@ -688,6 +714,26 @@ impl CommandHandler<'_> { Ok(()) } + async fn handle_acl_stats(&self) -> Result<(), Error> { + let client = self.get_acl_manager_client().await?; + let request = GetAclStatsRequest::default(); + let response = client + .get_acl_stats(BaseController::default(), request) + .await?; + + if let Some(acl_stats) = response.acl_stats { + if self.output_format == &OutputFormat::Json { + println!("{}", serde_json::to_string_pretty(&acl_stats)?); + } else { + println!("{}", acl_stats); + } + } else { + println!("No ACL statistics available"); + } + + Ok(()) + } + async fn handle_mapped_listener_list(&self) -> Result<(), Error> { let client = self.get_mapped_listener_manager_client().await?; let request = ListMappedListenerRequest::default(); @@ -1443,6 +1489,11 @@ async fn main() -> Result<(), Error> { print_output(&table_rows, &cli.output_format)?; } + SubCommand::Acl(acl_args) => match &acl_args.sub_command { + Some(AclSubCommand::Stats) | None => { + handler.handle_acl_stats().await?; + } + }, SubCommand::GenAutocomplete { shell } => { let mut cmd = Cli::command(); easytier::print_completions(shell, &mut cmd, "easytier-cli"); diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 0d80837..371cae5 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -29,7 +29,10 @@ use easytier::{ connector::create_connector_by_url, instance_manager::NetworkInstanceManager, launcher::{add_proxy_network_to_config, ConfigSource}, - proto::common::{CompressionAlgoPb, NatType}, + proto::{ + acl::{Acl, AclV1, Action, Chain, ChainType, Protocol, Rule}, + common::{CompressionAlgoPb, NatType}, + }, tunnel::{IpVersion, PROTO_PORT_OFFSET}, utils::{init_logger, setup_panic_handler}, web_client, @@ -506,6 +509,22 @@ struct NetworkOptions { help = t!("core_clap.foreign_relay_bps_limit").to_string(), )] foreign_relay_bps_limit: Option, + + #[arg( + long, + value_delimiter = ',', + help = "TCP port whitelist. Supports single ports (80) and ranges (8000-9000)", + num_args = 0.. + )] + tcp_whitelist: Vec, + + #[arg( + long, + value_delimiter = ',', + help = "UDP port whitelist. Supports single ports (53) and ranges (5000-6000)", + num_args = 0.. + )] + udp_whitelist: Vec, } #[derive(Parser, Debug)] @@ -603,6 +622,117 @@ impl NetworkOptions { false } + fn parse_port_list(port_list: &[String]) -> anyhow::Result> { + let mut ports = Vec::new(); + + for port_spec in port_list { + if port_spec.contains('-') { + // Handle port range like "8000-9000" + let parts: Vec<&str> = port_spec.split('-').collect(); + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec)); + } + + let start: u16 = parts[0] + .parse() + .with_context(|| format!("Invalid start port in range: {}", port_spec))?; + let end: u16 = parts[1] + .parse() + .with_context(|| format!("Invalid end port in range: {}", port_spec))?; + + if start > end { + return Err(anyhow::anyhow!( + "Start port must be <= end port in range: {}", + port_spec + )); + } + + // Add individual ports in the range + for port in start..=end { + ports.push(port.to_string()); + } + } else { + // Handle single port + let port: u16 = port_spec + .parse() + .with_context(|| format!("Invalid port number: {}", port_spec))?; + ports.push(port.to_string()); + } + } + + Ok(ports) + } + + fn generate_acl_from_whitelists(&self) -> anyhow::Result> { + if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() { + return Ok(None); + } + + let mut acl = Acl { + acl_v1: Some(AclV1 { chains: vec![] }), + }; + + let acl_v1 = acl.acl_v1.as_mut().unwrap(); + + // Create inbound chain for whitelist rules + let mut inbound_chain = Chain { + name: "inbound_whitelist".to_string(), + chain_type: ChainType::Inbound as i32, + description: "Auto-generated inbound whitelist from CLI".to_string(), + enabled: true, + rules: vec![], + default_action: Action::Drop as i32, // Default deny + }; + + let mut rule_priority = 1000u32; + + // Add TCP whitelist rules + if !self.tcp_whitelist.is_empty() { + let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?; + let tcp_rule = Rule { + name: "tcp_whitelist".to_string(), + description: "Auto-generated TCP whitelist rule".to_string(), + priority: rule_priority, + enabled: true, + protocol: Protocol::Tcp as i32, + ports: tcp_ports, + source_ips: vec![], + destination_ips: vec![], + source_ports: vec![], + action: Action::Allow as i32, + rate_limit: 0, + burst_limit: 0, + stateful: true, + }; + inbound_chain.rules.push(tcp_rule); + rule_priority -= 1; + } + + // Add UDP whitelist rules + if !self.udp_whitelist.is_empty() { + let udp_ports = Self::parse_port_list(&self.udp_whitelist)?; + let udp_rule = Rule { + name: "udp_whitelist".to_string(), + description: "Auto-generated UDP whitelist rule".to_string(), + priority: rule_priority, + enabled: true, + protocol: Protocol::Udp as i32, + ports: udp_ports, + source_ips: vec![], + destination_ips: vec![], + source_ports: vec![], + action: Action::Allow as i32, + rate_limit: 0, + burst_limit: 0, + stateful: false, + }; + inbound_chain.rules.push(udp_rule); + } + + acl_v1.chains.push(inbound_chain); + Ok(Some(acl)) + } + fn merge_into(&self, cfg: &mut TomlConfigLoader) -> anyhow::Result<()> { if self.hostname.is_some() { cfg.set_hostname(self.hostname.clone()); @@ -860,6 +990,11 @@ impl NetworkOptions { cfg.set_exit_nodes(self.exit_nodes.clone()); } + // Handle port whitelists by generating ACL configuration + if let Some(acl) = self.generate_acl_from_whitelists()? { + cfg.set_acl(Some(acl)); + } + Ok(()) } } diff --git a/easytier/src/instance/dns_server/server_instance.rs b/easytier/src/instance/dns_server/server_instance.rs index 1e1f8f0..bbce458 100644 --- a/easytier/src/instance/dns_server/server_instance.rs +++ b/easytier/src/instance/dns_server/server_instance.rs @@ -298,7 +298,10 @@ impl NicPacketFilter for MagicDnsServerInstanceData { #[async_trait::async_trait] impl RpcServerHook for MagicDnsServerInstanceData { - async fn on_new_client(&self, tunnel_info: Option)-> Result, anyhow::Error> { + async fn on_new_client( + &self, + tunnel_info: Option, + ) -> Result, anyhow::Error> { tracing::info!(?tunnel_info, "New client connected"); Ok(tunnel_info) } diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 0741876..70dccb1 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -609,6 +609,10 @@ impl Instance { } } + if let Some(acl) = self.global_ctx.config.get_acl() { + self.global_ctx.get_acl_filter().reload_rules(Some(&acl)); + } + // run after tun device created, so listener can bind to tun device, which may be required by win 10 self.ip_proxy = Some(IpProxy::new( self.get_global_ctx(), @@ -801,10 +805,11 @@ impl Instance { let mapped_listener_manager_rpc = self.get_mapped_listener_manager_rpc_service(); let s = self.rpc_server.as_mut().unwrap(); - s.registry().register( - PeerManageRpcServer::new(PeerManagerRpcService::new(peer_mgr)), - "", - ); + let peer_mgr_rpc_service = PeerManagerRpcService::new(peer_mgr.clone()); + s.registry() + .register(PeerManageRpcServer::new(peer_mgr_rpc_service.clone()), ""); + s.registry() + .register(AclManageRpcServer::new(peer_mgr_rpc_service), ""); s.registry().register( ConnectorManageRpcServer::new(ConnectorManagerRpcService(conn_manager)), "", diff --git a/easytier/src/peers/acl_filter.rs b/easytier/src/peers/acl_filter.rs new file mode 100644 index 0000000..8687345 --- /dev/null +++ b/easytier/src/peers/acl_filter.rs @@ -0,0 +1,289 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::Ordering; +use std::{ + net::IpAddr, + sync::{atomic::AtomicBool, Arc}, +}; + +use arc_swap::ArcSwap; +use pnet::packet::ipv6::Ipv6Packet; +use pnet::packet::{ + ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet as _, +}; + +use crate::proto::acl::{AclStats, Protocol}; +use crate::tunnel::packet_def::PacketType; +use crate::{ + common::acl_processor::{AclProcessor, AclResult, AclStatKey, AclStatType, PacketInfo}, + proto::acl::{Acl, Action, ChainType}, + tunnel::packet_def::ZCPacket, +}; + +/// ACL filter that can be inserted into the packet processing pipeline +/// Optimized with lock-free hot reloading via atomic processor replacement +pub struct AclFilter { + // Use ArcSwap for lock-free atomic replacement during hot reload + acl_processor: ArcSwap, + acl_enabled: Arc, +} + +impl AclFilter { + pub fn new() -> Self { + Self { + acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))), + acl_enabled: Arc::new(AtomicBool::new(false)), + } + } + + /// Hot reload ACL rules by creating a new processor instance + /// Preserves connection tracking and rate limiting state across reloads + /// Now lock-free and doesn't require &mut self! + pub fn reload_rules(&self, acl_config: Option<&Acl>) { + let Some(acl_config) = acl_config else { + self.acl_enabled.store(false, Ordering::Relaxed); + return; + }; + + // Get current processor to extract shared state + let current_processor = self.acl_processor.load(); + let (conn_track, rate_limiters, stats) = current_processor.get_shared_state(); + + // Create new processor with preserved state + let new_processor = AclProcessor::new_with_shared_state( + acl_config.clone(), + Some(conn_track), + Some(rate_limiters), + Some(stats), + ); + + // Atomic replacement - this is completely lock-free! + self.acl_processor.store(Arc::new(new_processor)); + self.acl_enabled.store(true, Ordering::Relaxed); + + tracing::info!("ACL rules hot reloaded with preserved state (lock-free)"); + } + + /// Get current processor for processing packets + fn get_processor(&self) -> Arc { + self.acl_processor.load_full() + } + + pub fn get_stats(&self) -> AclStats { + let processor = self.get_processor(); + let global_stats = processor.get_stats(); + let (conn_track, _, _) = processor.get_shared_state(); + let rules_stats = processor.get_rules_stats(); + + AclStats { + global: global_stats.into_iter().map(|(k, v)| (k, v)).collect(), + conn_track: conn_track.iter().map(|x| x.value().clone()).collect(), + rules: rules_stats, + } + } + + /// Extract packet information for ACL processing + fn extract_packet_info(&self, packet: &ZCPacket) -> Option { + let payload = packet.payload(); + + let src_ip; + let dst_ip; + let src_port; + let dst_port; + let protocol; + + let ipv4_packet = Ipv4Packet::new(payload)?; + if ipv4_packet.get_version() == 4 { + src_ip = IpAddr::V4(ipv4_packet.get_source()); + dst_ip = IpAddr::V4(ipv4_packet.get_destination()); + protocol = ipv4_packet.get_next_level_protocol(); + + (src_port, dst_port) = match protocol { + IpNextHeaderProtocols::Tcp => { + let tcp_packet = TcpPacket::new(ipv4_packet.payload())?; + ( + Some(tcp_packet.get_source()), + Some(tcp_packet.get_destination()), + ) + } + IpNextHeaderProtocols::Udp => { + let udp_packet = UdpPacket::new(ipv4_packet.payload())?; + ( + Some(udp_packet.get_source()), + Some(udp_packet.get_destination()), + ) + } + _ => (None, None), + }; + } else if ipv4_packet.get_version() == 6 { + let ipv6_packet = Ipv6Packet::new(payload)?; + src_ip = IpAddr::V6(ipv6_packet.get_source()); + dst_ip = IpAddr::V6(ipv6_packet.get_destination()); + protocol = ipv6_packet.get_next_header(); + + (src_port, dst_port) = match protocol { + IpNextHeaderProtocols::Tcp => { + let tcp_packet = TcpPacket::new(ipv6_packet.payload())?; + ( + Some(tcp_packet.get_source()), + Some(tcp_packet.get_destination()), + ) + } + IpNextHeaderProtocols::Udp => { + let udp_packet = UdpPacket::new(ipv6_packet.payload())?; + ( + Some(udp_packet.get_source()), + Some(udp_packet.get_destination()), + ) + } + _ => (None, None), + }; + } else { + return None; + } + + let acl_protocol = match protocol { + IpNextHeaderProtocols::Tcp => Protocol::Tcp, + IpNextHeaderProtocols::Udp => Protocol::Udp, + IpNextHeaderProtocols::Icmp => Protocol::Icmp, + IpNextHeaderProtocols::Icmpv6 => Protocol::IcmPv6, + _ => Protocol::Unspecified, + }; + + Some(PacketInfo { + src_ip, + dst_ip, + src_port, + dst_port, + protocol: acl_protocol, + packet_size: payload.len(), + }) + } + + /// Process ACL result and log if needed + fn handle_acl_result( + &self, + result: &AclResult, + packet_info: &PacketInfo, + chain_type: ChainType, + processor: &AclProcessor, + ) { + if result.should_log { + if let Some(ref log_context) = result.log_context { + let log_message = log_context.to_message(); + tracing::info!( + src_ip = %packet_info.src_ip, + dst_ip = %packet_info.dst_ip, + src_port = packet_info.src_port, + dst_port = packet_info.dst_port, + protocol = ?packet_info.protocol, + action = ?result.action, + rule = result.matched_rule_str().as_deref().unwrap_or("unknown"), + chain_type = ?chain_type, + "ACL: {}", log_message + ); + } + } + + // Update global statistics in the ACL processor + match result.action { + Action::Allow => { + processor.increment_stat(AclStatKey::PacketsAllowed); + processor.increment_stat(AclStatKey::from_chain_and_action( + chain_type, + AclStatType::Allowed, + )); + tracing::trace!("ACL: Packet allowed"); + } + Action::Drop => { + processor.increment_stat(AclStatKey::PacketsDropped); + processor.increment_stat(AclStatKey::from_chain_and_action( + chain_type, + AclStatType::Dropped, + )); + tracing::debug!("ACL: Packet dropped"); + } + Action::Noop => { + processor.increment_stat(AclStatKey::PacketsNoop); + processor.increment_stat(AclStatKey::from_chain_and_action( + chain_type, + AclStatType::Noop, + )); + tracing::trace!("ACL: No operation"); + } + } + + // Track total packets processed per chain + processor.increment_stat(AclStatKey::from_chain_and_action( + chain_type, + AclStatType::Total, + )); + processor.increment_stat(AclStatKey::PacketsTotal); + } + + /// Common ACL processing logic + pub fn process_packet_with_acl( + &self, + packet: &ZCPacket, + is_in: bool, + my_ipv4: Option, + my_ipv6: Option, + ) -> bool { + if !self.acl_enabled.load(Ordering::Relaxed) { + return true; + } + + if packet.peer_manager_header().unwrap().packet_type != PacketType::Data as u8 { + return true; + } + + // Extract packet information + let packet_info = match self.extract_packet_info(packet) { + Some(info) => info, + None => { + tracing::warn!( + "Failed to extract packet info from {:?} packet, header: {:?}", + if is_in { "inbound" } else { "outbound" }, + packet.peer_manager_header() + ); + // allow all unknown packets + return true; + } + }; + + let chain_type = if is_in { + if packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED) + || packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED) + { + ChainType::Inbound + } else { + ChainType::Forward + } + } else { + ChainType::Outbound + }; + + // Get current processor atomically + let processor = self.get_processor(); + + // Process through ACL rules + let acl_result = processor.process_packet(&packet_info, chain_type); + + self.handle_acl_result(&acl_result, &packet_info, chain_type, &processor); + + // Check if packet should be allowed + match acl_result.action { + Action::Allow | Action::Noop => true, + Action::Drop => { + tracing::trace!( + "ACL: Dropping {:?} packet from {} to {}, chain_type: {:?}", + packet_info.protocol, + packet_info.src_ip, + packet_info.dst_ip, + chain_type, + ); + + false + } + } + } +} diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index c3d9e1a..fcccfe0 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -1,5 +1,6 @@ mod graph_algo; +pub mod acl_filter; pub mod peer; // pub mod peer_conn; pub mod peer_conn; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 91cabf9..6de92f1 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -573,6 +573,8 @@ impl PeerManager { let foreign_mgr = self.foreign_network_manager.clone(); let encryptor = self.encryptor.clone(); let compress_algo = self.data_compress_algo; + let acl_filter = self.global_ctx.get_acl_filter().clone(); + let global_ctx = self.global_ctx.clone(); self.tasks.lock().await.spawn(async move { tracing::trace!("start_peer_recv"); while let Ok(ret) = recv_packet_from_chan(&mut recv).await { @@ -631,6 +633,15 @@ impl PeerManager { continue; } + if !acl_filter.process_packet_with_acl( + &ret, + true, + global_ctx.get_ipv4().map(|x| x.address()), + global_ctx.get_ipv6().map(|x| x.address()), + ) { + continue; + } + let mut processed = false; let mut zc_packet = Some(ret); let mut idx = 0; @@ -845,6 +856,14 @@ impl PeerManager { } async fn run_nic_packet_process_pipeline(&self, data: &mut ZCPacket) { + if !self + .global_ctx + .get_acl_filter() + .process_packet_with_acl(data, false, None, None) + { + return; + } + for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() { let _ = pipeline.try_process_packet_from_nic(data).await; } diff --git a/easytier/src/peers/rpc_service.rs b/easytier/src/peers/rpc_service.rs index 9e58894..e9913cc 100644 --- a/easytier/src/peers/rpc_service.rs +++ b/easytier/src/peers/rpc_service.rs @@ -2,10 +2,10 @@ use std::sync::Arc; use crate::proto::{ cli::{ - DumpRouteRequest, DumpRouteResponse, ListForeignNetworkRequest, ListForeignNetworkResponse, - ListGlobalForeignNetworkRequest, ListGlobalForeignNetworkResponse, ListPeerRequest, - ListPeerResponse, ListRouteRequest, ListRouteResponse, PeerInfo, PeerManageRpc, - ShowNodeInfoRequest, ShowNodeInfoResponse, + AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, GetAclStatsResponse, + ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest, + ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, + ListRouteResponse, PeerInfo, PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse, }, rpc_types::{self, controller::BaseController}, }; @@ -134,3 +134,23 @@ impl PeerManageRpc for PeerManagerRpcService { }) } } + +#[async_trait::async_trait] +impl AclManageRpc for PeerManagerRpcService { + type Controller = BaseController; + + async fn get_acl_stats( + &self, + _: BaseController, + _request: GetAclStatsRequest, + ) -> Result { + let acl_stats = self + .peer_manager + .get_global_ctx() + .get_acl_filter() + .get_stats(); + Ok(GetAclStatsResponse { + acl_stats: Some(acl_stats), + }) + } +} diff --git a/easytier/src/proto/acl.proto b/easytier/src/proto/acl.proto new file mode 100644 index 0000000..393fc74 --- /dev/null +++ b/easytier/src/proto/acl.proto @@ -0,0 +1,127 @@ +syntax = "proto3"; + +import "common.proto"; + +package acl; + +// Enhanced protocol enum with more granular options +enum Protocol { + Unspecified = 0; + TCP = 1; + UDP = 2; + ICMP = 3; + ICMPv6 = 4; + Any = 5; +} + +enum Action { + Noop = 0; + Allow = 1; + Drop = 2; // Silent drop (no response) +} + +enum ChainType { + UnspecifiedChain = 0; + // send to this node + Inbound = 1; + // send from this node + Outbound = 2; + // subnet proxy + Forward = 3; +} + +// Time-based access control +message TimeWindow { + // Days of week: 0=Sunday, 1=Monday, ..., 6=Saturday + repeated uint32 days_of_week = 1; + // Time in minutes from midnight (0-1439) + uint32 start_time = 2; + uint32 end_time = 3; + // Timezone offset in minutes from UTC + int32 timezone_offset = 4; +} + +// Enhanced rule with priority and metadata +message Rule { + // Rule identification and metadata + string name = 1; // Human-readable rule name + string description = 2; // Rule description + uint32 priority = 3; // Higher number = higher priority (0-65535) + bool enabled = 4; // Rule enabled/disabled state + + // Core matching criteria + Protocol protocol = 5; + repeated string ports = 6; + repeated string source_ips = 7; // Source IP ranges + repeated string destination_ips = 8; // Destination IP ranges + + // Enhanced matching criteria + repeated string source_ports = 9; // Source port range + + // Action and logging + Action action = 10; + + // Rate limiting (packets per second) + uint32 rate_limit = 11; // 0 = no limit + uint32 burst_limit = 12; // Burst allowance + + // Connection tracking + bool stateful = 13; // Enable connection tracking +} + +// Rule chain with metadata and optimization hints +message Chain { + // Chain identification + string name = 1; // Human-readable chain name + ChainType chain_type = 2; + string description = 3; // Chain description + bool enabled = 4; // Chain enabled/disabled state + + // Rules in priority order (highest priority first) + repeated Rule rules = 5; + + // Default action when no rules match + Action default_action = 6; +} + +message AclV1 { repeated Chain chains = 1; } + +enum ConnState { + New = 0; + Established = 1; + Related = 2; + Invalid = 3; +} + +// Connection tracking entry for stateful ACLs +message ConnTrackEntry { + common.SocketAddr src_addr = 1; + common.SocketAddr dst_addr = 2; + Protocol protocol = 3; // IP protocol number (e.g., 6 = TCP, 17 = UDP) + ConnState state = 4; + uint64 created_at = 5; // Unix timestamp (seconds) + uint64 last_seen = 6; // Unix timestamp (seconds) + uint64 packet_count = 7; + uint64 byte_count = 8; +} + +// Top-level ACL configuration +message Acl { + AclV1 acl_v1 = 2; +} + +message StatItem { + uint64 packet_count = 1; + uint64 byte_count = 2; +} + +message RuleStats { + Rule rule = 1; + StatItem stat = 2; +} + +message AclStats { + repeated RuleStats rules = 1; + repeated ConnTrackEntry conn_track = 2; + map global = 3; +} diff --git a/easytier/src/proto/acl.rs b/easytier/src/proto/acl.rs new file mode 100644 index 0000000..6948ba3 --- /dev/null +++ b/easytier/src/proto/acl.rs @@ -0,0 +1,95 @@ +use std::fmt::Display; + +include!(concat!(env!("OUT_DIR"), "/acl.rs")); + +impl Display for ConnTrackEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let src = self + .src_addr + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| "-".to_string()); + let dst = self + .dst_addr + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| "-".to_string()); + let last_seen = chrono::DateTime::::from_timestamp(self.last_seen as i64, 0) + .unwrap() + .with_timezone(&chrono::Local); + let created_at = chrono::DateTime::::from_timestamp(self.created_at as i64, 0) + .unwrap() + .with_timezone(&chrono::Local); + write!( + f, + "[src: {}, dst: {}, proto: {:?}, state: {:?}, pkts: {}, bytes: {}, created: {}, last_seen: {}]", + src, + dst, + Protocol::try_from(self.protocol).unwrap_or(Protocol::Unspecified), + ConnState::try_from(self.state).unwrap_or(ConnState::Invalid), + self.packet_count, + self.byte_count, + created_at, + last_seen + ) + } +} + +impl Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[name: '{}', prio: {}, action: {:?}, enabled: {}, proto: {:?}, ports: {:?}, src_ports: {:?}, src_ips: {:?}, dst_ips: {:?}, stateful: {}, rate: {}, burst: {}]", + self.name, + self.priority, + Action::try_from(self.action).unwrap_or(Action::Noop), + self.enabled, + Protocol::try_from(self.protocol).unwrap_or(Protocol::Unspecified), + self.ports, + self.source_ports, + self.source_ips, + self.destination_ips, + self.stateful, + self.rate_limit, + self.burst_limit + ) + } +} + +impl Display for StatItem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[pkts: {}, bytes: {}]", + self.packet_count, self.byte_count + ) + } +} + +impl Display for AclStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "AclStats:")?; + writeln!(f, " Global:")?; + for (k, v) in &self.global { + writeln!(f, " {}: {}", k, v)?; + } + writeln!(f, " ConnTrack:")?; + for entry in &self.conn_track { + writeln!(f, " {}", entry)?; + } + writeln!(f, " Rules:")?; + for rule_stat in &self.rules { + if let Some(rule) = &rule_stat.rule { + write!(f, " {} ", rule)?; + } else { + write!(f, " ")?; + } + if let Some(stat) = &rule_stat.stat { + writeln!(f, "{}", stat)?; + } else { + writeln!(f)?; + } + } + Ok(()) + } +} diff --git a/easytier/src/proto/cli.proto b/easytier/src/proto/cli.proto index c73205d..847e48a 100644 --- a/easytier/src/proto/cli.proto +++ b/easytier/src/proto/cli.proto @@ -2,6 +2,7 @@ syntax = "proto3"; import "common.proto"; import "peer_rpc.proto"; +import "acl.proto"; package cli; @@ -251,3 +252,13 @@ service TcpProxyRpc { rpc ListTcpProxyEntry(ListTcpProxyEntryRequest) returns (ListTcpProxyEntryResponse); } + +message GetAclStatsRequest {} + +message GetAclStatsResponse { + acl.AclStats acl_stats = 1; +} + +service AclManageRpc { + rpc GetAclStats(GetAclStatsRequest) returns (GetAclStatsResponse); +} diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index c6378e4..ad9da5b 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -18,7 +18,8 @@ message FlagsInConfig { bool disable_p2p = 11; bool relay_all_peer_rpc = 12; bool disable_udp_hole_punching = 13; - // string ipv6_listener = 14; [deprecated = true]; use -l udp://[::]:12345 instead + // string ipv6_listener = 14; [deprecated = true]; use -l udp://[::]:12345 + // instead bool multi_thread = 15; CompressionAlgoPb data_compress_algo = 16; bool bind_device = 17; @@ -144,6 +145,13 @@ message Ipv6Inet { uint32 network_length = 2; } +message IpInet { + oneof ip { + Ipv4Inet ipv4 = 1; + Ipv6Inet ipv6 = 2; + }; +} + message Url { string url = 1; } message SocketAddr { @@ -173,7 +181,7 @@ message PeerFeatureFlag { bool is_public_server = 1; bool avoid_relay_data = 2; bool kcp_input = 3; - bool no_relay_kcp = 4; + bool no_relay_kcp = 4; } enum SocketType { @@ -182,17 +190,17 @@ enum SocketType { } message PortForwardConfigPb { - SocketAddr bind_addr = 1; - SocketAddr dst_addr = 2; - SocketType socket_type = 3; + SocketAddr bind_addr = 1; + SocketAddr dst_addr = 2; + SocketType socket_type = 3; } -message ProxyDstInfo { - SocketAddr dst_addr = 1; -} +message ProxyDstInfo { SocketAddr dst_addr = 1; } message LimiterConfig { - optional uint64 burst_rate = 1; // default 1 means no burst (capacity is same with bps) + optional uint64 burst_rate = + 1; // default 1 means no burst (capacity is same with bps) optional uint64 bps = 2; // default 0 means no limit (unit is B/s) - optional uint64 fill_duration_ms = 3; // default 10ms, the period to fill the bucket + optional uint64 fill_duration_ms = + 3; // default 10ms, the period to fill the bucket } diff --git a/easytier/src/proto/common.rs b/easytier/src/proto/common.rs index 07d3d1c..e30a7d7 100644 --- a/easytier/src/proto/common.rs +++ b/easytier/src/proto/common.rs @@ -1,4 +1,7 @@ -use std::{fmt, str::FromStr}; +use std::{ + fmt::{self, Display}, + str::FromStr, +}; use anyhow::Context; @@ -166,6 +169,43 @@ impl FromStr for Ipv6Inet { } } +impl From for IpInet { + fn from(value: cidr::IpInet) -> Self { + match value { + cidr::IpInet::V4(v4) => IpInet { + ip: Some(ip_inet::Ip::Ipv4(Ipv4Inet::from(v4))), + }, + cidr::IpInet::V6(v6) => IpInet { + ip: Some(ip_inet::Ip::Ipv6(Ipv6Inet::from(v6))), + }, + } + } +} + +impl From for cidr::IpInet { + fn from(value: IpInet) -> Self { + match value.ip { + Some(ip_inet::Ip::Ipv4(v4)) => cidr::IpInet::V4(v4.into()), + Some(ip_inet::Ip::Ipv6(v6)) => cidr::IpInet::V6(v6.into()), + None => panic!("IpInet is None"), + } + } +} + +impl Display for IpInet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", cidr::IpInet::from(self.clone())) + } +} + +impl FromStr for IpInet { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + Ok(IpInet::from(cidr::IpInet::from_str(s)?)) + } +} + impl From for Url { fn from(value: url::Url) -> Self { Url { diff --git a/easytier/src/proto/mod.rs b/easytier/src/proto/mod.rs index 51fe99a..fd8b455 100644 --- a/easytier/src/proto/mod.rs +++ b/easytier/src/proto/mod.rs @@ -1,6 +1,7 @@ pub mod rpc_impl; pub mod rpc_types; +pub mod acl; pub mod cli; pub mod common; pub mod error; diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 901d67d..ba407ea 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -1328,3 +1328,183 @@ async fn avoid_tunnel_loop_back_to_virtual_network() { drop_insts(insts).await; } + +#[tokio::test] +#[serial_test::serial] +pub async fn acl_rule_test_inbound() { + use crate::tunnel::{ + common::tests::_tunnel_pingpong_netns, + tcp::{TcpTunnelConnector, TcpTunnelListener}, + udp::{UdpTunnelConnector, UdpTunnelListener}, + }; + use rand::Rng; + let insts = init_three_node("udp").await; + + // 构造 ACL 配置 + use crate::proto::acl::*; + let mut acl = Acl::default(); + let mut acl_v1 = AclV1::default(); + + let mut chain = Chain::default(); + chain.name = "test_inbound".to_string(); + chain.chain_type = ChainType::Inbound as i32; + chain.enabled = true; + + // 禁止 8080 + let mut deny_rule = Rule::default(); + deny_rule.name = "deny_8080".to_string(); + deny_rule.priority = 200; + deny_rule.enabled = true; + deny_rule.action = Action::Drop as i32; + deny_rule.protocol = Protocol::Any as i32; + deny_rule.ports = vec!["8080".to_string()]; + chain.rules.push(deny_rule); + + // 允许其他 + let mut allow_rule = Rule::default(); + allow_rule.name = "allow_all".to_string(); + allow_rule.priority = 100; + allow_rule.enabled = true; + allow_rule.action = Action::Allow as i32; + allow_rule.protocol = Protocol::Any as i32; + allow_rule.stateful = true; + chain.rules.push(allow_rule); + + // 禁止 src ip 为 10.144.144.2 的流量 + let mut deny_rule = Rule::default(); + deny_rule.name = "deny_10.144.144.2".to_string(); + deny_rule.priority = 200; + deny_rule.enabled = true; + deny_rule.action = Action::Drop as i32; + deny_rule.protocol = Protocol::Any as i32; + deny_rule.source_ips = vec!["10.144.144.2/32".to_string()]; + chain.rules.push(deny_rule); + + acl_v1.chains.push(chain); + acl.acl_v1 = Some(acl_v1); + + // convert acl to to toml + let acl_toml = toml::to_string(&acl).unwrap(); + println!("ACL TOML: {}", acl_toml); + + insts[2] + .get_global_ctx() + .get_acl_filter() + .reload_rules(Some(&acl)); + + // TCP 测试部分 + { + // 2. 在 inst2 上监听 8080 和 8081 + let listener_8080 = TcpTunnelListener::new("tcp://0.0.0.0:8080".parse().unwrap()); + let listener_8081 = TcpTunnelListener::new("tcp://0.0.0.0:8081".parse().unwrap()); + let listener_8082 = TcpTunnelListener::new("tcp://0.0.0.0:8082".parse().unwrap()); + + // 3. inst1 作为客户端,尝试连接 inst2 的 8080(应被拒绝)和 8081(应被允许) + let connector_8080 = + TcpTunnelConnector::new(format!("tcp://{}:8080", "10.144.144.3").parse().unwrap()); + let connector_8081 = + TcpTunnelConnector::new(format!("tcp://{}:8081", "10.144.144.3").parse().unwrap()); + let connector_8082 = + TcpTunnelConnector::new(format!("tcp://{}:8082", "10.144.144.3").parse().unwrap()); + + // 4. 构造测试数据 + let mut buf = vec![0; 32]; + rand::thread_rng().fill(&mut buf[..]); + + // 5. 8081 应该可以 pingpong 成功 + _tunnel_pingpong_netns( + listener_8081, + connector_8081, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + ) + .await; + + // 6. 8080 应该连接失败(被 ACL 拦截) + let result = tokio::time::timeout( + std::time::Duration::from_millis(200), + _tunnel_pingpong_netns( + listener_8080, + connector_8080, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + ), + ) + .await; + + assert!(result.is_err(), "TCP 连接 8080 应被 ACL 拦截,不能成功"); + + // 7. 从 10.144.144.2 连接 8082 应该连接失败(被 ACL 拦截) + let result = tokio::time::timeout( + std::time::Duration::from_millis(200), + _tunnel_pingpong_netns( + listener_8082, + connector_8082, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_b".into())), + buf.clone(), + ), + ) + .await; + + assert!(result.is_err(), "TCP 连接 8082 应被 ACL 拦截,不能成功"); + + let stats = insts[2].get_global_ctx().get_acl_filter().get_stats(); + println!("stats: {:?}", stats); + } + + // UDP 测试部分 + { + // 1. 在 inst2 上监听 UDP 8080 和 8081 + let listener_8080 = UdpTunnelListener::new("udp://0.0.0.0:8080".parse().unwrap()); + let listener_8081 = UdpTunnelListener::new("udp://0.0.0.0:8081".parse().unwrap()); + + // 2. inst1 作为客户端,尝试连接 inst2 的 8080(应被拒绝)和 8081(应被允许) + let connector_8080 = + UdpTunnelConnector::new(format!("udp://{}:8080", "10.144.144.3").parse().unwrap()); + let connector_8081 = + UdpTunnelConnector::new(format!("udp://{}:8081", "10.144.144.3").parse().unwrap()); + + // 3. 构造测试数据 + let mut buf = vec![0; 32]; + rand::thread_rng().fill(&mut buf[..]); + + // 4. 8081 应该可以 pingpong 成功 + _tunnel_pingpong_netns( + listener_8081, + connector_8081, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + ) + .await; + + // 5. 8080 应该连接失败(被 ACL 拦截) + let result = tokio::time::timeout( + std::time::Duration::from_millis(200), + _tunnel_pingpong_netns( + listener_8080, + connector_8080, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + ), + ) + .await; + + assert!(result.is_err(), "UDP 连接 8080 应被 ACL 拦截,不能成功"); + + let stats = insts[2].get_global_ctx().get_acl_filter().get_stats(); + println!("stats: {}", stats); + } + + // remove acl, 8080 should succ + insts[2] + .get_global_ctx() + .get_acl_filter() + .reload_rules(None); + + drop_insts(insts).await; +}