Web dual stack (#953)

* reimplement easytier-web dual stack
* add protocol check for dual stack listener current only support tcp and udp
This commit is contained in:
BlackLuny
2025-06-07 22:05:11 +08:00
committed by GitHub
parent 3c7837692e
commit 707963c0d9
2 changed files with 69 additions and 36 deletions

View File

@@ -1,21 +1,24 @@
pub mod session; pub mod session;
pub mod storage; pub mod storage;
use std::sync::Arc; use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use dashmap::DashMap; use dashmap::DashMap;
use easytier::{ use easytier::{proto::web::HeartbeatRequest, tunnel::TunnelListener};
common::scoped_task::ScopedTask, proto::web::HeartbeatRequest, tunnel::TunnelListener,
};
use session::Session; use session::Session;
use storage::{Storage, StorageToken}; use storage::{Storage, StorageToken};
use tokio::task::JoinSet;
use crate::db::{Db, UserIdInDb}; use crate::db::{Db, UserIdInDb};
#[derive(Debug)] #[derive(Debug)]
pub struct ClientManager { pub struct ClientManager {
accept_task: Option<ScopedTask<()>>, tasks: JoinSet<()>,
clear_task: Option<ScopedTask<()>>,
listeners_cnt: Arc<AtomicU32>,
client_sessions: Arc<DashMap<url::Url, Arc<Session>>>, client_sessions: Arc<DashMap<url::Url, Arc<Session>>>,
storage: Storage, storage: Storage,
@@ -23,24 +26,35 @@ pub struct ClientManager {
impl ClientManager { impl ClientManager {
pub fn new(db: Db) -> Self { pub fn new(db: Db) -> Self {
let client_sessions = Arc::new(DashMap::new());
let sessions: Arc<DashMap<url::Url, Arc<Session>>> = client_sessions.clone();
let mut tasks = JoinSet::new();
tasks.spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(15)).await;
sessions.retain(|_, session| session.is_running());
}
});
ClientManager { ClientManager {
accept_task: None, tasks,
clear_task: None,
client_sessions: Arc::new(DashMap::new()), listeners_cnt: Arc::new(AtomicU32::new(0)),
client_sessions,
storage: Storage::new(db), storage: Storage::new(db),
} }
} }
pub async fn serve<L: TunnelListener + 'static>( pub async fn add_listener<L: TunnelListener + 'static>(
&mut self, &mut self,
mut listener: L, mut listener: L,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
listener.listen().await?; listener.listen().await?;
self.listeners_cnt.fetch_add(1, Ordering::Relaxed);
let sessions = self.client_sessions.clone(); let sessions = self.client_sessions.clone();
let storage = self.storage.weak_ref(); let storage = self.storage.weak_ref();
let task = tokio::spawn(async move { let listeners_cnt = self.listeners_cnt.clone();
self.tasks.spawn(async move {
while let Ok(tunnel) = listener.accept().await { while let Ok(tunnel) = listener.accept().await {
let info = tunnel.info().unwrap(); let info = tunnel.info().unwrap();
let client_url: url::Url = info.remote_addr.unwrap().into(); let client_url: url::Url = info.remote_addr.unwrap().into();
@@ -49,24 +63,14 @@ impl ClientManager {
session.serve(tunnel).await; session.serve(tunnel).await;
sessions.insert(client_url, Arc::new(session)); sessions.insert(client_url, Arc::new(session));
} }
listeners_cnt.fetch_sub(1, Ordering::Relaxed);
}); });
self.accept_task = Some(ScopedTask::from(task));
let sessions = self.client_sessions.clone();
let task = tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(15)).await;
sessions.retain(|_, session| session.is_running());
}
});
self.clear_task = Some(ScopedTask::from(task));
Ok(()) Ok(())
} }
pub fn is_running(&self) -> bool { pub fn is_running(&self) -> bool {
self.accept_task.is_some() && self.clear_task.is_some() self.listeners_cnt.load(Ordering::Relaxed) > 0
} }
pub async fn list_sessions(&self) -> Vec<StorageToken> { pub async fn list_sessions(&self) -> Vec<StorageToken> {
@@ -132,7 +136,7 @@ mod tests {
async fn test_client() { async fn test_client() {
let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap()); let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap());
let mut mgr = ClientManager::new(Db::memory_db().await); let mut mgr = ClientManager::new(Db::memory_db().await);
mgr.serve(Box::new(listener)).await.unwrap(); mgr.add_listener(Box::new(listener)).await.unwrap();
mgr.db() mgr.db()
.inner() .inner()

View File

@@ -11,6 +11,7 @@ use easytier::{
config::{ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, TomlConfigLoader}, config::{ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, TomlConfigLoader},
constants::EASYTIER_VERSION, constants::EASYTIER_VERSION,
error::Error, error::Error,
network::{local_ipv4, local_ipv6},
}, },
tunnel::{ tunnel::{
tcp::TcpTunnelListener, udp::UdpTunnelListener, websocket::WSTunnelListener, TunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, websocket::WSTunnelListener, TunnelListener,
@@ -111,6 +112,31 @@ pub fn get_listener_by_url(l: &url::Url) -> Result<Box<dyn TunnelListener>, Erro
}) })
} }
async fn get_dual_stack_listener(
protocol: &str,
port: u16,
) -> Result<
(
Option<Box<dyn TunnelListener>>,
Option<Box<dyn TunnelListener>>,
),
Error,
> {
let is_protocol_support_dual_stack =
protocol.trim().to_lowercase() == "tcp" || protocol.trim().to_lowercase() == "udp";
let v6_listener = if is_protocol_support_dual_stack && local_ipv6().await.is_ok() {
get_listener_by_url(&format!("{}://[::0]:{}", protocol, port).parse().unwrap()).ok()
} else {
None
};
let v4_listener = if let Ok(_) = local_ipv4().await {
get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok()
} else {
None
};
Ok((v6_listener, v4_listener))
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let locale = sys_locale::get_locale().unwrap_or_else(|| String::from("en-US")); let locale = sys_locale::get_locale().unwrap_or_else(|| String::from("en-US"));
@@ -131,18 +157,21 @@ async fn main() {
// let db = db::Db::new(":memory:").await.unwrap(); // let db = db::Db::new(":memory:").await.unwrap();
let db = db::Db::new(cli.db).await.unwrap(); let db = db::Db::new(cli.db).await.unwrap();
let listener = get_listener_by_url(
&format!(
"{}://0.0.0.0:{}",
cli.config_server_protocol, cli.config_server_port
)
.parse()
.unwrap(),
)
.unwrap();
let mut mgr = client_manager::ClientManager::new(db.clone()); let mut mgr = client_manager::ClientManager::new(db.clone());
mgr.serve(listener).await.unwrap(); let (v6_listener, v4_listener) =
get_dual_stack_listener(&cli.config_server_protocol, cli.config_server_port)
.await
.unwrap();
if v4_listener.is_none() && v6_listener.is_none() {
panic!("Listen to both IPv4 and IPv6 failed");
}
if let Some(listener) = v6_listener {
mgr.add_listener(listener).await.unwrap();
}
if let Some(listener) = v4_listener {
mgr.add_listener(listener).await.unwrap();
}
let mgr = Arc::new(mgr); let mgr = Arc::new(mgr);
#[cfg(feature = "embed")] #[cfg(feature = "embed")]