mirror of
https://mirror.suhoan.cn/https://github.com/EasyTier/EasyTier.git
synced 2025-12-14 21:57:24 +08:00
fix firewall rule not specify interface (#1407)
This commit is contained in:
@@ -237,6 +237,7 @@ windows = { version = "0.52.0", features = [
|
|||||||
"Win32_System_Com",
|
"Win32_System_Com",
|
||||||
"Win32_Networking",
|
"Win32_Networking",
|
||||||
"Win32_System_Ole",
|
"Win32_System_Ole",
|
||||||
|
"Win32_System_Variant",
|
||||||
"Win32_Networking_WinSock",
|
"Win32_Networking_WinSock",
|
||||||
"Win32_System_IO",
|
"Win32_System_IO",
|
||||||
] }
|
] }
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use std::{io, net::SocketAddr, os::windows::io::AsRawSocket};
|
use std::{io, mem::ManuallyDrop, net::SocketAddr, os::windows::io::AsRawSocket};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use network_interface::NetworkInterfaceConfig;
|
use network_interface::NetworkInterfaceConfig;
|
||||||
@@ -18,6 +18,8 @@ use windows::{
|
|||||||
System::Com::{
|
System::Com::{
|
||||||
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
|
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
|
||||||
},
|
},
|
||||||
|
System::Ole::{SafeArrayCreateVector, SafeArrayPutElement},
|
||||||
|
System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -247,20 +249,20 @@ pub fn add_interface_to_firewall_allowlist(interface_name: &str) -> anyhow::Resu
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Create rules for each protocol type
|
// Create rules for each protocol type
|
||||||
add_protocol_firewall_rules(&policy, interface_name, "TCP", 6)?; // TCP protocol number 6
|
add_protocol_firewall_rules(&policy, interface_name, "TCP", Some(6))?; // TCP protocol number 6
|
||||||
tracing::debug!("Added TCP firewall rules for interface: {}", interface_name);
|
tracing::debug!("Added TCP firewall rules for interface: {}", interface_name);
|
||||||
|
|
||||||
add_protocol_firewall_rules(&policy, interface_name, "UDP", 17)?; // UDP protocol number 17
|
add_protocol_firewall_rules(&policy, interface_name, "UDP", Some(17))?; // UDP protocol number 17
|
||||||
tracing::debug!("Added UDP firewall rules for interface: {}", interface_name);
|
tracing::debug!("Added UDP firewall rules for interface: {}", interface_name);
|
||||||
|
|
||||||
add_protocol_firewall_rules(&policy, interface_name, "ICMP", 1)?; // ICMP protocol number 1
|
add_protocol_firewall_rules(&policy, interface_name, "ICMP", Some(1))?; // ICMP protocol number 1
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Added ICMP firewall rules for interface: {}",
|
"Added ICMP firewall rules for interface: {}",
|
||||||
interface_name
|
interface_name
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add fallback rules for all protocols
|
// Add fallback rules for all protocols
|
||||||
add_all_protocols_firewall_rules(&policy, interface_name)?;
|
add_protocol_firewall_rules(&policy, interface_name, "ALL", None)?;
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Added fallback all-protocols rules for interface: {}",
|
"Added fallback all-protocols rules for interface: {}",
|
||||||
interface_name
|
interface_name
|
||||||
@@ -279,7 +281,7 @@ fn add_protocol_firewall_rules(
|
|||||||
policy: &INetFwPolicy2,
|
policy: &INetFwPolicy2,
|
||||||
interface_name: &str,
|
interface_name: &str,
|
||||||
protocol_name: &str,
|
protocol_name: &str,
|
||||||
protocol_number: i32,
|
protocol_number: Option<i32>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// Create rules for both inbound and outbound traffic
|
// Create rules for both inbound and outbound traffic
|
||||||
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
|
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
|
||||||
@@ -307,7 +309,9 @@ fn add_protocol_firewall_rules(
|
|||||||
unsafe {
|
unsafe {
|
||||||
rule.SetName(&name_bstr)?;
|
rule.SetName(&name_bstr)?;
|
||||||
rule.SetDescription(&desc_bstr)?;
|
rule.SetDescription(&desc_bstr)?;
|
||||||
|
if let Some(protocol_number) = protocol_number {
|
||||||
rule.SetProtocol(protocol_number)?;
|
rule.SetProtocol(protocol_number)?;
|
||||||
|
}
|
||||||
rule.SetAction(NET_FW_ACTION_ALLOW)?;
|
rule.SetAction(NET_FW_ACTION_ALLOW)?;
|
||||||
|
|
||||||
if is_inbound {
|
if is_inbound {
|
||||||
@@ -322,61 +326,35 @@ fn add_protocol_firewall_rules(
|
|||||||
)?;
|
)?;
|
||||||
rule.SetGrouping(&BSTR::from("EasyTier"))?;
|
rule.SetGrouping(&BSTR::from("EasyTier"))?;
|
||||||
|
|
||||||
// Get rule collection and add new rule
|
// Set the interface for this rule to apply to the specific network interface
|
||||||
let rules = policy.Rules()?;
|
// According to Microsoft docs, interfaces should be represented by their friendly name
|
||||||
rules.Remove(&name_bstr)?; // Remove existing rule with same name first
|
// We need to create a SAFEARRAY of VARIANT strings containing the interface name
|
||||||
rules.Add(&rule)?;
|
let interface_bstr = BSTR::from(interface_name);
|
||||||
}
|
|
||||||
|
// Create a SAFEARRAY containing one interface name
|
||||||
|
let interface_array = SafeArrayCreateVector(VT_VARIANT, 0, 1);
|
||||||
|
if interface_array.is_null() {
|
||||||
|
return Err(anyhow::anyhow!("Failed to create SAFEARRAY"));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
let index = 0i32;
|
||||||
}
|
let mut variant_interface = VARIANT::default();
|
||||||
|
(*variant_interface.Anonymous.Anonymous).vt = VT_BSTR;
|
||||||
|
(*variant_interface.Anonymous.Anonymous).Anonymous.bstrVal =
|
||||||
|
ManuallyDrop::new(interface_bstr);
|
||||||
|
|
||||||
/// Add fallback rules for all protocols
|
SafeArrayPutElement(
|
||||||
fn add_all_protocols_firewall_rules(
|
interface_array,
|
||||||
policy: &INetFwPolicy2,
|
&index as *const _ as *const i32,
|
||||||
interface_name: &str,
|
&variant_interface as *const _ as *const std::ffi::c_void,
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
// Create rules for both inbound and outbound traffic
|
|
||||||
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
|
|
||||||
// Create firewall rule instance
|
|
||||||
let rule: INetFwRule = unsafe {
|
|
||||||
CoCreateInstance(
|
|
||||||
&windows::Win32::NetworkManagement::WindowsFirewall::NetFwRule,
|
|
||||||
None,
|
|
||||||
CLSCTX_ALL,
|
|
||||||
)
|
|
||||||
}?;
|
|
||||||
|
|
||||||
let rule_name = format!(
|
|
||||||
"EasyTier {} - All Protocols ({})",
|
|
||||||
interface_name, direction_name
|
|
||||||
);
|
|
||||||
let description = format!(
|
|
||||||
"Allow all protocol traffic on EasyTier interface {}",
|
|
||||||
interface_name
|
|
||||||
);
|
|
||||||
|
|
||||||
let name_bstr = BSTR::from(&rule_name);
|
|
||||||
let desc_bstr = BSTR::from(&description);
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
rule.SetName(&name_bstr)?;
|
|
||||||
rule.SetDescription(&desc_bstr)?;
|
|
||||||
// Don't set protocol - allows all protocols by default
|
|
||||||
rule.SetAction(NET_FW_ACTION_ALLOW)?;
|
|
||||||
|
|
||||||
if is_inbound {
|
|
||||||
rule.SetDirection(NET_FW_RULE_DIR_IN)?;
|
|
||||||
} else {
|
|
||||||
rule.SetDirection(NET_FW_RULE_DIR_OUT)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
rule.SetEnabled(windows::Win32::Foundation::VARIANT_TRUE)?;
|
|
||||||
rule.SetProfiles(
|
|
||||||
NET_FW_PROFILE2_PRIVATE.0 | NET_FW_PROFILE2_PUBLIC.0 | NET_FW_PROFILE2_DOMAIN.0,
|
|
||||||
)?;
|
)?;
|
||||||
rule.SetGrouping(&BSTR::from("EasyTier"))?;
|
|
||||||
|
// Create the VARIANT that contains the SAFEARRAY
|
||||||
|
let mut interface_variant = VARIANT::default();
|
||||||
|
(*interface_variant.Anonymous.Anonymous).vt = VARENUM(VT_ARRAY.0 | VT_VARIANT.0);
|
||||||
|
(*interface_variant.Anonymous.Anonymous).Anonymous.parray = interface_array;
|
||||||
|
|
||||||
|
rule.SetInterfaces(interface_variant)?;
|
||||||
|
|
||||||
// Get rule collection and add new rule
|
// Get rule collection and add new rule
|
||||||
let rules = policy.Rules()?;
|
let rules = policy.Rules()?;
|
||||||
@@ -402,8 +380,7 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
|
|||||||
|
|
||||||
let rules = unsafe { policy.Rules()? };
|
let rules = unsafe { policy.Rules()? };
|
||||||
|
|
||||||
// Remove protocol-specific rules
|
for protocol_name in ["TCP", "UDP", "ICMP", "ALL"] {
|
||||||
for protocol_name in ["TCP", "UDP", "ICMP"] {
|
|
||||||
for direction in ["Inbound", "Outbound"] {
|
for direction in ["Inbound", "Outbound"] {
|
||||||
let rule_name = format!(
|
let rule_name = format!(
|
||||||
"EasyTier {} - {} Protocol ({})",
|
"EasyTier {} - {} Protocol ({})",
|
||||||
@@ -416,18 +393,6 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove fallback protocol rules
|
|
||||||
for direction in ["Inbound", "Outbound"] {
|
|
||||||
let rule_name = format!(
|
|
||||||
"EasyTier {} - All Protocols ({})",
|
|
||||||
interface_name, direction
|
|
||||||
);
|
|
||||||
let name_bstr = BSTR::from(&rule_name);
|
|
||||||
unsafe {
|
|
||||||
let _ = rules.Remove(&name_bstr); // Ignore errors, rule might not exist
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user