bind socket to device on macos (#9)

bind socket to device on macos
This commit is contained in:
Sijie.Sun
2024-01-31 21:43:23 +08:00
committed by GitHub
parent a0e59f5c56
commit 95a52a4b5c
3 changed files with 55 additions and 26 deletions

View File

@@ -1,6 +1,6 @@
use std::{ use std::{
collections::VecDeque, collections::VecDeque,
net::IpAddr, net::{IpAddr, SocketAddr},
sync::Arc, sync::Arc,
task::{ready, Context, Poll}, task::{ready, Context, Poll},
}; };
@@ -269,6 +269,39 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None None
} }
pub(crate) fn setup_sokcet2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
) -> Result<(), TunnelError> {
socket2_socket.set_nonblocking(true)?;
socket2_socket.set_reuse_address(true)?;
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
socket2_socket.set_reuse_port(true)?;
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
// win can handle this with bind correctly
#[cfg(any(target_os = "ios", target_os = "macos"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
// use IP_BOUND_IF to bind device
unsafe {
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
tracing::warn!(?dev_idx, ?dev_name, "bind device");
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
}
}
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
}
Ok(())
}
pub mod tests { pub mod tests {
use std::time::Instant; use std::time::Instant;

View File

@@ -5,6 +5,8 @@ use futures::{stream::FuturesUnordered, StreamExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use crate::tunnels::common::setup_sokcet2;
use super::{ use super::{
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener, check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
}; };
@@ -112,32 +114,21 @@ impl TcpTunnelConnector {
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into()); return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
} }
async fn connect_with_custom_bind( async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
&mut self,
is_ipv4: bool,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let mut futures = FuturesUnordered::new(); let mut futures = FuturesUnordered::new();
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?; let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
for bind_addr in self.bind_addrs.iter() { for bind_addr in self.bind_addrs.iter() {
let socket = if is_ipv4 { tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
TcpSocket::new_v4()?
} else {
TcpSocket::new_v6()?
};
socket.set_reuseaddr(true)?;
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] let socket2_socket = socket2::Socket::new(
socket.set_reuseport(true)?; socket2::Domain::for_address(dst_addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_sokcet2(&socket2_socket, bind_addr)?;
socket.bind(*bind_addr)?; let socket = TcpSocket::from_std_stream(socket2_socket.into());
// linux does not use interface of bind_addr to send packet, so we need to bind device
// mac can handle this with bind correctly
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket.bind_device(Some(dev_name.as_bytes()))?;
}
futures.push(socket.connect(dst_addr.clone())); futures.push(socket.connect(dst_addr.clone()));
} }
@@ -156,10 +147,8 @@ impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
if self.bind_addrs.is_empty() { if self.bind_addrs.is_empty() {
self.connect_with_default_bind().await self.connect_with_default_bind().await
} else if self.bind_addrs[0].is_ipv4() {
self.connect_with_custom_bind(true).await
} else { } else {
self.connect_with_custom_bind(false).await self.connect_with_custom_bind().await
} }
} }

View File

@@ -20,7 +20,7 @@ use crate::{
use super::{ use super::{
codec::BytesCodec, codec::BytesCodec,
common::{FramedTunnel, TunnelWithCustomInfo}, common::{setup_sokcet2, FramedTunnel, TunnelWithCustomInfo},
ring_tunnel::create_ring_tunnel_pair, ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelListener, DatagramSink, DatagramStream, Tunnel, TunnelListener,
}; };
@@ -269,7 +269,14 @@ impl UdpTunnelListener {
impl TunnelListener for UdpTunnelListener { impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> { async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?; let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
self.socket = Some(Arc::new(UdpSocket::bind(addr).await?));
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &addr)?;
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
let socket = self.socket.as_ref().unwrap().clone(); let socket = self.socket.as_ref().unwrap().clone();
let forward_tasks = self.forward_tasks.clone(); let forward_tasks = self.forward_tasks.clone();