mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-12 20:57:26 +08:00
allow listener retry listen (#554)
This commit is contained in:
@@ -27,7 +27,7 @@ pub fn gen_default_flags() -> Flags {
|
||||
relay_all_peer_rpc: false,
|
||||
disable_udp_hole_punching: false,
|
||||
ipv6_listener: "udp://[::]:0".to_string(),
|
||||
multi_thread: false,
|
||||
multi_thread: true,
|
||||
data_compress_algo: CompressionAlgoPb::None.into(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,7 +230,10 @@ impl GlobalCtx {
|
||||
}
|
||||
|
||||
pub fn add_running_listener(&self, url: url::Url) {
|
||||
self.running_listeners.lock().unwrap().push(url);
|
||||
let mut l = self.running_listeners.lock().unwrap();
|
||||
if !l.contains(&url) {
|
||||
l.push(url);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
use crate::tunnel::quic::QUICTunnelListener;
|
||||
@@ -63,16 +62,20 @@ impl TunnelHandlerForListener for PeerManager {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Listener {
|
||||
inner: Arc<Mutex<dyn TunnelListener>>,
|
||||
pub trait ListenerCreatorTrait: Fn() -> Box<dyn TunnelListener> + Send + Sync {}
|
||||
impl<T: Send + Sync> ListenerCreatorTrait for T where T: Fn() -> Box<dyn TunnelListener> + Send {}
|
||||
pub type ListenerCreator = Box<dyn ListenerCreatorTrait>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ListenerFactory {
|
||||
creator_fn: Arc<ListenerCreator>,
|
||||
must_succ: bool,
|
||||
}
|
||||
|
||||
pub struct ListenerManager<H> {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
net_ns: NetNS,
|
||||
listeners: Vec<Listener>,
|
||||
listeners: Vec<ListenerFactory>,
|
||||
peer_manager: Arc<H>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
@@ -90,31 +93,39 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
||||
}
|
||||
|
||||
pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
|
||||
let self_id = self.global_ctx.get_id();
|
||||
self.add_listener(
|
||||
RingTunnelListener::new(
|
||||
format!("ring://{}", self.global_ctx.get_id())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
),
|
||||
move || {
|
||||
Box::new(RingTunnelListener::new(
|
||||
format!("ring://{}", self_id).parse().unwrap(),
|
||||
))
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for l in self.global_ctx.config.get_listener_uris().iter() {
|
||||
let Ok(lis) = get_listener_by_url(l, self.global_ctx.clone()) else {
|
||||
let l = l.clone();
|
||||
let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else {
|
||||
let msg = format!("failed to get listener by url: {}, maybe not supported", l);
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg));
|
||||
continue;
|
||||
};
|
||||
self.add_listener(lis, true).await?;
|
||||
let ctx = self.global_ctx.clone();
|
||||
self.add_listener(move || get_listener_by_url(&l, ctx.clone()).unwrap(), true)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if self.global_ctx.config.get_flags().enable_ipv6 {
|
||||
let ipv6_listener = self.global_ctx.config.get_flags().ipv6_listener.clone();
|
||||
let _ = self
|
||||
.add_listener(
|
||||
UdpTunnelListener::new(ipv6_listener.parse().unwrap()),
|
||||
move || {
|
||||
Box::new(UdpTunnelListener::new(
|
||||
ipv6_listener.clone().parse().unwrap(),
|
||||
))
|
||||
},
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
@@ -123,85 +134,91 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_listener<L>(&mut self, listener: L, must_succ: bool) -> Result<(), Error>
|
||||
where
|
||||
L: TunnelListener + 'static,
|
||||
{
|
||||
let listener = Arc::new(Mutex::new(listener));
|
||||
self.listeners.push(Listener {
|
||||
inner: listener,
|
||||
pub async fn add_listener<C: ListenerCreatorTrait + 'static>(
|
||||
&mut self,
|
||||
creator: C,
|
||||
must_succ: bool,
|
||||
) -> Result<(), Error> {
|
||||
self.listeners.push(ListenerFactory {
|
||||
creator_fn: Arc::new(Box::new(creator)),
|
||||
must_succ,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
#[tracing::instrument(skip(creator))]
|
||||
async fn run_listener(
|
||||
listener: Arc<Mutex<dyn TunnelListener>>,
|
||||
creator: Arc<ListenerCreator>,
|
||||
peer_manager: Arc<H>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) {
|
||||
let mut l = listener.lock().await;
|
||||
global_ctx.add_running_listener(l.local_url());
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
|
||||
loop {
|
||||
let ret = match l.accept().await {
|
||||
Ok(ret) => ret,
|
||||
let mut l = (creator)();
|
||||
let _g = global_ctx.net_ns.guard();
|
||||
match l.listen().await {
|
||||
Ok(_) => {
|
||||
global_ctx.add_running_listener(l.local_url());
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
|
||||
}
|
||||
Err(e) => {
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed(
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAddFailed(
|
||||
l.local_url(),
|
||||
e.to_string(),
|
||||
));
|
||||
tracing::error!(?e, ?l, "listener accept error");
|
||||
tracing::error!(?e, ?l, "listener listen error");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
}
|
||||
loop {
|
||||
let ret = match l.accept().await {
|
||||
Ok(ret) => ret,
|
||||
Err(e) => {
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed(
|
||||
l.local_url(),
|
||||
format!("error: {}, retry listen later...", e.to_string()),
|
||||
));
|
||||
tracing::error!(?e, ?l, "listener accept error");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let tunnel_info = ret.info().unwrap();
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
|
||||
tunnel_info
|
||||
.local_addr
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
tunnel_info
|
||||
.remote_addr
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
));
|
||||
tracing::info!(ret = ?ret, "conn accepted");
|
||||
let peer_manager = peer_manager.clone();
|
||||
let global_ctx = global_ctx.clone();
|
||||
tokio::spawn(async move {
|
||||
let server_ret = peer_manager.handle_tunnel(ret).await;
|
||||
if let Err(e) = &server_ret {
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
|
||||
tunnel_info.local_addr.unwrap_or_default().to_string(),
|
||||
tunnel_info.remote_addr.unwrap_or_default().to_string(),
|
||||
e.to_string(),
|
||||
));
|
||||
tracing::error!(error = ?e, "handle conn error");
|
||||
}
|
||||
});
|
||||
let tunnel_info = ret.info().unwrap();
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
|
||||
tunnel_info
|
||||
.local_addr
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
tunnel_info
|
||||
.remote_addr
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
));
|
||||
tracing::info!(ret = ?ret, "conn accepted");
|
||||
let peer_manager = peer_manager.clone();
|
||||
let global_ctx = global_ctx.clone();
|
||||
tokio::spawn(async move {
|
||||
let server_ret = peer_manager.handle_tunnel(ret).await;
|
||||
if let Err(e) = &server_ret {
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
|
||||
tunnel_info.local_addr.unwrap_or_default().to_string(),
|
||||
tunnel_info.remote_addr.unwrap_or_default().to_string(),
|
||||
e.to_string(),
|
||||
));
|
||||
tracing::error!(error = ?e, "handle conn error");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
for listener in &self.listeners {
|
||||
let _guard = self.net_ns.guard();
|
||||
let addr = listener.inner.lock().await.local_url();
|
||||
tracing::warn!("run listener: {:?}", listener);
|
||||
listener
|
||||
.inner
|
||||
.lock()
|
||||
.await
|
||||
.listen()
|
||||
.await
|
||||
.with_context(|| format!("failed to add listener {}", addr))?;
|
||||
self.tasks.spawn(Self::run_listener(
|
||||
listener.inner.clone(),
|
||||
listener.creator_fn.clone(),
|
||||
self.peer_manager.clone(),
|
||||
self.global_ctx.clone(),
|
||||
));
|
||||
@@ -213,12 +230,14 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx,
|
||||
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector},
|
||||
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector, TunnelError},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -245,12 +264,18 @@ mod tests {
|
||||
|
||||
let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
|
||||
|
||||
let ring_id_clone = ring_id.clone();
|
||||
listener_mgr
|
||||
.add_listener(RingTunnelListener::new(ring_id.parse().unwrap()), true)
|
||||
.add_listener(
|
||||
move || Box::new(RingTunnelListener::new(ring_id_clone.parse().unwrap())),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
listener_mgr.run().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
let connect_once = |ring_id| async move {
|
||||
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
|
||||
let (mut recv, _send) = tunnel.split();
|
||||
@@ -269,4 +294,60 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_listen() {
|
||||
let counter = Arc::new(AtomicI32::new(0));
|
||||
let drop_counter = Arc::new(AtomicI32::new(0));
|
||||
struct MockListener {
|
||||
counter: Arc<AtomicI32>,
|
||||
drop_counter: Arc<AtomicI32>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for MockListener {
|
||||
fn local_url(&self) -> url::Url {
|
||||
"mock://".parse().unwrap()
|
||||
}
|
||||
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
self.counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
Err(TunnelError::BufferFull)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MockListener {
|
||||
fn drop(&mut self) {
|
||||
self.drop_counter.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
let handler = Arc::new(MockListenerHandler {});
|
||||
let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone());
|
||||
let counter_clone = counter.clone();
|
||||
let drop_counter_clone = drop_counter.clone();
|
||||
listener_mgr
|
||||
.add_listener(
|
||||
move || {
|
||||
Box::new(MockListener {
|
||||
counter: counter_clone.clone(),
|
||||
drop_counter: drop_counter_clone.clone(),
|
||||
})
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
listener_mgr.run().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
|
||||
assert!(counter.load(Ordering::Relaxed) >= 2);
|
||||
assert!(drop_counter.load(Ordering::Relaxed) >= 1);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user