From 8440eb842b37ea1d94f29ac777dbfc325bfd32c7 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Tue, 7 May 2024 00:38:05 +0800 Subject: [PATCH] fix bugs and improve user experiance (#86) * correctly set mtu, and allow set mtu manually * communicate between enc and non-enc should not panic * allow loading cfg from file * allow change file log level dynamically --- easytier/src/common/config.rs | 7 +- easytier/src/easytier-core.rs | 100 +++++++----------- easytier/src/instance/virtual_nic.rs | 8 ++ easytier/src/peers/encrypt/aes_gcm.rs | 2 +- easytier/src/peers/encrypt/mod.rs | 11 +- easytier/src/peers/encrypt/ring_aes_gcm.rs | 2 +- easytier/src/peers/peer_manager.rs | 42 +++++++- easytier/src/utils.rs | 115 ++++++++++++++++++++- 8 files changed, 211 insertions(+), 76 deletions(-) diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index 3501300..5564a95 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -1,5 +1,6 @@ use std::{ net::SocketAddr, + path::PathBuf, sync::{Arc, Mutex}, }; @@ -144,6 +145,8 @@ pub struct Flags { pub enable_encryption: bool, #[derivative(Default(value = "true"))] pub enable_ipv6: bool, + #[derivative(Default(value = "1420"))] + pub mtu: u16, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] @@ -192,9 +195,9 @@ impl TomlConfigLoader { }) } - pub fn new(config_path: &str) -> Result { + pub fn new(config_path: &PathBuf) -> Result { let config_str = std::fs::read_to_string(config_path) - .with_context(|| format!("failed to read config file: {}", config_path))?; + .with_context(|| format!("failed to read config file: {:?}", config_path))?; Self::new_from_str(&config_str) } } diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 72d662d..3bbc25f 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -3,7 +3,7 @@ #[cfg(test)] mod tests; -use std::{backtrace, io::Write as _, net::SocketAddr}; +use std::{backtrace, io::Write as _, net::SocketAddr, path::PathBuf}; use anyhow::Context; use clap::Parser; @@ -17,19 +17,20 @@ mod peer_center; mod peers; mod rpc; mod tunnel; +mod utils; mod vpn_portal; -use common::{ - config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig}, - get_logger_timer_rfc3339, +use common::config::{ + ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig, }; use instance::instance::Instance; -use tracing::level_filters::LevelFilter; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; -use crate::common::{ - config::{ConfigLoader, TomlConfigLoader}, - global_ctx::GlobalCtxEvent, +use crate::{ + common::{ + config::{ConfigLoader, TomlConfigLoader}, + global_ctx::GlobalCtxEvent, + }, + utils::init_logger, }; #[cfg(feature = "mimalloc")] @@ -42,6 +43,13 @@ static GLOBAL_MIMALLOC: GlobalMiMalloc = GlobalMiMalloc; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Cli { + #[arg( + short, + long, + help = "path to the config file, NOTE: if this is set, all other options will be ignored" + )] + config_file: Option, + #[arg( long, help = "network name to identify this vpn network", @@ -146,11 +154,28 @@ and the vpn client is in network of 10.14.14.0/24" #[arg(long, help = "do not use ipv6", default_value = "false")] disable_ipv6: bool, + + #[arg( + long, + help = "mtu of the TUN device, default is 1420 for non-encryption, 1400 for encryption" + )] + mtu: Option, } impl From for TomlConfigLoader { fn from(cli: Cli) -> Self { + if let Some(config_file) = &cli.config_file { + println!( + "NOTICE: loading config file: {:?}, will ignore all command line flags\n", + config_file + ); + return TomlConfigLoader::new(config_file) + .with_context(|| format!("failed to load config file: {:?}", cli.config_file)) + .unwrap(); + } + let cfg = TomlConfigLoader::default(); + cfg.set_inst_name(cli.instance_name.clone()); cfg.set_network_identity(NetworkIdentity::new( cli.network_name.clone(), @@ -276,64 +301,15 @@ impl From for TomlConfigLoader { } f.enable_encryption = !cli.disable_encryption; f.enable_ipv6 = !cli.disable_ipv6; + if let Some(mtu) = cli.mtu { + f.mtu = mtu; + } cfg.set_flags(f); cfg } } -fn init_logger(config: impl ConfigLoader) { - let file_config = config.get_file_logger_config(); - let file_level = file_config - .level - .map(|s| s.parse().unwrap()) - .unwrap_or(LevelFilter::OFF); - - // logger to rolling file - let mut file_layer = None; - if file_level != LevelFilter::OFF { - let mut l = tracing_subscriber::fmt::layer(); - l.set_ansi(false); - let file_filter = EnvFilter::builder() - .with_default_directive(file_level.into()) - .from_env() - .unwrap(); - let file_appender = tracing_appender::rolling::Builder::new() - .rotation(tracing_appender::rolling::Rotation::DAILY) - .max_log_files(5) - .filename_prefix(file_config.file.unwrap_or("easytier".to_string())) - .build(file_config.dir.unwrap_or("./".to_string())) - .expect("failed to initialize rolling file appender"); - file_layer = Some( - l.with_writer(file_appender) - .with_timer(get_logger_timer_rfc3339()) - .with_filter(file_filter), - ); - } - - // logger to console - let console_config = config.get_console_logger_config(); - let console_level = console_config - .level - .map(|s| s.parse().unwrap()) - .unwrap_or(LevelFilter::OFF); - - let console_filter = EnvFilter::builder() - .with_default_directive(console_level.into()) - .from_env() - .unwrap(); - let console_layer = tracing_subscriber::fmt::layer() - .pretty() - .with_timer(get_logger_timer_rfc3339()) - .with_writer(std::io::stderr) - .with_filter(console_filter); - - tracing_subscriber::Registry::default() - .with(console_layer) - .with(file_layer) - .init(); -} - fn print_event(msg: String) { println!( "{}: {}", @@ -363,7 +339,7 @@ fn setup_panic_handler() { pub async fn async_main(cli: Cli) { let cfg: TomlConfigLoader = cli.into(); - init_logger(&cfg); + init_logger(&cfg, false).unwrap(); let mut inst = Instance::new(cfg.clone()); let mut events = inst.get_global_ctx().subscribe(); diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index e6b7238..3b4b69c 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -286,6 +286,14 @@ impl VirtualNic { todo!("queue_num != 1") } config.queues(self.queue_num); + + let flags = self.global_ctx.config.get_flags(); + let mut mtu_in_config = flags.mtu; + if flags.enable_encryption { + mtu_in_config -= 20; + } + + config.mtu(mtu_in_config as i32); config.up(); let dev = { diff --git a/easytier/src/peers/encrypt/aes_gcm.rs b/easytier/src/peers/encrypt/aes_gcm.rs index 886b470..ab39698 100644 --- a/easytier/src/peers/encrypt/aes_gcm.rs +++ b/easytier/src/peers/encrypt/aes_gcm.rs @@ -38,7 +38,7 @@ impl Encryptor for AesGcmCipher { fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { let pm_header = zc_packet.peer_manager_header().unwrap(); if !pm_header.is_encrypted() { - return Err(Error::NotEcrypted); + return Ok(()); } let payload_len = zc_packet.payload().len(); diff --git a/easytier/src/peers/encrypt/mod.rs b/easytier/src/peers/encrypt/mod.rs index c63dee7..9bdb485 100644 --- a/easytier/src/peers/encrypt/mod.rs +++ b/easytier/src/peers/encrypt/mod.rs @@ -8,8 +8,6 @@ pub mod aes_gcm; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("packet is not encrypted")] - NotEcrypted, #[error("packet is too short. len: {0}")] PacketTooShort(usize), #[error("decryption failed")] @@ -32,7 +30,12 @@ impl Encryptor for NullCipher { Ok(()) } - fn decrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> { - Ok(()) + fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { + let pm_header = zc_packet.peer_manager_header().unwrap(); + if pm_header.is_encrypted() { + return Err(Error::DecryptionFailed); + } else { + Ok(()) + } } } diff --git a/easytier/src/peers/encrypt/ring_aes_gcm.rs b/easytier/src/peers/encrypt/ring_aes_gcm.rs index 8878e13..603c25b 100644 --- a/easytier/src/peers/encrypt/ring_aes_gcm.rs +++ b/easytier/src/peers/encrypt/ring_aes_gcm.rs @@ -54,7 +54,7 @@ impl Encryptor for AesGcmCipher { fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { let pm_header = zc_packet.peer_manager_header().unwrap(); if !pm_header.is_encrypted() { - return Err(Error::NotEcrypted); + return Ok(()); } let payload_len = zc_packet.payload().len(); diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 0174828..6f861c5 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -338,11 +338,9 @@ impl PeerManager { tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } } else { - if let Err(e) = encryptor - .decrypt(&mut ret) - .with_context(|| "decrypt failed") - { + if let Err(e) = encryptor.decrypt(&mut ret) { tracing::error!(?e, "decrypt failed"); + continue; } let mut processed = false; @@ -680,14 +678,16 @@ impl PeerManager { #[cfg(test)] mod tests { - use std::{fmt::Debug, sync::Arc}; + use std::{fmt::Debug, sync::Arc, time::Duration}; use crate::{ + common::{config::Flags, global_ctx::tests::get_mock_global_ctx}, connector::{ create_connector_by_url, udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, }, instance::listeners::get_listener_by_url, peers::{ + peer_manager::RouteAlgoType, peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient}, tests::{connect_peer_manager, wait_for_condition, wait_route_appear}, }, @@ -822,4 +822,36 @@ mod tests { .unwrap(); assert_eq!(ret, "hello c abc"); } + + #[tokio::test] + async fn communicate_between_enc_and_non_enc() { + let create_mgr = |enable_encryption| async move { + let (s, _r) = tokio::sync::mpsc::channel(1000); + let mock_global_ctx = get_mock_global_ctx(); + mock_global_ctx.config.set_flags(Flags { + enable_encryption, + ..Default::default() + }); + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, mock_global_ctx, s)); + peer_mgr.run().await.unwrap(); + peer_mgr + }; + + let peer_mgr_a = create_mgr(true).await; + let peer_mgr_b = create_mgr(false).await; + + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; + + // wait 5sec should not crash. + tokio::time::sleep(Duration::from_secs(5)).await; + + // both mgr should alive + let mgr_c = create_mgr(true).await; + connect_peer_manager(peer_mgr_a.clone(), mgr_c.clone()).await; + wait_route_appear(mgr_c, peer_mgr_a).await.unwrap(); + + let mgr_d = create_mgr(false).await; + connect_peer_manager(peer_mgr_b.clone(), mgr_d.clone()).await; + wait_route_appear(mgr_d, peer_mgr_b).await.unwrap(); + } } diff --git a/easytier/src/utils.rs b/easytier/src/utils.rs index be3f047..7707e58 100644 --- a/easytier/src/utils.rs +++ b/easytier/src/utils.rs @@ -1,6 +1,13 @@ +use anyhow::Context; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; -use crate::rpc::cli::{NatType, PeerInfo, Route}; +use crate::{ + common::{config::ConfigLoader, get_logger_timer_rfc3339}, + rpc::cli::{NatType, PeerInfo, Route}, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PeerRoutePair { @@ -128,3 +135,109 @@ pub fn cost_to_str(cost: i32) -> String { pub fn float_to_str(f: f64, precision: usize) -> String { format!("{:.1$}", f, precision) } + +pub type NewFilterSender = mpsc::UnboundedSender; + +pub fn init_logger( + config: impl ConfigLoader, + need_reload: bool, +) -> Result, anyhow::Error> { + let file_config = config.get_file_logger_config(); + let file_level = file_config + .level + .map(|s| s.parse().unwrap()) + .unwrap_or(LevelFilter::OFF); + + let mut ret_sender: Option = None; + + // logger to rolling file + let mut file_layer = None; + if file_level != LevelFilter::OFF || need_reload { + let mut l = tracing_subscriber::fmt::layer(); + l.set_ansi(false); + let file_filter = EnvFilter::builder() + .with_default_directive(file_level.into()) + .from_env() + .with_context(|| "failed to create file filter")?; + let (file_filter, file_filter_reloader) = + tracing_subscriber::reload::Layer::new(file_filter); + + if need_reload { + let (sender, mut recver) = mpsc::unbounded_channel(); + ret_sender = Some(sender); + tokio::spawn(async move { + println!("Start log filter reloader"); + while let Some(lf) = recver.recv().await { + let e = file_filter_reloader.modify(|f| { + if let Ok(nf) = EnvFilter::builder() + .with_default_directive(lf.into()) + .from_env() + .with_context(|| "failed to create file filter") + { + println!("Reload log filter succeed, new filter level: {:?}", lf); + *f = nf; + } + }); + if e.is_err() { + println!("Failed to reload log filter: {:?}", e); + } + } + println!("Stop log filter reloader"); + }); + } + + let file_appender = tracing_appender::rolling::Builder::new() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .max_log_files(5) + .filename_prefix(file_config.file.unwrap_or("easytier".to_string())) + .build(file_config.dir.unwrap_or("./".to_string())) + .with_context(|| "failed to initialize rolling file appender")?; + file_layer = Some( + l.with_writer(file_appender) + .with_timer(get_logger_timer_rfc3339()) + .with_filter(file_filter), + ); + } + + // logger to console + let console_config = config.get_console_logger_config(); + let console_level = console_config + .level + .map(|s| s.parse().unwrap()) + .unwrap_or(LevelFilter::OFF); + + let console_filter = EnvFilter::builder() + .with_default_directive(console_level.into()) + .from_env() + .unwrap(); + + let console_layer = tracing_subscriber::fmt::layer() + .pretty() + .with_timer(get_logger_timer_rfc3339()) + .with_writer(std::io::stderr) + .with_filter(console_filter); + + tracing_subscriber::Registry::default() + .with(console_layer) + .with(file_layer) + .init(); + + Ok(ret_sender) +} + +#[cfg(test)] +mod tests { + use crate::common::config::{self}; + + use super::*; + + async fn test_logger_reload() { + println!("current working dir: {:?}", std::env::current_dir()); + let config = config::TomlConfigLoader::default(); + let s = init_logger(&config, true).unwrap(); + tracing::debug!("test not display debug"); + s.unwrap().send(LevelFilter::DEBUG).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + tracing::debug!("test display debug"); + } +}