use std::{ any::Any, net::{IpAddr, SocketAddr}, pin::Pin, sync::{Arc, Mutex}, task::{ready, Poll}, }; use futures::{stream::FuturesUnordered, Future, Sink, Stream}; use network_interface::NetworkInterfaceConfig as _; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio_stream::StreamExt; use tokio_util::io::poll_write_buf; use zerocopy::FromBytes as _; use crate::{ rpc::TunnelInfo, tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE}, }; use super::{ buf::BufList, packet_def::{TCPTunnelHeader, ZCPacketType, TCP_TUNNEL_HEADER_SIZE}, SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, }; pub struct TunnelWrapper { reader: Arc>>, writer: Arc>>, info: Option, associate_data: Option>, } impl TunnelWrapper { pub fn new(reader: R, writer: W, info: Option) -> Self { Self::new_with_associate_data(reader, writer, info, None) } pub fn new_with_associate_data( reader: R, writer: W, info: Option, associate_data: Option>, ) -> Self { TunnelWrapper { reader: Arc::new(Mutex::new(Some(reader))), writer: Arc::new(Mutex::new(Some(writer))), info, associate_data, } } } impl Tunnel for TunnelWrapper where R: ZCPacketStream + Send + 'static, W: ZCPacketSink + Send + 'static, { fn split(&self) -> (Pin>, Pin>) { let reader = self.reader.lock().unwrap().take().unwrap(); let writer = self.writer.lock().unwrap().take().unwrap(); (Box::pin(reader), Box::pin(writer)) } fn info(&self) -> Option { self.info.clone() } } // a length delimited codec for async reader pin_project! { pub struct FramedReader { #[pin] reader: R, buf: BytesMut, state: FrameReaderState, max_packet_size: usize, associate_data: Option>, } } // usize means the size remaining to read enum FrameReaderState { ReadingHeader(usize), ReadingBody(usize), } impl FramedReader { pub fn new(reader: R, max_packet_size: usize) -> Self { Self::new_with_associate_data(reader, max_packet_size, None) } pub fn new_with_associate_data( reader: R, max_packet_size: usize, associate_data: Option>, ) -> Self { FramedReader { reader, buf: BytesMut::with_capacity(max_packet_size), state: FrameReaderState::ReadingHeader(4), max_packet_size, associate_data, } } fn extract_one_packet(buf: &mut BytesMut) -> Option { if buf.len() < TCP_TUNNEL_HEADER_SIZE { // header is not complete return None; } let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap(); let body_len = header.len.get() as usize; if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len { // body is not complete return None; } // extract one packet let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len); Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)) } } impl Stream for FramedReader where R: AsyncRead + Send + 'static + Unpin, { type Item = StreamItem; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let mut self_mut = self.project(); loop { while let Some(packet) = Self::extract_one_packet(self_mut.buf) { return Poll::Ready(Some(Ok(packet))); } reserve_buf( &mut self_mut.buf, *self_mut.max_packet_size, *self_mut.max_packet_size * 32, ); let cap = self_mut.buf.capacity() - self_mut.buf.len(); let buf = self_mut.buf.chunk_mut().as_mut_ptr(); let buf = unsafe { std::slice::from_raw_parts_mut(buf, cap) }; let mut buf = ReadBuf::new(buf); let ret = ready!(self_mut.reader.as_mut().poll_read(cx, &mut buf)); let len = buf.filled().len(); unsafe { self_mut.buf.advance_mut(len) }; match ret { Ok(_) => { if len == 0 { return Poll::Ready(None); } } Err(e) => { return Poll::Ready(Some(Err(TunnelError::IOError(e)))); } } } } } pub trait ZCPacketToBytes { fn into_bytes(&self, zc_packet: ZCPacket) -> Result; } pub struct TcpZCPacketToBytes; impl ZCPacketToBytes for TcpZCPacketToBytes { fn into_bytes(&self, item: ZCPacket) -> Result { let mut item = item.convert_type(ZCPacketType::TCP); let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len(); let Some(header) = item.mut_tcp_tunnel_header() else { return Err(TunnelError::InvalidPacket("packet too short".to_string())); }; header.len.set(tcp_len.try_into().unwrap()); Ok(item.into_bytes()) } } pin_project! { pub struct FramedWriter { #[pin] writer: W, sending_bufs: BufList, associate_data: Option>, converter: C, } } impl FramedWriter { fn max_buffer_count(&self) -> usize { 64 } } impl FramedWriter { pub fn new(writer: W) -> Self { Self::new_with_associate_data(writer, None) } pub fn new_with_associate_data( writer: W, associate_data: Option>, ) -> Self { FramedWriter { writer, sending_bufs: BufList::new(), associate_data, converter: TcpZCPacketToBytes {}, } } } impl FramedWriter { pub fn new_with_converter(writer: W, converter: C) -> Self { Self::new_with_converter_and_associate_data(writer, converter, None) } pub fn new_with_converter_and_associate_data( writer: W, converter: C, associate_data: Option>, ) -> Self { FramedWriter { writer, sending_bufs: BufList::new(), associate_data, converter, } } } impl Sink for FramedWriter where W: AsyncWrite + Send + 'static, C: ZCPacketToBytes + Send + 'static, { type Error = TunnelError; fn poll_ready( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let max_buffer_count = self.max_buffer_count(); if self.sending_bufs.bufs_cnt() >= max_buffer_count { self.as_mut().poll_flush(cx) } else { tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send"); Poll::Ready(Ok(())) } } fn start_send(self: Pin<&mut Self>, item: ZCPacket) -> Result<(), Self::Error> { let pinned = self.project(); pinned.sending_bufs.push(pinned.converter.into_bytes(item)?); Ok(()) } fn poll_flush( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let mut pinned = self.project(); let mut remaining = pinned.sending_bufs.remaining(); while remaining != 0 { let n = ready!(poll_write_buf( pinned.writer.as_mut(), cx, pinned.sending_bufs ))?; if n == 0 { return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new( std::io::ErrorKind::WriteZero, "failed to \ write frame to transport", )))); } remaining -= n; } tracing::trace!(?remaining, "flushed"); // Try flushing the underlying IO ready!(pinned.writer.poll_flush(cx))?; Poll::Ready(Ok(())) } fn poll_close( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; ready!(self.project().writer.poll_shutdown(cx))?; Poll::Ready(Ok(())) } } pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option { if local_ip.is_unspecified() || local_ip.is_multicast() { return None; } let ifaces = network_interface::NetworkInterface::show().ok()?; for iface in ifaces { for addr in iface.addr { if addr.ip() == *local_ip { return Some(iface.name); } } } tracing::error!(?local_ip, "can not find interface name by ip"); None } pub(crate) fn setup_sokcet2_ext( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, bind_dev: Option, ) -> Result<(), TunnelError> { #[cfg(target_os = "windows")] { let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM); crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?; } if bind_addr.is_ipv6() { socket2_socket.set_only_v6(true)?; } socket2_socket.set_nonblocking(true)?; socket2_socket.set_reuse_address(true)?; socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?; // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] // socket2_socket.set_reuse_port(true)?; if bind_addr.ip().is_unspecified() { return Ok(()); } // linux/mac does not use interface of bind_addr to send packet, so we need to bind device // win can handle this with bind correctly #[cfg(any(target_os = "ios", target_os = "macos"))] if let Some(dev_name) = bind_dev { // use IP_BOUND_IF to bind device unsafe { let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8); tracing::warn!(?dev_idx, ?dev_name, "bind device"); socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?; tracing::warn!(?dev_idx, ?dev_name, "bind device doen"); } } #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] if let Some(dev_name) = bind_dev { tracing::trace!(dev_name = ?dev_name, "bind device"); socket2_socket.bind_device(Some(dev_name.as_bytes()))?; } Ok(()) } pub(crate) async fn wait_for_connect_futures( mut futures: FuturesUnordered, ) -> Result where Fut: Future> + Send + Sync, E: std::error::Error + Into + Send + Sync + 'static, { // return last error let mut last_err = None; while let Some(ret) = futures.next().await { if let Err(e) = ret { last_err = Some(e.into()); } else { return ret.map_err(|e| e.into()); } } Err(last_err.unwrap_or(TunnelError::Shutdown)) } pub(crate) fn setup_sokcet2( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, ) -> Result<(), TunnelError> { setup_sokcet2_ext( socket2_socket, bind_addr, super::common::get_interface_name_by_ip(&bind_addr.ip()), ) } pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) { if buf.capacity() < min_size { buf.reserve(max_size); } } pub mod tests { use std::time::Instant; use futures::{Future, SinkExt, StreamExt, TryStreamExt}; use tokio_util::bytes::{BufMut, Bytes, BytesMut}; use crate::{ common::netns::NetNS, tunnel::{packet_def::ZCPacket, TunnelConnector, TunnelListener}, }; pub async fn _tunnel_echo_server(tunnel: Box, once: bool) { let (mut recv, mut send) = tunnel.split(); if !once { while let Some(item) = recv.next().await { let Ok(msg) = item else { continue; }; if let Err(_) = send.send(msg).await { break; } } } else { let Some(ret) = recv.next().await else { assert!(false, "recv error"); return; }; if ret.is_err() { tracing::debug!(?ret, "recv error"); return; } let res = ret.unwrap(); tracing::debug!(?res, "recv a msg, try echo back"); send.send(res).await.unwrap(); } let _ = send.flush().await; let _ = send.close().await; tracing::warn!("echo server exit..."); } pub(crate) async fn _tunnel_pingpong(listener: L, connector: C) where L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, { _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await } pub(crate) async fn _tunnel_pingpong_netns( mut listener: L, mut connector: C, l_netns: NetNS, c_netns: NetNS, ) where L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, { l_netns .run_async(|| async { listener.listen().await.unwrap(); }) .await; let lis = tokio::spawn(async move { let ret = listener.accept().await.unwrap(); println!("accept: {:?}", ret.info()); assert_eq!( ret.info().unwrap().local_addr, listener.local_url().to_string() ); _tunnel_echo_server(ret, false).await }); let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap(); println!("connect: {:?}", tunnel.info()); assert_eq!( tunnel.info().unwrap().remote_addr, connector.remote_url().to_string() ); let (mut recv, mut send) = tunnel.split(); send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes())) .await .unwrap(); let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next()) .await .unwrap() .unwrap() .unwrap(); println!("echo back: {:?}", ret); assert_eq!(ret.payload(), Bytes::from("12345678abcdefg")); send.close().await.unwrap(); if ["udp", "wg"].contains(&connector.remote_url().scheme()) { lis.abort(); } else { // lis should finish in 1 second let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await; assert!(ret.is_ok()); } } pub(crate) async fn _tunnel_bench(mut listener: L, mut connector: C) where L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, { listener.listen().await.unwrap(); let lis = tokio::spawn(async move { let ret = listener.accept().await.unwrap(); _tunnel_echo_server(ret, false).await }); let tunnel = connector.connect().await.unwrap(); let (recv, mut send) = tunnel.split(); // prepare a 4k buffer with random data let mut send_buf = BytesMut::new(); for _ in 0..64 { send_buf.put_i128(rand::random::()); } let r = tokio::spawn(async move { let now = Instant::now(); let count = recv .try_fold(0usize, |mut ret, _| async move { ret += 1; Ok(ret) }) .await .unwrap(); println!( "bps: {}", (count / 1024) * 4 / now.elapsed().as_secs() as usize ); }); let now = Instant::now(); while now.elapsed().as_secs() < 10 { // send.feed(item) let item = ZCPacket::new_with_payload(send_buf.as_ref()); let _ = send.feed(item).await.unwrap(); } send.close().await.unwrap(); drop(send); drop(connector); drop(tunnel); tracing::warn!("wait for recv to finish..."); let _ = tokio::join!(r); lis.abort(); let _ = tokio::join!(lis); } pub fn enable_log() { let filter = tracing_subscriber::EnvFilter::builder() .with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into()) .from_env() .unwrap() .add_directive("tarpc=error".parse().unwrap()); tracing_subscriber::fmt::fmt() .pretty() .with_env_filter(filter) .init(); } pub async fn wait_for_condition(mut condition: F, timeout: std::time::Duration) -> () where F: FnMut() -> FRet + Send, FRet: Future, { let now = std::time::Instant::now(); while now.elapsed() < timeout { if condition().await { return; } tokio::time::sleep(std::time::Duration::from_millis(50)).await; } assert!(condition().await, "Timeout") } }