diff --git a/Cargo.toml b/Cargo.toml index 371d4f6..2f4ec61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,7 +91,7 @@ url = { version = "2.5", features = ["serde"] } byteorder = "1.5.0" # for proxy -cidr = "0.2.2" +cidr = { version = "0.2.2", features = ["serde"] } socket2 = "0.5.5" # for hole punching @@ -119,6 +119,8 @@ boringtun = { version = "0.6.0" } tabled = "0.15.*" humansize = "2.1.3" +base64 = "0.21.7" + [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.52", features = [ diff --git a/proto/cli.proto b/proto/cli.proto index 3001ea6..08fd85f 100644 --- a/proto/cli.proto +++ b/proto/cli.proto @@ -142,3 +142,18 @@ message GetGlobalPeerMapResponse { service PeerCenterRpc { rpc GetGlobalPeerMap (GetGlobalPeerMapRequest) returns (GetGlobalPeerMapResponse); } + +message VpnPortalInfo { + string vpn_type = 1; + string client_config = 2; + repeated string connected_clients = 3; +} + +message GetVpnPortalInfoRequest {} +message GetVpnPortalInfoResponse { + VpnPortalInfo vpn_portal_info = 1; +} + +service VpnPortalRpc { + rpc GetVpnPortalInfo (GetVpnPortalInfoRequest) returns (GetVpnPortalInfoResponse); +} diff --git a/src/common/config.rs b/src/common/config.rs index 9dd87ea..58a8bae 100644 --- a/src/common/config.rs +++ b/src/common/config.rs @@ -42,6 +42,9 @@ pub trait ConfigLoader: Send + Sync { fn get_rpc_portal(&self) -> Option; fn set_rpc_portal(&self, addr: SocketAddr); + fn get_vpn_portal_config(&self) -> Option; + fn set_vpn_portal_config(&self, config: VpnPortalConfig); + fn dump(&self) -> String; } @@ -87,6 +90,12 @@ pub struct ConsoleLoggerConfig { pub level: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct VpnPortalConfig { + pub client_cidr: cidr::Ipv4Cidr, + pub wireguard_listen: SocketAddr, +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] struct Config { netns: Option, @@ -103,6 +112,8 @@ struct Config { console_logger: Option, rpc_portal: Option, + + vpn_portal_config: Option, } #[derive(Debug, Clone)] @@ -314,6 +325,13 @@ impl ConfigLoader for TomlConfigLoader { self.config.lock().unwrap().rpc_portal = Some(addr); } + fn get_vpn_portal_config(&self) -> Option { + self.config.lock().unwrap().vpn_portal_config.clone() + } + fn set_vpn_portal_config(&self, config: VpnPortalConfig) { + self.config.lock().unwrap().vpn_portal_config = Some(config); + } + fn dump(&self) -> String { toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap() } diff --git a/src/common/global_ctx.rs b/src/common/global_ctx.rs index 8f60d92..3f71d7a 100644 --- a/src/common/global_ctx.rs +++ b/src/common/global_ctx.rs @@ -28,6 +28,9 @@ pub enum GlobalCtxEvent { Connecting(url::Url), ConnectError(String, String), // (dst, error message) + + VpnPortalClientConnected(String, String), // (portal, client ip) + VpnPortalClientDisconnected(String, String), // (portal, client ip) } type EventBus = tokio::sync::broadcast::Sender; @@ -192,6 +195,10 @@ impl GlobalCtx { pub fn add_running_listener(&self, url: url::Url) { self.running_listeners.lock().unwrap().push(url); } + + pub fn get_vpn_portal_cidr(&self) -> Option { + self.config.get_vpn_portal_config().map(|x| x.client_cidr) + } } #[cfg(test)] diff --git a/src/common/mod.rs b/src/common/mod.rs index 0aeeed1..d609ca5 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -54,7 +54,7 @@ pub fn join_joinset_background( } future::poll_fn(|cx| { - tracing::info!("try join joinset tasks"); + tracing::debug!("try join joinset tasks"); let Some(js) = js.upgrade() else { return std::task::Poll::Ready(()); }; diff --git a/src/common/stun.rs b/src/common/stun.rs index c5c0a0a..60fefcc 100644 --- a/src/common/stun.rs +++ b/src/common/stun.rs @@ -127,7 +127,7 @@ impl Stun { continue; }; - tracing::info!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg); + tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg); if msg.class() != MessageClass::SuccessResponse || msg.method() != BINDING @@ -194,7 +194,7 @@ impl Stun { changed_addr } - #[tracing::instrument(ret, err, level = Level::INFO)] + #[tracing::instrument(ret, err, level = Level::DEBUG)] pub async fn bind_request( &self, source_port: u16, @@ -250,7 +250,7 @@ impl Stun { real_port_changed, }; - tracing::info!( + tracing::debug!( ?stun_host, ?recv_addr, ?changed_socket_addr, @@ -300,7 +300,7 @@ impl UdpNatTypeDetector { let ret = stun.bind_request(source_port, true, true).await; if let Ok(resp) = ret { if !resp.real_ip_changed || !resp.real_port_changed { - tracing::info!( + tracing::debug!( ?server_ip, ?ret, "stun bind request return with unchanged ip and port" @@ -311,7 +311,7 @@ impl UdpNatTypeDetector { } ret_test2 = ret.ok(); ret_test3 = stun.bind_request(source_port, false, true).await.ok(); - tracing::info!(?ret_test3, "stun bind request with changed port"); + tracing::debug!(?ret_test3, "stun bind request with changed port"); succ = true; break; } @@ -320,7 +320,7 @@ impl UdpNatTypeDetector { return NatType::Unknown; } - tracing::info!( + tracing::debug!( ?ret_test1_1, ?ret_test1_2, ?ret_test2, diff --git a/src/easytier-cli.rs b/src/easytier-cli.rs index 3c1677e..b6b0662 100644 --- a/src/easytier-cli.rs +++ b/src/easytier-cli.rs @@ -3,6 +3,7 @@ use std::{net::SocketAddr, vec}; use clap::{command, Args, Parser, Subcommand}; +use rpc::vpn_portal_rpc_client::VpnPortalRpcClient; mod arch; mod common; @@ -38,6 +39,7 @@ enum SubCommand { Stun, Route, PeerCenter, + VpnPortal, } #[derive(Args, Debug)] @@ -216,6 +218,12 @@ impl CommandHandler { Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?) } + async fn get_vpn_portal_client( + &self, + ) -> Result, Error> { + Ok(VpnPortalRpcClient::connect(self.addr.clone()).await?) + } + async fn list_peers(&self) -> Result { let mut client = self.get_peer_manager_client().await?; let request = tonic::Request::new(ListPeerRequest::default()); @@ -452,6 +460,18 @@ async fn main() -> Result<(), Error> { .to_string() ); } + SubCommand::VpnPortal => { + let mut vpn_portal_client = handler.get_vpn_portal_client().await?; + let resp = vpn_portal_client + .get_vpn_portal_info(GetVpnPortalInfoRequest::default()) + .await? + .into_inner() + .vpn_portal_info + .unwrap_or_default(); + println!("portal_name: {}\n", resp.vpn_type); + println!("client_config:{}", resp.client_config); + println!("connected_clients:\n{:#?}", resp.connected_clients); + } } Ok(()) diff --git a/src/easytier-core.rs b/src/easytier-core.rs index 8503242..ff384f2 100644 --- a/src/easytier-core.rs +++ b/src/easytier-core.rs @@ -17,9 +17,10 @@ mod peer_center; mod peers; mod rpc; mod tunnels; +mod vpn_portal; use common::{ - config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig}, + config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig}, get_logger_timer_rfc3339, }; use instance::instance::Instance; @@ -105,6 +106,14 @@ struct Cli { help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000" )] instance_id: Option, + + #[arg( + long, + help = "url that defines the vpn portal, allow other vpn clients to connect. +example: wg://0.0.0.0:11010/10.14.14.0/24, means the vpn portal is a wireguard server listening on vpn.example.com:11010, +and the vpn client is in network of 10.14.14.0/24" + )] + vpn_portal: Option, } impl From for TomlConfigLoader { @@ -197,6 +206,38 @@ impl From for TomlConfigLoader { }); } + if cli.vpn_portal.is_some() { + let url: url::Url = cli + .vpn_portal + .clone() + .unwrap() + .parse() + .with_context(|| { + format!( + "failed to parse vpn portal url: {}", + cli.vpn_portal.unwrap() + ) + }) + .unwrap(); + cfg.set_vpn_portal_config(VpnPortalConfig { + client_cidr: url.path()[1..] + .parse() + .with_context(|| { + format!("failed to parse vpn portal client cidr: {}", url.path()) + }) + .unwrap(), + wireguard_listen: format!("{}:{}", url.host_str().unwrap(), url.port().unwrap()) + .parse() + .with_context(|| { + format!( + "failed to parse vpn portal wireguard listen address: {}", + url.host_str().unwrap() + ) + }) + .unwrap(), + }); + } + cfg } } @@ -337,6 +378,20 @@ pub async fn main() { GlobalCtxEvent::ConnectError(dst, err) => { print_event(format!("connect to peer error. dst: {}, err: {}", dst, err)); } + + GlobalCtxEvent::VpnPortalClientConnected(portal, client_addr) => { + print_event(format!( + "vpn portal client connected. portal: {}, client_addr: {}", + portal, client_addr + )); + } + + GlobalCtxEvent::VpnPortalClientDisconnected(portal, client_addr) => { + print_event(format!( + "vpn portal client disconnected. portal: {}, client_addr: {}", + portal, client_addr + )); + } } } }); diff --git a/src/instance/instance.rs b/src/instance/instance.rs index d212ebf..f4e1329 100644 --- a/src/instance/instance.rs +++ b/src/instance/instance.rs @@ -1,6 +1,6 @@ use std::borrow::BorrowMut; use std::net::Ipv4Addr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Context; use futures::StreamExt; @@ -25,7 +25,10 @@ use crate::peer_center::instance::PeerCenterInstance; use crate::peers::peer_conn::PeerConnId; use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::rpc_service::PeerManagerRpcService; +use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc; +use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; use crate::tunnels::SinkItem; +use crate::vpn_portal::{self, VpnPortal}; use tokio_stream::wrappers::ReceiverStream; @@ -54,6 +57,8 @@ pub struct Instance { peer_center: Arc, + vpn_portal: Arc>>, + global_ctx: ArcGlobalCtx, } @@ -102,6 +107,8 @@ impl Instance { let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone())); + let vpn_portal_inst = vpn_portal::wireguard::WireGuard::default(); + Instance { inst_name: global_ctx.inst_name.clone(), id, @@ -122,6 +129,8 @@ impl Instance { peer_center, + vpn_portal: Arc::new(Mutex::new(Box::new(vpn_portal_inst))), + global_ctx, } } @@ -134,6 +143,7 @@ impl Instance { if let Some(ipv4) = Ipv4Packet::new(&ret) { if ipv4.get_version() != 4 { tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4); + return; } let dst_ipv4 = ipv4.get_destination(); tracing::trace!( @@ -270,6 +280,14 @@ impl Instance { self.add_initial_peers().await?; + if let Some(_) = self.global_ctx.get_vpn_portal_cidr() { + self.vpn_portal + .lock() + .await + .start(self.get_global_ctx(), self.get_peer_manager()) + .await?; + } + Ok(()) } @@ -304,6 +322,45 @@ impl Instance { self.peer_manager.my_peer_id() } + fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc { + struct VpnPortalRpcService { + peer_mgr: Weak, + vpn_portal: Weak>>, + } + + #[tonic::async_trait] + impl VpnPortalRpc for VpnPortalRpcService { + async fn get_vpn_portal_info( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + let Some(vpn_portal) = self.vpn_portal.upgrade() else { + return Err(tonic::Status::unavailable("vpn portal not available")); + }; + + let Some(peer_mgr) = self.peer_mgr.upgrade() else { + return Err(tonic::Status::unavailable("peer manager not available")); + }; + + let vpn_portal = vpn_portal.lock().await; + let ret = GetVpnPortalInfoResponse { + vpn_portal_info: Some(VpnPortalInfo { + vpn_type: vpn_portal.name(), + client_config: vpn_portal.dump_client_config(peer_mgr).await, + connected_clients: vpn_portal.list_clients().await, + }), + }; + + Ok(tonic::Response::new(ret)) + } + } + + VpnPortalRpcService { + peer_mgr: Arc::downgrade(&self.peer_manager), + vpn_portal: Arc::downgrade(&self.vpn_portal), + } + } + fn run_rpc_server(&mut self) -> Result<(), Box> { let Some(addr) = self.global_ctx.config.get_rpc_portal() else { tracing::info!("rpc server not enabled, because rpc_portal is not set."); @@ -313,6 +370,7 @@ impl Instance { let conn_manager = self.conn_manager.clone(); let net_ns = self.global_ctx.net_ns.clone(); let peer_center = self.peer_center.clone(); + let vpn_portal_rpc = self.get_vpn_portal_rpc_service(); self.tasks.spawn(async move { let _g = net_ns.guard(); @@ -332,6 +390,9 @@ impl Instance { peer_center.get_rpc_service(), ), ) + .add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new( + vpn_portal_rpc, + )) .serve(addr) .await .with_context(|| format!("rpc server failed. addr: {}", addr)) diff --git a/src/peer_center/instance.rs b/src/peer_center/instance.rs index 281eb3b..37f566a 100644 --- a/src/peer_center/instance.rs +++ b/src/peer_center/instance.rs @@ -242,9 +242,9 @@ impl PeerCenterInstance { for _ in 1..10 { peers = ctx.job_ctx.service.list_peers().await.into(); if peers == *ctx.job_ctx.last_report_peers.lock().await { - break; + return Ok(3000); } - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(2)).await; } *ctx.job_ctx.last_report_peers.lock().await = peers.clone(); diff --git a/src/peers/foreign_network_manager.rs b/src/peers/foreign_network_manager.rs index 0477193..1d959f5 100644 --- a/src/peers/foreign_network_manager.rs +++ b/src/peers/foreign_network_manager.rs @@ -189,7 +189,7 @@ impl ForeignNetworkManager { } pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> { - tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager"); + tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager"); let entry = self .data @@ -222,10 +222,11 @@ impl ForeignNetworkManager { let mut s = self.global_ctx.subscribe(); self.tasks.lock().await.spawn(async move { while let Ok(e) = s.recv().await { - tracing::warn!(?e, "global event"); if let GlobalCtxEvent::PeerRemoved(peer_id) = &e { + tracing::info!(?e, "remove peer from foreign network manager"); data.remove_peer(*peer_id); } else if let GlobalCtxEvent::PeerConnRemoved(..) = &e { + tracing::info!(?e, "clear no conn peer from foreign network manager"); data.clear_no_conn_peer(); } } diff --git a/src/peers/peer_ospf_route.rs b/src/peers/peer_ospf_route.rs index 40334cc..7670390 100644 --- a/src/peers/peer_ospf_route.rs +++ b/src/peers/peer_ospf_route.rs @@ -99,6 +99,7 @@ impl RoutePeerInfo { .get_proxy_cidrs() .iter() .map(|x| x.to_string()) + .chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string())) .collect(), hostname: global_ctx.get_hostname(), udp_stun_info: global_ctx @@ -385,6 +386,10 @@ impl RouteTable { self.next_hop_map.get(&dst_peer_id).map(|x| *x) } + fn peer_reachable(&self, peer_id: PeerId) -> bool { + self.next_hop_map.contains_key(&peer_id) + } + fn get_nat_type(&self, peer_id: PeerId) -> Option { self.peer_infos .get(&peer_id) @@ -407,10 +412,10 @@ impl RouteTable { // build next hop map self.next_hop_map.clear(); + self.next_hop_map.insert(my_peer_id, (my_peer_id, 0)); for item in self.peer_infos.iter() { let peer_id = *item.key(); if peer_id == my_peer_id { - self.next_hop_map.insert(peer_id, (peer_id, 0)); continue; } let Some(path) = pathfinding::prelude::bfs( @@ -617,8 +622,7 @@ impl PeerRouteServiceImpl { .synced_route_info .update_my_peer_info(self.my_peer_id, &self.global_ctx) { - self.update_cached_local_conn_bitmap(); - self.update_route_table(); + self.update_route_table_and_cached_local_conn_bitmap(); return true; } false @@ -631,8 +635,7 @@ impl PeerRouteServiceImpl { .update_my_conn_info(self.my_peer_id, connected_peers); if updated { - self.update_cached_local_conn_bitmap(); - self.update_route_table(); + self.update_route_table_and_cached_local_conn_bitmap(); } updated @@ -643,12 +646,27 @@ impl PeerRouteServiceImpl { .build_from_synced_info(self.my_peer_id, &self.synced_route_info); } - fn update_cached_local_conn_bitmap(&self) { + fn update_route_table_and_cached_local_conn_bitmap(&self) { + // update route table first because we want to filter out unreachable peers. + self.update_route_table(); + + // the conn_bitmap should contain complete list of directly connected peers. + // use union of dst peers can preserve this property. + let all_dst_peer_ids = self + .synced_route_info + .conn_map + .iter() + .map(|x| x.value().clone().0.into_iter()) + .flatten() + .collect::>(); + let all_peer_ids = self .synced_route_info .conn_map .iter() .map(|x| (*x.key(), x.value().1.get())) + // do not sync conn info of peers that are not reachable from any peer. + .filter(|p| all_dst_peer_ids.contains(&p.0) || self.route_table.peer_reachable(p.0)) .collect::>(); let mut conn_bitmap = RouteConnBitmap::new(); @@ -680,6 +698,12 @@ impl PeerRouteServiceImpl { { continue; } + + // do not send unreachable peer info to dst peer. + if !self.route_table.peer_reachable(*item.key()) { + continue; + } + route_infos.push(item.value().clone()); } @@ -867,8 +891,7 @@ impl RouteService for RouteSessionManager { session.update_dst_saved_conn_bitmap_version(conn_bitmap); } - service_impl.update_cached_local_conn_bitmap(); - service_impl.update_route_table(); + service_impl.update_route_table_and_cached_local_conn_bitmap(); tracing::debug!( "sync_route_info: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}", @@ -1012,7 +1035,7 @@ impl RouteSessionManager { .map(|x| *x) .collect::>(); - tracing::info!(?service_impl.my_peer_id, ?peers, ?session_peers, ?initiator_candidates, "maintain_sessions begin"); + tracing::debug!(?service_impl.my_peer_id, ?peers, ?session_peers, ?initiator_candidates, "maintain_sessions begin"); if initiator_candidates.is_empty() { next_sleep_ms = 1000; diff --git a/src/peers/peer_rip_route.rs b/src/peers/peer_rip_route.rs index 7e7ad92..26b5207 100644 --- a/src/peers/peer_rip_route.rs +++ b/src/peers/peer_rip_route.rs @@ -52,6 +52,7 @@ impl SyncPeerInfo { .get_proxy_cidrs() .iter() .map(|x| x.to_string()) + .chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string())) .collect(), hostname: global_ctx.get_hostname(), udp_stun_info: global_ctx diff --git a/src/tunnels/udp_tunnel.rs b/src/tunnels/udp_tunnel.rs index 474ea7c..22a2a45 100644 --- a/src/tunnels/udp_tunnel.rs +++ b/src/tunnels/udp_tunnel.rs @@ -28,7 +28,7 @@ use super::{ DatagramSink, DatagramStream, Tunnel, TunnelListener, }; -pub const UDP_DATA_MTU: usize = 2500; +pub const UDP_DATA_MTU: usize = 65000; #[derive(Archive, Deserialize, Serialize)] #[archive(compare(PartialEq), check_bytes)] @@ -123,7 +123,7 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option { } if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::warn!(?udp_packet, "udp magic not match"); + tracing::trace!(?udp_packet, "udp magic not match"); return None; } @@ -351,7 +351,7 @@ impl TunnelListener for UdpTunnelListener { }; if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::info!(?udp_packet, "udp magic not match"); + tracing::trace!(?udp_packet, "udp magic not match"); continue; } @@ -471,7 +471,7 @@ impl UdpTunnelConnector { }; if udp_packet.magic != UDP_PACKET_MAGIC { - tracing::info!(?udp_packet, "udp magic not match"); + tracing::trace!(?udp_packet, "udp magic not match"); return Err(super::TunnelError::ConnectError(format!( "udp connect error, magic not match. magic: {:?}", udp_packet.magic diff --git a/src/tunnels/wireguard.rs b/src/tunnels/wireguard.rs index 67a2ba0..4166895 100644 --- a/src/tunnels/wireguard.rs +++ b/src/tunnels/wireguard.rs @@ -30,14 +30,25 @@ use super::{ DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, }; -const MAX_PACKET: usize = 4096; +const MAX_PACKET: usize = 65500; + +#[derive(Debug, Clone)] +enum WgType { + // used by easytier peer, need remove/add ip header for in/out wg msg + InternalUse, + // used by wireguard peer, keep original ip header + ExternalUse, +} #[derive(Clone)] pub struct WgConfig { my_secret_key: StaticSecret, my_public_key: PublicKey, + peer_secret_key: StaticSecret, peer_public_key: PublicKey, + + wg_type: WgType, } impl WgConfig { @@ -56,14 +67,47 @@ impl WgConfig { let my_secret_key = StaticSecret::from(my_sec); let my_public_key = PublicKey::from(&my_secret_key); + let peer_secret_key = StaticSecret::from(my_sec); let peer_public_key = my_public_key.clone(); WgConfig { my_secret_key, my_public_key, + peer_secret_key, peer_public_key, + + wg_type: WgType::InternalUse, } } + + pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self { + let server_cfg = Self::new_from_network_identity("server", server_key_seed); + let client_cfg = Self::new_from_network_identity("client", client_key_seed); + Self { + my_secret_key: server_cfg.my_secret_key, + my_public_key: server_cfg.my_public_key, + peer_secret_key: client_cfg.my_secret_key, + peer_public_key: client_cfg.my_public_key, + + wg_type: WgType::ExternalUse, + } + } + + pub fn my_secret_key(&self) -> &[u8] { + self.my_secret_key.as_bytes() + } + + pub fn peer_secret_key(&self) -> &[u8] { + self.peer_secret_key.as_bytes() + } + + pub fn my_public_key(&self) -> &[u8] { + self.my_public_key.as_bytes() + } + + pub fn peer_public_key(&self) -> &[u8] { + self.peer_public_key.as_bytes() + } } #[derive(Clone)] @@ -73,6 +117,7 @@ struct WgPeerData { tunn: Arc>, sink: Arc>>>, stream: Arc>>>, + wg_type: WgType, } impl Debug for WgPeerData { @@ -88,12 +133,17 @@ impl WgPeerData { #[tracing::instrument] async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> { let mut send_buf = [0u8; MAX_PACKET]; + let encapsulate_result = { let mut peer = self.tunn.lock().await; - peer.encapsulate(&packet, &mut send_buf) + if matches!(self.wg_type, WgType::InternalUse) { + peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf) + } else { + peer.encapsulate(&packet, &mut send_buf) + } }; - tracing::info!( + tracing::trace!( ?encapsulate_result, "Received {} bytes from me", packet.len() @@ -177,9 +227,13 @@ impl WgPeerData { .lock() .await .send( - WgPeer::remove_ip_header(packet, packet[0] >> 4 == 4) - .to_vec() - .into(), + if matches!(self.wg_type, WgType::InternalUse) { + self.remove_ip_header(packet, packet[0] >> 4 == 4) + } else { + packet + } + .to_vec() + .into(), ) .await; if ret.is_err() { @@ -250,6 +304,31 @@ impl WgPeerData { self.handle_routine_tun_result(tun_result).await; } } + + fn add_ip_header(&self, packet: &[u8]) -> Vec { + let mut ret = vec![0u8; packet.len() + 20]; + let ip_header = ret.as_mut_slice(); + ip_header[0] = 0x45; + ip_header[1] = 0; + ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes()); + ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); + ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); + ip_header[8] = 64; + ip_header[9] = 0; + ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); + ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); + ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); + ip_header[20..].copy_from_slice(packet); + ret + } + + fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] { + if is_v4 { + return &packet[20..]; + } else { + return &packet[40..]; + } + } } struct WgPeer { @@ -277,36 +356,9 @@ impl WgPeer { } } - fn add_ip_header(packet: &[u8]) -> Vec { - let mut ret = vec![0u8; packet.len() + 20]; - let ip_header = ret.as_mut_slice(); - ip_header[0] = 0x45; - ip_header[1] = 0; - ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes()); - ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); - ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); - ip_header[8] = 64; - ip_header[9] = 0; - ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); - ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); - ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); - ip_header[20..].copy_from_slice(packet); - ret - } - - fn remove_ip_header(packet: &[u8], is_v4: bool) -> &[u8] { - if is_v4 { - return &packet[20..]; - } else { - return &packet[40..]; - } - } - async fn handle_packet_from_me(data: WgPeerData) { while let Some(Ok(packet)) = data.stream.lock().await.next().await { - let ret = data - .handle_one_packet_from_me(&Self::add_ip_header(&packet)) - .await; + let ret = data.handle_one_packet_from_me(&packet).await; if let Err(e) = ret { tracing::error!("Failed to handle packet from me: {}", e); } @@ -315,7 +367,7 @@ impl WgPeer { async fn handle_packet_from_peer(&mut self, packet: &[u8]) { self.access_time = std::time::Instant::now(); - tracing::info!("Received {} bytes from peer", packet.len()); + tracing::trace!("Received {} bytes from peer", packet.len()); let data = self.data.as_ref().unwrap(); data.handle_one_packet_from_peer(packet).await; } @@ -339,6 +391,7 @@ impl WgPeer { )), sink: Arc::new(Mutex::new(stunnel.pin_sink())), stream: Arc::new(Mutex::new(stunnel.pin_stream())), + wg_type: self.config.wg_type.clone(), }; self.data = Some(data.clone()); @@ -349,6 +402,17 @@ impl WgPeer { } } +impl Drop for WgPeer { + fn drop(&mut self) { + self.tasks.abort_all(); + if let Some(data) = self.data.clone() { + tokio::spawn(async move { + let _ = data.sink.lock().await.close().await; + }); + } + } +} + type ConnSender = tokio::sync::mpsc::UnboundedSender>; type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver>; @@ -406,7 +470,7 @@ impl WgTunnelListener { }; let data = &buf[..n]; - tracing::info!("Received {} bytes from {}", n, addr); + tracing::trace!("Received {} bytes from {}", n, addr); if !peer_map.contains_key(&addr) { tracing::info!("New peer: {}", addr); @@ -636,13 +700,17 @@ pub mod tests { let server_cfg = WgConfig { my_secret_key: my_secret_key.clone(), my_public_key, + peer_secret_key: their_secret_key.clone(), peer_public_key: their_public_key.clone(), + wg_type: WgType::InternalUse, }; let client_cfg = WgConfig { my_secret_key: their_secret_key, my_public_key: their_public_key, + peer_secret_key: my_secret_key, peer_public_key: my_public_key, + wg_type: WgType::InternalUse, }; (server_cfg, client_cfg) diff --git a/src/vpn_portal/mod.rs b/src/vpn_portal/mod.rs new file mode 100644 index 0000000..3a2c171 --- /dev/null +++ b/src/vpn_portal/mod.rs @@ -0,0 +1,24 @@ +// with vpn portal, user can use other vpn client to connect to easytier servers +// without installing easytier. +// these vpn client include: +// 1. wireguard +// 2. openvpn (TODO) +// 3. shadowsocks (TODO) + +use std::sync::Arc; + +use crate::{common::global_ctx::ArcGlobalCtx, peers::peer_manager::PeerManager}; + +pub mod wireguard; + +#[async_trait::async_trait] +pub trait VpnPortal: Send + Sync { + async fn start( + &mut self, + global_ctx: ArcGlobalCtx, + peer_mgr: Arc, + ) -> anyhow::Result<()>; + async fn dump_client_config(&self, peer_mgr: Arc) -> String; + fn name(&self) -> String; + async fn list_clients(&self) -> Vec; +} diff --git a/src/vpn_portal/wireguard.rs b/src/vpn_portal/wireguard.rs new file mode 100644 index 0000000..cb25471 --- /dev/null +++ b/src/vpn_portal/wireguard.rs @@ -0,0 +1,346 @@ +use std::{ + net::{Ipv4Addr, SocketAddr}, + pin::Pin, + sync::Arc, +}; + +use anyhow::Context; +use base64::{prelude::BASE64_STANDARD, Engine}; +use cidr::Ipv4Inet; +use dashmap::DashMap; +use futures::{SinkExt, StreamExt}; +use pnet::packet::ipv4::Ipv4Packet; +use tokio::{sync::Mutex, task::JoinSet}; +use tokio_util::bytes::Bytes; + +use crate::{ + common::{ + global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, + join_joinset_background, + }, + peers::{ + packet::{self, ArchivedPacket}, + peer_manager::PeerManager, + PeerPacketFilter, + }, + tunnels::{ + wireguard::{WgConfig, WgTunnelListener}, + DatagramSink, Tunnel, TunnelListener, + }, +}; + +use super::VpnPortal; + +type WgPeerIpTable = Arc>>; + +struct ClientEntry { + endpoint_addr: Option, + sink: Mutex>>, +} + +struct WireGuardImpl { + global_ctx: ArcGlobalCtx, + peer_mgr: Arc, + wg_config: WgConfig, + listenr_addr: SocketAddr, + + wg_peer_ip_table: WgPeerIpTable, + + tasks: Arc>>, +} + +impl WireGuardImpl { + fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc) -> Self { + let nid = global_ctx.get_network_identity(); + let key_seed = format!("{}{}", nid.network_name, nid.network_secret); + let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed); + + let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap(); + let listenr_addr = vpn_cfg.wireguard_listen; + + Self { + global_ctx, + peer_mgr, + wg_config, + listenr_addr, + wg_peer_ip_table: Arc::new(DashMap::new()), + tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), + } + } + + async fn handle_incoming_conn( + t: Box, + peer_mgr: Arc, + wg_peer_ip_table: WgPeerIpTable, + ) { + let mut s = t.pin_stream(); + let mut ip_registered = false; + + let info = t.info().unwrap_or_default(); + let remote_addr = info.remote_addr.clone(); + peer_mgr + .get_global_ctx() + .issue_event(GlobalCtxEvent::VpnPortalClientConnected( + info.local_addr, + info.remote_addr, + )); + + while let Some(Ok(msg)) = s.next().await { + let Some(i) = Ipv4Packet::new(&msg) else { + tracing::error!(?msg, "Failed to parse ipv4 packet"); + continue; + }; + if !ip_registered { + let client_entry = Arc::new(ClientEntry { + endpoint_addr: remote_addr.parse().ok(), + sink: Mutex::new(t.pin_sink()), + }); + wg_peer_ip_table.insert(i.get_source(), client_entry.clone()); + ip_registered = true; + } + tracing::trace!(?i, "Received from wg client"); + let _ = peer_mgr + .send_msg_ipv4(msg.clone(), i.get_destination()) + .await; + } + + let info = t.info().unwrap_or_default(); + peer_mgr + .get_global_ctx() + .issue_event(GlobalCtxEvent::VpnPortalClientDisconnected( + info.local_addr, + info.remote_addr, + )); + } + + async fn start_pipeline_processor(&self) { + struct PeerPacketFilterForVpnPortal { + wg_peer_ip_table: WgPeerIpTable, + } + + #[async_trait::async_trait] + impl PeerPacketFilter for PeerPacketFilterForVpnPortal { + async fn try_process_packet_from_peer( + &self, + packet: &ArchivedPacket, + _: &Bytes, + ) -> Option<()> { + if packet.packet_type != packet::PacketType::Data { + return None; + }; + + let payload_bytes = packet.payload.as_bytes(); + + let ipv4 = Ipv4Packet::new(payload_bytes)?; + if ipv4.get_version() != 4 { + return None; + } + + let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone(); + + tracing::trace!(?ipv4, "Packet filter for vpn portal"); + + let ret = entry + .sink + .lock() + .await + .send(Bytes::copy_from_slice(payload_bytes)) + .await; + + ret.ok() + } + } + + self.peer_mgr + .add_packet_process_pipeline(Box::new(PeerPacketFilterForVpnPortal { + wg_peer_ip_table: self.wg_peer_ip_table.clone(), + })) + .await; + } + + async fn start(&self) -> anyhow::Result<()> { + let mut l = WgTunnelListener::new( + format!("wg://{}", self.listenr_addr).parse().unwrap(), + self.wg_config.clone(), + ); + + l.listen() + .await + .with_context(|| "Failed to start wireguard listener for vpn portal")?; + + join_joinset_background(self.tasks.clone(), "wireguard".to_string()); + + let tasks = Arc::downgrade(&self.tasks.clone()); + let peer_mgr = self.peer_mgr.clone(); + let wg_peer_ip_table = self.wg_peer_ip_table.clone(); + self.tasks.lock().unwrap().spawn(async move { + while let Ok(t) = l.accept().await { + let Some(tasks) = tasks.upgrade() else { + break; + }; + tasks.lock().unwrap().spawn(Self::handle_incoming_conn( + t, + peer_mgr.clone(), + wg_peer_ip_table.clone(), + )); + } + }); + + self.start_pipeline_processor().await; + + Ok(()) + } +} + +#[derive(Default)] +pub struct WireGuard { + inner: Option, +} + +#[async_trait::async_trait] +impl VpnPortal for WireGuard { + async fn start( + &mut self, + global_ctx: ArcGlobalCtx, + peer_mgr: Arc, + ) -> anyhow::Result<()> { + assert!(self.inner.is_none()); + + let vpn_cfg = global_ctx.config.get_vpn_portal_config(); + if vpn_cfg.is_none() { + anyhow::bail!("vpn cfg is not set for wireguard vpn portal"); + } + + let inner = WireGuardImpl::new(global_ctx, peer_mgr); + inner.start().await?; + self.inner = Some(inner); + Ok(()) + } + + async fn dump_client_config(&self, peer_mgr: Arc) -> String { + let global_ctx = self.inner.as_ref().unwrap().global_ctx.clone(); + let routes = peer_mgr.list_routes().await; + let mut allow_ips = routes + .iter() + .map(|x| x.proxy_cidrs.iter().map(String::to_string)) + .flatten() + .collect::>(); + for ipv4 in routes.iter().map(|x| &x.ipv4_addr) { + let Ok(ipv4) = ipv4.parse() else { + continue; + }; + let inet = Ipv4Inet::new(ipv4, 24).unwrap(); + allow_ips.push(inet.network().to_string()); + break; + } + + let allow_ips = allow_ips + .into_iter() + .map(|x| x.to_string()) + .collect::>() + .join(","); + + let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap(); + let client_cidr = vpn_cfg.client_cidr; + + let cfg = self.inner.as_ref().unwrap().wg_config.clone(); + let cfg_str = format!( + r#" +[Interface] +PrivateKey = {peer_secret_key} +Address = {client_cidr} # should assign an ip from this cidr manually + +[Peer] +PublicKey = {my_public_key} +AllowedIPs = {allow_ips} +Endpoint = {listenr_addr} # should be the public ip of the vpn server +"#, + peer_secret_key = BASE64_STANDARD.encode(cfg.peer_secret_key()), + my_public_key = BASE64_STANDARD.encode(cfg.my_public_key()), + listenr_addr = self.inner.as_ref().unwrap().listenr_addr, + allow_ips = allow_ips, + client_cidr = client_cidr, + ); + + cfg_str + } + + fn name(&self) -> String { + "wireguard".to_string() + } + + async fn list_clients(&self) -> Vec { + self.inner + .as_ref() + .unwrap() + .wg_peer_ip_table + .iter() + .map(|x| { + x.value() + .endpoint_addr + .as_ref() + .map(|x| x.to_string()) + .unwrap_or_default() + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + use crate::{ + common::{ + config::{NetworkIdentity, VpnPortalConfig}, + global_ctx::tests::get_mock_global_ctx_with_network, + }, + connector::udp_hole_punch::tests::replace_stun_info_collector, + peers::{ + peer_manager::{PeerManager, RouteAlgoType}, + tests::wait_for_condition, + }, + rpc::NatType, + tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector}, + }; + + async fn portal_test() { + let (s, _r) = tokio::sync::mpsc::channel(1000); + let peer_mgr = Arc::new(PeerManager::new( + RouteAlgoType::Ospf, + get_mock_global_ctx_with_network(Some(NetworkIdentity { + network_name: "sijie".to_string(), + network_secret: "1919119".to_string(), + })), + s, + )); + replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown); + peer_mgr + .get_global_ctx() + .config + .set_vpn_portal_config(VpnPortalConfig { + wireguard_listen: "0.0.0.0:11021".parse().unwrap(), + client_cidr: "10.14.14.0/24".parse().unwrap(), + }); + peer_mgr.run().await.unwrap(); + let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap()); + let tunnel = pmgr_conn.connect().await; + peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap(); + wait_for_condition( + || async { + let routes = peer_mgr.list_routes().await; + println!("Routes: {:?}", routes); + routes.len() != 0 + }, + std::time::Duration::from_secs(10), + ) + .await; + + let mut wg = WireGuard::default(); + wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone()) + .await + .unwrap(); + } +}