mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-13 21:27:25 +08:00
zero copy tunnel (#55)
make tunnel zero copy, for better performance. remove most of the locks in io path. introduce quic tunnel prepare for encryption
This commit is contained in:
539
easytier/src/tunnel/common.rs
Normal file
539
easytier/src/tunnel/common.rs
Normal file
@@ -0,0 +1,539 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
net::{IpAddr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
task::{ready, Poll},
|
||||
};
|
||||
|
||||
use futures::{stream::FuturesUnordered, Future, Sink, Stream};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::{poll_read_buf, poll_write_buf};
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE},
|
||||
};
|
||||
|
||||
use super::{
|
||||
buf::BufList,
|
||||
packet_def::{TCPTunnelHeader, ZCPacketType, TCP_TUNNEL_HEADER_SIZE},
|
||||
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
||||
};
|
||||
|
||||
pub struct TunnelWrapper<R, W> {
|
||||
reader: Arc<Mutex<Option<R>>>,
|
||||
writer: Arc<Mutex<Option<W>>>,
|
||||
info: Option<TunnelInfo>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
|
||||
impl<R, W> TunnelWrapper<R, W> {
|
||||
pub fn new(reader: R, writer: W, info: Option<TunnelInfo>) -> Self {
|
||||
Self::new_with_associate_data(reader, writer, info, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
reader: R,
|
||||
writer: W,
|
||||
info: Option<TunnelInfo>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
TunnelWrapper {
|
||||
reader: Arc::new(Mutex::new(Some(reader))),
|
||||
writer: Arc::new(Mutex::new(Some(writer))),
|
||||
info,
|
||||
associate_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, W> Tunnel for TunnelWrapper<R, W>
|
||||
where
|
||||
R: ZCPacketStream + Send + 'static,
|
||||
W: ZCPacketSink + Send + 'static,
|
||||
{
|
||||
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
||||
let reader = self.reader.lock().unwrap().take().unwrap();
|
||||
let writer = self.writer.lock().unwrap().take().unwrap();
|
||||
(Box::pin(reader), Box::pin(writer))
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.info.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// a length delimited codec for async reader
|
||||
pin_project! {
|
||||
pub struct FramedReader<R> {
|
||||
#[pin]
|
||||
reader: R,
|
||||
buf: BytesMut,
|
||||
state: FrameReaderState,
|
||||
max_packet_size: usize,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
}
|
||||
|
||||
// usize means the size remaining to read
|
||||
enum FrameReaderState {
|
||||
ReadingHeader(usize),
|
||||
ReadingBody(usize),
|
||||
}
|
||||
|
||||
impl<R> FramedReader<R> {
|
||||
pub fn new(reader: R, max_packet_size: usize) -> Self {
|
||||
Self::new_with_associate_data(reader, max_packet_size, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
reader: R,
|
||||
max_packet_size: usize,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
FramedReader {
|
||||
reader,
|
||||
buf: BytesMut::with_capacity(max_packet_size),
|
||||
state: FrameReaderState::ReadingHeader(4),
|
||||
max_packet_size,
|
||||
associate_data,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_one_packet(buf: &mut BytesMut) -> Option<ZCPacket> {
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
|
||||
// header is not complete
|
||||
return None;
|
||||
}
|
||||
|
||||
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
|
||||
let body_len = header.len.get() as usize;
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
|
||||
// body is not complete
|
||||
return None;
|
||||
}
|
||||
|
||||
// extract one packet
|
||||
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
|
||||
Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Stream for FramedReader<R>
|
||||
where
|
||||
R: AsyncRead + Send + 'static + Unpin,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
let mut self_mut = self.project();
|
||||
|
||||
loop {
|
||||
while let Some(packet) = Self::extract_one_packet(self_mut.buf) {
|
||||
return Poll::Ready(Some(Ok(packet)));
|
||||
}
|
||||
|
||||
reserve_buf(
|
||||
&mut self_mut.buf,
|
||||
*self_mut.max_packet_size,
|
||||
*self_mut.max_packet_size * 64,
|
||||
);
|
||||
|
||||
match ready!(poll_read_buf(
|
||||
self_mut.reader.as_mut(),
|
||||
cx,
|
||||
&mut self_mut.buf
|
||||
)) {
|
||||
Ok(size) => {
|
||||
if size == 0 {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Poll::Ready(Some(Err(TunnelError::IOError(e))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct FramedWriter<W> {
|
||||
#[pin]
|
||||
writer: W,
|
||||
sending_bufs: BufList<Bytes>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> FramedWriter<W> {
|
||||
pub fn new(writer: W) -> Self {
|
||||
Self::new_with_associate_data(writer, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
writer: W,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
FramedWriter {
|
||||
writer,
|
||||
sending_bufs: BufList::new(),
|
||||
associate_data: associate_data,
|
||||
}
|
||||
}
|
||||
|
||||
fn max_buffer_count(&self) -> usize {
|
||||
64
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> Sink<SinkItem> for FramedWriter<W>
|
||||
where
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
type Error = TunnelError;
|
||||
|
||||
fn poll_ready(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let max_buffer_count = self.max_buffer_count();
|
||||
if self.sending_bufs.bufs_cnt() >= max_buffer_count {
|
||||
self.as_mut().poll_flush(cx)
|
||||
} else {
|
||||
tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, mut item: ZCPacket) -> Result<(), Self::Error> {
|
||||
let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len();
|
||||
let Some(header) = item.mut_tcp_tunnel_header() else {
|
||||
return Err(TunnelError::InvalidPacket("packet too short".to_string()));
|
||||
};
|
||||
header.len.set(tcp_len.try_into().unwrap());
|
||||
|
||||
let item = item.into_bytes(ZCPacketType::TCP);
|
||||
self.project().sending_bufs.push(item);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let mut pinned = self.project();
|
||||
let mut remaining = pinned.sending_bufs.remaining();
|
||||
while remaining != 0 {
|
||||
let n = ready!(poll_write_buf(
|
||||
pinned.writer.as_mut(),
|
||||
cx,
|
||||
pinned.sending_bufs
|
||||
))?;
|
||||
if n == 0 {
|
||||
return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"failed to \
|
||||
write frame to transport",
|
||||
))));
|
||||
}
|
||||
remaining -= n;
|
||||
}
|
||||
|
||||
tracing::trace!(?remaining, "flushed");
|
||||
|
||||
// Try flushing the underlying IO
|
||||
ready!(pinned.writer.poll_flush(cx))?;
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
ready!(self.project().writer.poll_shutdown(cx))?;
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
||||
return None;
|
||||
}
|
||||
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
||||
for iface in ifaces {
|
||||
for addr in iface.addr {
|
||||
if addr.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::error!(?local_ip, "can not find interface name by ip");
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
bind_dev: Option<String>,
|
||||
) -> Result<(), TunnelError> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
||||
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
socket2_socket.set_reuse_address(true)?;
|
||||
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
|
||||
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket2_socket.set_reuse_port(true)?;
|
||||
|
||||
if bind_addr.ip().is_unspecified() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// win can handle this with bind correctly
|
||||
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
// use IP_BOUND_IF to bind device
|
||||
unsafe {
|
||||
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
||||
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
|
||||
E: std::error::Error + Into<TunnelError> + Send + Sync + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_sokcet2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
if buf.capacity() < min_size {
|
||||
buf.reserve(max_size);
|
||||
}
|
||||
}
|
||||
|
||||
pub mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::netns::NetNS,
|
||||
tunnel::{packet_def::ZCPacket, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
if !once {
|
||||
recv.forward(send).await.unwrap();
|
||||
} else {
|
||||
let Some(ret) = recv.next().await else {
|
||||
assert!(false, "recv error");
|
||||
return;
|
||||
};
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::debug!(?ret, "recv error");
|
||||
return;
|
||||
}
|
||||
|
||||
let res = ret.unwrap();
|
||||
tracing::debug!(?res, "recv a msg, try echo back");
|
||||
send.send(res).await.unwrap();
|
||||
}
|
||||
|
||||
tracing::warn!("echo server exit...");
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
mut listener: L,
|
||||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
l_netns
|
||||
.run_async(|| async {
|
||||
listener.listen().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
connector.remote_url().to_string()
|
||||
);
|
||||
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
|
||||
|
||||
drop(send);
|
||||
|
||||
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
||||
lis.abort();
|
||||
} else {
|
||||
// lis should finish in 1 second
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
|
||||
let (recv, mut send) = tunnel.split();
|
||||
|
||||
// prepare a 4k buffer with random data
|
||||
let mut send_buf = BytesMut::new();
|
||||
for _ in 0..64 {
|
||||
send_buf.put_i128(rand::random::<i128>());
|
||||
}
|
||||
|
||||
let r = tokio::spawn(async move {
|
||||
let now = Instant::now();
|
||||
let count = recv
|
||||
.try_fold(0usize, |mut ret, _| async move {
|
||||
ret += 1;
|
||||
Ok(ret)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"bps: {}",
|
||||
(count / 1024) * 4 / now.elapsed().as_secs() as usize
|
||||
);
|
||||
});
|
||||
|
||||
let now = Instant::now();
|
||||
while now.elapsed().as_secs() < 10 {
|
||||
// send.feed(item)
|
||||
let item = ZCPacket::new_with_payload(send_buf.as_ref());
|
||||
let _ = send.feed(item).await.unwrap();
|
||||
}
|
||||
|
||||
drop(send);
|
||||
drop(connector);
|
||||
drop(tunnel);
|
||||
|
||||
tracing::warn!("wait for recv to finish...");
|
||||
|
||||
let _ = tokio::join!(r);
|
||||
|
||||
lis.abort();
|
||||
let _ = tokio::join!(lis);
|
||||
}
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.pretty()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user