diff --git a/easytier-web/src/client_manager/mod.rs b/easytier-web/src/client_manager/mod.rs index d8354a3..9c4830e 100644 --- a/easytier-web/src/client_manager/mod.rs +++ b/easytier-web/src/client_manager/mod.rs @@ -1,21 +1,24 @@ pub mod session; pub mod storage; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, +}; use dashmap::DashMap; -use easytier::{ - common::scoped_task::ScopedTask, proto::web::HeartbeatRequest, tunnel::TunnelListener, -}; +use easytier::{proto::web::HeartbeatRequest, tunnel::TunnelListener}; use session::Session; use storage::{Storage, StorageToken}; +use tokio::task::JoinSet; use crate::db::{Db, UserIdInDb}; #[derive(Debug)] pub struct ClientManager { - accept_task: Option>, - clear_task: Option>, + tasks: JoinSet<()>, + + listeners_cnt: Arc, client_sessions: Arc>>, storage: Storage, @@ -23,24 +26,35 @@ pub struct ClientManager { impl ClientManager { pub fn new(db: Db) -> Self { + let client_sessions = Arc::new(DashMap::new()); + let sessions: Arc>> = client_sessions.clone(); + let mut tasks = JoinSet::new(); + tasks.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + sessions.retain(|_, session| session.is_running()); + } + }); ClientManager { - accept_task: None, - clear_task: None, + tasks, - client_sessions: Arc::new(DashMap::new()), + listeners_cnt: Arc::new(AtomicU32::new(0)), + + client_sessions, storage: Storage::new(db), } } - pub async fn serve( + pub async fn add_listener( &mut self, mut listener: L, ) -> Result<(), anyhow::Error> { listener.listen().await?; - + self.listeners_cnt.fetch_add(1, Ordering::Relaxed); let sessions = self.client_sessions.clone(); let storage = self.storage.weak_ref(); - let task = tokio::spawn(async move { + let listeners_cnt = self.listeners_cnt.clone(); + self.tasks.spawn(async move { while let Ok(tunnel) = listener.accept().await { let info = tunnel.info().unwrap(); let client_url: url::Url = info.remote_addr.unwrap().into(); @@ -49,24 +63,14 @@ impl ClientManager { session.serve(tunnel).await; sessions.insert(client_url, Arc::new(session)); } + listeners_cnt.fetch_sub(1, Ordering::Relaxed); }); - self.accept_task = Some(ScopedTask::from(task)); - - let sessions = self.client_sessions.clone(); - let task = tokio::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(15)).await; - sessions.retain(|_, session| session.is_running()); - } - }); - self.clear_task = Some(ScopedTask::from(task)); - Ok(()) } pub fn is_running(&self) -> bool { - self.accept_task.is_some() && self.clear_task.is_some() + self.listeners_cnt.load(Ordering::Relaxed) > 0 } pub async fn list_sessions(&self) -> Vec { @@ -132,7 +136,7 @@ mod tests { async fn test_client() { let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap()); let mut mgr = ClientManager::new(Db::memory_db().await); - mgr.serve(Box::new(listener)).await.unwrap(); + mgr.add_listener(Box::new(listener)).await.unwrap(); mgr.db() .inner() diff --git a/easytier-web/src/main.rs b/easytier-web/src/main.rs index 6a2e6d1..3ad85bb 100644 --- a/easytier-web/src/main.rs +++ b/easytier-web/src/main.rs @@ -11,6 +11,7 @@ use easytier::{ config::{ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, TomlConfigLoader}, constants::EASYTIER_VERSION, error::Error, + network::{local_ipv4, local_ipv6}, }, tunnel::{ tcp::TcpTunnelListener, udp::UdpTunnelListener, websocket::WSTunnelListener, TunnelListener, @@ -111,6 +112,31 @@ pub fn get_listener_by_url(l: &url::Url) -> Result, Erro }) } +async fn get_dual_stack_listener( + protocol: &str, + port: u16, +) -> Result< + ( + Option>, + Option>, + ), + Error, +> { + let is_protocol_support_dual_stack = + protocol.trim().to_lowercase() == "tcp" || protocol.trim().to_lowercase() == "udp"; + let v6_listener = if is_protocol_support_dual_stack && local_ipv6().await.is_ok() { + get_listener_by_url(&format!("{}://[::0]:{}", protocol, port).parse().unwrap()).ok() + } else { + None + }; + let v4_listener = if let Ok(_) = local_ipv4().await { + get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok() + } else { + None + }; + Ok((v6_listener, v4_listener)) +} + #[tokio::main] async fn main() { let locale = sys_locale::get_locale().unwrap_or_else(|| String::from("en-US")); @@ -131,18 +157,21 @@ async fn main() { // let db = db::Db::new(":memory:").await.unwrap(); let db = db::Db::new(cli.db).await.unwrap(); - - let listener = get_listener_by_url( - &format!( - "{}://0.0.0.0:{}", - cli.config_server_protocol, cli.config_server_port - ) - .parse() - .unwrap(), - ) - .unwrap(); let mut mgr = client_manager::ClientManager::new(db.clone()); - mgr.serve(listener).await.unwrap(); + let (v6_listener, v4_listener) = + get_dual_stack_listener(&cli.config_server_protocol, cli.config_server_port) + .await + .unwrap(); + if v4_listener.is_none() && v6_listener.is_none() { + panic!("Listen to both IPv4 and IPv6 failed"); + } + if let Some(listener) = v6_listener { + mgr.add_listener(listener).await.unwrap(); + } + if let Some(listener) = v4_listener { + mgr.add_listener(listener).await.unwrap(); + } + let mgr = Arc::new(mgr); #[cfg(feature = "embed")]