fix firewall rule not specify interface (#1407)

This commit is contained in:
Sijie.Sun
2025-09-25 00:11:26 +08:00
committed by GitHub
parent 4445916ba7
commit 7035a3fef4
2 changed files with 38 additions and 72 deletions

View File

@@ -237,6 +237,7 @@ windows = { version = "0.52.0", features = [
"Win32_System_Com",
"Win32_Networking",
"Win32_System_Ole",
"Win32_System_Variant",
"Win32_Networking_WinSock",
"Win32_System_IO",
] }

View File

@@ -1,4 +1,4 @@
use std::{io, net::SocketAddr, os::windows::io::AsRawSocket};
use std::{io, mem::ManuallyDrop, net::SocketAddr, os::windows::io::AsRawSocket};
use anyhow::Context;
use network_interface::NetworkInterfaceConfig;
@@ -18,6 +18,8 @@ use windows::{
System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
},
System::Ole::{SafeArrayCreateVector, SafeArrayPutElement},
System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT},
},
};
@@ -247,20 +249,20 @@ pub fn add_interface_to_firewall_allowlist(interface_name: &str) -> anyhow::Resu
);
// Create rules for each protocol type
add_protocol_firewall_rules(&policy, interface_name, "TCP", 6)?; // TCP protocol number 6
add_protocol_firewall_rules(&policy, interface_name, "TCP", Some(6))?; // TCP protocol number 6
tracing::debug!("Added TCP firewall rules for interface: {}", interface_name);
add_protocol_firewall_rules(&policy, interface_name, "UDP", 17)?; // UDP protocol number 17
add_protocol_firewall_rules(&policy, interface_name, "UDP", Some(17))?; // UDP protocol number 17
tracing::debug!("Added UDP firewall rules for interface: {}", interface_name);
add_protocol_firewall_rules(&policy, interface_name, "ICMP", 1)?; // ICMP protocol number 1
add_protocol_firewall_rules(&policy, interface_name, "ICMP", Some(1))?; // ICMP protocol number 1
tracing::debug!(
"Added ICMP firewall rules for interface: {}",
interface_name
);
// Add fallback rules for all protocols
add_all_protocols_firewall_rules(&policy, interface_name)?;
add_protocol_firewall_rules(&policy, interface_name, "ALL", None)?;
tracing::debug!(
"Added fallback all-protocols rules for interface: {}",
interface_name
@@ -279,7 +281,7 @@ fn add_protocol_firewall_rules(
policy: &INetFwPolicy2,
interface_name: &str,
protocol_name: &str,
protocol_number: i32,
protocol_number: Option<i32>,
) -> anyhow::Result<()> {
// Create rules for both inbound and outbound traffic
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
@@ -307,7 +309,9 @@ fn add_protocol_firewall_rules(
unsafe {
rule.SetName(&name_bstr)?;
rule.SetDescription(&desc_bstr)?;
rule.SetProtocol(protocol_number)?;
if let Some(protocol_number) = protocol_number {
rule.SetProtocol(protocol_number)?;
}
rule.SetAction(NET_FW_ACTION_ALLOW)?;
if is_inbound {
@@ -322,61 +326,35 @@ fn add_protocol_firewall_rules(
)?;
rule.SetGrouping(&BSTR::from("EasyTier"))?;
// Get rule collection and add new rule
let rules = policy.Rules()?;
rules.Remove(&name_bstr)?; // Remove existing rule with same name first
rules.Add(&rule)?;
}
}
// Set the interface for this rule to apply to the specific network interface
// According to Microsoft docs, interfaces should be represented by their friendly name
// We need to create a SAFEARRAY of VARIANT strings containing the interface name
let interface_bstr = BSTR::from(interface_name);
Ok(())
}
/// Add fallback rules for all protocols
fn add_all_protocols_firewall_rules(
policy: &INetFwPolicy2,
interface_name: &str,
) -> anyhow::Result<()> {
// Create rules for both inbound and outbound traffic
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
// Create firewall rule instance
let rule: INetFwRule = unsafe {
CoCreateInstance(
&windows::Win32::NetworkManagement::WindowsFirewall::NetFwRule,
None,
CLSCTX_ALL,
)
}?;
let rule_name = format!(
"EasyTier {} - All Protocols ({})",
interface_name, direction_name
);
let description = format!(
"Allow all protocol traffic on EasyTier interface {}",
interface_name
);
let name_bstr = BSTR::from(&rule_name);
let desc_bstr = BSTR::from(&description);
unsafe {
rule.SetName(&name_bstr)?;
rule.SetDescription(&desc_bstr)?;
// Don't set protocol - allows all protocols by default
rule.SetAction(NET_FW_ACTION_ALLOW)?;
if is_inbound {
rule.SetDirection(NET_FW_RULE_DIR_IN)?;
} else {
rule.SetDirection(NET_FW_RULE_DIR_OUT)?;
// Create a SAFEARRAY containing one interface name
let interface_array = SafeArrayCreateVector(VT_VARIANT, 0, 1);
if interface_array.is_null() {
return Err(anyhow::anyhow!("Failed to create SAFEARRAY"));
}
rule.SetEnabled(windows::Win32::Foundation::VARIANT_TRUE)?;
rule.SetProfiles(
NET_FW_PROFILE2_PRIVATE.0 | NET_FW_PROFILE2_PUBLIC.0 | NET_FW_PROFILE2_DOMAIN.0,
let index = 0i32;
let mut variant_interface = VARIANT::default();
(*variant_interface.Anonymous.Anonymous).vt = VT_BSTR;
(*variant_interface.Anonymous.Anonymous).Anonymous.bstrVal =
ManuallyDrop::new(interface_bstr);
SafeArrayPutElement(
interface_array,
&index as *const _ as *const i32,
&variant_interface as *const _ as *const std::ffi::c_void,
)?;
rule.SetGrouping(&BSTR::from("EasyTier"))?;
// Create the VARIANT that contains the SAFEARRAY
let mut interface_variant = VARIANT::default();
(*interface_variant.Anonymous.Anonymous).vt = VARENUM(VT_ARRAY.0 | VT_VARIANT.0);
(*interface_variant.Anonymous.Anonymous).Anonymous.parray = interface_array;
rule.SetInterfaces(interface_variant)?;
// Get rule collection and add new rule
let rules = policy.Rules()?;
@@ -402,8 +380,7 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
let rules = unsafe { policy.Rules()? };
// Remove protocol-specific rules
for protocol_name in ["TCP", "UDP", "ICMP"] {
for protocol_name in ["TCP", "UDP", "ICMP", "ALL"] {
for direction in ["Inbound", "Outbound"] {
let rule_name = format!(
"EasyTier {} - {} Protocol ({})",
@@ -416,18 +393,6 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
}
}
// Remove fallback protocol rules
for direction in ["Inbound", "Outbound"] {
let rule_name = format!(
"EasyTier {} - All Protocols ({})",
interface_name, direction
);
let name_bstr = BSTR::from(&rule_name);
unsafe {
let _ = rules.Remove(&name_bstr); // Ignore errors, rule might not exist
}
}
Ok(())
}