diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index eef23fd..6ec7750 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -4,8 +4,8 @@ use std::{ }; use anyhow::Context; +use arc_swap::ArcSwap; use dashmap::{DashMap, DashSet}; -use tokio::sync::RwLock; use crate::{ common::{ @@ -32,7 +32,7 @@ pub struct PeerMap { my_peer_id: PeerId, peer_map: DashMap>, packet_send: PacketRecvChan, - routes: RwLock>, + route: ArcSwap>, alive_conns: Arc>, } @@ -43,7 +43,7 @@ impl PeerMap { my_peer_id, peer_map: DashMap::new(), packet_send, - routes: RwLock::new(Vec::new()), + route: ArcSwap::from(Arc::new(None)), alive_conns: Arc::new(DashMap::new()), } } @@ -154,7 +154,7 @@ impl PeerMap { } // get route info - for route in self.routes.read().await.iter() { + if let Some(route) = self.route.load().as_ref() { if let Some(gateway_peer_id) = route .get_next_hop_with_policy(dst_peer_id, policy.clone()) .await @@ -171,14 +171,13 @@ impl PeerMap { &self, network_identity: &NetworkIdentity, ) -> Vec { - let mut ret = Vec::new(); - for route in self.routes.read().await.iter() { - let peers = route + if let Some(route) = self.route.load().as_ref() { + route .list_peers_own_foreign_network(&network_identity) - .await; - ret.extend(peers); + .await + } else { + Vec::new() } - ret } pub async fn send_msg( @@ -199,32 +198,27 @@ impl PeerMap { } pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option { - for route in self.routes.read().await.iter() { - let peer_id = route.get_peer_id_by_ipv4(ipv4).await; - if peer_id.is_some() { - return peer_id; - } + if let Some(route) = self.route.load().as_ref() { + route.get_peer_id_by_ipv4(ipv4).await + } else { + None } - None } pub async fn get_peer_id_by_ipv6(&self, ipv6: &Ipv6Addr) -> Option { - for route in self.routes.read().await.iter() { - let peer_id = route.get_peer_id_by_ipv6(ipv6).await; - if peer_id.is_some() { - return peer_id; - } + if let Some(route) = self.route.load().as_ref() { + route.get_peer_id_by_ipv6(ipv6).await + } else { + None } - None } pub async fn get_route_peer_info(&self, peer_id: PeerId) -> Option { - for route in self.routes.read().await.iter() { - if let Some(info) = route.get_peer_info(peer_id).await { - return Some(info); - } + if let Some(route) = self.route.load().as_ref() { + route.get_peer_info(peer_id).await + } else { + None } - None } pub async fn get_origin_my_peer_id( @@ -232,15 +226,13 @@ impl PeerMap { network_name: &str, foreign_my_peer_id: PeerId, ) -> Option { - for route in self.routes.read().await.iter() { - let origin_peer_id = route + if let Some(route) = self.route.load().as_ref() { + route .get_origin_my_peer_id(network_name, foreign_my_peer_id) - .await; - if origin_peer_id.is_some() { - return origin_peer_id; - } + .await + } else { + None } - None } pub fn is_empty(&self) -> bool { @@ -309,8 +301,7 @@ impl PeerMap { } pub async fn add_route(&self, route: ArcRoute) { - let mut routes = self.routes.write().await; - routes.insert(0, route); + self.route.store(Arc::new(Some(route))); } pub async fn clean_peer_without_conn(&self) { @@ -330,7 +321,7 @@ impl PeerMap { pub async fn list_routes(&self) -> DashMap { let route_map = DashMap::new(); - for route in self.routes.read().await.iter() { + if let Some(route) = self.route.load().as_ref() { for item in route.list_routes().await.iter() { route_map.insert(item.peer_id, item.next_hop_peer_id); } @@ -339,10 +330,11 @@ impl PeerMap { } pub async fn list_route_infos(&self) -> Vec { - for route in self.routes.read().await.iter() { - return route.list_routes().await; + if let Some(route) = self.route.load().as_ref() { + route.list_routes().await + } else { + vec![] } - vec![] } pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result { @@ -383,3 +375,42 @@ impl Drop for PeerMap { ); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::global_ctx::tests::get_mock_global_ctx; + use tokio::sync::mpsc; + + #[tokio::test] + async fn test_peer_map_route_arcswap() { + let (packet_send, _packet_recv) = mpsc::channel(128); + let global_ctx = get_mock_global_ctx(); + let my_peer_id = 1; + + let peer_map = PeerMap::new(packet_send, global_ctx, my_peer_id); + + // Initially, no route should be set + assert!(peer_map.route.load().is_none()); + + // Test that methods return None/empty when no route is set + assert_eq!( + peer_map + .get_gateway_peer_id(2, NextHopPolicy::LeastHop) + .await, + None + ); + assert_eq!( + peer_map + .get_peer_id_by_ipv4(&"192.168.1.1".parse().unwrap()) + .await, + None + ); + assert_eq!(peer_map.get_route_peer_info(2).await, None); + assert_eq!(peer_map.list_route_infos().await.len(), 0); + + // The route field should be accessible and work with ArcSwap + let route_loaded = peer_map.route.load(); + assert!(route_loaded.is_none()); + } +}