use path with least cost if hop count is same

This commit is contained in:
sijie.sun
2024-05-13 19:02:05 +08:00
parent 29365c39ed
commit 3e6b1ac384
3 changed files with 175 additions and 94 deletions

View File

@@ -10,6 +10,11 @@ use std::{
};
use dashmap::DashMap;
use petgraph::{
algo::{all_simple_paths, astar, dijkstra},
graph::NodeIndex,
Directed, Graph,
};
use serde::{Deserialize, Serialize};
use tokio::{select, sync::Mutex, task::JoinSet};
@@ -367,11 +372,15 @@ impl SyncedRouteInfo {
}
}
type PeerGraph = Graph<PeerId, i32, Directed>;
type PeerIdToNodexIdxMap = DashMap<PeerId, NodeIndex>;
type NextHopMap = DashMap<PeerId, (PeerId, i32)>;
// computed with SyncedRouteInfo. used to get next hop.
#[derive(Debug)]
struct RouteTable {
peer_infos: DashMap<PeerId, RoutePeerInfo>,
next_hop_map: DashMap<PeerId, (PeerId, i32)>,
next_hop_map: NextHopMap,
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
@@ -400,41 +409,119 @@ impl RouteTable {
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
}
fn find_path_with_least_cost<T: RouteCostCalculatorInterface>(
my_peer_id: PeerId,
peer_id: PeerId,
fn build_peer_graph_from_synced_info<T: RouteCostCalculatorInterface>(
peers: Vec<PeerId>,
synced_info: &SyncedRouteInfo,
cost_calc: &mut T,
) -> Option<Vec<PeerId>> {
let Some((path, _cost)): Option<(Vec<u32>, i32)> = pathfinding::prelude::dijkstra(
&my_peer_id,
|src_peer| {
synced_info
.get_connected_peers(*src_peer)
.unwrap_or_else(|| BTreeSet::new())
.into_iter()
.map(|dst_peer| {
let cost = cost_calc.calculate_cost(*src_peer, dst_peer);
(dst_peer, cost)
})
.collect::<BTreeSet<_>>()
},
|x| *x == peer_id,
) else {
return None;
};
if !path.is_empty() {
Some(path)
} else {
None
) -> (PeerGraph, PeerIdToNodexIdxMap) {
let mut graph: PeerGraph = Graph::new();
let peer_id_to_node_index = PeerIdToNodexIdxMap::new();
for peer_id in peers.iter() {
peer_id_to_node_index.insert(*peer_id, graph.add_node(*peer_id));
}
for peer_id in peers.iter() {
let connected_peers = synced_info
.get_connected_peers(*peer_id)
.unwrap_or(BTreeSet::new());
for dst_peer_id in connected_peers.iter() {
let Some(dst_idx) = peer_id_to_node_index.get(dst_peer_id) else {
continue;
};
graph.add_edge(
*peer_id_to_node_index.get(&peer_id).unwrap(),
*dst_idx,
cost_calc.calculate_cost(*peer_id, *dst_peer_id),
);
}
}
(graph, peer_id_to_node_index)
}
fn gen_next_hop_map_with_least_hop<T: RouteCostCalculatorInterface>(
my_peer_id: PeerId,
graph: &PeerGraph,
idx_map: &PeerIdToNodexIdxMap,
cost_calc: &mut T,
) -> NextHopMap {
let res = dijkstra(&graph, *idx_map.get(&my_peer_id).unwrap(), None, |_| 1);
let next_hop_map = NextHopMap::new();
for (node_idx, cost) in res.iter() {
if *cost == 0 {
continue;
}
let all_paths = all_simple_paths::<Vec<_>, _>(
graph,
*idx_map.get(&my_peer_id).unwrap(),
*node_idx,
*cost - 1,
Some(*cost - 1),
)
.collect::<Vec<_>>();
assert!(!all_paths.is_empty());
// find a path with least cost.
let mut min_cost = i32::MAX;
let mut min_path = Vec::new();
for path in all_paths.iter() {
let mut cost = 0;
for i in 0..path.len() - 1 {
let src_peer_id = *graph.node_weight(path[i]).unwrap();
let dst_peer_id = *graph.node_weight(path[i + 1]).unwrap();
cost += cost_calc.calculate_cost(src_peer_id, dst_peer_id);
}
if cost <= min_cost {
min_cost = cost;
min_path = path.clone();
}
}
next_hop_map.insert(
*graph.node_weight(*node_idx).unwrap(),
(*graph.node_weight(min_path[1]).unwrap(), *cost as i32),
);
}
next_hop_map
}
fn gen_next_hop_map_with_least_cost(
my_peer_id: PeerId,
graph: &PeerGraph,
idx_map: &PeerIdToNodexIdxMap,
) -> NextHopMap {
let next_hop_map = NextHopMap::new();
for item in idx_map.iter() {
if *item.key() == my_peer_id {
continue;
}
let dst_peer_node_idx = *item.value();
let Some((cost, path)) = astar::astar(
graph,
*idx_map.get(&my_peer_id).unwrap(),
|node_idx| node_idx == dst_peer_node_idx,
|e| *e.weight(),
|_| 0,
) else {
continue;
};
next_hop_map.insert(*item.key(), (*graph.node_weight(path[1]).unwrap(), cost));
}
next_hop_map
}
fn build_from_synced_info<T: RouteCostCalculatorInterface>(
&self,
my_peer_id: PeerId,
synced_info: &SyncedRouteInfo,
policy: NextHopPolicy,
mut cost_calc: T,
) {
// build peer_infos
@@ -453,21 +540,20 @@ impl RouteTable {
// build next hop map
self.next_hop_map.clear();
self.next_hop_map.insert(my_peer_id, (my_peer_id, 0));
for item in self.peer_infos.iter() {
let peer_id = *item.key();
if peer_id == my_peer_id {
continue;
}
let path =
Self::find_path_with_least_cost(my_peer_id, peer_id, synced_info, &mut cost_calc);
if let Some(path) = path {
assert!(path.len() >= 2);
self.next_hop_map
.insert(peer_id, (path[1], (path.len() - 1) as i32));
}
let (graph, idx_map) = Self::build_peer_graph_from_synced_info(
self.peer_infos.iter().map(|x| *x.key()).collect(),
&synced_info,
&mut cost_calc,
);
let next_hop_map = if matches!(policy, NextHopPolicy::LeastHop) {
Self::gen_next_hop_map_with_least_hop(my_peer_id, &graph, &idx_map, &mut cost_calc)
} else {
Self::gen_next_hop_map_with_least_cost(my_peer_id, &graph, &idx_map)
};
for item in next_hop_map.iter() {
self.next_hop_map.insert(*item.key(), *item.value());
}
// build graph
// build ipv4_peer_id_map, cidr_peer_id_map
self.ipv4_peer_id_map.clear();
@@ -695,21 +781,20 @@ impl PeerRouteServiceImpl {
}
fn update_route_table(&self) {
let mut calc_locked = self.cost_calculator.lock().unwrap();
calc_locked.as_mut().unwrap().begin_update();
self.route_table.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
DefaultRouteCostCalculator::default(),
NextHopPolicy::LeastHop,
calc_locked.as_mut().unwrap(),
);
let mut calc_locked = self.cost_calculator.lock().unwrap();
if calc_locked.is_none() {
return;
}
calc_locked.as_mut().unwrap().begin_update();
self.route_table_with_cost.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
NextHopPolicy::LeastCost,
calc_locked.as_mut().unwrap(),
);
calc_locked.as_mut().unwrap().end_update();
@@ -1710,17 +1795,21 @@ mod tests {
let p_a = create_mock_pmgr().await;
let p_b = create_mock_pmgr().await;
let p_c = create_mock_pmgr().await;
let p_d = create_mock_pmgr().await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_c.clone(), p_b.clone()).await;
connect_peer_manager(p_a.clone(), p_c.clone()).await;
connect_peer_manager(p_d.clone(), p_b.clone()).await;
connect_peer_manager(p_d.clone(), p_c.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
let _r_a = create_mock_route(p_a.clone()).await;
let _r_b = create_mock_route(p_b.clone()).await;
let r_c = create_mock_route(p_c.clone()).await;
let _r_c = create_mock_route(p_c.clone()).await;
let r_d = create_mock_route(p_d.clone()).await;
// in normal mode, packet from p_c should directly forward to p_a
wait_for_condition(
|| async { r_c.get_next_hop(p_a.my_peer_id()).await == Some(p_a.my_peer_id()) },
|| async { r_d.get_next_hop(p_a.my_peer_id()).await != None },
Duration::from_secs(5),
)
.await;
@@ -1729,29 +1818,57 @@ mod tests {
p_a_peer_id: PeerId,
p_b_peer_id: PeerId,
p_c_peer_id: PeerId,
p_d_peer_id: PeerId,
}
impl RouteCostCalculatorInterface for TestCostCalculator {
fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 {
if src == self.p_c_peer_id && dst == self.p_a_peer_id {
if src == self.p_d_peer_id && dst == self.p_b_peer_id {
return 100;
}
if src == self.p_d_peer_id && dst == self.p_c_peer_id {
return 1;
}
if src == self.p_c_peer_id && dst == self.p_a_peer_id {
return 101;
}
if src == self.p_b_peer_id && dst == self.p_a_peer_id {
return 1;
}
if src == self.p_c_peer_id && dst == self.p_b_peer_id {
return 2;
}
1
}
}
r_c.set_route_cost_fn(Box::new(TestCostCalculator {
r_d.set_route_cost_fn(Box::new(TestCostCalculator {
p_a_peer_id: p_a.my_peer_id(),
p_b_peer_id: p_b.my_peer_id(),
p_c_peer_id: p_c.my_peer_id(),
p_d_peer_id: p_d.my_peer_id(),
}))
.await;
// after set cost, packet from p_c should forward to p_b first
wait_for_condition(
|| async {
r_c.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost)
r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost)
.await
== Some(p_c.my_peer_id())
},
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async {
r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastHop)
.await
== Some(p_b.my_peer_id())
},