diff --git a/Cargo.toml b/Cargo.toml index ceb5093..371d4f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,7 @@ async-stream = "0.3.5" async-trait = "0.1.74" dashmap = "5.5.3" -timedmap = "1.0.1" +timedmap = "=1.0.1" # for tap device tun = { version = "0.6.1", features = ["async"] } @@ -112,6 +112,9 @@ network-interface = "1.1.1" # for ospf route pathfinding = "4.9.1" +# for encryption +boringtun = { version = "0.6.0" } + # for cli tabled = "0.15.*" humansize = "2.1.3" @@ -135,6 +138,7 @@ zip = "0.6.6" [dev-dependencies] serial_test = "3.0.0" +rstest = "0.18.2" [profile.dev] panic = "abort" diff --git a/README.md b/README.md index e8b684d..220138a 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,12 @@ [简体中文](/README_CN.md) | [English](/README.md) - EasyTier is a simple, plug-and-play, decentralized VPN networking solution implemented with the Rust language and Tokio framework. + EasyTier is a simple, safe and decentralized VPN networking solution implemented with the Rust language and Tokio framework. ## Features - **Decentralized**: No need to rely on centralized services, nodes are equal and independent. +- **Safe**: Use WireGuard protocol to encrypt data. - **Cross-platform**: Supports MacOS/Linux/Windows, will support IOS and Android in the future. The executable file is statically linked, making deployment simple. - **Networking without public IP**: Supports networking using shared public nodes, refer to [Configuration Guide](#Networking-without-public-IP) - **NAT traversal**: Supports UDP-based NAT traversal, able to establish stable connections even in complex network environments. diff --git a/README_CN.md b/README_CN.md index 865b199..350c741 100644 --- a/README_CN.md +++ b/README_CN.md @@ -7,11 +7,12 @@ [简体中文](/README_CN.md) | [English](/README.md) -一个简单、即插即用、去中心化的内网穿透 VPN 组网方案,使用 Rust 语言和 Tokio 框架实现。 +一个简单、安全、去中心化的内网穿透 VPN 组网方案,使用 Rust 语言和 Tokio 框架实现。 ## 特点 - **去中心化**:无需依赖中心化服务,节点平等且独立。 +- **安全**:支持利用 WireGuard 加密通信。 - **跨平台**:支持 MacOS/Linux/Windows,未来将支持 IOS 和 Android。可执行文件静态链接,部署简单。 - **无公网 IP 组网**:支持利用共享的公网节点组网,可参考 [配置指南](#无公网IP组网) - **NAT 穿透**:支持基于 UDP 的 NAT 穿透,即使在复杂的网络环境下也能建立稳定的连接。 @@ -178,7 +179,7 @@ sudo easytier-core --ipv4 10.144.144.2 --network-name abc --network-secret abc - # 路线图 - [ ] 完善文档和用户指南。 -- [ ] 支持加密、TCP 打洞等特性。 +- [ ] 支持 TCP 打洞等特性。 - [ ] 支持 Android、IOS 等移动平台。 - [ ] 支持 Web 配置管理。 diff --git a/src/connector/direct.rs b/src/connector/direct.rs index e9843b5..291daad 100644 --- a/src/connector/direct.rs +++ b/src/connector/direct.rs @@ -163,7 +163,7 @@ impl DirectConnectorManager { return Err(Error::UrlInBlacklist); } - let connector = create_connector_by_url(&addr, data.global_ctx.get_ip_collector()).await?; + let connector = create_connector_by_url(&addr, &data.global_ctx).await?; let (peer_id, conn_id) = timeout( std::time::Duration::from_secs(5), data.peer_manager.try_connect(connector), diff --git a/src/connector/manual.rs b/src/connector/manual.rs index 49538e0..f8f6199 100644 --- a/src/connector/manual.rs +++ b/src/connector/manual.rs @@ -90,7 +90,7 @@ impl ManualConnectorManager { } pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> { - self.add_connector(create_connector_by_url(url, self.global_ctx.get_ip_collector()).await?); + self.add_connector(create_connector_by_url(url, &self.global_ctx).await?); Ok(()) } diff --git a/src/connector/mod.rs b/src/connector/mod.rs index f398d91..af275d6 100644 --- a/src/connector/mod.rs +++ b/src/connector/mod.rs @@ -4,10 +4,13 @@ use std::{ }; use crate::{ - common::{error::Error, network::IPCollector}, + common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, tunnels::{ - ring_tunnel::RingTunnelConnector, tcp_tunnel::TcpTunnelConnector, - udp_tunnel::UdpTunnelConnector, TunnelConnector, + ring_tunnel::RingTunnelConnector, + tcp_tunnel::TcpTunnelConnector, + udp_tunnel::UdpTunnelConnector, + wireguard::{WgConfig, WgTunnelConnector}, + TunnelConnector, }, }; @@ -41,7 +44,7 @@ async fn set_bind_addr_for_peer_connector( pub async fn create_connector_by_url( url: &str, - ip_collector: Arc, + global_ctx: &ArcGlobalCtx, ) -> Result, Error> { let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?; match url.scheme() { @@ -49,16 +52,24 @@ pub async fn create_connector_by_url( let dst_addr = crate::tunnels::check_scheme_and_get_socket_addr::(&url, "tcp")?; let mut connector = TcpTunnelConnector::new(url); - set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector) - .await; + set_bind_addr_for_peer_connector( + &mut connector, + dst_addr.is_ipv4(), + &global_ctx.get_ip_collector(), + ) + .await; return Ok(Box::new(connector)); } "udp" => { let dst_addr = crate::tunnels::check_scheme_and_get_socket_addr::(&url, "udp")?; let mut connector = UdpTunnelConnector::new(url); - set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector) - .await; + set_bind_addr_for_peer_connector( + &mut connector, + dst_addr.is_ipv4(), + &global_ctx.get_ip_collector(), + ) + .await; return Ok(Box::new(connector)); } "ring" => { @@ -66,6 +77,14 @@ pub async fn create_connector_by_url( let connector = RingTunnelConnector::new(url); return Ok(Box::new(connector)); } + "wg" => { + crate::tunnels::check_scheme_and_get_socket_addr::(&url, "wg")?; + let nid = global_ctx.get_network_identity(); + let wg_config = + WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); + let connector = WgTunnelConnector::new(url, wg_config); + return Ok(Box::new(connector)); + } _ => { return Err(Error::InvalidUrl(url.into())); } diff --git a/src/easytier-core.rs b/src/easytier-core.rs index ec19d5e..8503242 100644 --- a/src/easytier-core.rs +++ b/src/easytier-core.rs @@ -73,7 +73,8 @@ struct Cli { #[arg(short, long, help = "listeners to accept connections, pass '' to avoid listening.", default_values_t = ["tcp://0.0.0.0:11010".to_string(), - "udp://0.0.0.0:11010".to_string()])] + "udp://0.0.0.0:11010".to_string(), + "wg://0.0.0.0:11011".to_string()])] listeners: Vec, /// specify the linux network namespace, default is the root namespace diff --git a/src/instance/listeners.rs b/src/instance/listeners.rs index 36a725c..5886022 100644 --- a/src/instance/listeners.rs +++ b/src/instance/listeners.rs @@ -12,8 +12,11 @@ use crate::{ }, peers::peer_manager::PeerManager, tunnels::{ - ring_tunnel::RingTunnelListener, tcp_tunnel::TcpTunnelListener, - udp_tunnel::UdpTunnelListener, Tunnel, TunnelListener, + ring_tunnel::RingTunnelListener, + tcp_tunnel::TcpTunnelListener, + udp_tunnel::UdpTunnelListener, + wireguard::{WgConfig, WgTunnelListener}, + Tunnel, TunnelListener, }, }; @@ -66,6 +69,13 @@ impl ListenerManage "udp" => { self.add_listener(UdpTunnelListener::new(l.clone())).await?; } + "wg" => { + let nid = self.global_ctx.get_network_identity(); + let wg_config = + WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); + self.add_listener(WgTunnelListener::new(l.clone(), wg_config)) + .await?; + } _ => { log::warn!("unsupported listener uri: {}", l); } diff --git a/src/tests/three_node.rs b/src/tests/three_node.rs index 47fdc99..acf91cf 100644 --- a/src/tests/three_node.rs +++ b/src/tests/three_node.rs @@ -19,6 +19,7 @@ use crate::{ ring_tunnel::RingTunnelConnector, tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener}, udp_tunnel::{UdpTunnelConnector, UdpTunnelListener}, + wireguard::{WgConfig, WgTunnelConnector}, }, }; @@ -50,6 +51,7 @@ pub fn get_inst_config(inst_name: &str, ns: Option<&str>, ipv4: &str) -> TomlCon config.set_listeners(vec![ "tcp://0.0.0.0:11010".parse().unwrap(), "udp://0.0.0.0:11010".parse().unwrap(), + "wg://0.0.0.0:11011".parse().unwrap(), ]); config } @@ -72,12 +74,22 @@ pub async fn init_three_node(proto: &str) -> Vec { .add_connector(TcpTunnelConnector::new( "tcp://10.1.1.1:11010".parse().unwrap(), )); - } else { + } else if proto == "udp" { inst2 .get_conn_manager() .add_connector(UdpTunnelConnector::new( "udp://10.1.1.1:11010".parse().unwrap(), )); + } else if proto == "wg" { + inst2 + .get_conn_manager() + .add_connector(WgTunnelConnector::new( + "wg://10.1.1.1:11011".parse().unwrap(), + WgConfig::new_from_network_identity( + &inst1.get_global_ctx().get_network_identity().network_name, + &inst1.get_global_ctx().get_network_identity().network_secret, + ), + )); } inst2 @@ -101,10 +113,11 @@ pub async fn init_three_node(proto: &str) -> Vec { vec![inst1, inst2, inst3] } +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn basic_three_node_test_tcp() { - let insts = init_three_node("tcp").await; +pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + let insts = init_three_node(proto).await; check_route( "10.144.144.2", @@ -119,28 +132,11 @@ pub async fn basic_three_node_test_tcp() { ); } +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn basic_three_node_test_udp() { - let insts = init_three_node("udp").await; - - check_route( - "10.144.144.2", - insts[1].peer_id(), - insts[0].get_peer_manager().list_routes().await, - ); - - check_route( - "10.144.144.3", - insts[2].peer_id(), - insts[0].get_peer_manager().list_routes().await, - ); -} - -#[tokio::test] -#[serial_test::serial] -pub async fn tcp_proxy_three_node_test() { - let insts = init_three_node("tcp").await; +pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + let insts = init_three_node(proto).await; insts[2] .get_global_ctx() @@ -171,10 +167,11 @@ pub async fn tcp_proxy_three_node_test() { .await; } +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn icmp_proxy_three_node_test() { - let insts = init_three_node("tcp").await; +pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + let insts = init_three_node(proto).await; insts[2] .get_global_ctx() @@ -205,34 +202,71 @@ pub async fn icmp_proxy_three_node_test() { assert_eq!(code.code().unwrap(), 0); } +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn proxy_three_node_disconnect_test() { +pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) { + let insts = init_three_node(proto).await; let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4")); - inst4 - .get_conn_manager() - .add_connector(TcpTunnelConnector::new( - "tcp://10.1.2.3:11010".parse().unwrap(), - )); + if proto == "tcp" { + inst4 + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.2.3:11010".parse().unwrap(), + )); + } else if proto == "wg" { + inst4 + .get_conn_manager() + .add_connector(WgTunnelConnector::new( + "wg://10.1.2.3:11011".parse().unwrap(), + WgConfig::new_from_network_identity( + &inst4.get_global_ctx().get_network_identity().network_name, + &inst4.get_global_ctx().get_network_identity().network_secret, + ), + )); + } else { + unreachable!("not support"); + } inst4.run().await.unwrap(); - tokio::spawn(async { - loop { - tokio::time::sleep(tokio::time::Duration::from_secs(6)).await; + let task = tokio::spawn(async move { + for _ in 1..=2 { + tokio::time::sleep(tokio::time::Duration::from_secs(8)).await; + // inst4 should be in inst1's route list + let routes = insts[0].get_peer_manager().list_routes().await; + assert!( + routes + .iter() + .find(|r| r.peer_id == inst4.peer_id()) + .is_some(), + "inst4 should be in inst1's route list, {:?}", + routes + ); + set_link_status("net_d", false); - tokio::time::sleep(tokio::time::Duration::from_secs(6)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(8)).await; + let routes = insts[0].get_peer_manager().list_routes().await; + assert!( + routes + .iter() + .find(|r| r.peer_id == inst4.peer_id()) + .is_none(), + "inst4 should not be in inst1's route list, {:?}", + routes + ); set_link_status("net_d", true); } }); - // TODO: add some traffic here, also should check route & peer list - tokio::time::sleep(tokio::time::Duration::from_secs(35)).await; + let (ret,) = tokio::join!(task); + assert!(ret.is_ok()); } +#[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn udp_proxy_three_node_test() { - let insts = init_three_node("tcp").await; +pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + let insts = init_three_node(proto).await; insts[2] .get_global_ctx() diff --git a/src/tunnels/common.rs b/src/tunnels/common.rs index ab2b931..512c2f7 100644 --- a/src/tunnels/common.rs +++ b/src/tunnels/common.rs @@ -402,7 +402,7 @@ pub mod tests { close_tunnel(&tunnel).await.unwrap(); - if connector.remote_url().scheme() == "udp" { + if ["udp", "wg"].contains(&connector.remote_url().scheme()) { lis.abort(); } else { // lis should finish in 1 second diff --git a/src/tunnels/mod.rs b/src/tunnels/mod.rs index b064842..da31a6f 100644 --- a/src/tunnels/mod.rs +++ b/src/tunnels/mod.rs @@ -5,6 +5,7 @@ pub mod stats; pub mod tcp_tunnel; pub mod tunnel_filter; pub mod udp_tunnel; +pub mod wireguard; use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc}; diff --git a/src/tunnels/udp_tunnel.rs b/src/tunnels/udp_tunnel.rs index 3d49826..474ea7c 100644 --- a/src/tunnels/udp_tunnel.rs +++ b/src/tunnels/udp_tunnel.rs @@ -204,12 +204,12 @@ fn get_tunnel_from_socket( ) } -struct StreamSinkPair( - Pin>, - Pin>, - u32, +pub(crate) struct StreamSinkPair( + pub Pin>, + pub Pin>, + pub u32, ); -type ArcStreamSinkPair = Arc>; +pub(crate) type ArcStreamSinkPair = Arc>; pub struct UdpTunnelListener { addr: url::Url, diff --git a/src/tunnels/wireguard.rs b/src/tunnels/wireguard.rs new file mode 100644 index 0000000..67a2ba0 --- /dev/null +++ b/src/tunnels/wireguard.rs @@ -0,0 +1,666 @@ +use std::{ + collections::hash_map::DefaultHasher, + fmt::{Debug, Formatter}, + hash::Hasher, + net::SocketAddr, + pin::Pin, + sync::Arc, + time::Duration, +}; + +use anyhow::Context; +use async_recursion::async_recursion; +use async_trait::async_trait; +use boringtun::{ + noise::{errors::WireGuardError, Tunn, TunnResult}, + x25519::{PublicKey, StaticSecret}, +}; +use dashmap::DashMap; +use futures::{SinkExt, StreamExt}; +use rand::RngCore; +use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; + +use crate::{ + rpc::TunnelInfo, + tunnels::{build_url_from_socket_addr, common::TunnelWithCustomInfo}, +}; + +use super::{ + check_scheme_and_get_socket_addr, common::setup_sokcet2, ring_tunnel::create_ring_tunnel_pair, + DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, +}; + +const MAX_PACKET: usize = 4096; + +#[derive(Clone)] +pub struct WgConfig { + my_secret_key: StaticSecret, + my_public_key: PublicKey, + + peer_public_key: PublicKey, +} + +impl WgConfig { + pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self { + let mut my_sec = [0u8; 32]; + let mut hasher = DefaultHasher::new(); + hasher.write(network_name.as_bytes()); + hasher.write(network_secret.as_bytes()); + my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..8]); + my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..16]); + my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..24]); + my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes()); + + let my_secret_key = StaticSecret::from(my_sec); + let my_public_key = PublicKey::from(&my_secret_key); + let peer_public_key = my_public_key.clone(); + + WgConfig { + my_secret_key, + my_public_key, + peer_public_key, + } + } +} + +#[derive(Clone)] +struct WgPeerData { + udp: Arc, // only for send + endpoint: SocketAddr, + tunn: Arc>, + sink: Arc>>>, + stream: Arc>>>, +} + +impl Debug for WgPeerData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WgPeerData") + .field("endpoint", &self.endpoint) + .field("local", &self.udp.local_addr()) + .finish() + } +} + +impl WgPeerData { + #[tracing::instrument] + async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> { + let mut send_buf = [0u8; MAX_PACKET]; + let encapsulate_result = { + let mut peer = self.tunn.lock().await; + peer.encapsulate(&packet, &mut send_buf) + }; + + tracing::info!( + ?encapsulate_result, + "Received {} bytes from me", + packet.len() + ); + + match encapsulate_result { + TunnResult::WriteToNetwork(packet) => { + self.udp + .send_to(packet, self.endpoint) + .await + .context("Failed to send encrypted IP packet to WireGuard endpoint.")?; + tracing::debug!( + "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", + packet.len() + ); + } + TunnResult::Err(e) => { + tracing::error!("Failed to encapsulate IP packet: {:?}", e); + } + TunnResult::Done => { + // Ignored + } + other => { + tracing::error!( + "Unexpected WireGuard state during encapsulation: {:?}", + other + ); + } + }; + Ok(()) + } + + /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, + /// decapsulates them, and dispatches newly received IP packets. + #[tracing::instrument] + pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) { + let mut send_buf = [0u8; MAX_PACKET]; + let data = &recv_buf[..]; + let decapsulate_result = { + let mut peer = self.tunn.lock().await; + peer.decapsulate(None, data, &mut send_buf) + }; + + tracing::debug!("Decapsulation result: {:?}", decapsulate_result); + + match decapsulate_result { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + return; + } + }; + let mut peer = self.tunn.lock().await; + loop { + let mut send_buf = [0u8; MAX_PACKET]; + match peer.decapsulate(None, &[], &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + break; + } + }; + } + _ => { + break; + } + } + } + } + TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { + tracing::debug!( + "WireGuard endpoint sent IP packet of {} bytes", + packet.len() + ); + let ret = self + .sink + .lock() + .await + .send( + WgPeer::remove_ip_header(packet, packet[0] >> 4 == 4) + .to_vec() + .into(), + ) + .await; + if ret.is_err() { + tracing::error!("Failed to send packet to tunnel: {:?}", ret); + } + } + _ => { + tracing::warn!( + "Unexpected WireGuard state during decapsulation: {:?}", + decapsulate_result + ); + } + } + } + + #[tracing::instrument] + #[async_recursion] + async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () { + match result { + TunnResult::WriteToNetwork(packet) => { + tracing::debug!( + "Sending routine packet of {} bytes to WireGuard endpoint", + packet.len() + ); + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!( + "Failed to send routine packet to WireGuard endpoint: {:?}", + e + ); + } + }; + } + TunnResult::Err(WireGuardError::ConnectionExpired) => { + tracing::warn!("Wireguard handshake has expired!"); + + let mut buf = vec![0u8; MAX_PACKET]; + let result = self + .tunn + .lock() + .await + .format_handshake_initiation(&mut buf[..], false); + + self.handle_routine_tun_result(result).await + } + TunnResult::Err(e) => { + tracing::error!( + "Failed to prepare routine packet for WireGuard endpoint: {:?}", + e + ); + } + TunnResult::Done => { + // Sleep for a bit + tokio::time::sleep(Duration::from_millis(1)).await; + } + other => { + tracing::warn!("Unexpected WireGuard routine task state: {:?}", other); + } + }; + } + + /// WireGuard Routine task. Handles Handshake, keep-alive, etc. + pub async fn routine_task(self) { + loop { + let mut send_buf = [0u8; MAX_PACKET]; + let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) }; + self.handle_routine_tun_result(tun_result).await; + } + } +} + +struct WgPeer { + udp: Arc, // only for send + config: WgConfig, + endpoint: SocketAddr, + + data: Option, + tasks: JoinSet<()>, + + access_time: std::time::Instant, +} + +impl WgPeer { + fn new(udp: Arc, config: WgConfig, endpoint: SocketAddr) -> Self { + WgPeer { + udp, + config, + endpoint, + + data: None, + tasks: JoinSet::new(), + + access_time: std::time::Instant::now(), + } + } + + fn add_ip_header(packet: &[u8]) -> Vec { + let mut ret = vec![0u8; packet.len() + 20]; + let ip_header = ret.as_mut_slice(); + ip_header[0] = 0x45; + ip_header[1] = 0; + ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes()); + ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); + ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); + ip_header[8] = 64; + ip_header[9] = 0; + ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); + ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); + ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); + ip_header[20..].copy_from_slice(packet); + ret + } + + fn remove_ip_header(packet: &[u8], is_v4: bool) -> &[u8] { + if is_v4 { + return &packet[20..]; + } else { + return &packet[40..]; + } + } + + async fn handle_packet_from_me(data: WgPeerData) { + while let Some(Ok(packet)) = data.stream.lock().await.next().await { + let ret = data + .handle_one_packet_from_me(&Self::add_ip_header(&packet)) + .await; + if let Err(e) = ret { + tracing::error!("Failed to handle packet from me: {}", e); + } + } + } + + async fn handle_packet_from_peer(&mut self, packet: &[u8]) { + self.access_time = std::time::Instant::now(); + tracing::info!("Received {} bytes from peer", packet.len()); + let data = self.data.as_ref().unwrap(); + data.handle_one_packet_from_peer(packet).await; + } + + fn start_and_get_tunnel(&mut self) -> Box { + let (stunnel, ctunnel) = create_ring_tunnel_pair(); + + let data = WgPeerData { + udp: self.udp.clone(), + endpoint: self.endpoint, + tunn: Arc::new(Mutex::new( + Tunn::new( + self.config.my_secret_key.clone(), + self.config.peer_public_key.clone(), + None, + None, + rand::thread_rng().next_u32(), + None, + ) + .unwrap(), + )), + sink: Arc::new(Mutex::new(stunnel.pin_sink())), + stream: Arc::new(Mutex::new(stunnel.pin_stream())), + }; + + self.data = Some(data.clone()); + self.tasks.spawn(Self::handle_packet_from_me(data.clone())); + self.tasks.spawn(data.routine_task()); + + ctunnel + } +} + +type ConnSender = tokio::sync::mpsc::UnboundedSender>; +type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver>; + +pub struct WgTunnelListener { + addr: url::Url, + config: WgConfig, + + udp: Option>, + conn_recv: ConnReceiver, + conn_send: Option, + + tasks: JoinSet<()>, +} + +impl WgTunnelListener { + pub fn new(addr: url::Url, config: WgConfig) -> Self { + let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel(); + WgTunnelListener { + addr, + config, + + udp: None, + conn_recv, + conn_send: Some(conn_send), + + tasks: JoinSet::new(), + } + } + + fn get_udp_socket(&self) -> Arc { + self.udp.as_ref().unwrap().clone() + } + + async fn handle_udp_incoming( + socket: Arc, + config: WgConfig, + conn_sender: ConnSender, + ) { + let mut tasks = JoinSet::new(); + let peer_map: Arc> = Arc::new(DashMap::new()); + + let peer_map_clone = peer_map.clone(); + tasks.spawn(async move { + loop { + peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600); + tokio::time::sleep(Duration::from_secs(60)).await; + } + }); + + let mut buf = [0u8; 4096]; + loop { + let Ok((n, addr)) = socket.recv_from(&mut buf).await else { + tracing::error!("Failed to receive from UDP socket"); + break; + }; + + let data = &buf[..n]; + tracing::info!("Received {} bytes from {}", n, addr); + + if !peer_map.contains_key(&addr) { + tracing::info!("New peer: {}", addr); + let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone()); + let tunnel = Box::new(TunnelWithCustomInfo::new( + wg.start_and_get_tunnel(), + TunnelInfo { + tunnel_type: "wg".to_owned(), + local_addr: build_url_from_socket_addr( + &socket.local_addr().unwrap().to_string(), + "wg", + ) + .into(), + remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(), + }, + )); + if let Err(e) = conn_sender.send(tunnel) { + tracing::error!("Failed to send tunnel to conn_sender: {}", e); + } + peer_map.insert(addr, wg); + } + + let mut peer = peer_map.get_mut(&addr).unwrap(); + peer.handle_packet_from_peer(data).await; + } + } +} + +#[async_trait] +impl TunnelListener for WgTunnelListener { + async fn listen(&mut self) -> Result<(), super::TunnelError> { + let addr = check_scheme_and_get_socket_addr::(&self.addr, "wg")?; + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + setup_sokcet2(&socket2_socket, &addr)?; + self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); + self.tasks.spawn(Self::handle_udp_incoming( + self.get_udp_socket(), + self.config.clone(), + self.conn_send.take().unwrap(), + )); + + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + while let Some(tunnel) = self.conn_recv.recv().await { + tracing::info!(?tunnel, "Accepted tunnel"); + return Ok(tunnel); + } + Err(TunnelError::CommonError( + "Failed to accept tunnel".to_string(), + )) + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +pub struct WgClientTunnel { + wg_peer: WgPeer, + tunnel: Box, + info: TunnelInfo, +} + +impl Tunnel for WgClientTunnel { + fn stream(&self) -> Box { + self.tunnel.stream() + } + + fn sink(&self) -> Box { + self.tunnel.sink() + } + + fn info(&self) -> Option { + Some(self.info.clone()) + } +} + +#[derive(Clone)] +pub struct WgTunnelConnector { + addr: url::Url, + config: WgConfig, + udp: Option>, +} + +impl Debug for WgTunnelConnector { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WgTunnelConnector") + .field("addr", &self.addr) + .field("udp", &self.udp) + .finish() + } +} + +impl WgTunnelConnector { + pub fn new(addr: url::Url, config: WgConfig) -> Self { + WgTunnelConnector { + addr, + config, + udp: None, + } + } + + fn create_handshake_init(tun: &mut Tunn) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_init = tun.format_handshake_initiation(&mut dst, false); + assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); + let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { + sent + } else { + unreachable!(); + }; + + handshake_init.into() + } + + fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + assert!( + matches!(keepalive, TunnResult::WriteToNetwork(_)), + "Failed to parse handshake response, {:?}", + keepalive + ); + + let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { + sent + } else { + unreachable!(); + }; + + keepalive.into() + } +} + +#[async_trait] +impl super::TunnelConnector for WgTunnelConnector { + #[tracing::instrument] + async fn connect(&mut self) -> Result, super::TunnelError> { + let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "wg")?; + tracing::warn!("wg connect: {:?}", self.addr); + let udp = UdpSocket::bind("0.0.0.0:0").await?; + let local_addr = udp.local_addr().unwrap().to_string(); + + let mut my_tun = Tunn::new( + self.config.my_secret_key.clone(), + self.config.peer_public_key.clone(), + None, + None, + rand::thread_rng().next_u32(), + None, + ) + .unwrap(); + + let init = Self::create_handshake_init(&mut my_tun); + udp.send_to(&init, addr).await?; + + let mut buf = [0u8; MAX_PACKET]; + let (n, _) = udp.recv_from(&mut buf).await.unwrap(); + let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]); + udp.send_to(&keepalive, addr).await?; + + let mut wg_peer = WgPeer::new(Arc::new(udp), self.config.clone(), addr); + let tunnel = wg_peer.start_and_get_tunnel(); + + let data = wg_peer.data.as_ref().unwrap().clone(); + wg_peer.tasks.spawn(async move { + loop { + let mut buf = [0u8; MAX_PACKET]; + let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); + if recv_addr != addr { + continue; + } + data.handle_one_packet_from_peer(&buf[..n]).await; + } + }); + + let ret = Box::new(WgClientTunnel { + wg_peer, + tunnel, + info: TunnelInfo { + tunnel_type: "wg".to_owned(), + local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(), + remote_addr: self.remote_url().into(), + }, + }); + + Ok(ret) + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } +} + +#[cfg(test)] +pub mod tests { + use boringtun::*; + + use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; + use crate::tunnels::wireguard::*; + + pub fn enable_log() { + let filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into()) + .from_env() + .unwrap() + .add_directive("tarpc=error".parse().unwrap()); + tracing_subscriber::fmt::fmt() + .pretty() + .with_env_filter(filter) + .init(); + } + + pub fn create_wg_config() -> (WgConfig, WgConfig) { + let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); + let my_public_key = x25519::PublicKey::from(&my_secret_key); + + let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); + let their_public_key = x25519::PublicKey::from(&their_secret_key); + + let server_cfg = WgConfig { + my_secret_key: my_secret_key.clone(), + my_public_key, + peer_public_key: their_public_key.clone(), + }; + + let client_cfg = WgConfig { + my_secret_key: their_secret_key, + my_public_key: their_public_key, + peer_public_key: my_public_key, + }; + + (server_cfg, client_cfg) + } + + #[tokio::test] + async fn test_wg() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg); + let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn udp_bench() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg); + let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg); + _tunnel_bench(listener, connector).await + } +}