mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-14 21:57:24 +08:00
support respond stun request in udp tunnel (#484)
we can use this to help the hole punching. (getting public mapped address stablely)
This commit is contained in:
@@ -132,7 +132,7 @@ impl StunClient {
|
|||||||
async fn wait_stun_response<'a, const N: usize>(
|
async fn wait_stun_response<'a, const N: usize>(
|
||||||
&self,
|
&self,
|
||||||
buf: &'a mut [u8; N],
|
buf: &'a mut [u8; N],
|
||||||
tids: &Vec<u128>,
|
tids: &Vec<u32>,
|
||||||
expected_ip_changed: bool,
|
expected_ip_changed: bool,
|
||||||
expected_port_changed: bool,
|
expected_port_changed: bool,
|
||||||
stun_host: &SocketAddr,
|
stun_host: &SocketAddr,
|
||||||
@@ -170,7 +170,7 @@ impl StunClient {
|
|||||||
|
|
||||||
if msg.class() != MessageClass::SuccessResponse
|
if msg.class() != MessageClass::SuccessResponse
|
||||||
|| msg.method() != BINDING
|
|| msg.method() != BINDING
|
||||||
|| !tids.contains(&tid_to_u128(&msg.transaction_id()))
|
|| !tids.contains(&tid_to_u32(&msg.transaction_id()))
|
||||||
{
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -239,7 +239,7 @@ impl StunClient {
|
|||||||
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
||||||
|
|
||||||
let mut message =
|
let mut message =
|
||||||
Message::<Attribute>::new(MessageClass::Request, BINDING, u128_to_tid(tid as u128));
|
Message::<Attribute>::new(MessageClass::Request, BINDING, u32_to_tid(tid));
|
||||||
message.add_attribute(ChangeRequest::new(change_ip, change_port));
|
message.add_attribute(ChangeRequest::new(change_ip, change_port));
|
||||||
|
|
||||||
// Encodes the message
|
// Encodes the message
|
||||||
@@ -247,7 +247,7 @@ impl StunClient {
|
|||||||
let msg = encoder
|
let msg = encoder
|
||||||
.encode_into_bytes(message.clone())
|
.encode_into_bytes(message.clone())
|
||||||
.with_context(|| "encode stun message")?;
|
.with_context(|| "encode stun message")?;
|
||||||
tids.push(tid as u128);
|
tids.push(tid);
|
||||||
tracing::trace!(?message, ?msg, tid, "send stun request");
|
tracing::trace!(?message, ?msg, tid, "send stun request");
|
||||||
self.socket
|
self.socket
|
||||||
.send_to(msg.as_slice().into(), &stun_host)
|
.send_to(msg.as_slice().into(), &stun_host)
|
||||||
@@ -818,6 +818,8 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::tunnel::{udp::UdpTunnelListener, TunnelListener};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -836,4 +838,30 @@ mod tests {
|
|||||||
let port_mapping = collector.get_udp_port_mapping(3000).await;
|
let port_mapping = collector.get_udp_port_mapping(3000).await;
|
||||||
println!("{:#?}", port_mapping);
|
println!("{:#?}", port_mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_internal_stun_server() {
|
||||||
|
let mut udp_server1 = UdpTunnelListener::new("udp://0.0.0.0:55555".parse().unwrap());
|
||||||
|
let mut udp_server2 = UdpTunnelListener::new("udp://0.0.0.0:55535".parse().unwrap());
|
||||||
|
|
||||||
|
let mut tasks = JoinSet::new();
|
||||||
|
tasks.spawn(async move {
|
||||||
|
udp_server1.listen().await.unwrap();
|
||||||
|
loop {
|
||||||
|
udp_server1.accept().await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
tasks.spawn(async move {
|
||||||
|
udp_server2.listen().await.unwrap();
|
||||||
|
loop {
|
||||||
|
udp_server2.accept().await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stun_servers = vec!["127.0.0.1:55555".to_string(), "127.0.0.1:55535".to_string()];
|
||||||
|
let detector = UdpNatTypeDetector::new(stun_servers, 1);
|
||||||
|
let ret = detector.detect_nat_type(0).await;
|
||||||
|
println!("{:#?}, {:?}", ret, ret.as_ref().unwrap().nat_type());
|
||||||
|
assert_eq!(ret.unwrap().nat_type(), NatType::PortRestricted);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -267,17 +267,18 @@ impl_encode!(ChangeRequestEncoder, ChangeRequest, |item: Self::Item| {
|
|||||||
((ip << 1 | port) << 1) as u32
|
((ip << 1 | port) << 1) as u32
|
||||||
});
|
});
|
||||||
|
|
||||||
pub fn tid_to_u128(tid: &TransactionId) -> u128 {
|
pub fn tid_to_u32(tid: &TransactionId) -> u32 {
|
||||||
let mut tid_buf = [0u8; 16];
|
let mut tid_buf = [0u8; 4];
|
||||||
// copy bytes from msg_tid to tid_buf
|
// copy bytes from msg_tid to tid_buf
|
||||||
tid_buf[..tid.as_bytes().len()].copy_from_slice(tid.as_bytes());
|
tid_buf[..].copy_from_slice(&tid.as_bytes()[8..12]);
|
||||||
u128::from_le_bytes(tid_buf)
|
u32::from_le_bytes(tid_buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn u128_to_tid(tid: u128) -> TransactionId {
|
pub fn u32_to_tid(tid: u32) -> TransactionId {
|
||||||
let tid_buf = tid.to_le_bytes();
|
let tid_buf = tid.to_le_bytes();
|
||||||
let mut tid_arr = [0u8; 12];
|
let mut tid_arr = [0u8; 12];
|
||||||
tid_arr.copy_from_slice(&tid_buf[..12]);
|
tid_arr[..4].copy_from_slice(&0xdeadbeefu32.to_be_bytes());
|
||||||
|
tid_arr[8..12].copy_from_slice(&tid_buf);
|
||||||
TransactionId::new(tid_arr)
|
TransactionId::new(tid_arr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ use std::{
|
|||||||
sync::{Arc, Weak},
|
sync::{Arc, Weak},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use futures::{stream::FuturesUnordered, StreamExt};
|
use futures::{stream::FuturesUnordered, StreamExt};
|
||||||
use rand::{Rng, SeedableRng};
|
use rand::{Rng, SeedableRng};
|
||||||
|
use zerocopy::AsBytes;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
@@ -95,7 +97,60 @@ pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
|
fn is_stun_packet(b: &[u8]) -> bool {
|
||||||
|
// stun has following pattern:
|
||||||
|
// 1. first two bits are 0b00
|
||||||
|
// 2. magic cookie between 32-64 bits: 0x2112A442
|
||||||
|
b[4..8] == [0x21, 0x12, 0xA4, 0x42] && b[0] & 0xC0 == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn respond_stun_packet(
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
addr: SocketAddr,
|
||||||
|
req_buf: Vec<u8>,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
use crate::common::stun_codec_ext::*;
|
||||||
|
use bytecodec::DecodeExt as _;
|
||||||
|
use bytecodec::EncodeExt as _;
|
||||||
|
use stun_codec::rfc5389::attributes::MappedAddress;
|
||||||
|
use stun_codec::rfc5389::methods::BINDING;
|
||||||
|
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
|
||||||
|
|
||||||
|
let mut decoder = MessageDecoder::<Attribute>::new();
|
||||||
|
let req_msg = decoder
|
||||||
|
.decode_from_bytes(&req_buf)
|
||||||
|
.map_err(|e| anyhow::anyhow!("stun decode error: {:?}", e))?
|
||||||
|
.map_err(|e| anyhow::anyhow!("stun decode broken message error: {:?}", e))?;
|
||||||
|
|
||||||
|
let tid = req_msg.transaction_id();
|
||||||
|
// we only respond easytier stun req, whose tid has 0xdeadbeef prefix
|
||||||
|
if tid.as_bytes()[0..4] != [0xde, 0xad, 0xbe, 0xef] {
|
||||||
|
anyhow::bail!("stun req tid not from easytier");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut resp_msg = Message::<Attribute>::new(
|
||||||
|
MessageClass::SuccessResponse,
|
||||||
|
BINDING,
|
||||||
|
// we discard the prefix, make sure our implementation is not compatible with other stun client
|
||||||
|
u32_to_tid(tid_to_u32(&tid)),
|
||||||
|
);
|
||||||
|
resp_msg.add_attribute(Attribute::MappedAddress(MappedAddress::new(addr.clone())));
|
||||||
|
|
||||||
|
let mut encoder = MessageEncoder::new();
|
||||||
|
let rsp_buf = encoder
|
||||||
|
.encode_into_bytes(resp_msg.clone())
|
||||||
|
.map_err(|e| anyhow::anyhow!("stun encode error: {:?}", e))?;
|
||||||
|
|
||||||
|
socket
|
||||||
|
.send_to(&rsp_buf, addr.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "send stun response error")?;
|
||||||
|
|
||||||
|
tracing::debug!(?addr, ?req_msg, "udp respond stun packet done");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_zcpacket_from_buf(buf: BytesMut, allow_stun: bool) -> Result<ZCPacket, TunnelError> {
|
||||||
let dg_size = buf.len();
|
let dg_size = buf.len();
|
||||||
if dg_size < UDP_TUNNEL_HEADER_SIZE {
|
if dg_size < UDP_TUNNEL_HEADER_SIZE {
|
||||||
return Err(TunnelError::InvalidPacket(format!(
|
return Err(TunnelError::InvalidPacket(format!(
|
||||||
@@ -104,6 +159,10 @@ fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if allow_stun && is_stun_packet(&buf[..UDP_TUNNEL_HEADER_SIZE]) {
|
||||||
|
return Ok(ZCPacket::new_from_buf(buf, ZCPacketType::UDP));
|
||||||
|
}
|
||||||
|
|
||||||
let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
||||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
let payload_len = header.len.get() as usize;
|
let payload_len = header.len.get() as usize;
|
||||||
@@ -154,7 +213,7 @@ async fn forward_from_ring_to_udp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn udp_recv_from_socket_forward_task<F>(socket: Arc<UdpSocket>, mut f: F)
|
async fn udp_recv_from_socket_forward_task<F>(socket: Arc<UdpSocket>, allow_stun: bool, mut f: F)
|
||||||
where
|
where
|
||||||
F: FnMut(ZCPacket, SocketAddr) -> (),
|
F: FnMut(ZCPacket, SocketAddr) -> (),
|
||||||
{
|
{
|
||||||
@@ -175,7 +234,7 @@ where
|
|||||||
dg_size
|
dg_size
|
||||||
);
|
);
|
||||||
|
|
||||||
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
|
let zc_packet = match get_zcpacket_from_buf(buf.split(), allow_stun) {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(?e, "udp get zc packet from buf error");
|
tracing::warn!(?e, "udp get zc packet from buf error");
|
||||||
@@ -337,6 +396,16 @@ impl UdpTunnelListenerData {
|
|||||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
if header.msg_type == UdpPacketType::Syn as u8 {
|
if header.msg_type == UdpPacketType::Syn as u8 {
|
||||||
tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet));
|
tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet));
|
||||||
|
} else if is_stun_packet(header.as_bytes()) {
|
||||||
|
// ignore stun packet
|
||||||
|
tracing::debug!("udp forward packet ignore stun packet");
|
||||||
|
let socket = self.socket.as_ref().unwrap().clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let ret = respond_stun_packet(socket, addr, zc_packet.inner().to_vec()).await;
|
||||||
|
if let Err(e) = ret {
|
||||||
|
tracing::error!(?e, "udp respond stun packet error");
|
||||||
|
}
|
||||||
|
});
|
||||||
} else if header.msg_type != UdpPacketType::HolePunch as u8 {
|
} else if header.msg_type != UdpPacketType::HolePunch as u8 {
|
||||||
let Some(mut conn) = self.sock_map.get_mut(&addr) else {
|
let Some(mut conn) = self.sock_map.get_mut(&addr) else {
|
||||||
tracing::trace!(?header, "udp forward packet error, connection not found");
|
tracing::trace!(?header, "udp forward packet error, connection not found");
|
||||||
@@ -350,7 +419,7 @@ impl UdpTunnelListenerData {
|
|||||||
|
|
||||||
async fn do_forward_task(self: Self) {
|
async fn do_forward_task(self: Self) {
|
||||||
let socket = self.socket.as_ref().unwrap().clone();
|
let socket = self.socket.as_ref().unwrap().clone();
|
||||||
udp_recv_from_socket_forward_task(socket, |zc_packet, addr| {
|
udp_recv_from_socket_forward_task(socket, true, |zc_packet, addr| {
|
||||||
self.do_forward_one_packet_to_conn(zc_packet, addr);
|
self.do_forward_one_packet_to_conn(zc_packet, addr);
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
@@ -501,7 +570,7 @@ impl UdpTunnelConnector {
|
|||||||
socket.recv_buf_from(&mut buf),
|
socket.recv_buf_from(&mut buf),
|
||||||
)
|
)
|
||||||
.await??;
|
.await??;
|
||||||
let zc_packet = get_zcpacket_from_buf(buf.split())?;
|
let zc_packet = get_zcpacket_from_buf(buf.split(), false)?;
|
||||||
if recv_addr != addr {
|
if recv_addr != addr {
|
||||||
tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match");
|
tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match");
|
||||||
}
|
}
|
||||||
@@ -588,7 +657,7 @@ impl UdpTunnelConnector {
|
|||||||
tracing::debug!("connector udp close event");
|
tracing::debug!("connector udp close event");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
_ = udp_recv_from_socket_forward_task(socket_clone, |zc_packet, addr| {
|
_ = udp_recv_from_socket_forward_task(socket_clone,false, |zc_packet, addr| {
|
||||||
tracing::debug!(?addr, "connector udp forward task done");
|
tracing::debug!(?addr, "connector udp forward task done");
|
||||||
if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) {
|
if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) {
|
||||||
tracing::trace!(?e, ?addr, "udp forward packet error");
|
tracing::trace!(?e, ?addr, "udp forward packet error");
|
||||||
|
|||||||
Reference in New Issue
Block a user