diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index 36aafb7..301efcd 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -132,7 +132,7 @@ impl StunClient { async fn wait_stun_response<'a, const N: usize>( &self, buf: &'a mut [u8; N], - tids: &Vec, + tids: &Vec, expected_ip_changed: bool, expected_port_changed: bool, stun_host: &SocketAddr, @@ -170,7 +170,7 @@ impl StunClient { if msg.class() != MessageClass::SuccessResponse || msg.method() != BINDING - || !tids.contains(&tid_to_u128(&msg.transaction_id())) + || !tids.contains(&tid_to_u32(&msg.transaction_id())) { continue; } @@ -239,7 +239,7 @@ impl StunClient { unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) }; let mut message = - Message::::new(MessageClass::Request, BINDING, u128_to_tid(tid as u128)); + Message::::new(MessageClass::Request, BINDING, u32_to_tid(tid)); message.add_attribute(ChangeRequest::new(change_ip, change_port)); // Encodes the message @@ -247,7 +247,7 @@ impl StunClient { let msg = encoder .encode_into_bytes(message.clone()) .with_context(|| "encode stun message")?; - tids.push(tid as u128); + tids.push(tid); tracing::trace!(?message, ?msg, tid, "send stun request"); self.socket .send_to(msg.as_slice().into(), &stun_host) @@ -818,6 +818,8 @@ impl StunInfoCollectorTrait for MockStunInfoCollector { #[cfg(test)] mod tests { + use crate::tunnel::{udp::UdpTunnelListener, TunnelListener}; + use super::*; #[tokio::test] @@ -836,4 +838,30 @@ mod tests { let port_mapping = collector.get_udp_port_mapping(3000).await; println!("{:#?}", port_mapping); } + + #[tokio::test] + async fn test_internal_stun_server() { + let mut udp_server1 = UdpTunnelListener::new("udp://0.0.0.0:55555".parse().unwrap()); + let mut udp_server2 = UdpTunnelListener::new("udp://0.0.0.0:55535".parse().unwrap()); + + let mut tasks = JoinSet::new(); + tasks.spawn(async move { + udp_server1.listen().await.unwrap(); + loop { + udp_server1.accept().await.unwrap(); + } + }); + tasks.spawn(async move { + udp_server2.listen().await.unwrap(); + loop { + udp_server2.accept().await.unwrap(); + } + }); + + let stun_servers = vec!["127.0.0.1:55555".to_string(), "127.0.0.1:55535".to_string()]; + let detector = UdpNatTypeDetector::new(stun_servers, 1); + let ret = detector.detect_nat_type(0).await; + println!("{:#?}, {:?}", ret, ret.as_ref().unwrap().nat_type()); + assert_eq!(ret.unwrap().nat_type(), NatType::PortRestricted); + } } diff --git a/easytier/src/common/stun_codec_ext.rs b/easytier/src/common/stun_codec_ext.rs index 3ea28b5..c013995 100644 --- a/easytier/src/common/stun_codec_ext.rs +++ b/easytier/src/common/stun_codec_ext.rs @@ -267,17 +267,18 @@ impl_encode!(ChangeRequestEncoder, ChangeRequest, |item: Self::Item| { ((ip << 1 | port) << 1) as u32 }); -pub fn tid_to_u128(tid: &TransactionId) -> u128 { - let mut tid_buf = [0u8; 16]; +pub fn tid_to_u32(tid: &TransactionId) -> u32 { + let mut tid_buf = [0u8; 4]; // copy bytes from msg_tid to tid_buf - tid_buf[..tid.as_bytes().len()].copy_from_slice(tid.as_bytes()); - u128::from_le_bytes(tid_buf) + tid_buf[..].copy_from_slice(&tid.as_bytes()[8..12]); + u32::from_le_bytes(tid_buf) } -pub fn u128_to_tid(tid: u128) -> TransactionId { +pub fn u32_to_tid(tid: u32) -> TransactionId { let tid_buf = tid.to_le_bytes(); let mut tid_arr = [0u8; 12]; - tid_arr.copy_from_slice(&tid_buf[..12]); + tid_arr[..4].copy_from_slice(&0xdeadbeefu32.to_be_bytes()); + tid_arr[8..12].copy_from_slice(&tid_buf); TransactionId::new(tid_arr) } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 86c417b..9a60501 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -3,11 +3,13 @@ use std::{ sync::{Arc, Weak}, }; +use anyhow::Context; use async_trait::async_trait; use bytes::BytesMut; use dashmap::DashMap; use futures::{stream::FuturesUnordered, StreamExt}; use rand::{Rng, SeedableRng}; +use zerocopy::AsBytes; use std::net::SocketAddr; use tokio::{ @@ -95,7 +97,60 @@ pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket { ) } -fn get_zcpacket_from_buf(buf: BytesMut) -> Result { +fn is_stun_packet(b: &[u8]) -> bool { + // stun has following pattern: + // 1. first two bits are 0b00 + // 2. magic cookie between 32-64 bits: 0x2112A442 + b[4..8] == [0x21, 0x12, 0xA4, 0x42] && b[0] & 0xC0 == 0 +} + +async fn respond_stun_packet( + socket: Arc, + addr: SocketAddr, + req_buf: Vec, +) -> Result<(), anyhow::Error> { + use crate::common::stun_codec_ext::*; + use bytecodec::DecodeExt as _; + use bytecodec::EncodeExt as _; + use stun_codec::rfc5389::attributes::MappedAddress; + use stun_codec::rfc5389::methods::BINDING; + use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder}; + + let mut decoder = MessageDecoder::::new(); + let req_msg = decoder + .decode_from_bytes(&req_buf) + .map_err(|e| anyhow::anyhow!("stun decode error: {:?}", e))? + .map_err(|e| anyhow::anyhow!("stun decode broken message error: {:?}", e))?; + + let tid = req_msg.transaction_id(); + // we only respond easytier stun req, whose tid has 0xdeadbeef prefix + if tid.as_bytes()[0..4] != [0xde, 0xad, 0xbe, 0xef] { + anyhow::bail!("stun req tid not from easytier"); + } + + let mut resp_msg = Message::::new( + MessageClass::SuccessResponse, + BINDING, + // we discard the prefix, make sure our implementation is not compatible with other stun client + u32_to_tid(tid_to_u32(&tid)), + ); + resp_msg.add_attribute(Attribute::MappedAddress(MappedAddress::new(addr.clone()))); + + let mut encoder = MessageEncoder::new(); + let rsp_buf = encoder + .encode_into_bytes(resp_msg.clone()) + .map_err(|e| anyhow::anyhow!("stun encode error: {:?}", e))?; + + socket + .send_to(&rsp_buf, addr.clone()) + .await + .with_context(|| "send stun response error")?; + + tracing::debug!(?addr, ?req_msg, "udp respond stun packet done"); + Ok(()) +} + +fn get_zcpacket_from_buf(buf: BytesMut, allow_stun: bool) -> Result { let dg_size = buf.len(); if dg_size < UDP_TUNNEL_HEADER_SIZE { return Err(TunnelError::InvalidPacket(format!( @@ -104,6 +159,10 @@ fn get_zcpacket_from_buf(buf: BytesMut) -> Result { ))); } + if allow_stun && is_stun_packet(&buf[..UDP_TUNNEL_HEADER_SIZE]) { + return Ok(ZCPacket::new_from_buf(buf, ZCPacketType::UDP)); + } + let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP); let header = zc_packet.udp_tunnel_header().unwrap(); let payload_len = header.len.get() as usize; @@ -154,7 +213,7 @@ async fn forward_from_ring_to_udp( } } -async fn udp_recv_from_socket_forward_task(socket: Arc, mut f: F) +async fn udp_recv_from_socket_forward_task(socket: Arc, allow_stun: bool, mut f: F) where F: FnMut(ZCPacket, SocketAddr) -> (), { @@ -175,7 +234,7 @@ where dg_size ); - let zc_packet = match get_zcpacket_from_buf(buf.split()) { + let zc_packet = match get_zcpacket_from_buf(buf.split(), allow_stun) { Ok(v) => v, Err(e) => { tracing::warn!(?e, "udp get zc packet from buf error"); @@ -337,6 +396,16 @@ impl UdpTunnelListenerData { let header = zc_packet.udp_tunnel_header().unwrap(); if header.msg_type == UdpPacketType::Syn as u8 { tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet)); + } else if is_stun_packet(header.as_bytes()) { + // ignore stun packet + tracing::debug!("udp forward packet ignore stun packet"); + let socket = self.socket.as_ref().unwrap().clone(); + tokio::spawn(async move { + let ret = respond_stun_packet(socket, addr, zc_packet.inner().to_vec()).await; + if let Err(e) = ret { + tracing::error!(?e, "udp respond stun packet error"); + } + }); } else if header.msg_type != UdpPacketType::HolePunch as u8 { let Some(mut conn) = self.sock_map.get_mut(&addr) else { tracing::trace!(?header, "udp forward packet error, connection not found"); @@ -350,7 +419,7 @@ impl UdpTunnelListenerData { async fn do_forward_task(self: Self) { let socket = self.socket.as_ref().unwrap().clone(); - udp_recv_from_socket_forward_task(socket, |zc_packet, addr| { + udp_recv_from_socket_forward_task(socket, true, |zc_packet, addr| { self.do_forward_one_packet_to_conn(zc_packet, addr); }) .await; @@ -501,7 +570,7 @@ impl UdpTunnelConnector { socket.recv_buf_from(&mut buf), ) .await??; - let zc_packet = get_zcpacket_from_buf(buf.split())?; + let zc_packet = get_zcpacket_from_buf(buf.split(), false)?; if recv_addr != addr { tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match"); } @@ -588,7 +657,7 @@ impl UdpTunnelConnector { tracing::debug!("connector udp close event"); return; } - _ = udp_recv_from_socket_forward_task(socket_clone, |zc_packet, addr| { + _ = udp_recv_from_socket_forward_task(socket_clone,false, |zc_packet, addr| { tracing::debug!(?addr, "connector udp forward task done"); if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) { tracing::trace!(?e, ?addr, "udp forward packet error");