Perf improve (#59)

* improve perf

* fix forward
This commit is contained in:
Sijie.Sun
2024-04-26 23:02:07 +08:00
committed by GitHub
parent 096af6aa45
commit 69651ae3fd
16 changed files with 370 additions and 162 deletions

View File

@@ -3,11 +3,7 @@ use std::sync::Arc;
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use tokio::{
select,
sync::{mpsc, Mutex},
task::JoinHandle,
};
use tokio::{select, sync::mpsc, task::JoinHandle};
use tracing::Instrument;
@@ -25,7 +21,7 @@ use crate::{
tunnel::packet_def::ZCPacket,
};
type ArcPeerConn = Arc<Mutex<PeerConn>>;
type ArcPeerConn = Arc<PeerConn>;
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
pub struct Peer {
@@ -73,7 +69,7 @@ impl Peer {
if let Some((_, conn)) = conns_copy.remove(&ret) {
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
conn.lock().await.get_conn_info(),
conn.get_conn_info(),
));
}
}
@@ -108,12 +104,11 @@ impl Peer {
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
conn.set_close_event_sender(self.close_event_sender.clone());
conn.start_recv_loop(self.packet_recv_chan.clone());
conn.start_recv_loop(self.packet_recv_chan.clone()).await;
conn.start_pingpong();
self.global_ctx
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
self.conns
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
self.conns.insert(conn.get_conn_id(), Arc::new(conn));
}
async fn select_conn(&self) -> Option<ArcPeerConn> {
@@ -128,7 +123,7 @@ impl Peer {
}
let conn = conn.unwrap().clone();
self.default_conn_id.store(conn.lock().await.get_conn_id());
self.default_conn_id.store(conn.get_conn_id());
Some(conn)
}
@@ -136,10 +131,7 @@ impl Peer {
let Some(conn) = self.select_conn().await else {
return Err(Error::PeerNoConnectionError(self.peer_node_id));
};
let conn_clone = conn.clone();
drop(conn);
conn_clone.lock().await.send_msg(msg).await?;
conn.send_msg(msg).await?;
Ok(())
}
@@ -162,7 +154,7 @@ impl Peer {
let mut ret = Vec::new();
for conn in conns {
ret.push(conn.lock().await.get_conn_info());
ret.push(conn.get_conn_info());
}
ret
}

View File

@@ -279,19 +279,10 @@ impl PeerManager {
let from_peer_id = hdr.from_peer_id.get();
let to_peer_id = hdr.to_peer_id.get();
if to_peer_id != my_peer_id {
log::trace!(
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
to_peer_id,
my_peer_id
);
tracing::trace!(?to_peer_id, ?my_peer_id, "need forward");
let ret = peers.send_msg(ret, to_peer_id).await;
if ret.is_err() {
log::error!(
"forward packet error: {:?}, dst: {:?}, from: {:?}",
ret,
to_peer_id,
from_peer_id
);
tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error");
}
} else {
let mut processed = false;
@@ -516,15 +507,11 @@ impl PeerManager {
msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8);
if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await {
if let Err(e) = self.peers.send_msg_directly(msg.clone(), gateway).await {
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
errs.push(e);
}
} else if self.foreign_network_client.has_next_hop(*peer_id) {
if let Err(e) = self
.foreign_network_client
.send_msg(msg.clone(), *peer_id)
.await
{
if let Err(e) = self.foreign_network_client.send_msg(msg, *peer_id).await {
errs.push(e);
}
}
@@ -622,12 +609,23 @@ impl PeerManager {
#[cfg(test)]
mod tests {
use std::{fmt::Debug, sync::Arc};
use crate::{
connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
connector::{
create_connector_by_url, udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
},
instance::listeners::get_listener_by_url,
peers::{
peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient},
tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
},
rpc::NatType,
tunnel::{TunnelConnector, TunnelListener},
};
use super::PeerManager;
#[tokio::test]
async fn drop_peer_manager() {
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
@@ -659,4 +657,98 @@ mod tests {
)
.await;
}
/// Test helper: establishes one tunnel between two peer managers, with
/// `client_mgr` dialing `server_mgr` over the supplied connector/listener pair.
///
/// Order matters here: the listener is started *before* the client dials
/// (avoiding a connect-before-listen race), the dial runs on a spawned task
/// so that `accept()` below can make progress concurrently, and the accepted
/// tunnel is registered on the server manager before this helper returns.
async fn connect_peer_manager_with<C: TunnelConnector + Debug + 'static, L: TunnelListener>(
    client_mgr: Arc<PeerManager>,
    server_mgr: &Arc<PeerManager>,
    mut client: C,
    server: &mut L,
) {
    // Start listening first so the spawned dial below cannot race the listener.
    server.listen().await.unwrap();
    tokio::spawn(async move {
        // NOTE(review): clearing bind addrs presumably lets the connector pick
        // its own local address — confirm against TunnelConnector semantics.
        client.set_bind_addrs(vec![]);
        client_mgr.try_connect(client).await.unwrap();
    });
    // Accept exactly one inbound connection and hand the tunnel to the server side.
    server_mgr
        .add_client_tunnel(server.accept().await.unwrap())
        .await
        .unwrap();
}
/// End-to-end forwarding test over an A <-> B <-> C chain: A and C each run a
/// mock RPC service (id 100); B is only a relay. The test asserts that an RPC
/// issued by A reaches C's service, i.e. packets are forwarded through B.
/// Runs for every (proto1, proto2) combination of tcp/udp/wg/quic.
#[rstest::rstest]
#[tokio::test]
// Serialized because the test binds fixed ports 31013/31014; concurrent runs would collide.
#[serial_test::serial(forward_packet_test)]
async fn forward_packet(
    #[values("tcp", "udp", "wg", "quic")] proto1: &str,
    #[values("tcp", "udp", "wg", "quic")] proto2: &str,
) {
    // Peer A: near endpoint, serves MockService with prefix "hello a".
    let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    peer_mgr_a.get_peer_rpc_mgr().run_service(
        100,
        MockService {
            prefix: "hello a".to_owned(),
        }
        .serve(),
    );

    // Peer B: the middle hop — registers no service, only forwards traffic.
    let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;

    // Peer C: far endpoint, serves MockService with prefix "hello c".
    let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    peer_mgr_c.get_peer_rpc_mgr().run_service(
        100,
        MockService {
            prefix: "hello c".to_owned(),
        }
        .serve(),
    );

    // Link A -> B using proto1 on port 31013 (B listens, A dials).
    let mut listener1 = get_listener_by_url(
        &format!("{}://0.0.0.0:31013", proto1).parse().unwrap(),
        peer_mgr_b.get_global_ctx(),
    )
    .unwrap();
    let connector1 = create_connector_by_url(
        format!("{}://127.0.0.1:31013", proto1).as_str(),
        &peer_mgr_a.get_global_ctx(),
    )
    .await
    .unwrap();
    connect_peer_manager_with(peer_mgr_a.clone(), &peer_mgr_b, connector1, &mut listener1)
        .await;
    // Wait until routing has converged between A and B before adding the next hop.
    wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
        .await
        .unwrap();

    // Link B -> C using proto2 on port 31014 (C listens, B dials).
    let mut listener2 = get_listener_by_url(
        &format!("{}://0.0.0.0:31014", proto2).parse().unwrap(),
        peer_mgr_c.get_global_ctx(),
    )
    .unwrap();
    let connector2 = create_connector_by_url(
        format!("{}://127.0.0.1:31014", proto2).as_str(),
        &peer_mgr_b.get_global_ctx(),
    )
    .await
    .unwrap();
    connect_peer_manager_with(peer_mgr_b.clone(), &peer_mgr_c, connector2, &mut listener2)
        .await;
    // A must learn a route to C — necessarily via B, the only shared link.
    wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
        .await
        .unwrap();

    // Issue an RPC from A addressed to C's peer id; the request and response
    // both have to traverse B's forwarding path.
    let ret = peer_mgr_a
        .get_peer_rpc_mgr()
        .do_client_rpc_scoped(100, peer_mgr_c.my_peer_id(), |c| async {
            let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
            let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
            ret
        })
        .await
        .unwrap();
    // The "hello c" prefix proves the call was answered by C's service, not A's own.
    assert_eq!(ret, "hello c abc");
}
}

View File

@@ -389,7 +389,7 @@ impl PeerRpcManager {
}
#[cfg(test)]
mod tests {
pub mod tests {
use std::{pin::Pin, sync::Arc};
use futures::{SinkExt, StreamExt};
@@ -415,8 +415,8 @@ mod tests {
}
#[derive(Clone)]
struct MockService {
prefix: String,
pub struct MockService {
pub prefix: String,
}
#[tarpc::server]

View File

@@ -13,7 +13,7 @@ use futures::{SinkExt, StreamExt, TryFutureExt};
use prost::Message;
use tokio::{
sync::{broadcast, mpsc},
sync::{broadcast, mpsc, Mutex},
task::JoinSet,
time::{timeout, Duration},
};
@@ -52,9 +52,9 @@ pub struct PeerConn {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
tunnel: Box<dyn Any + Send + 'static>,
tunnel: Arc<Mutex<Box<dyn Any + Send + 'static>>>,
sink: MpscTunnelSender,
recv: Option<Pin<Box<dyn ZCPacketStream>>>,
recv: Arc<Mutex<Option<Pin<Box<dyn ZCPacketStream>>>>>,
tunnel_info: Option<TunnelInfo>,
tasks: JoinSet<Result<(), TunnelError>>,
@@ -98,9 +98,9 @@ impl PeerConn {
my_peer_id,
global_ctx,
tunnel: Box::new(mpsc_tunnel),
tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))),
sink,
recv: Some(recv),
recv: Arc::new(Mutex::new(Some(recv))),
tunnel_info,
tasks: JoinSet::new(),
@@ -121,7 +121,8 @@ impl PeerConn {
}
async fn wait_handshake(&mut self) -> Result<HandshakeRequest, Error> {
let recv = self.recv.as_mut().unwrap();
let mut locked = self.recv.lock().await;
let recv = locked.as_mut().unwrap();
let Some(rsp) = recv.next().await else {
return Err(Error::WaitRespError(
"conn closed during wait handshake response".to_owned(),
@@ -199,8 +200,8 @@ impl PeerConn {
self.info.is_some()
}
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
let mut stream = self.recv.take().unwrap();
pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
let mut stream = self.recv.lock().await.take().unwrap();
let sink = self.sink.clone();
let mut sender = PollSender::new(packet_recv_chan.clone());
let close_event_sender = self.close_event_sender.clone().unwrap();
@@ -286,7 +287,7 @@ impl PeerConn {
});
}
pub async fn send_msg(&mut self, msg: ZCPacket) -> Result<(), Error> {
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
Ok(self.sink.send(msg).await?)
}
@@ -398,7 +399,9 @@ mod tests {
);
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
s_peer
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
.await;
assert!(c_ret.is_ok());
assert!(s_ret.is_ok());
@@ -406,7 +409,9 @@ mod tests {
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
c_peer.set_close_event_sender(close_send);
c_peer.start_pingpong();
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
c_peer
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
.await;
// wait 5s, conn should not be disconnected
tokio::time::sleep(Duration::from_secs(15)).await;