allow listener retry listen (#554)

This commit is contained in:
Sijie.Sun
2025-01-09 00:01:41 +08:00
committed by GitHub
parent d2ec60e108
commit 306817ae9a
3 changed files with 157 additions and 73 deletions

View File

@@ -27,7 +27,7 @@ pub fn gen_default_flags() -> Flags {
relay_all_peer_rpc: false,
disable_udp_hole_punching: false,
ipv6_listener: "udp://[::]:0".to_string(),
multi_thread: false,
multi_thread: true,
data_compress_algo: CompressionAlgoPb::None.into(),
}
}

View File

@@ -230,7 +230,10 @@ impl GlobalCtx {
}
pub fn add_running_listener(&self, url: url::Url) {
self.running_listeners.lock().unwrap().push(url);
let mut l = self.running_listeners.lock().unwrap();
if !l.contains(&url) {
l.push(url);
}
}
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {

View File

@@ -1,8 +1,7 @@
use std::{fmt::Debug, sync::Arc};
use anyhow::Context;
use async_trait::async_trait;
use tokio::{sync::Mutex, task::JoinSet};
use tokio::task::JoinSet;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelListener;
@@ -63,16 +62,20 @@ impl TunnelHandlerForListener for PeerManager {
}
}
#[derive(Debug, Clone)]
struct Listener {
inner: Arc<Mutex<dyn TunnelListener>>,
pub trait ListenerCreatorTrait: Fn() -> Box<dyn TunnelListener> + Send + Sync {}
impl<T: Send + Sync> ListenerCreatorTrait for T where T: Fn() -> Box<dyn TunnelListener> + Send {}
pub type ListenerCreator = Box<dyn ListenerCreatorTrait>;
#[derive(Clone)]
struct ListenerFactory {
creator_fn: Arc<ListenerCreator>,
must_succ: bool,
}
pub struct ListenerManager<H> {
global_ctx: ArcGlobalCtx,
net_ns: NetNS,
listeners: Vec<Listener>,
listeners: Vec<ListenerFactory>,
peer_manager: Arc<H>,
tasks: JoinSet<()>,
@@ -90,31 +93,39 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
}
pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
let self_id = self.global_ctx.get_id();
self.add_listener(
RingTunnelListener::new(
format!("ring://{}", self.global_ctx.get_id())
.parse()
.unwrap(),
),
move || {
Box::new(RingTunnelListener::new(
format!("ring://{}", self_id).parse().unwrap(),
))
},
true,
)
.await?;
for l in self.global_ctx.config.get_listener_uris().iter() {
let Ok(lis) = get_listener_by_url(l, self.global_ctx.clone()) else {
let l = l.clone();
let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else {
let msg = format!("failed to get listener by url: {}, maybe not supported", l);
self.global_ctx
.issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg));
continue;
};
self.add_listener(lis, true).await?;
let ctx = self.global_ctx.clone();
self.add_listener(move || get_listener_by_url(&l, ctx.clone()).unwrap(), true)
.await?;
}
if self.global_ctx.config.get_flags().enable_ipv6 {
let ipv6_listener = self.global_ctx.config.get_flags().ipv6_listener.clone();
let _ = self
.add_listener(
UdpTunnelListener::new(ipv6_listener.parse().unwrap()),
move || {
Box::new(UdpTunnelListener::new(
ipv6_listener.clone().parse().unwrap(),
))
},
false,
)
.await?;
@@ -123,85 +134,91 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
Ok(())
}
pub async fn add_listener<L>(&mut self, listener: L, must_succ: bool) -> Result<(), Error>
where
L: TunnelListener + 'static,
{
let listener = Arc::new(Mutex::new(listener));
self.listeners.push(Listener {
inner: listener,
pub async fn add_listener<C: ListenerCreatorTrait + 'static>(
&mut self,
creator: C,
must_succ: bool,
) -> Result<(), Error> {
self.listeners.push(ListenerFactory {
creator_fn: Arc::new(Box::new(creator)),
must_succ,
});
Ok(())
}
#[tracing::instrument]
#[tracing::instrument(skip(creator))]
async fn run_listener(
listener: Arc<Mutex<dyn TunnelListener>>,
creator: Arc<ListenerCreator>,
peer_manager: Arc<H>,
global_ctx: ArcGlobalCtx,
) {
let mut l = listener.lock().await;
global_ctx.add_running_listener(l.local_url());
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
loop {
let ret = match l.accept().await {
Ok(ret) => ret,
let mut l = (creator)();
let _g = global_ctx.net_ns.guard();
match l.listen().await {
Ok(_) => {
global_ctx.add_running_listener(l.local_url());
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
}
Err(e) => {
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed(
global_ctx.issue_event(GlobalCtxEvent::ListenerAddFailed(
l.local_url(),
e.to_string(),
));
tracing::error!(?e, ?l, "listener accept error");
tracing::error!(?e, ?l, "listener listen error");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
continue;
}
};
}
loop {
let ret = match l.accept().await {
Ok(ret) => ret,
Err(e) => {
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed(
l.local_url(),
format!("error: {}, retry listen later...", e.to_string()),
));
tracing::error!(?e, ?l, "listener accept error");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
break;
}
};
let tunnel_info = ret.info().unwrap();
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
tunnel_info
.local_addr
.clone()
.unwrap_or_default()
.to_string(),
tunnel_info
.remote_addr
.clone()
.unwrap_or_default()
.to_string(),
));
tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone();
let global_ctx = global_ctx.clone();
tokio::spawn(async move {
let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr.unwrap_or_default().to_string(),
tunnel_info.remote_addr.unwrap_or_default().to_string(),
e.to_string(),
));
tracing::error!(error = ?e, "handle conn error");
}
});
let tunnel_info = ret.info().unwrap();
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
tunnel_info
.local_addr
.clone()
.unwrap_or_default()
.to_string(),
tunnel_info
.remote_addr
.clone()
.unwrap_or_default()
.to_string(),
));
tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone();
let global_ctx = global_ctx.clone();
tokio::spawn(async move {
let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr.unwrap_or_default().to_string(),
tunnel_info.remote_addr.unwrap_or_default().to_string(),
e.to_string(),
));
tracing::error!(error = ?e, "handle conn error");
}
});
}
}
}
pub async fn run(&mut self) -> Result<(), Error> {
for listener in &self.listeners {
let _guard = self.net_ns.guard();
let addr = listener.inner.lock().await.local_url();
tracing::warn!("run listener: {:?}", listener);
listener
.inner
.lock()
.await
.listen()
.await
.with_context(|| format!("failed to add listener {}", addr))?;
self.tasks.spawn(Self::run_listener(
listener.inner.clone(),
listener.creator_fn.clone(),
self.peer_manager.clone(),
self.global_ctx.clone(),
));
@@ -213,12 +230,14 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicI32, Ordering};
use futures::{SinkExt, StreamExt};
use tokio::time::timeout;
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector},
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector, TunnelError},
};
use super::*;
@@ -245,12 +264,18 @@ mod tests {
let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
let ring_id_clone = ring_id.clone();
listener_mgr
.add_listener(RingTunnelListener::new(ring_id.parse().unwrap()), true)
.add_listener(
move || Box::new(RingTunnelListener::new(ring_id_clone.parse().unwrap())),
true,
)
.await
.unwrap();
listener_mgr.run().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let connect_once = |ring_id| async move {
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
let (mut recv, _send) = tunnel.split();
@@ -269,4 +294,60 @@ mod tests {
.await
.unwrap();
}
#[tokio::test]
async fn retry_listen() {
let counter = Arc::new(AtomicI32::new(0));
let drop_counter = Arc::new(AtomicI32::new(0));
struct MockListener {
counter: Arc<AtomicI32>,
drop_counter: Arc<AtomicI32>,
}
#[async_trait::async_trait]
impl TunnelListener for MockListener {
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
async fn listen(&mut self) -> Result<(), TunnelError> {
self.counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Err(TunnelError::BufferFull)
}
}
impl Drop for MockListener {
fn drop(&mut self) {
self.drop_counter.fetch_add(1, Ordering::Relaxed);
}
}
let handler = Arc::new(MockListenerHandler {});
let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone());
let counter_clone = counter.clone();
let drop_counter_clone = drop_counter.clone();
listener_mgr
.add_listener(
move || {
Box::new(MockListener {
counter: counter_clone.clone(),
drop_counter: drop_counter_clone.clone(),
})
},
true,
)
.await
.unwrap();
listener_mgr.run().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
assert!(counter.load(Ordering::Relaxed) >= 2);
assert!(drop_counter.load(Ordering::Relaxed) >= 1);
}
}