From d5bc15cf7ac7d879d14cae1f8822123f579d87d3 Mon Sep 17 00:00:00 2001 From: "sijie.sun" Date: Sat, 3 Aug 2024 11:01:17 +0800 Subject: [PATCH] fix session_task and session mismatch --- easytier/src/connector/udp_hole_punch.rs | 18 --- easytier/src/peers/peer_ospf_route.rs | 164 ++++++++++++++++------- 2 files changed, 115 insertions(+), 67 deletions(-) diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs index 33980c4..4a35345 100644 --- a/easytier/src/connector/udp_hole_punch.rs +++ b/easytier/src/connector/udp_hole_punch.rs @@ -1192,22 +1192,4 @@ pub mod tests { ) .await; } - - #[tokio::test] - async fn udp_listener() { - let p_a = create_mock_peer_manager().await; - wait_for_condition( - || async { - p_a.get_global_ctx() - .get_stun_info_collector() - .get_stun_info() - .udp_nat_type - != NatType::Unknown as i32 - }, - Duration::from_secs(20), - ) - .await; - let l = UdpHolePunchListener::new(p_a.clone()).await.unwrap(); - println!("{:#?}", l.mapped_addr); - } } diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 1dc0684..5a40428 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -16,7 +16,11 @@ use petgraph::{ Directed, Graph, }; use serde::{Deserialize, Serialize}; -use tokio::{select, sync::Mutex, task::JoinSet}; +use tokio::{ + select, + sync::Mutex, + task::{JoinHandle, JoinSet}, +}; use crate::{ common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, @@ -602,6 +606,48 @@ type SessionId = u64; type AtomicSessionId = atomic_shim::AtomicU64; +struct SessionTask { + task: Arc>>>, +} + +impl SessionTask { + fn new() -> Self { + SessionTask { + task: Arc::new(std::sync::Mutex::new(None)), + } + } + + fn set_task(&self, task: JoinHandle<()>) { + if let Some(old) = self.task.lock().unwrap().replace(task) { + old.abort(); + } + } + + fn is_running(&self) -> bool { + if let Some(task) = self.task.lock().unwrap().as_ref() { + !task.is_finished() + } else { + false + } + } +} + +impl Drop for SessionTask { + fn drop(&mut self) { + if let Some(task) = self.task.lock().unwrap().take() { + task.abort(); + } + } +} + +impl Debug for SessionTask { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionTask") + .field("is_running", &self.is_running()) + .finish() + } +} + // if we need to sync route info with one peer, we create a SyncRouteSession with that peer. #[derive(Debug)] struct SyncRouteSession { @@ -620,6 +666,8 @@ struct SyncRouteSession { rpc_tx_count: AtomicU32, rpc_rx_count: AtomicU32, + + task: SessionTask, } impl SyncRouteSession { @@ -639,6 +687,8 @@ impl SyncRouteSession { rpc_tx_count: AtomicU32::new(0), rpc_rx_count: AtomicU32::new(0), + + task: SessionTask::new(), } } @@ -684,6 +734,20 @@ impl SyncRouteSession { self.dst_saved_peer_info_versions.clear(); } } + + fn short_debug_string(&self) -> String { + format!( + "session_dst_peer: {:?}, my_session_id: {:?}, dst_session_id: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, rpc_tx_count: {:?}, rpc_rx_count: {:?}, task: {:?}", + self.dst_peer_id, + self.my_session_id, + self.dst_session_id, + self.we_are_initiator, + self.dst_is_initiator, + self.rpc_tx_count, + self.rpc_rx_count, + self.task + ) + } } struct PeerRouteServiceImpl { @@ -756,6 +820,10 @@ impl PeerRouteServiceImpl { self.sessions.remove(&dst_peer_id); } + fn list_session_peers(&self) -> Vec { + self.sessions.iter().map(|x| *x.key()).collect() + } + async fn list_peers_from_interface>(&self) -> T { self.interface .lock() @@ -944,7 +1012,11 @@ impl PeerRouteServiceImpl { dst_peer_id: PeerId, peer_rpc: Arc, ) -> bool { - let session = self.get_or_create_session(dst_peer_id); + let Some(session) = self.get_session(dst_peer_id) else { + // if session not exist, exit the sync loop. + return true; + }; + let my_peer_id = self.my_peer_id; let (peer_infos, conn_bitmap) = self.build_sync_request(&session); @@ -1018,7 +1090,6 @@ impl PeerRouteServiceImpl { struct RouteSessionManager { service_impl: Weak, peer_rpc: Weak, - session_tasks: Arc>>, sync_now_broadcast: tokio::sync::broadcast::Sender<()>, } @@ -1026,14 +1097,6 @@ struct RouteSessionManager { impl Debug for RouteSessionManager { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RouteSessionManager") - .field( - "session_tasks", - &self - .session_tasks - .iter() - .map(|x| *x.key()) - .collect::>(), - ) .field("dump_sessions", &self.dump_sessions()) .finish() } @@ -1101,7 +1164,6 @@ impl RouteSessionManager { RouteSessionManager { service_impl: Arc::downgrade(&service_impl), peer_rpc: Arc::downgrade(&peer_rpc), - session_tasks: Arc::new(DashMap::new()), sync_now_broadcast: tokio::sync::broadcast::channel(100).0, } @@ -1143,7 +1205,6 @@ impl RouteSessionManager { fn stop_session(&self, peer_id: PeerId) -> Result<(), Error> { tracing::warn!(?peer_id, "stop ospf sync session"); - self.session_tasks.remove(&peer_id); let Some(service_impl) = self.service_impl.upgrade() else { return Err(Error::Stopped); }; @@ -1151,24 +1212,15 @@ impl RouteSessionManager { Ok(()) } - fn start_session(&self, peer_id: PeerId) -> Result, Error> { - let Some(service_impl) = self.service_impl.upgrade() else { - return Err(Error::Stopped); - }; - - tracing::warn!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session"); - - let mut tasks = JoinSet::new(); - tasks.spawn(Self::session_task( - self.peer_rpc.clone(), - self.service_impl.clone(), - peer_id, - self.sync_now_broadcast.subscribe(), - )); - - let session = service_impl.get_or_create_session(peer_id); - self.session_tasks.insert(peer_id, tasks); - Ok(session) + fn start_session_task(&self, session: &Arc) { + if !session.task.is_running() { + session.task.set_task(tokio::spawn(Self::session_task( + self.peer_rpc.clone(), + self.service_impl.clone(), + session.dst_peer_id, + self.sync_now_broadcast.subscribe(), + ))); + } } fn get_or_start_session(&self, peer_id: PeerId) -> Result, Error> { @@ -1176,11 +1228,11 @@ impl RouteSessionManager { return Err(Error::Stopped); }; - if let Some(session) = service_impl.get_session(peer_id) { - return Ok(session); - } + tracing::info!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session"); - self.start_session(peer_id) + let session = service_impl.get_or_create_session(peer_id); + self.start_session_task(&session); + Ok(session) } #[tracing::instrument(skip(self))] @@ -1267,9 +1319,10 @@ impl RouteSessionManager { // clear sessions that are neither dst_initiator or we_are_initiator. for peer_id in session_peers.iter() { if let Some(session) = service_impl.get_session(*peer_id) { - if session.dst_is_initiator.load(Ordering::Relaxed) + if (session.dst_is_initiator.load(Ordering::Relaxed) || session.we_are_initiator.load(Ordering::Relaxed) - || session.need_sync_initiator_info.load(Ordering::Relaxed) + || session.need_sync_initiator_info.load(Ordering::Relaxed)) + && session.task.is_running() { continue; } @@ -1283,10 +1336,11 @@ impl RouteSessionManager { } fn list_session_peers(&self) -> Vec { - self.session_tasks - .iter() - .map(|x| *x.key()) - .collect::>() + let Some(service_impl) = self.service_impl.upgrade() else { + return vec![]; + }; + + service_impl.list_session_peers() } fn dump_sessions(&self) -> Result { @@ -1296,10 +1350,12 @@ impl RouteSessionManager { let mut ret = format!("my_peer_id: {:?}\n", service_impl.my_peer_id); for item in service_impl.sessions.iter() { - ret += format!(" session: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, need_sync_initiator_info: {:?}\n", - item.key(), item.value().we_are_initiator.load(Ordering::Relaxed), - item.value().dst_is_initiator.load(Ordering::Relaxed), - item.value().need_sync_initiator_info.load(Ordering::Relaxed)).as_str(); + ret += format!( + " session: {}, {}\n", + item.key(), + item.value().short_debug_string() + ) + .as_str(); } Ok(ret.to_string()) @@ -1582,8 +1638,9 @@ mod tests { assert_eq!(2, r_a.service_impl.synced_route_info.peer_infos.len()); assert_eq!(2, r_b.service_impl.synced_route_info.peer_infos.len()); - assert_eq!(1, r_a.session_mgr.session_tasks.len()); - assert_eq!(1, r_b.session_mgr.session_tasks.len()); + for s in r_a.service_impl.sessions.iter() { + assert!(s.value().task.is_running()); + } assert_eq!( r_a.service_impl @@ -1619,7 +1676,12 @@ mod tests { Duration::from_secs(5), ) .await; - assert_eq!(0, r_a.session_mgr.session_tasks.len()); + + wait_for_condition( + || async { r_a.service_impl.sessions.is_empty() }, + Duration::from_secs(5), + ) + .await; } #[tokio::test] @@ -1687,11 +1749,15 @@ mod tests { connect_peer_manager(p_e.clone(), last_p.clone()).await; wait_for_condition( - || async { r_e.session_mgr.session_tasks.len() == 1 }, + || async { r_e.session_mgr.list_session_peers().len() == 1 }, Duration::from_secs(3), ) .await; + for s in r_e.service_impl.sessions.iter() { + assert!(s.value().task.is_running()); + } + tokio::time::sleep(Duration::from_secs(2)).await; check_rpc_counter(&r_e, last_p.my_peer_id(), 2, 2);