Some Improvements (#1172)

1. do not exit when dns query failed on et startup.
2. do not send secret digest to client when secret mismatch.
This commit is contained in:
Sijie.Sun
2025-07-29 23:05:38 +08:00
committed by GitHub
parent 2ec88da823
commit 3d610c0f0f
2 changed files with 38 additions and 48 deletions

View File

@@ -4,11 +4,11 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use dashmap::{DashMap, DashSet}; use dashmap::DashSet;
use tokio::{ use tokio::{
sync::{ sync::{
broadcast::{error::RecvError, Receiver}, broadcast::{error::RecvError, Receiver},
mpsc, Mutex, mpsc,
}, },
task::JoinSet, task::JoinSet,
time::timeout, time::timeout,
@@ -32,7 +32,6 @@ use crate::{
global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
netns::NetNS, netns::NetNS,
}, },
connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
proto::cli::{ proto::cli::{
Connector, ConnectorManageRpc, ConnectorStatus, ListConnectorRequest, Connector, ConnectorManageRpc, ConnectorStatus, ListConnectorRequest,
@@ -43,8 +42,7 @@ use crate::{
use super::create_connector_by_url; use super::create_connector_by_url;
type MutexConnector = Arc<Mutex<Box<dyn TunnelConnector>>>; type ConnectorMap = Arc<DashSet<String>>;
type ConnectorMap = Arc<DashMap<String, MutexConnector>>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ReconnResult { struct ReconnResult {
@@ -72,7 +70,7 @@ pub struct ManualConnectorManager {
impl ManualConnectorManager { impl ManualConnectorManager {
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self { pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
let connectors = Arc::new(DashMap::new()); let connectors = Arc::new(DashSet::new());
let tasks = JoinSet::new(); let tasks = JoinSet::new();
let event_subscriber = global_ctx.subscribe(); let event_subscriber = global_ctx.subscribe();
@@ -105,14 +103,11 @@ impl ManualConnectorManager {
T: TunnelConnector + 'static, T: TunnelConnector + 'static,
{ {
tracing::info!("add_connector: {}", connector.remote_url()); tracing::info!("add_connector: {}", connector.remote_url());
self.data.connectors.insert( self.data.connectors.insert(connector.remote_url().into());
connector.remote_url().into(),
Arc::new(Mutex::new(Box::new(connector))),
);
} }
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> { pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
self.add_connector(create_connector_by_url(url, &self.global_ctx, IpVersion::Both).await?); self.data.connectors.insert(url.to_owned());
Ok(()) Ok(())
} }
@@ -236,16 +231,16 @@ impl ManualConnectorManager {
for dead_url in dead_urls { for dead_url in dead_urls {
let data_clone = data.clone(); let data_clone = data.clone();
let sender = reconn_result_send.clone(); let sender = reconn_result_send.clone();
let (_, connector) = data.connectors.remove(&dead_url).unwrap(); data.connectors.remove(&dead_url).unwrap();
let insert_succ = data.reconnecting.insert(dead_url.clone()); let insert_succ = data.reconnecting.insert(dead_url.clone());
assert!(insert_succ); assert!(insert_succ);
tasks.lock().unwrap().spawn(async move { tasks.lock().unwrap().spawn(async move {
let reconn_ret = Self::conn_reconnect(data_clone.clone(), dead_url.clone(), connector.clone()).await; let reconn_ret = Self::conn_reconnect(data_clone.clone(), dead_url.clone() ).await;
let _ = sender.send(reconn_ret).await; let _ = sender.send(reconn_ret).await;
data_clone.reconnecting.remove(&dead_url).unwrap(); data_clone.reconnecting.remove(&dead_url).unwrap();
data_clone.connectors.insert(dead_url.clone(), connector); data_clone.connectors.insert(dead_url.clone());
}); });
} }
tracing::info!("reconn_interval tick, done"); tracing::info!("reconn_interval tick, done");
@@ -323,25 +318,13 @@ impl ManualConnectorManager {
async fn conn_reconnect_with_ip_version( async fn conn_reconnect_with_ip_version(
data: Arc<ConnectorManagerData>, data: Arc<ConnectorManagerData>,
dead_url: String, dead_url: String,
connector: MutexConnector,
ip_version: IpVersion, ip_version: IpVersion,
) -> Result<ReconnResult, Error> { ) -> Result<ReconnResult, Error> {
let ip_collector = data.global_ctx.get_ip_collector(); let connector =
create_connector_by_url(&dead_url, &data.global_ctx.clone(), ip_version).await?;
connector.lock().await.set_ip_version(ip_version); data.global_ctx
.issue_event(GlobalCtxEvent::Connecting(connector.remote_url().clone()));
if data.global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
connector.lock().await.as_mut(),
ip_version == IpVersion::V4,
&ip_collector,
)
.await;
}
data.global_ctx.issue_event(GlobalCtxEvent::Connecting(
connector.lock().await.remote_url().clone(),
));
tracing::info!("reconnect try connect... conn: {:?}", connector); tracing::info!("reconnect try connect... conn: {:?}", connector);
let Some(pm) = data.peer_manager.upgrade() else { let Some(pm) = data.peer_manager.upgrade() else {
return Err(Error::AnyhowError(anyhow::anyhow!( return Err(Error::AnyhowError(anyhow::anyhow!(
@@ -349,9 +332,7 @@ impl ManualConnectorManager {
))); )));
}; };
let (peer_id, conn_id) = pm let (peer_id, conn_id) = pm.try_direct_connect(connector).await?;
.try_direct_connect(connector.lock().await.as_mut())
.await?;
tracing::info!("reconnect succ: {} {} {}", peer_id, conn_id, dead_url); tracing::info!("reconnect succ: {} {} {}", peer_id, conn_id, dead_url);
Ok(ReconnResult { Ok(ReconnResult {
dead_url, dead_url,
@@ -363,7 +344,6 @@ impl ManualConnectorManager {
async fn conn_reconnect( async fn conn_reconnect(
data: Arc<ConnectorManagerData>, data: Arc<ConnectorManagerData>,
dead_url: String, dead_url: String,
connector: MutexConnector,
) -> Result<ReconnResult, Error> { ) -> Result<ReconnResult, Error> {
tracing::info!("reconnect: {}", dead_url); tracing::info!("reconnect: {}", dead_url);
@@ -415,12 +395,7 @@ impl ManualConnectorManager {
let ret = timeout( let ret = timeout(
// allow http connector to wait longer // allow http connector to wait longer
std::time::Duration::from_secs(if use_long_timeout { 20 } else { 2 }), std::time::Duration::from_secs(if use_long_timeout { 20 } else { 2 }),
Self::conn_reconnect_with_ip_version( Self::conn_reconnect_with_ip_version(data.clone(), dead_url.clone(), ip_version),
data.clone(),
dead_url.clone(),
connector.clone(),
ip_version,
),
) )
.await; .await;
tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret); tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret);

View File

@@ -247,7 +247,7 @@ impl PeerConn {
.await? .await?
} }
async fn send_handshake(&mut self) -> Result<(), Error> { async fn send_handshake(&mut self, send_secret_digest: bool) -> Result<(), Error> {
let network = self.global_ctx.get_network_identity(); let network = self.global_ctx.get_network_identity();
let mut req = HandshakeRequest { let mut req = HandshakeRequest {
magic: MAGIC, magic: MAGIC,
@@ -257,8 +257,16 @@ impl PeerConn {
network_name: network.network_name.clone(), network_name: network.network_name.clone(),
..Default::default() ..Default::default()
}; };
req.network_secret_digrest
.extend_from_slice(&network.network_secret_digest.unwrap_or_default()); // only send network secret digest if the network is the same
if send_secret_digest {
req.network_secret_digrest
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
} else {
// fill zero
req.network_secret_digrest
.extend_from_slice(&[0u8; std::mem::size_of::<NetworkSecretDigest>()]);
}
let hs_req = req.encode_to_vec(); let hs_req = req.encode_to_vec();
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
@@ -295,7 +303,8 @@ impl PeerConn {
self.info = Some(rsp); self.info = Some(rsp);
self.is_client = Some(false); self.is_client = Some(false);
self.send_handshake().await?; let send_digest = self.get_network_identity() == self.global_ctx.get_network_identity();
self.send_handshake(send_digest).await?;
if self.get_peer_id() == self.my_peer_id { if self.get_peer_id() == self.my_peer_id {
Err(Error::WaitRespError("peer id conflict".to_owned())) Err(Error::WaitRespError("peer id conflict".to_owned()))
@@ -310,10 +319,14 @@ impl PeerConn {
tracing::info!("handshake request: {:?}", rsp); tracing::info!("handshake request: {:?}", rsp);
self.info = Some(rsp); self.info = Some(rsp);
self.is_client = Some(false); self.is_client = Some(false);
self.send_handshake().await?;
let send_digest = self.get_network_identity() == self.global_ctx.get_network_identity();
self.send_handshake(send_digest).await?;
if self.get_peer_id() == self.my_peer_id { if self.get_peer_id() == self.my_peer_id {
Err(Error::WaitRespError("peer id conflict".to_owned())) Err(Error::WaitRespError(
"peer id conflict, are you connecting to yourself?".to_owned(),
))
} else { } else {
Ok(()) Ok(())
} }
@@ -321,7 +334,7 @@ impl PeerConn {
#[tracing::instrument] #[tracing::instrument]
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> { pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
self.send_handshake().await?; self.send_handshake(true).await?;
tracing::info!("waiting for handshake request from server"); tracing::info!("waiting for handshake request from server");
let rsp = self.wait_handshake_loop().await?; let rsp = self.wait_handshake_loop().await?;
tracing::info!("handshake response: {:?}", rsp); tracing::info!("handshake response: {:?}", rsp);
@@ -329,7 +342,9 @@ impl PeerConn {
self.is_client = Some(true); self.is_client = Some(true);
if self.get_peer_id() == self.my_peer_id { if self.get_peer_id() == self.my_peer_id {
Err(Error::WaitRespError("peer id conflict".to_owned())) Err(Error::WaitRespError(
"peer id conflict, are you connecting to yourself?".to_owned(),
))
} else { } else {
Ok(()) Ok(())
} }