fix firewall rule not specify interface (#1407)

This commit is contained in:
Sijie.Sun
2025-09-25 00:11:26 +08:00
committed by GitHub
parent 4445916ba7
commit 7035a3fef4
2 changed files with 38 additions and 72 deletions

View File

@@ -237,6 +237,7 @@ windows = { version = "0.52.0", features = [
"Win32_System_Com", "Win32_System_Com",
"Win32_Networking", "Win32_Networking",
"Win32_System_Ole", "Win32_System_Ole",
"Win32_System_Variant",
"Win32_Networking_WinSock", "Win32_Networking_WinSock",
"Win32_System_IO", "Win32_System_IO",
] } ] }

View File

@@ -1,4 +1,4 @@
use std::{io, net::SocketAddr, os::windows::io::AsRawSocket}; use std::{io, mem::ManuallyDrop, net::SocketAddr, os::windows::io::AsRawSocket};
use anyhow::Context; use anyhow::Context;
use network_interface::NetworkInterfaceConfig; use network_interface::NetworkInterfaceConfig;
@@ -18,6 +18,8 @@ use windows::{
System::Com::{ System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
}, },
System::Ole::{SafeArrayCreateVector, SafeArrayPutElement},
System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT},
}, },
}; };
@@ -247,20 +249,20 @@ pub fn add_interface_to_firewall_allowlist(interface_name: &str) -> anyhow::Resu
); );
// Create rules for each protocol type // Create rules for each protocol type
add_protocol_firewall_rules(&policy, interface_name, "TCP", 6)?; // TCP protocol number 6 add_protocol_firewall_rules(&policy, interface_name, "TCP", Some(6))?; // TCP protocol number 6
tracing::debug!("Added TCP firewall rules for interface: {}", interface_name); tracing::debug!("Added TCP firewall rules for interface: {}", interface_name);
add_protocol_firewall_rules(&policy, interface_name, "UDP", 17)?; // UDP protocol number 17 add_protocol_firewall_rules(&policy, interface_name, "UDP", Some(17))?; // UDP protocol number 17
tracing::debug!("Added UDP firewall rules for interface: {}", interface_name); tracing::debug!("Added UDP firewall rules for interface: {}", interface_name);
add_protocol_firewall_rules(&policy, interface_name, "ICMP", 1)?; // ICMP protocol number 1 add_protocol_firewall_rules(&policy, interface_name, "ICMP", Some(1))?; // ICMP protocol number 1
tracing::debug!( tracing::debug!(
"Added ICMP firewall rules for interface: {}", "Added ICMP firewall rules for interface: {}",
interface_name interface_name
); );
// Add fallback rules for all protocols // Add fallback rules for all protocols
add_all_protocols_firewall_rules(&policy, interface_name)?; add_protocol_firewall_rules(&policy, interface_name, "ALL", None)?;
tracing::debug!( tracing::debug!(
"Added fallback all-protocols rules for interface: {}", "Added fallback all-protocols rules for interface: {}",
interface_name interface_name
@@ -279,7 +281,7 @@ fn add_protocol_firewall_rules(
policy: &INetFwPolicy2, policy: &INetFwPolicy2,
interface_name: &str, interface_name: &str,
protocol_name: &str, protocol_name: &str,
protocol_number: i32, protocol_number: Option<i32>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create rules for both inbound and outbound traffic // Create rules for both inbound and outbound traffic
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] { for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
@@ -307,7 +309,9 @@ fn add_protocol_firewall_rules(
unsafe { unsafe {
rule.SetName(&name_bstr)?; rule.SetName(&name_bstr)?;
rule.SetDescription(&desc_bstr)?; rule.SetDescription(&desc_bstr)?;
if let Some(protocol_number) = protocol_number {
rule.SetProtocol(protocol_number)?; rule.SetProtocol(protocol_number)?;
}
rule.SetAction(NET_FW_ACTION_ALLOW)?; rule.SetAction(NET_FW_ACTION_ALLOW)?;
if is_inbound { if is_inbound {
@@ -322,61 +326,35 @@ fn add_protocol_firewall_rules(
)?; )?;
rule.SetGrouping(&BSTR::from("EasyTier"))?; rule.SetGrouping(&BSTR::from("EasyTier"))?;
// Get rule collection and add new rule // Set the interface for this rule to apply to the specific network interface
let rules = policy.Rules()?; // According to Microsoft docs, interfaces should be represented by their friendly name
rules.Remove(&name_bstr)?; // Remove existing rule with same name first // We need to create a SAFEARRAY of VARIANT strings containing the interface name
rules.Add(&rule)?; let interface_bstr = BSTR::from(interface_name);
}
// Create a SAFEARRAY containing one interface name
let interface_array = SafeArrayCreateVector(VT_VARIANT, 0, 1);
if interface_array.is_null() {
return Err(anyhow::anyhow!("Failed to create SAFEARRAY"));
} }
Ok(()) let index = 0i32;
} let mut variant_interface = VARIANT::default();
(*variant_interface.Anonymous.Anonymous).vt = VT_BSTR;
(*variant_interface.Anonymous.Anonymous).Anonymous.bstrVal =
ManuallyDrop::new(interface_bstr);
/// Add fallback rules for all protocols SafeArrayPutElement(
fn add_all_protocols_firewall_rules( interface_array,
policy: &INetFwPolicy2, &index as *const _ as *const i32,
interface_name: &str, &variant_interface as *const _ as *const std::ffi::c_void,
) -> anyhow::Result<()> {
// Create rules for both inbound and outbound traffic
for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] {
// Create firewall rule instance
let rule: INetFwRule = unsafe {
CoCreateInstance(
&windows::Win32::NetworkManagement::WindowsFirewall::NetFwRule,
None,
CLSCTX_ALL,
)
}?;
let rule_name = format!(
"EasyTier {} - All Protocols ({})",
interface_name, direction_name
);
let description = format!(
"Allow all protocol traffic on EasyTier interface {}",
interface_name
);
let name_bstr = BSTR::from(&rule_name);
let desc_bstr = BSTR::from(&description);
unsafe {
rule.SetName(&name_bstr)?;
rule.SetDescription(&desc_bstr)?;
// Don't set protocol - allows all protocols by default
rule.SetAction(NET_FW_ACTION_ALLOW)?;
if is_inbound {
rule.SetDirection(NET_FW_RULE_DIR_IN)?;
} else {
rule.SetDirection(NET_FW_RULE_DIR_OUT)?;
}
rule.SetEnabled(windows::Win32::Foundation::VARIANT_TRUE)?;
rule.SetProfiles(
NET_FW_PROFILE2_PRIVATE.0 | NET_FW_PROFILE2_PUBLIC.0 | NET_FW_PROFILE2_DOMAIN.0,
)?; )?;
rule.SetGrouping(&BSTR::from("EasyTier"))?;
// Create the VARIANT that contains the SAFEARRAY
let mut interface_variant = VARIANT::default();
(*interface_variant.Anonymous.Anonymous).vt = VARENUM(VT_ARRAY.0 | VT_VARIANT.0);
(*interface_variant.Anonymous.Anonymous).Anonymous.parray = interface_array;
rule.SetInterfaces(interface_variant)?;
// Get rule collection and add new rule // Get rule collection and add new rule
let rules = policy.Rules()?; let rules = policy.Rules()?;
@@ -402,8 +380,7 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
let rules = unsafe { policy.Rules()? }; let rules = unsafe { policy.Rules()? };
// Remove protocol-specific rules for protocol_name in ["TCP", "UDP", "ICMP", "ALL"] {
for protocol_name in ["TCP", "UDP", "ICMP"] {
for direction in ["Inbound", "Outbound"] { for direction in ["Inbound", "Outbound"] {
let rule_name = format!( let rule_name = format!(
"EasyTier {} - {} Protocol ({})", "EasyTier {} - {} Protocol ({})",
@@ -416,18 +393,6 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<(
} }
} }
// Remove fallback protocol rules
for direction in ["Inbound", "Outbound"] {
let rule_name = format!(
"EasyTier {} - All Protocols ({})",
interface_name, direction
);
let name_bstr = BSTR::from(&rule_name);
unsafe {
let _ = rules.Remove(&name_bstr); // Ignore errors, rule might not exist
}
}
Ok(()) Ok(())
} }