mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-14 13:47:24 +08:00
introduce websocket tunnel
This commit is contained in:
@@ -431,7 +431,14 @@ pub mod tests {
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
if !once {
|
||||
recv.forward(send).await.unwrap();
|
||||
while let Some(item) = recv.next().await {
|
||||
let Ok(msg) = item else {
|
||||
continue;
|
||||
};
|
||||
if let Err(_) = send.send(msg).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let Some(ret) = recv.next().await else {
|
||||
assert!(false, "recv error");
|
||||
@@ -447,6 +454,8 @@ pub mod tests {
|
||||
tracing::debug!(?res, "recv a msg, try echo back");
|
||||
send.send(res).await.unwrap();
|
||||
}
|
||||
let _ = send.flush().await;
|
||||
let _ = send.close().await;
|
||||
|
||||
tracing::warn!("echo server exit...");
|
||||
}
|
||||
@@ -506,7 +515,7 @@ pub mod tests {
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
|
||||
|
||||
drop(send);
|
||||
send.close().await.unwrap();
|
||||
|
||||
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
||||
lis.abort();
|
||||
@@ -562,6 +571,7 @@ pub mod tests {
|
||||
let _ = send.feed(item).await.unwrap();
|
||||
}
|
||||
|
||||
send.close().await.unwrap();
|
||||
drop(send);
|
||||
drop(connector);
|
||||
drop(tunnel);
|
||||
@@ -576,7 +586,7 @@ pub mod tests {
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
|
||||
86
easytier/src/tunnel/insecure_tls.rs
Normal file
86
easytier/src/tunnel/insecure_tls.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
|
||||
|
||||
/// Dummy certificate verifier that treats any certificate as valid.
|
||||
/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing.
|
||||
#[derive(Debug)]
|
||||
struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
|
||||
|
||||
impl SkipServerVerification {
|
||||
fn new(provider: Arc<rustls::crypto::CryptoProvider>) -> Arc<Self> {
|
||||
Arc::new(Self(provider))
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls12_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&self.0.signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &CertificateDer<'_>,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
rustls::crypto::verify_tls13_signature(
|
||||
message,
|
||||
cert,
|
||||
dss,
|
||||
&self.0.signature_verification_algorithms,
|
||||
)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
self.0.signature_verification_algorithms.supported_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_crypto_provider() {
|
||||
let _ =
|
||||
rustls::crypto::CryptoProvider::install_default(rustls::crypto::ring::default_provider());
|
||||
}
|
||||
|
||||
pub fn get_insecure_tls_client_config() -> rustls::ClientConfig {
|
||||
init_crypto_provider();
|
||||
let provider = rustls::crypto::CryptoProvider::get_default().unwrap();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(SkipServerVerification::new(provider.clone()))
|
||||
.with_no_client_auth();
|
||||
config.enable_sni = false;
|
||||
config.enable_early_data = false;
|
||||
config
|
||||
}
|
||||
|
||||
pub fn get_insecure_tls_cert<'a>() -> (Vec<CertificateDer<'a>>, PrivateKeyDer<'a>) {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
|
||||
let cert_der = cert.serialize_der().unwrap();
|
||||
let priv_key = cert.serialize_private_key_der();
|
||||
let priv_key = rustls::pki_types::PrivatePkcs8KeyDer::from(priv_key);
|
||||
let cert_chain = vec![cert_der.clone().into()];
|
||||
|
||||
(cert_chain, priv_key.into())
|
||||
}
|
||||
@@ -28,6 +28,12 @@ pub mod wireguard;
|
||||
#[cfg(feature = "quic")]
|
||||
pub mod quic;
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
pub mod websocket;
|
||||
|
||||
#[cfg(any(feature = "quic", feature = "websocket"))]
|
||||
pub mod insecure_tls;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("io error")]
|
||||
@@ -62,6 +68,10 @@ pub enum TunnelError {
|
||||
#[error("no dns record found")]
|
||||
NoDnsRecordFound(IpVersion),
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
#[error("websocket error: {0}")]
|
||||
WebSocketError(#[from] tokio_websockets::Error),
|
||||
|
||||
#[error("tunnel error: {0}")]
|
||||
TunError(String),
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::{pin::Pin, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::{task::JoinHandle, time::timeout};
|
||||
|
||||
use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream};
|
||||
|
||||
@@ -42,6 +42,8 @@ impl<T: Tunnel> MpscTunnel<T> {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let close_ret = timeout(Duration::from_secs(5), sink.close()).await;
|
||||
tracing::warn!(?close_ret, "mpsc close sink");
|
||||
});
|
||||
|
||||
Self {
|
||||
|
||||
@@ -114,6 +114,7 @@ pub struct ZCPacketOffsets {
|
||||
pub tcp_tunnel_header_offset: usize,
|
||||
pub udp_tunnel_header_offset: usize,
|
||||
pub wg_tunnel_header_offset: usize,
|
||||
pub dummy_tunnel_header_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
@@ -126,6 +127,8 @@ pub enum ZCPacketType {
|
||||
WG,
|
||||
// received from local tun device, should reserve header space for tcp or udp tunnel
|
||||
NIC,
|
||||
// tunnel without header
|
||||
DummyTunnel,
|
||||
}
|
||||
|
||||
const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max(
|
||||
@@ -158,6 +161,7 @@ impl ZCPacketType {
|
||||
TCP_TUNNEL_HEADER_SIZE,
|
||||
WG_TUNNEL_HEADER_SIZE,
|
||||
),
|
||||
dummy_tunnel_header_offset: get_converted_offset(TCP_TUNNEL_HEADER_SIZE, 0),
|
||||
},
|
||||
ZCPacketType::UDP => ZCPacketOffsets {
|
||||
payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||
@@ -171,6 +175,7 @@ impl ZCPacketType {
|
||||
UDP_TUNNEL_HEADER_SIZE,
|
||||
WG_TUNNEL_HEADER_SIZE,
|
||||
),
|
||||
dummy_tunnel_header_offset: get_converted_offset(UDP_TUNNEL_HEADER_SIZE, 0),
|
||||
},
|
||||
ZCPacketType::WG => ZCPacketOffsets {
|
||||
payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||
@@ -184,6 +189,7 @@ impl ZCPacketType {
|
||||
UDP_TUNNEL_HEADER_SIZE,
|
||||
),
|
||||
wg_tunnel_header_offset: 0,
|
||||
dummy_tunnel_header_offset: get_converted_offset(WG_TUNNEL_HEADER_SIZE, 0),
|
||||
},
|
||||
ZCPacketType::NIC => ZCPacketOffsets {
|
||||
payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET,
|
||||
@@ -198,6 +204,16 @@ impl ZCPacketType {
|
||||
wg_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE
|
||||
- WG_TUNNEL_HEADER_SIZE,
|
||||
dummy_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE,
|
||||
},
|
||||
ZCPacketType::DummyTunnel => ZCPacketOffsets {
|
||||
payload_offset: PEER_MANAGER_HEADER_SIZE,
|
||||
peer_manager_header_offset: 0,
|
||||
tcp_tunnel_header_offset: get_converted_offset(0, TCP_TUNNEL_HEADER_SIZE),
|
||||
udp_tunnel_header_offset: get_converted_offset(0, UDP_TUNNEL_HEADER_SIZE),
|
||||
wg_tunnel_header_offset: get_converted_offset(0, WG_TUNNEL_HEADER_SIZE),
|
||||
dummy_tunnel_header_offset: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -349,13 +365,21 @@ impl ZCPacket {
|
||||
hdr.len.set(payload_len as u32);
|
||||
}
|
||||
|
||||
fn tunnel_payload(&self) -> &[u8] {
|
||||
pub fn tunnel_payload(&self) -> &[u8] {
|
||||
&self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.peer_manager_header_offset..]
|
||||
}
|
||||
|
||||
pub fn tunnel_payload_bytes(mut self) -> BytesMut {
|
||||
self.inner.split_off(
|
||||
self.packet_type
|
||||
.get_packet_offsets()
|
||||
.peer_manager_header_offset,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn convert_type(mut self, target_packet_type: ZCPacketType) -> Self {
|
||||
if target_packet_type == self.packet_type {
|
||||
return self;
|
||||
@@ -377,6 +401,11 @@ impl ZCPacket {
|
||||
.get_packet_offsets()
|
||||
.wg_tunnel_header_offset
|
||||
}
|
||||
ZCPacketType::DummyTunnel => {
|
||||
self.packet_type
|
||||
.get_packet_offsets()
|
||||
.dummy_tunnel_header_offset
|
||||
}
|
||||
ZCPacketType::NIC => unreachable!(),
|
||||
};
|
||||
|
||||
|
||||
@@ -12,44 +12,18 @@ use crate::{
|
||||
},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use quinn::{ClientConfig, Connection, Endpoint, ServerConfig};
|
||||
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Connection, Endpoint, ServerConfig};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr, IpVersion, Tunnel, TunnelConnector, TunnelError,
|
||||
TunnelListener,
|
||||
check_scheme_and_get_socket_addr,
|
||||
insecure_tls::{get_insecure_tls_cert, get_insecure_tls_client_config},
|
||||
IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||
};
|
||||
|
||||
/// Dummy certificate verifier that treats any certificate as valid.
|
||||
/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing.
|
||||
struct SkipServerVerification;
|
||||
|
||||
impl SkipServerVerification {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::client::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::Certificate,
|
||||
_intermediates: &[rustls::Certificate],
|
||||
_server_name: &rustls::ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: std::time::SystemTime,
|
||||
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_client() -> ClientConfig {
|
||||
let crypto = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(SkipServerVerification::new())
|
||||
.with_no_client_auth();
|
||||
|
||||
ClientConfig::new(Arc::new(crypto))
|
||||
ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(get_insecure_tls_client_config()).unwrap(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
|
||||
@@ -68,18 +42,14 @@ pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec<u8>)
|
||||
|
||||
/// Returns default server configuration along with its certificate.
|
||||
fn configure_server() -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
|
||||
let cert_der = cert.serialize_der().unwrap();
|
||||
let priv_key = cert.serialize_private_key_der();
|
||||
let priv_key = rustls::PrivateKey(priv_key);
|
||||
let cert_chain = vec![rustls::Certificate(cert_der.clone())];
|
||||
let (certs, key) = get_insecure_tls_cert();
|
||||
|
||||
let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
|
||||
let mut server_config = ServerConfig::with_single_cert(certs.clone(), key.into())?;
|
||||
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
|
||||
transport_config.max_concurrent_uni_streams(10_u8.into());
|
||||
transport_config.max_concurrent_bidi_streams(10_u8.into());
|
||||
|
||||
Ok((server_config, cert_der))
|
||||
Ok((server_config, certs[0].to_vec()))
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
|
||||
262
easytier/src/tunnel/websocket.rs
Normal file
262
easytier/src/tunnel/websocket.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_websockets::{ClientBuilder, Limits, Message};
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
|
||||
|
||||
use super::{
|
||||
common::{setup_sokcet2, TunnelWrapper},
|
||||
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
|
||||
packet_def::{ZCPacket, ZCPacketType},
|
||||
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||
};
|
||||
|
||||
fn is_wss(addr: &url::Url) -> Result<bool, TunnelError> {
|
||||
match addr.scheme() {
|
||||
"ws" => Ok(false),
|
||||
"wss" => Ok(true),
|
||||
_ => Err(TunnelError::InvalidProtocol(addr.scheme().to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn sink_from_zc_packet<E>(msg: ZCPacket) -> Result<Message, E> {
|
||||
Ok(Message::binary(msg.tunnel_payload_bytes().freeze()))
|
||||
}
|
||||
|
||||
async fn map_from_ws_message(
|
||||
msg: Result<Message, tokio_websockets::Error>,
|
||||
) -> Option<Result<ZCPacket, TunnelError>> {
|
||||
if msg.is_err() {
|
||||
tracing::error!(?msg, "recv from websocket error");
|
||||
return Some(Err(TunnelError::WebSocketError(msg.unwrap_err())));
|
||||
}
|
||||
|
||||
let msg = msg.unwrap();
|
||||
if msg.is_close() {
|
||||
tracing::warn!("recv close message from websocket");
|
||||
return None;
|
||||
}
|
||||
|
||||
if !msg.is_binary() {
|
||||
let msg = format!("{:?}", msg);
|
||||
tracing::error!(?msg, "Invalid packet");
|
||||
return Some(Err(TunnelError::InvalidPacket(msg)));
|
||||
}
|
||||
|
||||
Some(Ok(ZCPacket::new_from_buf(
|
||||
BytesMut::from(msg.into_payload().as_bytes()),
|
||||
ZCPacketType::DummyTunnel,
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WSTunnelListener {
|
||||
addr: url::Url,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl WSTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
WSTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_accept(&mut self, stream: TcpStream) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: self.addr.scheme().to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(
|
||||
&stream.peer_addr()?.to_string(),
|
||||
self.addr.scheme().to_string().as_str(),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
|
||||
let server_bulder = tokio_websockets::ServerBuilder::new().limits(Limits::unlimited());
|
||||
|
||||
let ret: Box<dyn Tunnel> = if is_wss(&self.addr)? {
|
||||
init_crypto_provider();
|
||||
let (certs, key) = get_insecure_tls_cert();
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.with_context(|| "Failed to create server config")?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
|
||||
let stream = acceptor.accept(stream).await?;
|
||||
let (write, read) = server_bulder.accept(stream).await?.split();
|
||||
|
||||
Box::new(TunnelWrapper::new(
|
||||
read.filter_map(move |msg| map_from_ws_message(msg)),
|
||||
write.with(move |msg| sink_from_zc_packet(msg)),
|
||||
Some(info),
|
||||
))
|
||||
} else {
|
||||
let (write, read) = server_bulder.accept(stream).await?.split();
|
||||
Box::new(TunnelWrapper::new(
|
||||
read.filter_map(move |msg| map_from_ws_message(msg)),
|
||||
write.with(move |msg| sink_from_zc_packet(msg)),
|
||||
Some(info),
|
||||
))
|
||||
};
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for WSTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both)?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
self.addr
|
||||
.set_port(Some(socket.local_addr()?.port()))
|
||||
.unwrap();
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
loop {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
// only fail on tcp accept error
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
match self.try_accept(stream).await {
|
||||
Ok(tunnel) => return Ok(tunnel),
|
||||
Err(e) => {
|
||||
tracing::error!(?e, ?self, "Failed to accept ws/wss tunnel");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WSTunnelConnector {
|
||||
addr: url::Url,
|
||||
ip_version: IpVersion,
|
||||
}
|
||||
|
||||
impl WSTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
WSTunnelConnector {
|
||||
addr,
|
||||
ip_version: IpVersion::Both,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelConnector for WSTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let is_wss = is_wss(&self.addr)?;
|
||||
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?;
|
||||
let local_addr = if addr.is_ipv4() {
|
||||
"0.0.0.0:0"
|
||||
} else {
|
||||
"[::]:0"
|
||||
};
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: self.addr.scheme().to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(
|
||||
&local_addr.to_string(),
|
||||
self.addr.scheme().to_string().as_str(),
|
||||
)
|
||||
.into(),
|
||||
remote_addr: self.addr.to_string(),
|
||||
};
|
||||
|
||||
let connector =
|
||||
tokio_websockets::Connector::Rustls(Arc::new(get_insecure_tls_client_config()).into());
|
||||
let mut client_builder =
|
||||
ClientBuilder::from_uri(http::Uri::try_from(self.addr.to_string()).unwrap());
|
||||
if is_wss {
|
||||
init_crypto_provider();
|
||||
client_builder = client_builder.connector(&connector);
|
||||
}
|
||||
|
||||
let (client, _) = client_builder.connect().await?;
|
||||
|
||||
let (write, read) = client.split();
|
||||
let read = read.filter_map(move |msg| map_from_ws_message(msg));
|
||||
let write = write.with(move |msg| sink_from_zc_packet(msg));
|
||||
|
||||
Ok(Box::new(TunnelWrapper::new(read, write, Some(info))))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_ip_version(&mut self, ip_version: IpVersion) {
|
||||
self.ip_version = ip_version;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use crate::tunnel::common::tests::_tunnel_pingpong;
|
||||
use crate::tunnel::websocket::{WSTunnelConnector, WSTunnelListener};
|
||||
use crate::tunnel::{TunnelConnector, TunnelListener};
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
async fn ws_pingpong(#[values("ws", "wss")] proto: &str) {
|
||||
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
|
||||
let connector =
|
||||
WSTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
// TODO: tokio-websockets cannot correctly handle close, benchmark case is disabled
|
||||
// #[rstest::rstest]
|
||||
// #[tokio::test]
|
||||
// #[serial_test::serial]
|
||||
// async fn ws_bench(#[values("ws", "wss")] proto: &str) {
|
||||
// enable_log();
|
||||
// let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
|
||||
// let connector =
|
||||
// WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
|
||||
// _tunnel_bench(listener, connector).await
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn ws_accept_wss() {
|
||||
let mut listener = WSTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
|
||||
listener.listen().await.unwrap();
|
||||
let j = tokio::spawn(async move {
|
||||
let _ = listener.accept().await;
|
||||
});
|
||||
|
||||
let mut connector = WSTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
|
||||
connector.connect().await.unwrap_err();
|
||||
|
||||
let mut connector = WSTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
|
||||
connector.connect().await.unwrap();
|
||||
|
||||
j.abort();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user