fix dns query (#864)

1. dns resolver should be global unique so dns cache can work. avoid dns query influence hole punching.
2. when system dns failed, fallback to hickory dns.
This commit is contained in:
Sijie.Sun
2025-05-23 10:34:28 +08:00
committed by GitHub
parent 83d1ecc4da
commit 5a2fd4465c
14 changed files with 201 additions and 110 deletions

134
easytier/src/common/dns.rs Normal file
View File

@@ -0,0 +1,134 @@
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use anyhow::Context;
use hickory_proto::runtime::TokioRuntimeProvider;
use hickory_proto::xfer::Protocol;
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::{GenericConnector, TokioConnectionProvider};
use hickory_resolver::system_conf::read_system_conf;
use hickory_resolver::{Resolver, TokioResolver};
use once_cell::sync::Lazy;
use tokio::net::lookup_host;
use super::error::Error;
pub fn get_default_resolver_config() -> ResolverConfig {
let mut default_resolve_config = ResolverConfig::new();
default_resolve_config.add_name_server(NameServerConfig::new(
"223.5.5.5:53".parse().unwrap(),
Protocol::Udp,
));
default_resolve_config.add_name_server(NameServerConfig::new(
"180.184.1.1:53".parse().unwrap(),
Protocol::Udp,
));
default_resolve_config
}
pub static ALLOW_USE_SYSTEM_DNS_RESOLVER: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(true));
pub static RESOLVER: Lazy<Arc<Resolver<GenericConnector<TokioRuntimeProvider>>>> =
Lazy::new(|| {
let system_cfg = read_system_conf();
let mut cfg = get_default_resolver_config();
let mut opt = ResolverOpts::default();
if let Ok(s) = system_cfg {
for ns in s.0.name_servers() {
cfg.add_name_server(ns.clone());
}
opt = s.1;
}
opt.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
.with_options(opt);
Arc::new(builder.build())
});
pub async fn resolve_txt_record(domain_name: &str) -> Result<String, Error> {
let r = RESOLVER.clone();
let response = r.txt_lookup(domain_name).await.with_context(|| {
format!(
"txt_lookup failed, domain_name: {}",
domain_name.to_string()
)
})?;
let txt_record = response.iter().next().with_context(|| {
format!(
"no txt record found, domain_name: {}",
domain_name.to_string()
)
})?;
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
tracing::info!(?txt_data, ?domain_name, "get txt record");
Ok(txt_data.to_string())
}
pub async fn socket_addrs(
url: &url::Url,
default_port_number: impl Fn() -> Option<u16>,
) -> Result<Vec<SocketAddr>, Error> {
let host = url.host_str().ok_or(Error::InvalidUrl(url.to_string()))?;
let port = url
.port()
.or_else(default_port_number)
.ok_or(Error::InvalidUrl(url.to_string()))?;
// if host is an ip address, return it directly
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return Ok(vec![SocketAddr::new(ip, port)]);
}
if ALLOW_USE_SYSTEM_DNS_RESOLVER.load(std::sync::atomic::Ordering::Relaxed) {
let socket_addr = format!("{}:{}", host, port);
match lookup_host(socket_addr).await {
Ok(a) => {
let a = a.collect();
tracing::debug!(?a, "system dns lookup done");
return Ok(a);
}
Err(e) => {
tracing::error!(?e, "system dns lookup failed");
}
}
}
// use hickory_resolver
let ret = RESOLVER.lookup_ip(host).await.with_context(|| {
format!(
"hickory dns lookup_ip failed, host: {}, port: {}",
host, port
)
})?;
Ok(ret
.iter()
.map(|ip| SocketAddr::new(ip, port))
.collect::<Vec<_>>())
}
#[cfg(test)]
mod tests {
use crate::defer;
use super::*;
#[tokio::test]
async fn test_socket_addrs() {
let url = url::Url::parse("tcp://public.easytier.cn:80").unwrap();
let addrs = socket_addrs(&url, || Some(80)).await.unwrap();
assert_eq!(2, addrs.len(), "addrs: {:?}", addrs);
println!("addrs: {:?}", addrs);
ALLOW_USE_SYSTEM_DNS_RESOLVER.store(false, std::sync::atomic::Ordering::Relaxed);
defer!(
ALLOW_USE_SYSTEM_DNS_RESOLVER.store(true, std::sync::atomic::Ordering::Relaxed);
);
let addrs = socket_addrs(&url, || Some(80)).await.unwrap();
assert_eq!(2, addrs.len(), "addrs: {:?}", addrs);
println!("addrs2: {:?}", addrs);
}
}

View File

@@ -11,6 +11,7 @@ pub mod compressor;
pub mod config;
pub mod constants;
pub mod defer;
pub mod dns;
pub mod error;
pub mod global_ctx;
pub mod ifcfg;

View File

@@ -8,10 +8,6 @@ use crate::proto::common::{NatType, StunInfo};
use anyhow::Context;
use chrono::Local;
use crossbeam::atomic::AtomicCell;
use hickory_proto::xfer::Protocol;
use hickory_resolver::config::{NameServerConfig, ResolverConfig};
use hickory_resolver::name_server::TokioConnectionProvider;
use hickory_resolver::TokioResolver;
use rand::seq::IteratorRandom;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::{broadcast, Mutex};
@@ -24,45 +20,9 @@ use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
use crate::common::error::Error;
use super::dns::resolve_txt_record;
use super::stun_codec_ext::*;
pub fn get_default_resolver_config() -> ResolverConfig {
let mut default_resolve_config = ResolverConfig::new();
default_resolve_config.add_name_server(NameServerConfig::new(
"223.5.5.5:53".parse().unwrap(),
Protocol::Udp,
));
default_resolve_config.add_name_server(NameServerConfig::new(
"180.184.1.1:53".parse().unwrap(),
Protocol::Udp,
));
default_resolve_config
}
pub async fn resolve_txt_record(
domain_name: &str,
resolver: &TokioResolver,
) -> Result<String, Error> {
let response = resolver.txt_lookup(domain_name).await.with_context(|| {
format!(
"txt_lookup failed, domain_name: {}",
domain_name.to_string()
)
})?;
let txt_record = response.iter().next().with_context(|| {
format!(
"no txt record found, domain_name: {}",
domain_name.to_string()
)
})?;
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
tracing::info!(?txt_data, ?domain_name, "get txt record");
Ok(txt_data.to_string())
}
struct HostResolverIter {
hostnames: Vec<String>,
ips: Vec<SocketAddr>,
@@ -81,13 +41,7 @@ impl HostResolverIter {
}
async fn get_txt_record(domain_name: &str) -> Result<Vec<String>, Error> {
let resolver = TokioResolver::builder_tokio()
.unwrap_or(TokioResolver::builder_with_config(
get_default_resolver_config(),
TokioConnectionProvider::default(),
))
.build();
let txt_data = resolve_txt_record(domain_name, &resolver).await?;
let txt_data = resolve_txt_record(domain_name).await?;
Ok(txt_data.split(" ").map(|x| x.to_string()).collect())
}

View File

@@ -2,17 +2,15 @@ use std::{net::SocketAddr, sync::Arc};
use crate::{
common::{
dns::{resolve_txt_record, RESOLVER},
error::Error,
global_ctx::ArcGlobalCtx,
stun::{get_default_resolver_config, resolve_txt_record},
},
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET},
};
use anyhow::Context;
use dashmap::DashSet;
use hickory_resolver::{
name_server::TokioConnectionProvider, proto::rr::rdata::SRV, TokioResolver,
};
use hickory_resolver::proto::rr::rdata::SRV;
use rand::{seq::SliceRandom, Rng as _};
use crate::proto::common::TunnelInfo;
@@ -58,14 +56,7 @@ impl DNSTunnelConnector {
&self,
domain_name: &str,
) -> Result<Box<dyn TunnelConnector>, Error> {
let resolver = TokioResolver::builder_tokio()
.unwrap_or(TokioResolver::builder_with_config(
get_default_resolver_config(),
TokioConnectionProvider::default(),
))
.build();
let txt_data = resolve_txt_record(domain_name, &resolver)
let txt_data = resolve_txt_record(domain_name)
.await
.with_context(|| format!("resolve txt record failed, domain_name: {}", domain_name))?;
@@ -120,13 +111,6 @@ impl DNSTunnelConnector {
) -> Result<Box<dyn TunnelConnector>, Error> {
tracing::info!("handle_srv_record: {}", domain_name);
let resolver = TokioResolver::builder_tokio()
.unwrap_or(TokioResolver::builder_with_config(
get_default_resolver_config(),
TokioConnectionProvider::default(),
))
.build();
let srv_domains = PROTO_PORT_OFFSET
.iter()
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name}
@@ -136,7 +120,7 @@ impl DNSTunnelConnector {
let srv_lookup_tasks = srv_domains
.iter()
.map(|(srv_domain, protocol)| {
let resolver = resolver.clone();
let resolver = RESOLVER.clone();
let responses = responses.clone();
async move {
let response = resolver.srv_lookup(srv_domain).await.with_context(|| {

View File

@@ -60,7 +60,8 @@ pub async fn create_connector_by_url(
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() {
"tcp" => {
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version)?;
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version).await?;
let mut connector = TcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
@@ -73,7 +74,8 @@ pub async fn create_connector_by_url(
Box::new(connector)
}
"udp" => {
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version)?;
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version).await?;
let mut connector = UdpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
@@ -90,14 +92,14 @@ pub async fn create_connector_by_url(
Box::new(connector)
}
"ring" => {
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both)?;
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both).await?;
let connector = RingTunnelConnector::new(url);
Box::new(connector)
}
#[cfg(feature = "quic")]
"quic" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version)?;
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version).await?;
let mut connector = QUICTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
@@ -111,7 +113,8 @@ pub async fn create_connector_by_url(
}
#[cfg(feature = "wireguard")]
"wg" => {
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version)?;
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version).await?;
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
@@ -131,7 +134,7 @@ pub async fn create_connector_by_url(
#[cfg(feature = "websocket")]
"ws" | "wss" => {
use crate::tunnel::FromUrl;
let dst_addr = SocketAddr::from_url(url.clone(), ip_version)?;
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(

View File

@@ -19,7 +19,7 @@ use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::task::JoinSet;
use crate::common::stun::get_default_resolver_config;
use crate::common::dns::get_default_resolver_config;
use super::config::{GeneralConfig, Record, RunConfig};

View File

@@ -401,8 +401,8 @@ pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
E: std::error::Error + Into<TunnelError> + Send + Sync + 'static,
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;

View File

@@ -8,6 +8,7 @@ use std::fmt::Debug;
use tokio::time::error::Elapsed;
use crate::common::dns::socket_addrs;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket;
@@ -169,13 +170,14 @@ impl std::fmt::Debug for dyn TunnelListener {
}
}
#[async_trait::async_trait]
pub(crate) trait FromUrl {
fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError>
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError>
where
Self: Sized;
}
pub(crate) fn check_scheme_and_get_socket_addr_ext<T>(
pub(crate) async fn check_scheme_and_get_socket_addr_ext<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
@@ -187,10 +189,10 @@ where
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
Ok(T::from_url(url.clone(), ip_version)?)
Ok(T::from_url(url.clone(), ip_version).await?)
}
pub(crate) fn check_scheme_and_get_socket_addr<T>(
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
@@ -202,7 +204,7 @@ where
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
Ok(T::from_url(url.clone(), ip_version)?)
Ok(T::from_url(url.clone(), ip_version).await?)
}
fn default_port(scheme: &str) -> Option<u16> {
@@ -217,9 +219,17 @@ fn default_port(scheme: &str) -> Option<u16> {
}
}
#[async_trait::async_trait]
impl FromUrl for SocketAddr {
fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = url.socket_addrs(|| default_port(url.scheme()))?;
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = socket_addrs(&url, || default_port(url.scheme()))
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
let addrs = addrs
.into_iter()
@@ -239,8 +249,9 @@ impl FromUrl for SocketAddr {
}
}
#[async_trait::async_trait]
impl FromUrl for uuid::Uuid {
fn from_url(url: url::Url, _ip_version: IpVersion) -> Result<Self, TunnelError> {
async fn from_url(url: url::Url, _ip_version: IpVersion) -> Result<Self, TunnelError> {
let o = url.host_str().unwrap();
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
Ok(o)

View File

@@ -85,7 +85,8 @@ impl QUICTunnelListener {
impl TunnelListener for QUICTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)
.await?;
let (endpoint, server_cert) = make_server_endpoint(addr).unwrap();
self.endpoint = Some(endpoint);
self.server_cert = Some(server_cert);
@@ -149,11 +150,9 @@ impl QUICTunnelConnector {
#[async_trait::async_trait]
impl TunnelConnector for QUICTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr = check_scheme_and_get_socket_addr_ext::<SocketAddr>(
&self.addr,
"quic",
self.ip_version,
)?;
let addr =
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "quic", self.ip_version)
.await?;
let local_addr = if addr.is_ipv4() {
"0.0.0.0:0"
} else {

View File

@@ -230,12 +230,13 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
}
impl RingTunnelListener {
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
async fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
check_scheme_and_get_socket_addr::<Uuid>(
&self.listerner_addr,
"ring",
super::IpVersion::Both,
)
.await
}
}
@@ -246,13 +247,13 @@ impl TunnelListener for RingTunnelListener {
CONNECTION_MAP
.lock()
.await
.insert(self.get_addr()?, self.conn_sender.clone());
.insert(self.get_addr().await?, self.conn_sender.clone());
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr);
let my_addr = self.get_addr()?;
let my_addr = self.get_addr().await?;
if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr {
tracing::info!("accept new conn of key: {}", self.listerner_addr);
@@ -292,7 +293,8 @@ impl TunnelConnector for RingTunnelConnector {
&self.remote_addr,
"ring",
super::IpVersion::Both,
)?;
)
.await?;
let entry = CONNECTION_MAP
.lock()
.await

View File

@@ -59,7 +59,8 @@ impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)
.await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -190,7 +191,8 @@ impl TcpTunnelConnector {
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "tcp", self.ip_version)?;
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "tcp", self.ip_version)
.await?;
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
} else {

View File

@@ -477,7 +477,8 @@ impl TunnelListener for UdpTunnelListener {
&self.addr,
"udp",
IpVersion::Both,
)?;
)
.await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -781,7 +782,8 @@ impl super::TunnelConnector for UdpTunnelConnector {
&self.addr,
"udp",
self.ip_version,
)?;
)
.await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {
@@ -963,6 +965,7 @@ mod tests {
"udp",
IpVersion::Both,
)
.await
.unwrap();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),

View File

@@ -121,7 +121,7 @@ impl WSTunnelListener {
#[async_trait::async_trait]
impl TunnelListener for WSTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both)?;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
@@ -182,7 +182,7 @@ impl WSTunnelConnector {
tcp_socket: TcpSocket,
) -> Result<Box<dyn Tunnel>, TunnelError> {
let is_wss = is_wss(&addr)?;
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version)?;
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version).await?;
let domain = addr.domain();
let host = socket_addr.ip();
let stream = tcp_socket.connect(socket_addr).await?;
@@ -205,12 +205,8 @@ impl WSTunnelConnector {
let tls_conn =
tokio_rustls::TlsConnector::from(Arc::new(get_insecure_tls_client_config()));
let domain_or_ip = match domain {
None => {
host.to_string()
}
Some(domain) => {
domain.to_string()
}
None => host.to_string(),
Some(domain) => domain.to_string(),
};
let stream = tls_conn
.connect(domain_or_ip.try_into().unwrap(), stream)
@@ -274,7 +270,7 @@ impl WSTunnelConnector {
#[async_trait::async_trait]
impl TunnelConnector for WSTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?;
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {

View File

@@ -548,7 +548,8 @@ impl WgTunnelListener {
impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)
.await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
@@ -705,7 +706,8 @@ impl super::TunnelConnector for WgTunnelConnector {
&self.addr,
"wg",
self.ip_version,
)?;
)
.await?;
if addr.is_ipv6() {
return self.connect_with_ipv6(addr).await;