diff --git a/easytier/src/connector/manual.rs b/easytier/src/connector/manual.rs index 2cce883..fc04d54 100644 --- a/easytier/src/connector/manual.rs +++ b/easytier/src/connector/manual.rs @@ -4,11 +4,11 @@ use std::{ }; use anyhow::Context; -use dashmap::{DashMap, DashSet}; +use dashmap::DashSet; use tokio::{ sync::{ broadcast::{error::RecvError, Receiver}, - mpsc, Mutex, + mpsc, }, task::JoinSet, time::timeout, @@ -32,7 +32,6 @@ use crate::{ global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, netns::NetNS, }, - connector::set_bind_addr_for_peer_connector, peers::peer_manager::PeerManager, proto::cli::{ Connector, ConnectorManageRpc, ConnectorStatus, ListConnectorRequest, @@ -43,8 +42,7 @@ use crate::{ use super::create_connector_by_url; -type MutexConnector = Arc>>; -type ConnectorMap = Arc>; +type ConnectorMap = Arc>; #[derive(Debug, Clone)] struct ReconnResult { @@ -72,7 +70,7 @@ pub struct ManualConnectorManager { impl ManualConnectorManager { pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc) -> Self { - let connectors = Arc::new(DashMap::new()); + let connectors = Arc::new(DashSet::new()); let tasks = JoinSet::new(); let event_subscriber = global_ctx.subscribe(); @@ -105,14 +103,11 @@ impl ManualConnectorManager { T: TunnelConnector + 'static, { tracing::info!("add_connector: {}", connector.remote_url()); - self.data.connectors.insert( - connector.remote_url().into(), - Arc::new(Mutex::new(Box::new(connector))), - ); + self.data.connectors.insert(connector.remote_url().into()); } pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> { - self.add_connector(create_connector_by_url(url, &self.global_ctx, IpVersion::Both).await?); + self.data.connectors.insert(url.to_owned()); Ok(()) } @@ -236,16 +231,16 @@ impl ManualConnectorManager { for dead_url in dead_urls { let data_clone = data.clone(); let sender = reconn_result_send.clone(); - let (_, connector) = data.connectors.remove(&dead_url).unwrap(); + data.connectors.remove(&dead_url).unwrap(); let insert_succ = data.reconnecting.insert(dead_url.clone()); assert!(insert_succ); tasks.lock().unwrap().spawn(async move { - let reconn_ret = Self::conn_reconnect(data_clone.clone(), dead_url.clone(), connector.clone()).await; + let reconn_ret = Self::conn_reconnect(data_clone.clone(), dead_url.clone() ).await; let _ = sender.send(reconn_ret).await; data_clone.reconnecting.remove(&dead_url).unwrap(); - data_clone.connectors.insert(dead_url.clone(), connector); + data_clone.connectors.insert(dead_url.clone()); }); } tracing::info!("reconn_interval tick, done"); @@ -323,25 +318,13 @@ impl ManualConnectorManager { async fn conn_reconnect_with_ip_version( data: Arc, dead_url: String, - connector: MutexConnector, ip_version: IpVersion, ) -> Result { - let ip_collector = data.global_ctx.get_ip_collector(); + let connector = + create_connector_by_url(&dead_url, &data.global_ctx.clone(), ip_version).await?; - connector.lock().await.set_ip_version(ip_version); - - if data.global_ctx.config.get_flags().bind_device { - set_bind_addr_for_peer_connector( - connector.lock().await.as_mut(), - ip_version == IpVersion::V4, - &ip_collector, - ) - .await; - } - - data.global_ctx.issue_event(GlobalCtxEvent::Connecting( - connector.lock().await.remote_url().clone(), - )); + data.global_ctx + .issue_event(GlobalCtxEvent::Connecting(connector.remote_url().clone())); tracing::info!("reconnect try connect... conn: {:?}", connector); let Some(pm) = data.peer_manager.upgrade() else { return Err(Error::AnyhowError(anyhow::anyhow!( @@ -349,9 +332,7 @@ impl ManualConnectorManager { ))); }; - let (peer_id, conn_id) = pm - .try_direct_connect(connector.lock().await.as_mut()) - .await?; + let (peer_id, conn_id) = pm.try_direct_connect(connector).await?; tracing::info!("reconnect succ: {} {} {}", peer_id, conn_id, dead_url); Ok(ReconnResult { dead_url, @@ -363,7 +344,6 @@ impl ManualConnectorManager { async fn conn_reconnect( data: Arc, dead_url: String, - connector: MutexConnector, ) -> Result { tracing::info!("reconnect: {}", dead_url); @@ -415,12 +395,7 @@ impl ManualConnectorManager { let ret = timeout( // allow http connector to wait longer std::time::Duration::from_secs(if use_long_timeout { 20 } else { 2 }), - Self::conn_reconnect_with_ip_version( - data.clone(), - dead_url.clone(), - connector.clone(), - ip_version, - ), + Self::conn_reconnect_with_ip_version(data.clone(), dead_url.clone(), ip_version), ) .await; tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret); diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 8989d8d..59820eb 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -247,7 +247,7 @@ impl PeerConn { .await? } - async fn send_handshake(&mut self) -> Result<(), Error> { + async fn send_handshake(&mut self, send_secret_digest: bool) -> Result<(), Error> { let network = self.global_ctx.get_network_identity(); let mut req = HandshakeRequest { magic: MAGIC, @@ -257,8 +257,16 @@ impl PeerConn { network_name: network.network_name.clone(), ..Default::default() }; - req.network_secret_digrest - .extend_from_slice(&network.network_secret_digest.unwrap_or_default()); + + // only send network secret digest if the network is the same + if send_secret_digest { + req.network_secret_digrest + .extend_from_slice(&network.network_secret_digest.unwrap_or_default()); + } else { + // fill zero + req.network_secret_digrest + .extend_from_slice(&[0u8; std::mem::size_of::()]); + } let hs_req = req.encode_to_vec(); let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); @@ -295,7 +303,8 @@ impl PeerConn { self.info = Some(rsp); self.is_client = Some(false); - self.send_handshake().await?; + let send_digest = self.get_network_identity() == self.global_ctx.get_network_identity(); + self.send_handshake(send_digest).await?; if self.get_peer_id() == self.my_peer_id { Err(Error::WaitRespError("peer id conflict".to_owned())) @@ -310,10 +319,14 @@ impl PeerConn { tracing::info!("handshake request: {:?}", rsp); self.info = Some(rsp); self.is_client = Some(false); - self.send_handshake().await?; + + let send_digest = self.get_network_identity() == self.global_ctx.get_network_identity(); + self.send_handshake(send_digest).await?; if self.get_peer_id() == self.my_peer_id { - Err(Error::WaitRespError("peer id conflict".to_owned())) + Err(Error::WaitRespError( + "peer id conflict, are you connecting to yourself?".to_owned(), + )) } else { Ok(()) } @@ -321,7 +334,7 @@ impl PeerConn { #[tracing::instrument] pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> { - self.send_handshake().await?; + self.send_handshake(true).await?; tracing::info!("waiting for handshake request from server"); let rsp = self.wait_handshake_loop().await?; tracing::info!("handshake response: {:?}", rsp); @@ -329,7 +342,9 @@ impl PeerConn { self.is_client = Some(true); if self.get_peer_id() == self.my_peer_id { - Err(Error::WaitRespError("peer id conflict".to_owned())) + Err(Error::WaitRespError( + "peer id conflict, are you connecting to yourself?".to_owned(), + )) } else { Ok(()) }