Compare commits

...

3 Commits

Author SHA1 Message Date
sijie.sun
ba3f36d22b remove lock on pipelines 2025-07-25 10:46:06 +08:00
sijie.sun
78d8848ede fix cond of rpc encrypt 2025-07-25 09:13:42 +08:00
sijie.sun
601a0bf719 remove lock of routes 2025-07-25 09:11:05 +08:00
12 changed files with 119 additions and 88 deletions

View File

@@ -306,7 +306,7 @@ impl IcmpProxy {
return Err(anyhow::anyhow!("peer manager is gone").into()); return Err(anyhow::anyhow!("peer manager is gone").into());
}; };
pm.add_packet_process_pipeline(Box::new(self.clone())).await; pm.add_packet_process_pipeline(self.clone()).await;
Ok(()) Ok(())
} }

View File

@@ -341,13 +341,13 @@ impl KcpProxySrc {
pub async fn start(&self) { pub async fn start(&self) {
self.peer_manager self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.tcp_proxy.clone())) .add_nic_packet_process_pipeline(Arc::new(self.tcp_proxy.clone()))
.await; .await;
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(self.tcp_proxy.0.clone())) .add_packet_process_pipeline(Arc::new(self.tcp_proxy.0.clone()))
.await; .await;
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(KcpEndpointFilter { .add_packet_process_pipeline(Arc::new(KcpEndpointFilter {
kcp_endpoint: self.kcp_endpoint.clone(), kcp_endpoint: self.kcp_endpoint.clone(),
is_src: true, is_src: true,
})) }))
@@ -484,7 +484,7 @@ impl KcpProxyDst {
pub async fn start(&mut self) { pub async fn start(&mut self) {
self.run_accept_task().await; self.run_accept_task().await;
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(KcpEndpointFilter { .add_packet_process_pipeline(Arc::new(KcpEndpointFilter {
kcp_endpoint: self.kcp_endpoint.clone(), kcp_endpoint: self.kcp_endpoint.clone(),
is_src: false, is_src: false,
})) }))

View File

@@ -227,10 +227,10 @@ impl QUICProxySrc {
pub async fn start(&self) { pub async fn start(&self) {
self.peer_manager self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.tcp_proxy.clone())) .add_nic_packet_process_pipeline(Arc::new(self.tcp_proxy.clone()))
.await; .await;
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(self.tcp_proxy.0.clone())) .add_packet_process_pipeline(Arc::new(self.tcp_proxy.0.clone()))
.await; .await;
self.tcp_proxy.0.start(false).await.unwrap(); self.tcp_proxy.0.start(false).await.unwrap();
} }

View File

@@ -621,7 +621,7 @@ impl Socks5Server {
if need_start { if need_start {
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone())) .add_packet_process_pipeline(self.clone())
.await; .await;
self.run_net_update_task().await; self.run_net_update_task().await;

View File

@@ -476,10 +476,10 @@ impl<C: NatDstConnector> TcpProxy<C> {
self.run_listener().await?; self.run_listener().await?;
if add_pipeline { if add_pipeline {
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone())) .add_packet_process_pipeline(self.clone())
.await; .await;
self.peer_manager self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.clone())) .add_nic_packet_process_pipeline(self.clone())
.await; .await;
} }
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned()); join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());

View File

@@ -404,7 +404,7 @@ impl UdpProxy {
pub async fn start(self: &Arc<Self>) -> Result<(), Error> { pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
self.peer_manager self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone())) .add_packet_process_pipeline(self.clone())
.await; .await;
// clean up nat table // clean up nat table

View File

@@ -404,9 +404,7 @@ impl MagicDnsServerInstance {
.register(MagicDnsServerRpcServer::new(data.clone()), ""); .register(MagicDnsServerRpcServer::new(data.clone()), "");
rpc_server.set_hook(data.clone()); rpc_server.set_hook(data.clone());
peer_mgr peer_mgr.add_nic_packet_process_pipeline(data.clone()).await;
.add_nic_packet_process_pipeline(Box::new(data.clone()))
.await;
let data_clone = data.clone(); let data_clone = data.clone();
tokio::task::spawn_blocking(move || data_clone.do_system_config(DEFAULT_ET_DNS_ZONE)) tokio::task::spawn_blocking(move || data_clone.do_system_config(DEFAULT_ET_DNS_ZONE))

View File

@@ -23,6 +23,8 @@ pub mod peer_task;
#[cfg(test)] #[cfg(test)]
pub mod tests; pub mod tests;
use std::sync::Arc;
use crate::tunnel::packet_def::ZCPacket; use crate::tunnel::packet_def::ZCPacket;
#[async_trait::async_trait] #[async_trait::async_trait]
@@ -43,8 +45,8 @@ pub trait NicPacketFilter {
} }
} }
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>; type BoxPeerPacketFilter = Arc<dyn PeerPacketFilter + Send + Sync>;
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>; type BoxNicPacketFilter = Arc<dyn NicPacketFilter + Send + Sync>;
// pub type PacketRecvChan = tachyonix::Sender<ZCPacket>; // pub type PacketRecvChan = tachyonix::Sender<ZCPacket>;
// pub type PacketRecvChanReceiver = tachyonix::Receiver<ZCPacket>; // pub type PacketRecvChanReceiver = tachyonix::Receiver<ZCPacket>;

View File

@@ -6,6 +6,7 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use arc_swap::ArcSwap;
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
@@ -13,7 +14,7 @@ use dashmap::DashMap;
use tokio::{ use tokio::{
sync::{ sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender}, mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, RwLock, Mutex,
}, },
task::JoinSet, task::JoinSet,
}; };
@@ -86,7 +87,8 @@ impl PeerRpcManagerTransport for RpcTransport {
.get_route_peer_info(dst_peer_id) .get_route_peer_info(dst_peer_id)
.await .await
.and_then(|x| x.feature_flag.map(|x| x.is_public_server)) .and_then(|x| x.feature_flag.map(|x| x.is_public_server))
.unwrap_or(true); // if dst is directly connected, it's must not public server
.unwrap_or(!peers.has_peer(dst_peer_id));
if !is_dst_peer_public_server { if !is_dst_peer_public_server {
self.encryptor self.encryptor
.encrypt(&mut msg) .encrypt(&mut msg)
@@ -130,8 +132,8 @@ pub struct PeerManager {
peer_rpc_mgr: Arc<PeerRpcManager>, peer_rpc_mgr: Arc<PeerRpcManager>,
peer_rpc_tspt: Arc<RpcTransport>, peer_rpc_tspt: Arc<RpcTransport>,
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>, peer_packet_process_pipeline: Arc<ArcSwap<Vec<BoxPeerPacketFilter>>>,
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>, nic_packet_process_pipeline: ArcSwap<Vec<BoxNicPacketFilter>>,
route_algo_inst: RouteAlgoInst, route_algo_inst: RouteAlgoInst,
@@ -260,8 +262,8 @@ impl PeerManager {
peer_rpc_mgr, peer_rpc_mgr,
peer_rpc_tspt: rpc_tspt, peer_rpc_tspt: rpc_tspt,
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())), peer_packet_process_pipeline: Arc::new(ArcSwap::from(Arc::new(Vec::new()))),
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())), nic_packet_process_pipeline: ArcSwap::from(Arc::new(Vec::new())),
route_algo_inst, route_algo_inst,
@@ -646,7 +648,7 @@ impl PeerManager {
let mut processed = false; let mut processed = false;
let mut zc_packet = Some(ret); let mut zc_packet = Some(ret);
let mut idx = 0; let mut idx = 0;
for pipeline in pipe_line.read().await.iter().rev() { for pipeline in pipe_line.load().iter().rev() {
tracing::trace!(?zc_packet, ?idx, "try_process_packet_from_peer"); tracing::trace!(?zc_packet, ?idx, "try_process_packet_from_peer");
idx += 1; idx += 1;
zc_packet = pipeline zc_packet = pipeline
@@ -668,18 +670,20 @@ impl PeerManager {
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) { pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
// newest pipeline will be executed first // newest pipeline will be executed first
let current = self.peer_packet_process_pipeline.load();
let mut new_pipelines = (*(*current)).iter().map(|x| x.clone()).collect::<Vec<_>>();
new_pipelines.push(pipeline);
self.peer_packet_process_pipeline self.peer_packet_process_pipeline
.write() .swap(Arc::new(new_pipelines));
.await
.push(pipeline);
} }
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) { pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
// newest pipeline will be executed first // newest pipeline will be executed first
let current = self.nic_packet_process_pipeline.load();
let mut new_pipelines = (*current).iter().map(|x| x.clone()).collect::<Vec<_>>();
new_pipelines.push(pipeline);
self.nic_packet_process_pipeline self.nic_packet_process_pipeline
.write() .swap(Arc::new(new_pipelines));
.await
.push(pipeline);
} }
async fn init_packet_process_pipeline(&self) { async fn init_packet_process_pipeline(&self) {
@@ -701,7 +705,7 @@ impl PeerManager {
} }
} }
} }
self.add_packet_process_pipeline(Box::new(NicPacketProcessor { self.add_packet_process_pipeline(Arc::new(NicPacketProcessor {
nic_channel: self.nic_channel.clone(), nic_channel: self.nic_channel.clone(),
})) }))
.await; .await;
@@ -726,7 +730,7 @@ impl PeerManager {
} }
} }
} }
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor { self.add_packet_process_pipeline(Arc::new(PeerRpcPacketProcessor {
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(), peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
})) }))
.await; .await;
@@ -734,12 +738,8 @@ impl PeerManager {
pub async fn add_route<T>(&self, route: T) pub async fn add_route<T>(&self, route: T)
where where
T: Route + PeerPacketFilter + Send + Sync + Clone + 'static, T: Route + Send + Sync + Clone + 'static,
{ {
// for route
self.add_packet_process_pipeline(Box::new(route.clone()))
.await;
struct Interface { struct Interface {
my_peer_id: PeerId, my_peer_id: PeerId,
peers: Weak<PeerMap>, peers: Weak<PeerMap>,
@@ -865,15 +865,19 @@ impl PeerManager {
return; return;
} }
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() { let pipelines = self.nic_packet_process_pipeline.load();
for pipeline in pipelines.iter().rev() {
let _ = pipeline.try_process_packet_from_nic(data).await; let _ = pipeline.try_process_packet_from_nic(data).await;
} }
} }
pub async fn remove_nic_packet_process_pipeline(&self, id: String) -> Result<(), Error> { pub async fn remove_nic_packet_process_pipeline(&self, id: String) -> Result<(), Error> {
let mut pipelines = self.nic_packet_process_pipeline.write().await; let current = self.nic_packet_process_pipeline.load();
if let Some(pos) = pipelines.iter().position(|x| x.id() == id) { let mut new_pipelines = (*current).iter().map(|x| x.clone()).collect::<Vec<_>>();
pipelines.remove(pos); if let Some(pos) = new_pipelines.iter().position(|x| x.id() == id) {
new_pipelines.remove(pos);
self.nic_packet_process_pipeline
.swap(Arc::new(new_pipelines));
Ok(()) Ok(())
} else { } else {
Err(Error::NotFound) Err(Error::NotFound)
@@ -1205,10 +1209,9 @@ impl PeerManager {
} }
pub async fn clear_resources(&self) { pub async fn clear_resources(&self) {
let mut peer_pipeline = self.peer_packet_process_pipeline.write().await; self.peer_packet_process_pipeline
peer_pipeline.clear(); .store(Arc::new(Vec::new()));
let mut nic_pipeline = self.nic_packet_process_pipeline.write().await; self.nic_packet_process_pipeline.store(Arc::new(Vec::new()));
nic_pipeline.clear();
self.peer_rpc_mgr.rpc_server().registry().unregister_all(); self.peer_rpc_mgr.rpc_server().registry().unregister_all();
} }

View File

@@ -4,8 +4,8 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use arc_swap::ArcSwap;
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use tokio::sync::RwLock;
use crate::{ use crate::{
common::{ common::{
@@ -32,7 +32,7 @@ pub struct PeerMap {
my_peer_id: PeerId, my_peer_id: PeerId,
peer_map: DashMap<PeerId, Arc<Peer>>, peer_map: DashMap<PeerId, Arc<Peer>>,
packet_send: PacketRecvChan, packet_send: PacketRecvChan,
routes: RwLock<Vec<ArcRoute>>, route: ArcSwap<Option<ArcRoute>>,
alive_conns: Arc<DashMap<(PeerId, PeerConnId), PeerConnInfo>>, alive_conns: Arc<DashMap<(PeerId, PeerConnId), PeerConnInfo>>,
} }
@@ -43,7 +43,7 @@ impl PeerMap {
my_peer_id, my_peer_id,
peer_map: DashMap::new(), peer_map: DashMap::new(),
packet_send, packet_send,
routes: RwLock::new(Vec::new()), route: ArcSwap::from(Arc::new(None)),
alive_conns: Arc::new(DashMap::new()), alive_conns: Arc::new(DashMap::new()),
} }
} }
@@ -154,7 +154,7 @@ impl PeerMap {
} }
// get route info // get route info
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
if let Some(gateway_peer_id) = route if let Some(gateway_peer_id) = route
.get_next_hop_with_policy(dst_peer_id, policy.clone()) .get_next_hop_with_policy(dst_peer_id, policy.clone())
.await .await
@@ -171,14 +171,13 @@ impl PeerMap {
&self, &self,
network_identity: &NetworkIdentity, network_identity: &NetworkIdentity,
) -> Vec<PeerId> { ) -> Vec<PeerId> {
let mut ret = Vec::new(); if let Some(route) = self.route.load().as_ref() {
for route in self.routes.read().await.iter() { route
let peers = route
.list_peers_own_foreign_network(&network_identity) .list_peers_own_foreign_network(&network_identity)
.await; .await
ret.extend(peers); } else {
Vec::new()
} }
ret
} }
pub async fn send_msg( pub async fn send_msg(
@@ -199,32 +198,27 @@ impl PeerMap {
} }
pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> { pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
let peer_id = route.get_peer_id_by_ipv4(ipv4).await; route.get_peer_id_by_ipv4(ipv4).await
if peer_id.is_some() { } else {
return peer_id; None
}
} }
None
} }
pub async fn get_peer_id_by_ipv6(&self, ipv6: &Ipv6Addr) -> Option<PeerId> { pub async fn get_peer_id_by_ipv6(&self, ipv6: &Ipv6Addr) -> Option<PeerId> {
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
let peer_id = route.get_peer_id_by_ipv6(ipv6).await; route.get_peer_id_by_ipv6(ipv6).await
if peer_id.is_some() { } else {
return peer_id; None
}
} }
None
} }
pub async fn get_route_peer_info(&self, peer_id: PeerId) -> Option<RoutePeerInfo> { pub async fn get_route_peer_info(&self, peer_id: PeerId) -> Option<RoutePeerInfo> {
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
if let Some(info) = route.get_peer_info(peer_id).await { route.get_peer_info(peer_id).await
return Some(info); } else {
} None
} }
None
} }
pub async fn get_origin_my_peer_id( pub async fn get_origin_my_peer_id(
@@ -232,15 +226,13 @@ impl PeerMap {
network_name: &str, network_name: &str,
foreign_my_peer_id: PeerId, foreign_my_peer_id: PeerId,
) -> Option<PeerId> { ) -> Option<PeerId> {
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
let origin_peer_id = route route
.get_origin_my_peer_id(network_name, foreign_my_peer_id) .get_origin_my_peer_id(network_name, foreign_my_peer_id)
.await; .await
if origin_peer_id.is_some() { } else {
return origin_peer_id; None
}
} }
None
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
@@ -309,8 +301,7 @@ impl PeerMap {
} }
pub async fn add_route(&self, route: ArcRoute) { pub async fn add_route(&self, route: ArcRoute) {
let mut routes = self.routes.write().await; self.route.store(Arc::new(Some(route)));
routes.insert(0, route);
} }
pub async fn clean_peer_without_conn(&self) { pub async fn clean_peer_without_conn(&self) {
@@ -330,7 +321,7 @@ impl PeerMap {
pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> { pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> {
let route_map = DashMap::new(); let route_map = DashMap::new();
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
for item in route.list_routes().await.iter() { for item in route.list_routes().await.iter() {
route_map.insert(item.peer_id, item.next_hop_peer_id); route_map.insert(item.peer_id, item.next_hop_peer_id);
} }
@@ -339,10 +330,11 @@ impl PeerMap {
} }
pub async fn list_route_infos(&self) -> Vec<cli::Route> { pub async fn list_route_infos(&self) -> Vec<cli::Route> {
for route in self.routes.read().await.iter() { if let Some(route) = self.route.load().as_ref() {
return route.list_routes().await; route.list_routes().await
} else {
vec![]
} }
vec![]
} }
pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result<bool, Error> { pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result<bool, Error> {
@@ -383,3 +375,42 @@ impl Drop for PeerMap {
); );
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::common::global_ctx::tests::get_mock_global_ctx;
use tokio::sync::mpsc;
#[tokio::test]
async fn test_peer_map_route_arcswap() {
let (packet_send, _packet_recv) = mpsc::channel(128);
let global_ctx = get_mock_global_ctx();
let my_peer_id = 1;
let peer_map = PeerMap::new(packet_send, global_ctx, my_peer_id);
// Initially, no route should be set
assert!(peer_map.route.load().is_none());
// Test that methods return None/empty when no route is set
assert_eq!(
peer_map
.get_gateway_peer_id(2, NextHopPolicy::LeastHop)
.await,
None
);
assert_eq!(
peer_map
.get_peer_id_by_ipv4(&"192.168.1.1".parse().unwrap())
.await,
None
);
assert_eq!(peer_map.get_route_peer_info(2).await, None);
assert_eq!(peer_map.list_route_infos().await.len(), 0);
// The route field should be accessible and work with ArcSwap
let route_loaded = peer_map.route.load();
assert!(route_loaded.is_none());
}
}

View File

@@ -55,7 +55,6 @@ use super::{
DefaultRouteCostCalculator, ForeignNetworkRouteInfoMap, NextHopPolicy, RouteCostCalculator, DefaultRouteCostCalculator, ForeignNetworkRouteInfoMap, NextHopPolicy, RouteCostCalculator,
RouteCostCalculatorInterface, RouteCostCalculatorInterface,
}, },
PeerPacketFilter,
}; };
static SERVICE_ID: u32 = 7; static SERVICE_ID: u32 = 7;
@@ -2369,8 +2368,6 @@ impl Route for PeerRoute {
} }
} }
impl PeerPacketFilter for Arc<PeerRoute> {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{ use std::{

View File

@@ -203,7 +203,7 @@ impl WireGuardImpl {
} }
self.peer_mgr self.peer_mgr
.add_packet_process_pipeline(Box::new(PeerPacketFilterForVpnPortal { .add_packet_process_pipeline(Arc::new(PeerPacketFilterForVpnPortal {
wg_peer_ip_table: self.wg_peer_ip_table.clone(), wg_peer_ip_table: self.wg_peer_ip_table.clone(),
})) }))
.await; .await;