diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 9c938d8..3b63896 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -237,6 +237,7 @@ windows = { version = "0.52.0", features = [ "Win32_System_Com", "Win32_Networking", "Win32_System_Ole", + "Win32_System_Variant", "Win32_Networking_WinSock", "Win32_System_IO", ] } diff --git a/easytier/src/arch/windows.rs b/easytier/src/arch/windows.rs index 4dae5c8..9d97ecf 100644 --- a/easytier/src/arch/windows.rs +++ b/easytier/src/arch/windows.rs @@ -1,4 +1,4 @@ -use std::{io, net::SocketAddr, os::windows::io::AsRawSocket}; +use std::{io, mem::ManuallyDrop, net::SocketAddr, os::windows::io::AsRawSocket}; use anyhow::Context; use network_interface::NetworkInterfaceConfig; @@ -18,6 +18,8 @@ use windows::{ System::Com::{ CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, }, + System::Ole::{SafeArrayCreateVector, SafeArrayPutElement}, + System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT}, }, }; @@ -247,20 +249,20 @@ pub fn add_interface_to_firewall_allowlist(interface_name: &str) -> anyhow::Resu ); // Create rules for each protocol type - add_protocol_firewall_rules(&policy, interface_name, "TCP", 6)?; // TCP protocol number 6 + add_protocol_firewall_rules(&policy, interface_name, "TCP", Some(6))?; // TCP protocol number 6 tracing::debug!("Added TCP firewall rules for interface: {}", interface_name); - add_protocol_firewall_rules(&policy, interface_name, "UDP", 17)?; // UDP protocol number 17 + add_protocol_firewall_rules(&policy, interface_name, "UDP", Some(17))?; // UDP protocol number 17 tracing::debug!("Added UDP firewall rules for interface: {}", interface_name); - add_protocol_firewall_rules(&policy, interface_name, "ICMP", 1)?; // ICMP protocol number 1 + add_protocol_firewall_rules(&policy, interface_name, "ICMP", Some(1))?; // ICMP protocol number 1 tracing::debug!( "Added ICMP firewall rules for interface: {}", interface_name ); // Add fallback rules for all protocols - add_all_protocols_firewall_rules(&policy, interface_name)?; + add_protocol_firewall_rules(&policy, interface_name, "ALL", None)?; tracing::debug!( "Added fallback all-protocols rules for interface: {}", interface_name @@ -279,7 +281,7 @@ fn add_protocol_firewall_rules( policy: &INetFwPolicy2, interface_name: &str, protocol_name: &str, - protocol_number: i32, + protocol_number: Option, ) -> anyhow::Result<()> { // Create rules for both inbound and outbound traffic for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] { @@ -307,7 +309,9 @@ fn add_protocol_firewall_rules( unsafe { rule.SetName(&name_bstr)?; rule.SetDescription(&desc_bstr)?; - rule.SetProtocol(protocol_number)?; + if let Some(protocol_number) = protocol_number { + rule.SetProtocol(protocol_number)?; + } rule.SetAction(NET_FW_ACTION_ALLOW)?; if is_inbound { @@ -322,61 +326,35 @@ fn add_protocol_firewall_rules( )?; rule.SetGrouping(&BSTR::from("EasyTier"))?; - // Get rule collection and add new rule - let rules = policy.Rules()?; - rules.Remove(&name_bstr)?; // Remove existing rule with same name first - rules.Add(&rule)?; - } - } + // Set the interface for this rule to apply to the specific network interface + // According to Microsoft docs, interfaces should be represented by their friendly name + // We need to create a SAFEARRAY of VARIANT strings containing the interface name + let interface_bstr = BSTR::from(interface_name); - Ok(()) -} - -/// Add fallback rules for all protocols -fn add_all_protocols_firewall_rules( - policy: &INetFwPolicy2, - interface_name: &str, -) -> anyhow::Result<()> { - // Create rules for both inbound and outbound traffic - for (is_inbound, direction_name) in [(true, "Inbound"), (false, "Outbound")] { - // Create firewall rule instance - let rule: INetFwRule = unsafe { - CoCreateInstance( - &windows::Win32::NetworkManagement::WindowsFirewall::NetFwRule, - None, - CLSCTX_ALL, - ) - }?; - - let rule_name = format!( - "EasyTier {} - All Protocols ({})", - interface_name, direction_name - ); - let description = format!( - "Allow all protocol traffic on EasyTier interface {}", - interface_name - ); - - let name_bstr = BSTR::from(&rule_name); - let desc_bstr = BSTR::from(&description); - - unsafe { - rule.SetName(&name_bstr)?; - rule.SetDescription(&desc_bstr)?; - // Don't set protocol - allows all protocols by default - rule.SetAction(NET_FW_ACTION_ALLOW)?; - - if is_inbound { - rule.SetDirection(NET_FW_RULE_DIR_IN)?; - } else { - rule.SetDirection(NET_FW_RULE_DIR_OUT)?; + // Create a SAFEARRAY containing one interface name + let interface_array = SafeArrayCreateVector(VT_VARIANT, 0, 1); + if interface_array.is_null() { + return Err(anyhow::anyhow!("Failed to create SAFEARRAY")); } - rule.SetEnabled(windows::Win32::Foundation::VARIANT_TRUE)?; - rule.SetProfiles( - NET_FW_PROFILE2_PRIVATE.0 | NET_FW_PROFILE2_PUBLIC.0 | NET_FW_PROFILE2_DOMAIN.0, + let index = 0i32; + let mut variant_interface = VARIANT::default(); + (*variant_interface.Anonymous.Anonymous).vt = VT_BSTR; + (*variant_interface.Anonymous.Anonymous).Anonymous.bstrVal = + ManuallyDrop::new(interface_bstr); + + SafeArrayPutElement( + interface_array, + &index as *const _ as *const i32, + &variant_interface as *const _ as *const std::ffi::c_void, )?; - rule.SetGrouping(&BSTR::from("EasyTier"))?; + + // Create the VARIANT that contains the SAFEARRAY + let mut interface_variant = VARIANT::default(); + (*interface_variant.Anonymous.Anonymous).vt = VARENUM(VT_ARRAY.0 | VT_VARIANT.0); + (*interface_variant.Anonymous.Anonymous).Anonymous.parray = interface_array; + + rule.SetInterfaces(interface_variant)?; // Get rule collection and add new rule let rules = policy.Rules()?; @@ -402,8 +380,7 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<( let rules = unsafe { policy.Rules()? }; - // Remove protocol-specific rules - for protocol_name in ["TCP", "UDP", "ICMP"] { + for protocol_name in ["TCP", "UDP", "ICMP", "ALL"] { for direction in ["Inbound", "Outbound"] { let rule_name = format!( "EasyTier {} - {} Protocol ({})", @@ -416,18 +393,6 @@ pub fn remove_interface_firewall_rules(interface_name: &str) -> anyhow::Result<( } } - // Remove fallback protocol rules - for direction in ["Inbound", "Outbound"] { - let rule_name = format!( - "EasyTier {} - All Protocols ({})", - interface_name, direction - ); - let name_bstr = BSTR::from(&rule_name); - unsafe { - let _ = rules.Remove(&name_bstr); // Ignore errors, rule might not exist - } - } - Ok(()) }