diff options
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 30 | ||||
| -rw-r--r-- | talpid-core/src/firewall/mod.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows/mod.rs | 604 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows/winfw/mod.rs | 328 | ||||
| -rw-r--r-- | talpid-core/src/firewall/windows/winfw/sys.rs | 277 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 18 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 85 | ||||
| -rw-r--r-- | talpid-openvpn/Cargo.toml | 2 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/objectpurger.cpp | 16 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/objectpurger.h | 1 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/winfw.cpp | 56 | ||||
| -rw-r--r-- | windows/winfw/src/winfw/winfw.h | 4 |
13 files changed, 862 insertions, 566 deletions
diff --git a/Cargo.lock b/Cargo.lock index 5aa2d3b30d..482657c481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5445,6 +5445,7 @@ dependencies = [ "triggered", "uuid", "widestring", + "winapi", "windows-sys 0.52.0", "winreg 0.51.0", ] diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index efdbea5cc6..c33ecf542b 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -83,6 +83,8 @@ use std::{ sync::{Arc, Weak}, time::Duration, }; +#[cfg(not(target_os = "android"))] +use talpid_core::tunnel_state_machine::BlockWhenDisconnected; use talpid_core::{ mpsc::Sender, split_tunnel, @@ -875,7 +877,9 @@ impl Daemon { tunnel_state_machine::InitialTunnelState { allow_lan: settings.allow_lan, #[cfg(not(target_os = "android"))] - block_when_disconnected: settings.block_when_disconnected, + block_when_disconnected: BlockWhenDisconnected::from( + settings.block_when_disconnected, + ), dns_config: dns::addresses_from_options(&settings.tunnel_options.dns_options), allowed_endpoint: access_mode_handler .get_current() @@ -2385,7 +2389,7 @@ impl Daemon { Ok(settings_changed) => { if settings_changed { self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected( - block_when_disconnected, + BlockWhenDisconnected::from(block_when_disconnected), oneshot_map(tx, |tx, ()| { Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response"); }), @@ -3087,7 +3091,7 @@ impl Daemon { { let (tx, _rx) = oneshot::channel(); self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected( - self.settings.block_when_disconnected, + BlockWhenDisconnected::from(self.settings.block_when_disconnected), tx, )); } @@ -3149,7 +3153,10 @@ impl Daemon { { log::debug!("Blocking firewall during shutdown"); let (tx, _rx) = oneshot::channel(); - self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx)); + self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected( + BlockWhenDisconnected::yes(), + tx, + )); } self.disconnect_tunnel(); @@ -3164,8 +3171,21 @@ impl Daemon { // without causing the service to be restarted. #[cfg(not(target_os = "android"))] if *self.target_state == TargetState::Secured { + let persist = if cfg!(target_os = "windows") { + // During app upgrades, as a safety measure, we make the firewall filters + // non-persistent. If the installation of the new version fails and + // the user is left in blocked state with no app, they can reboot + // to regain internet access. + self.settings.settings().block_when_disconnected + || self.settings.settings().auto_connect + } else { + true + }; let (tx, _rx) = oneshot::channel(); - self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx)); + self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected( + BlockWhenDisconnected::yes().persist(persist), + tx, + )); } self.target_state.lock(); diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index 053317f2a5..4255854624 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -326,4 +326,10 @@ impl Firewall { log::info!("Resetting firewall policy"); self.inner.reset_policy() } + + /// Sets whether the firewall should persist the blocking rules across a reboot. + #[cfg(target_os = "windows")] + pub fn persist(&mut self, persist: bool) { + self.inner.persist(persist); + } } diff --git a/talpid-core/src/firewall/windows/mod.rs b/talpid-core/src/firewall/windows/mod.rs index 3ef49938dd..13d3f5db90 100644 --- a/talpid-core/src/firewall/windows/mod.rs +++ b/talpid-core/src/firewall/windows/mod.rs @@ -1,20 +1,18 @@ -#![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. +use std::{net::IpAddr, sync::LazyLock}; -use crate::{dns::ResolvedDnsConfig, tunnel::TunnelMetadata}; - -use std::{ffi::CStr, io, net::IpAddr, ptr, sync::LazyLock}; - -use self::winfw::*; -use super::{FirewallArguments, FirewallPolicy, InitialFirewallState}; use talpid_types::{ net::{AllowedEndpoint, AllowedTunnelTraffic}, tunnel::FirewallPolicyError, ErrorExt, }; use widestring::WideCString; -use windows_sys::Win32::Globalization::{MultiByteToWideChar, CP_ACP}; + +use self::winfw::*; +use super::{FirewallArguments, FirewallPolicy, InitialFirewallState}; +use crate::{dns::ResolvedDnsConfig, tunnel::TunnelMetadata}; mod hyperv; +mod winfw; const HYPERV_LEAK_WARNING_MSG: &str = "Hyper-V (e.g. WSL machines) may leak in blocked states."; @@ -74,13 +72,19 @@ pub enum Error { ResettingPolicy(#[source] FirewallPolicyError), } -/// Timeout for acquiring the WFP transaction lock -const WINFW_TIMEOUT_SECONDS: u32 = 5; - -const LOGGING_CONTEXT: &[u8] = b"WinFw\0"; - /// The Windows implementation for the firewall. -pub struct Firewall(()); +pub struct Firewall { + /// If firewall rules should even if firewall module is shut down or dies. + /// + /// This should only very cautiously be turned off. + persist: bool, +} + +impl Default for Firewall { + fn default() -> Self { + Self { persist: true } + } +} impl Firewall { pub fn from_args(args: FirewallArguments) -> Result<Self, Error> { @@ -92,35 +96,16 @@ impl Firewall { } pub fn new() -> Result<Self, Error> { - unsafe { - WinFw_Initialize( - WINFW_TIMEOUT_SECONDS, - Some(log_sink), - LOGGING_CONTEXT.as_ptr(), - ) - .into_result()? - }; - + winfw::initialize()?; log::trace!("Successfully initialized windows firewall module"); - Ok(Firewall(())) + Ok(Firewall::default()) } fn initialize_blocked( allowed_endpoint: AllowedEndpoint, allow_lan: bool, ) -> Result<Self, Error> { - let cfg = &WinFwSettings::new(allow_lan); - let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint); - unsafe { - WinFw_InitializeBlocked( - WINFW_TIMEOUT_SECONDS, - cfg, - &allowed_endpoint.as_endpoint(), - Some(log_sink), - LOGGING_CONTEXT.as_ptr(), - ) - .into_result()? - }; + winfw::initialize_blocked(allowed_endpoint, allow_lan)?; log::trace!("Successfully initialized windows firewall module to a blocking state"); with_wmi_if_enabled(|wmi| { @@ -128,7 +113,7 @@ impl Firewall { consume_and_log_hyperv_err("Add block-all Hyper-V filter", result); }); - Ok(Firewall(())) + Ok(Firewall::default()) } pub fn apply_policy(&mut self, policy: FirewallPolicy) -> Result<(), Error> { @@ -150,8 +135,8 @@ impl Firewall { self.set_connecting_state( &peer_endpoint, cfg, - &tunnel, - &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(), + tunnel.as_ref(), + allowed_endpoint, &allowed_tunnel_traffic, ) } @@ -190,7 +175,7 @@ impl Firewall { } pub fn reset_policy(&mut self) -> Result<(), Error> { - unsafe { WinFw_Reset().into_result().map_err(Error::ResettingPolicy) }?; + winfw::reset().map_err(Error::ResettingPolicy)?; with_wmi_if_enabled(|wmi| { let result = hyperv::remove_blocking_hyperv_firewall_rules(wmi); @@ -200,113 +185,28 @@ impl Firewall { Ok(()) } + pub fn persist(&mut self, persist: bool) { + self.persist = persist; + } + fn set_connecting_state( &mut self, - endpoint: &AllowedEndpoint, + peer_endpoint: &AllowedEndpoint, winfw_settings: &WinFwSettings, - tunnel_metadata: &Option<TunnelMetadata>, - allowed_endpoint: &WinFwAllowedEndpoint<'_>, + tunnel_metadata: Option<&TunnelMetadata>, + allowed_endpoint: AllowedEndpoint, allowed_tunnel_traffic: &AllowedTunnelTraffic, ) -> Result<(), Error> { log::trace!("Applying 'connecting' firewall policy"); - let ip_str = widestring_ip(endpoint.endpoint.address.ip()); - let winfw_relay = WinFwEndpoint { - ip: ip_str.as_ptr(), - port: endpoint.endpoint.address.port(), - protocol: WinFwProt::from(endpoint.endpoint.protocol), - }; - - // SAFETY: `endpoint1_ip`, `endpoint2_ip`, `endpoint1`, `endpoint2`, `relay_client_wstrs` - // must not be dropped until `WinFw_ApplyPolicyConnecting` has returned. - - let relay_client_wstrs: Vec<_> = endpoint - .clients - .iter() - .map(WideCString::from_os_str_truncate) - .collect(); - let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs - .iter() - .map(|wstr| wstr.as_ptr()) - .collect(); - let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); - - let interface_wstr = tunnel_metadata - .as_ref() - .map(|metadata| WideCString::from_str_truncate(&metadata.interface)); - let interface_wstr_ptr = if let Some(ref wstr) = interface_wstr { - wstr.as_ptr() - } else { - ptr::null() - }; - - let mut endpoint1_ip = WideCString::new(); - let mut endpoint2_ip = WideCString::new(); - let (endpoint1, endpoint2) = match allowed_tunnel_traffic { - AllowedTunnelTraffic::One(endpoint) => { - endpoint1_ip = widestring_ip(endpoint.address.ip()); - ( - Some(WinFwEndpoint { - ip: endpoint1_ip.as_ptr(), - port: endpoint.address.port(), - protocol: WinFwProt::from(endpoint.protocol), - }), - None, - ) - } - AllowedTunnelTraffic::Two(endpoint1, endpoint2) => { - endpoint1_ip = widestring_ip(endpoint1.address.ip()); - let endpoint1 = Some(WinFwEndpoint { - ip: endpoint1_ip.as_ptr(), - port: endpoint1.address.port(), - protocol: WinFwProt::from(endpoint1.protocol), - }); - endpoint2_ip = widestring_ip(endpoint2.address.ip()); - let endpoint2 = Some(WinFwEndpoint { - ip: endpoint2_ip.as_ptr(), - port: endpoint2.address.port(), - protocol: WinFwProt::from(endpoint2.protocol), - }); - (endpoint1, endpoint2) - } - AllowedTunnelTraffic::None | AllowedTunnelTraffic::All => (None, None), - }; - - let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic { - type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic), - endpoint1: endpoint1 - .as_ref() - .map(|ep| ep as *const _) - .unwrap_or(ptr::null()), - endpoint2: endpoint2 - .as_ref() - .map(|ep| ep as *const _) - .unwrap_or(ptr::null()), - }; - - let res = unsafe { - WinFw_ApplyPolicyConnecting( - winfw_settings, - &winfw_relay, - relay_client_wstr_ptrs.as_ptr(), - relay_client_wstr_ptrs_len, - interface_wstr_ptr, - allowed_endpoint, - &allowed_tunnel_traffic, - ) - .into_result() - .map_err(Error::ApplyingConnectingPolicy) - }; - // SAFETY: All of these hold stack allocated memory which is pointed to by - // `allowed_tunnel_traffic` and must remain allocated until `WinFw_ApplyPolicyConnecting` - // has returned. - drop(endpoint1_ip); - drop(endpoint2_ip); - #[allow(clippy::drop_non_drop)] - drop(endpoint1); - #[allow(clippy::drop_non_drop)] - drop(endpoint2); - drop(relay_client_wstrs); - res + let tunnel_interface = tunnel_metadata.map(|metadata| metadata.interface.as_ref()); + winfw::apply_policy_connecting( + peer_endpoint, + winfw_settings, + tunnel_interface, + allowed_endpoint, + allowed_tunnel_traffic, + ) + .map_err(Error::ApplyingConnectingPolicy) } fn set_connected_state( @@ -317,69 +217,9 @@ impl Firewall { dns_config: &ResolvedDnsConfig, ) -> Result<(), Error> { log::trace!("Applying 'connected' firewall policy"); - let ip_str = widestring_ip(endpoint.endpoint.address.ip()); - - let tunnel_alias = WideCString::from_str_truncate(&tunnel_metadata.interface); - - // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay - let winfw_relay = WinFwEndpoint { - ip: ip_str.as_ptr(), - port: endpoint.endpoint.address.port(), - protocol: WinFwProt::from(endpoint.endpoint.protocol), - }; - - // SAFETY: `relay_client_wstrs` must not be dropped until `WinFw_ApplyPolicyConnected` has - // returned. - let relay_client_wstrs: Vec<_> = endpoint - .clients - .iter() - .map(WideCString::from_os_str_truncate) - .collect(); - let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs - .iter() - .map(|wstr| wstr.as_ptr()) - .collect(); - let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); - - let tunnel_dns_servers: Vec<WideCString> = dns_config - .tunnel_config() - .iter() - .cloned() - .map(widestring_ip) - .collect(); - let tunnel_dns_servers: Vec<*const u16> = - tunnel_dns_servers.iter().map(|ip| ip.as_ptr()).collect(); - let non_tunnel_dns_servers: Vec<WideCString> = dns_config - .non_tunnel_config() - .iter() - .cloned() - .map(widestring_ip) - .collect(); - let non_tunnel_dns_servers: Vec<*const u16> = non_tunnel_dns_servers - .iter() - .map(|ip| ip.as_ptr()) - .collect(); - - let result = unsafe { - WinFw_ApplyPolicyConnected( - winfw_settings, - &winfw_relay, - relay_client_wstr_ptrs.as_ptr(), - relay_client_wstr_ptrs_len, - tunnel_alias.as_ptr(), - tunnel_dns_servers.as_ptr(), - tunnel_dns_servers.len(), - non_tunnel_dns_servers.as_ptr(), - non_tunnel_dns_servers.len(), - ) - .into_result() + let tunnel_interface = &tunnel_metadata.interface; + winfw::apply_policy_connected(endpoint, winfw_settings, tunnel_interface, dns_config) .map_err(Error::ApplyingConnectedPolicy) - }; - - // SAFETY: `relay_client_wstrs` holds memory pointed to by pointers used in C++ and must - // not be dropped until after `WinFw_ApplyPolicyConnected` has returned. - drop(relay_client_wstrs); - result } fn set_blocked_state( @@ -388,35 +228,25 @@ impl Firewall { allowed_endpoint: Option<WinFwAllowedEndpointContainer>, ) -> Result<(), Error> { log::trace!("Applying 'blocked' firewall policy"); - let endpoint = allowed_endpoint - .as_ref() - .map(WinFwAllowedEndpointContainer::as_endpoint); - - unsafe { - WinFw_ApplyPolicyBlocked( - winfw_settings, - endpoint - .as_ref() - .map(|container| container as *const _) - .unwrap_or(ptr::null()), - ) - .into_result() + winfw::apply_policy_blocked(winfw_settings, allowed_endpoint) .map_err(Error::ApplyingBlockedPolicy) - } } } impl Drop for Firewall { fn drop(&mut self) { - if unsafe { - WinFw_Deinitialize(WinFwCleanupPolicy::ContinueBlocking) - .into_result() - .is_ok() - } { - log::trace!("Successfully deinitialized windows firewall module"); + // Deinitialize WinFW with or without persistent filters. + // All other filters should still remain intact. + let cleanup_policy = if self.persist { + WinFwCleanupPolicy::ContinueBlocking } else { - log::error!("Failed to deinitialize windows firewall module"); + WinFwCleanupPolicy::BlockingUntilReboot }; + + match winfw::deinit(cleanup_policy) { + Ok(()) => log::trace!("Successfully deinitialized windows firewall module"), + Err(_) => log::error!("Failed to deinitialize windows firewall module"), + } } } @@ -424,109 +254,6 @@ fn widestring_ip(ip: IpAddr) -> WideCString { WideCString::from_str_truncate(ip.to_string()) } -/// Logging callback implementation. -pub extern "system" fn log_sink( - level: log::Level, - msg: *const std::ffi::c_char, - context: *mut std::ffi::c_void, -) { - if msg.is_null() { - log::error!("Log message from FFI boundary is NULL"); - } else { - let target = if context.is_null() { - "UNKNOWN".into() - } else { - unsafe { CStr::from_ptr(context as *const _).to_string_lossy() } - }; - - let mb_string = unsafe { CStr::from_ptr(msg) }; - - let managed_msg = match multibyte_to_wide(mb_string, CP_ACP) { - Ok(wide_str) => String::from_utf16_lossy(&wide_str), - // Best effort: - Err(_) => mb_string.to_string_lossy().into_owned(), - }; - - log::logger().log( - &log::Record::builder() - .level(level) - .target(&target) - .args(format_args!("{}", managed_msg)) - .build(), - ); - } -} - -/// Convert `mb_string`, with the given character encoding `codepage`, to a UTF-16 string. -fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Error> { - if mb_string.is_empty() { - return Ok(vec![]); - } - - // SAFETY: `mb_string` is null-terminated and valid. - let wc_size = unsafe { - MultiByteToWideChar( - codepage, - 0, - mb_string.as_ptr() as *const u8, - -1, - ptr::null_mut(), - 0, - ) - }; - - if wc_size == 0 { - return Err(io::Error::last_os_error()); - } - - let mut wc_buffer = vec![0u16; usize::try_from(wc_size).unwrap()]; - - // SAFETY: `wc_buffer` can contain up to `wc_size` characters, including a null - // terminator. - let chars_written = unsafe { - MultiByteToWideChar( - codepage, - 0, - mb_string.as_ptr() as *const u8, - -1, - wc_buffer.as_mut_ptr(), - wc_size, - ) - }; - - if chars_written == 0 { - return Err(io::Error::last_os_error()); - } - - wc_buffer.truncate(usize::try_from(chars_written - 1).unwrap()); - - Ok(wc_buffer) -} - -#[cfg(test)] -mod test { - use super::multibyte_to_wide; - use windows_sys::Win32::Globalization::CP_UTF8; - - #[test] - fn test_multibyte_to_wide() { - // € = 0x20AC in UTF-16 - let converted = multibyte_to_wide(c"€€", CP_UTF8); - const EXPECTED: &[u16] = &[0x20AC, 0x20AC]; - assert!( - matches!(converted.as_deref(), Ok(EXPECTED)), - "expected Ok({EXPECTED:?}), got {converted:?}", - ); - - // boundary case - let converted = multibyte_to_wide(c"", CP_UTF8); - assert!( - matches!(converted.as_deref(), Ok([])), - "unexpected result {converted:?}" - ); - } -} - // Convert `result` into an option and log the error, if any. fn consume_and_log_hyperv_err<T>( action: &'static str, @@ -553,226 +280,3 @@ fn with_wmi_if_enabled(f: impl FnOnce(&wmi::WMIConnection)) { } }) } - -#[allow(non_snake_case)] -mod winfw { - use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString}; - use std::ffi::{c_char, c_void}; - use talpid_types::net::TransportProtocol; - - type LogSink = extern "system" fn(level: log::Level, msg: *const c_char, context: *mut c_void); - - pub struct WinFwAllowedEndpointContainer { - _clients: Box<[WideCString]>, - clients_ptrs: Box<[*const u16]>, - ip: WideCString, - port: u16, - protocol: WinFwProt, - } - - impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer { - fn from(endpoint: AllowedEndpoint) -> Self { - let clients = endpoint - .clients - .iter() - .map(WideCString::from_os_str_truncate) - .collect::<Box<_>>(); - let clients_ptrs = clients - .iter() - .map(|client| client.as_ptr()) - .collect::<Box<_>>(); - let ip = widestring_ip(endpoint.endpoint.address.ip()); - - WinFwAllowedEndpointContainer { - _clients: clients, - clients_ptrs, - ip, - port: endpoint.endpoint.address.port(), - protocol: WinFwProt::from(endpoint.endpoint.protocol), - } - } - } - - impl WinFwAllowedEndpointContainer { - pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> { - WinFwAllowedEndpoint { - num_clients: self.clients_ptrs.len() as u32, - clients: self.clients_ptrs.as_ptr(), - endpoint: WinFwEndpoint { - ip: self.ip.as_ptr(), - port: self.port, - protocol: self.protocol, - }, - - _phantom: std::marker::PhantomData, - } - } - } - - #[repr(C)] - pub struct WinFwAllowedEndpoint<'a> { - num_clients: u32, - clients: *const *const libc::wchar_t, - endpoint: WinFwEndpoint, - - _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>, - } - - #[repr(C)] - pub struct WinFwAllowedTunnelTraffic { - pub type_: WinFwAllowedTunnelTrafficType, - pub endpoint1: *const WinFwEndpoint, - pub endpoint2: *const WinFwEndpoint, - } - - #[repr(u8)] - #[derive(Clone, Copy)] - pub enum WinFwAllowedTunnelTrafficType { - None, - All, - One, - Two, - } - - impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType { - fn from(traffic: &AllowedTunnelTraffic) -> Self { - match traffic { - AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None, - AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All, - AllowedTunnelTraffic::One(..) => WinFwAllowedTunnelTrafficType::One, - AllowedTunnelTraffic::Two(..) => WinFwAllowedTunnelTrafficType::Two, - } - } - } - - #[repr(C)] - pub struct WinFwEndpoint { - pub ip: *const libc::wchar_t, - pub port: u16, - pub protocol: WinFwProt, - } - - #[repr(u8)] - #[derive(Clone, Copy)] - pub enum WinFwProt { - Tcp = 0u8, - Udp = 1u8, - } - - impl From<TransportProtocol> for WinFwProt { - fn from(prot: TransportProtocol) -> WinFwProt { - match prot { - TransportProtocol::Tcp => WinFwProt::Tcp, - TransportProtocol::Udp => WinFwProt::Udp, - } - } - } - - #[repr(C)] - pub struct WinFwSettings { - permitDhcp: bool, - permitLan: bool, - } - - impl WinFwSettings { - pub fn new(permit_lan: bool) -> WinFwSettings { - WinFwSettings { - permitDhcp: true, - permitLan: permit_lan, - } - } - } - - #[allow(dead_code)] - #[repr(u32)] - #[derive(Clone, Copy)] - pub enum WinFwCleanupPolicy { - ContinueBlocking = 0, - ResetFirewall = 1, - } - - ffi_error!(InitializationResult, Error::Initialization); - ffi_error!(DeinitializationResult, Error::Deinitialization); - - #[derive(Debug)] - #[allow(dead_code)] - #[repr(u32)] - pub enum WinFwPolicyStatus { - Success = 0, - GeneralFailure = 1, - LockTimeout = 2, - } - - impl WinFwPolicyStatus { - pub fn into_result(self) -> Result<(), super::FirewallPolicyError> { - match self { - WinFwPolicyStatus::Success => Ok(()), - WinFwPolicyStatus::GeneralFailure => Err(super::FirewallPolicyError::Generic), - WinFwPolicyStatus::LockTimeout => { - // TODO: Obtain application name and string from WinFw - Err(super::FirewallPolicyError::Locked(None)) - } - } - } - } - - impl From<WinFwPolicyStatus> for Result<(), super::FirewallPolicyError> { - fn from(val: WinFwPolicyStatus) -> Self { - val.into_result() - } - } - - unsafe extern "system" { - #[link_name = "WinFw_Initialize"] - pub fn WinFw_Initialize( - timeout: libc::c_uint, - sink: Option<LogSink>, - sink_context: *const u8, - ) -> InitializationResult; - - #[link_name = "WinFw_InitializeBlocked"] - pub fn WinFw_InitializeBlocked( - timeout: libc::c_uint, - settings: &WinFwSettings, - allowed_endpoint: *const WinFwAllowedEndpoint<'_>, - sink: Option<LogSink>, - sink_context: *const u8, - ) -> InitializationResult; - - #[link_name = "WinFw_Deinitialize"] - pub fn WinFw_Deinitialize(cleanupPolicy: WinFwCleanupPolicy) -> DeinitializationResult; - - #[link_name = "WinFw_ApplyPolicyConnecting"] - pub fn WinFw_ApplyPolicyConnecting( - settings: &WinFwSettings, - relay: &WinFwEndpoint, - relayClient: *const *const libc::wchar_t, - relayClientLen: usize, - tunnelIfaceAlias: *const libc::wchar_t, - allowedEndpoint: *const WinFwAllowedEndpoint<'_>, - allowedTunnelTraffic: &WinFwAllowedTunnelTraffic, - ) -> WinFwPolicyStatus; - - #[link_name = "WinFw_ApplyPolicyConnected"] - pub fn WinFw_ApplyPolicyConnected( - settings: &WinFwSettings, - relay: &WinFwEndpoint, - relayClient: *const *const libc::wchar_t, - relayClientLen: usize, - tunnelIfaceAlias: *const libc::wchar_t, - tunnelDnsServers: *const *const libc::wchar_t, - numTunnelDnsServers: usize, - nonTunnelDnsServers: *const *const libc::wchar_t, - numNonTunnelDnsServers: usize, - ) -> WinFwPolicyStatus; - - #[link_name = "WinFw_ApplyPolicyBlocked"] - pub fn WinFw_ApplyPolicyBlocked( - settings: &WinFwSettings, - allowed_endpoint: *const WinFwAllowedEndpoint<'_>, - ) -> WinFwPolicyStatus; - - #[link_name = "WinFw_Reset"] - pub fn WinFw_Reset() -> WinFwPolicyStatus; - } -} diff --git a/talpid-core/src/firewall/windows/winfw/mod.rs b/talpid-core/src/firewall/windows/winfw/mod.rs new file mode 100644 index 0000000000..a13ed7c7da --- /dev/null +++ b/talpid-core/src/firewall/windows/winfw/mod.rs @@ -0,0 +1,328 @@ +//! Safe bindings for the WinFW library. + +use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString}; +use std::ptr; +use talpid_types::{net::TransportProtocol, tunnel::FirewallPolicyError}; + +mod sys; +use sys::*; +pub use sys::{WinFwAllowedEndpointContainer, WinFwCleanupPolicy, WinFwSettings}; + +/// Timeout for acquiring the WFP transaction lock +const WINFW_TIMEOUT_SECONDS: u32 = 5; + +/// Initialize WinFw module. Returns an initialization error if called multiple times without +/// interleaving [Self::deinit]. +pub(super) fn initialize() -> Result<(), Error> { + // SAFETY: This function is always safe to call. + let init = unsafe { + WinFw_Initialize( + WINFW_TIMEOUT_SECONDS, + Some(log_sink), + LOGGING_CONTEXT.as_ptr(), + ) + }; + + init.into_result() +} + +/// Initialize WinFw module and apply blocking rules. Returns an initialization error if called +/// multiple times without interleaving [Self::deinit]. +pub(super) fn initialize_blocked( + allowed_endpoint: AllowedEndpoint, + allow_lan: bool, +) -> Result<(), Error> { + let cfg = WinFwSettings::new(allow_lan); + let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint); + // SAFETY: This function is always safe to call. + let init = unsafe { + WinFw_InitializeBlocked( + WINFW_TIMEOUT_SECONDS, + &cfg, + &allowed_endpoint.as_endpoint(), + Some(log_sink), + LOGGING_CONTEXT.as_ptr(), + ) + }; + init.into_result() +} + +/// Deinitialize WinFw module. Trying to use WinFw after calling deinit will result in an +/// error before [Self::initialize] is called. +pub(super) fn deinit(cleanup_policy: WinFwCleanupPolicy) -> Result<(), Error> { + // SAFETY: WinFw_Deinitialize is always safe to call. + // Will simply return false if WinFw already has been deinitialized. + let deinit = unsafe { WinFw_Deinitialize(cleanup_policy) }; + deinit.into_result() +} + +/// Reset all firewall policies applied by [winfw]. +/// +/// Sets the underlying active policy to None. +pub(super) fn reset() -> Result<(), FirewallPolicyError> { + // SAFETY: WinFw_Reset is always safe to call, even before WinFW has been + // initialized and after WinFW has been deinitialized. + let reset = unsafe { WinFw_Reset() }; + reset.into_result() +} + +/// Apply blocking firewall rules Sets the underlying active policy to Blocked. Exceptions +/// permitted through the firewall is defined by `winfw_settings` and `allowed_endpoint`. See +/// the BlockAll class for more information. +/// +/// Returns an error if [winfw] is not initialized. +pub(super) fn apply_policy_blocked( + winfw_settings: &WinFwSettings, + allowed_endpoint: Option<WinFwAllowedEndpointContainer>, +) -> Result<(), FirewallPolicyError> { + let allowed_endpoint = allowed_endpoint + .as_ref() + .map(WinFwAllowedEndpointContainer::as_endpoint) + .as_ref() + .map(ptr::from_ref) + .unwrap_or(ptr::null()); + // SAFETY: This function is always safe to call + let application = unsafe { WinFw_ApplyPolicyBlocked(winfw_settings, allowed_endpoint) }; + application.into_result() +} + +pub(super) fn apply_policy_connecting( + peer_endpoint: &AllowedEndpoint, + winfw_settings: &WinFwSettings, + tunnel_interface: Option<&str>, + allowed_endpoint: AllowedEndpoint, + allowed_tunnel_traffic: &AllowedTunnelTraffic, +) -> Result<(), FirewallPolicyError> { + let ip_str = widestring_ip(peer_endpoint.endpoint.address.ip()); + let winfw_relay = WinFwEndpoint { + ip: ip_str.as_ptr(), + port: peer_endpoint.endpoint.address.port(), + protocol: WinFwProt::from(peer_endpoint.endpoint.protocol), + }; + + // SAFETY: `endpoint1_ip`, `endpoint2_ip`, `endpoint1`, `endpoint2`, `relay_client_wstrs` + // must not be dropped until `WinFw_ApplyPolicyConnecting` has returned. + + let relay_client_wstrs: Vec<_> = peer_endpoint + .clients + .iter() + .map(WideCString::from_os_str_truncate) + .collect(); + let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs + .iter() + .map(|wstr| wstr.as_ptr()) + .collect(); + let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); + + let interface_wstr = tunnel_interface + .as_ref() + .map(WideCString::from_str_truncate); + let interface_wstr_ptr = if let Some(ref wstr) = interface_wstr { + wstr.as_ptr() + } else { + ptr::null() + }; + + let mut endpoint1_ip = WideCString::new(); + let mut endpoint2_ip = WideCString::new(); + let (endpoint1, endpoint2) = match allowed_tunnel_traffic { + AllowedTunnelTraffic::One(endpoint) => { + endpoint1_ip = widestring_ip(endpoint.address.ip()); + ( + Some(WinFwEndpoint { + ip: endpoint1_ip.as_ptr(), + port: endpoint.address.port(), + protocol: WinFwProt::from(endpoint.protocol), + }), + None, + ) + } + AllowedTunnelTraffic::Two(endpoint1, endpoint2) => { + endpoint1_ip = widestring_ip(endpoint1.address.ip()); + let endpoint1 = Some(WinFwEndpoint { + ip: endpoint1_ip.as_ptr(), + port: endpoint1.address.port(), + protocol: WinFwProt::from(endpoint1.protocol), + }); + endpoint2_ip = widestring_ip(endpoint2.address.ip()); + let endpoint2 = Some(WinFwEndpoint { + ip: endpoint2_ip.as_ptr(), + port: endpoint2.address.port(), + protocol: WinFwProt::from(endpoint2.protocol), + }); + (endpoint1, endpoint2) + } + AllowedTunnelTraffic::None | AllowedTunnelTraffic::All => (None, None), + }; + + let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint); + let allowed_endpoint = allowed_endpoint.as_endpoint(); + + let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic { + type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic), + endpoint1: endpoint1 + .as_ref() + .map(|ep| ep as *const _) + .unwrap_or(ptr::null()), + endpoint2: endpoint2 + .as_ref() + .map(|ep| ep as *const _) + .unwrap_or(ptr::null()), + }; + + #[allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. + let res = unsafe { + WinFw_ApplyPolicyConnecting( + winfw_settings, + &winfw_relay, + relay_client_wstr_ptrs.as_ptr(), + relay_client_wstr_ptrs_len, + interface_wstr_ptr, + &allowed_endpoint, + &allowed_tunnel_traffic, + ) + }; + // SAFETY: All of these hold stack allocated memory which is pointed to by + // `allowed_tunnel_traffic` and must remain allocated until `WinFw_ApplyPolicyConnecting` + // has returned. + drop(endpoint1_ip); + drop(endpoint2_ip); + #[allow(clippy::drop_non_drop)] + drop(endpoint1); + #[allow(clippy::drop_non_drop)] + drop(endpoint2); + drop(relay_client_wstrs); + res.into_result() +} + +pub(super) fn apply_policy_connected( + endpoint: &AllowedEndpoint, + winfw_settings: &WinFwSettings, + tunnel_interface: &str, + dns_config: &crate::dns::ResolvedDnsConfig, +) -> Result<(), FirewallPolicyError> { + let ip_str = widestring_ip(endpoint.endpoint.address.ip()); + + let tunnel_alias = WideCString::from_str_truncate(tunnel_interface); + + // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay + let winfw_relay = WinFwEndpoint { + ip: ip_str.as_ptr(), + port: endpoint.endpoint.address.port(), + protocol: WinFwProt::from(endpoint.endpoint.protocol), + }; + + // SAFETY: `relay_client_wstrs` must not be dropped until `WinFw_ApplyPolicyConnected` has + // returned. + let relay_client_wstrs: Vec<_> = endpoint + .clients + .iter() + .map(WideCString::from_os_str_truncate) + .collect(); + let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs + .iter() + .map(|wstr| wstr.as_ptr()) + .collect(); + let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); + + let tunnel_dns_servers: Vec<WideCString> = dns_config + .tunnel_config() + .iter() + .cloned() + .map(widestring_ip) + .collect(); + let tunnel_dns_servers: Vec<*const u16> = + tunnel_dns_servers.iter().map(|ip| ip.as_ptr()).collect(); + let non_tunnel_dns_servers: Vec<WideCString> = dns_config + .non_tunnel_config() + .iter() + .cloned() + .map(widestring_ip) + .collect(); + let non_tunnel_dns_servers: Vec<*const u16> = non_tunnel_dns_servers + .iter() + .map(|ip| ip.as_ptr()) + .collect(); + + #[allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare. + let result = unsafe { + WinFw_ApplyPolicyConnected( + winfw_settings, + &winfw_relay, + relay_client_wstr_ptrs.as_ptr(), + relay_client_wstr_ptrs_len, + tunnel_alias.as_ptr(), + tunnel_dns_servers.as_ptr(), + tunnel_dns_servers.len(), + non_tunnel_dns_servers.as_ptr(), + non_tunnel_dns_servers.len(), + ) + }; + + // SAFETY: `relay_client_wstrs` holds memory pointed to by pointers used in C++ and must + // not be dropped until after `WinFw_ApplyPolicyConnected` has returned. + drop(relay_client_wstrs); + result.into_result() +} + +impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer { + fn from(endpoint: AllowedEndpoint) -> Self { + let clients = endpoint + .clients + .iter() + .map(WideCString::from_os_str_truncate) + .collect::<Box<_>>(); + let clients_ptrs = clients + .iter() + .map(|client| client.as_ptr()) + .collect::<Box<_>>(); + let ip = widestring_ip(endpoint.endpoint.address.ip()); + + WinFwAllowedEndpointContainer { + _clients: clients, + clients_ptrs, + ip, + port: endpoint.endpoint.address.port(), + protocol: WinFwProt::from(endpoint.endpoint.protocol), + } + } +} + +impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType { + fn from(traffic: &AllowedTunnelTraffic) -> Self { + match traffic { + AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None, + AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All, + AllowedTunnelTraffic::One(..) => WinFwAllowedTunnelTrafficType::One, + AllowedTunnelTraffic::Two(..) => WinFwAllowedTunnelTrafficType::Two, + } + } +} + +impl From<TransportProtocol> for WinFwProt { + fn from(prot: TransportProtocol) -> WinFwProt { + match prot { + TransportProtocol::Tcp => WinFwProt::Tcp, + TransportProtocol::Udp => WinFwProt::Udp, + } + } +} + +impl WinFwPolicyStatus { + pub fn into_result(self) -> Result<(), super::FirewallPolicyError> { + match self { + WinFwPolicyStatus::Success => Ok(()), + WinFwPolicyStatus::GeneralFailure => Err(super::FirewallPolicyError::Generic), + WinFwPolicyStatus::LockTimeout => { + // TODO: Obtain application name and string from WinFw + Err(super::FirewallPolicyError::Locked(None)) + } + } + } +} + +impl From<WinFwPolicyStatus> for Result<(), super::FirewallPolicyError> { + fn from(val: WinFwPolicyStatus) -> Self { + val.into_result() + } +} diff --git a/talpid-core/src/firewall/windows/winfw/sys.rs b/talpid-core/src/firewall/windows/winfw/sys.rs new file mode 100644 index 0000000000..3349205cf3 --- /dev/null +++ b/talpid-core/src/firewall/windows/winfw/sys.rs @@ -0,0 +1,277 @@ +//! Data types and thin wrappers around WinFW C FFI. + +use std::ffi::{c_char, c_void, CStr}; +use std::io; +use std::ptr; + +use windows_sys::Win32::Globalization::{MultiByteToWideChar, CP_ACP}; + +use super::{Error, WideCString}; + +pub const LOGGING_CONTEXT: &CStr = c"WinFw"; + +#[repr(C)] +#[allow(non_snake_case)] +pub struct WinFwSettings { + permitDhcp: bool, + permitLan: bool, +} + +impl WinFwSettings { + pub fn new(permit_lan: bool) -> WinFwSettings { + WinFwSettings { + permitDhcp: true, + permitLan: permit_lan, + } + } +} + +#[allow(dead_code)] +#[repr(u32)] +#[derive(Clone, Copy)] +pub enum WinFwCleanupPolicy { + ContinueBlocking = 0, + ResetFirewall = 1, + BlockingUntilReboot = 2, +} + +#[derive(Debug)] +#[allow(dead_code)] +#[repr(u32)] +pub enum WinFwPolicyStatus { + Success = 0, + GeneralFailure = 1, + LockTimeout = 2, +} + +#[repr(C)] +pub struct WinFwEndpoint { + pub ip: *const libc::wchar_t, + pub port: u16, + pub protocol: WinFwProt, +} + +#[repr(u8)] +#[derive(Clone, Copy)] +pub enum WinFwProt { + Tcp = 0u8, + Udp = 1u8, +} + +#[repr(C)] +pub struct WinFwAllowedEndpoint<'a> { + num_clients: u32, + clients: *const *const libc::wchar_t, + endpoint: WinFwEndpoint, + + _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>, +} + +pub struct WinFwAllowedEndpointContainer { + pub _clients: Box<[WideCString]>, + pub clients_ptrs: Box<[*const u16]>, + pub ip: WideCString, + pub port: u16, + pub protocol: WinFwProt, +} + +impl WinFwAllowedEndpointContainer { + pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> { + WinFwAllowedEndpoint { + num_clients: self.clients_ptrs.len() as u32, + clients: self.clients_ptrs.as_ptr(), + endpoint: WinFwEndpoint { + ip: self.ip.as_ptr(), + port: self.port, + protocol: self.protocol, + }, + + _phantom: std::marker::PhantomData, + } + } +} + +#[repr(C)] +pub struct WinFwAllowedTunnelTraffic { + pub type_: WinFwAllowedTunnelTrafficType, + pub endpoint1: *const WinFwEndpoint, + pub endpoint2: *const WinFwEndpoint, +} + +#[repr(u8)] +#[derive(Clone, Copy)] +pub enum WinFwAllowedTunnelTrafficType { + None, + All, + One, + Two, +} + +ffi_error!(InitializationResult, Error::Initialization); +ffi_error!(DeinitializationResult, Error::Deinitialization); + +unsafe extern "system" { + #[link_name = "WinFw_Initialize"] + pub fn WinFw_Initialize( + timeout: libc::c_uint, + sink: Option<LogSink>, + sink_context: *const c_char, + ) -> InitializationResult; + + #[link_name = "WinFw_InitializeBlocked"] + pub fn WinFw_InitializeBlocked( + timeout: libc::c_uint, + settings: &WinFwSettings, + allowed_endpoint: *const WinFwAllowedEndpoint<'_>, + sink: Option<LogSink>, + sink_context: *const c_char, + ) -> InitializationResult; + + #[link_name = "WinFw_Deinitialize"] + pub fn WinFw_Deinitialize(cleanupPolicy: WinFwCleanupPolicy) -> DeinitializationResult; + + #[link_name = "WinFw_ApplyPolicyConnecting"] + pub fn WinFw_ApplyPolicyConnecting( + settings: &WinFwSettings, + relay: &WinFwEndpoint, + relayClient: *const *const libc::wchar_t, + relayClientLen: usize, + tunnelIfaceAlias: *const libc::wchar_t, + allowedEndpoint: *const WinFwAllowedEndpoint<'_>, + allowedTunnelTraffic: &WinFwAllowedTunnelTraffic, + ) -> WinFwPolicyStatus; + + #[link_name = "WinFw_ApplyPolicyConnected"] + pub fn WinFw_ApplyPolicyConnected( + settings: &WinFwSettings, + relay: &WinFwEndpoint, + relayClient: *const *const libc::wchar_t, + relayClientLen: usize, + tunnelIfaceAlias: *const libc::wchar_t, + tunnelDnsServers: *const *const libc::wchar_t, + numTunnelDnsServers: usize, + nonTunnelDnsServers: *const *const libc::wchar_t, + numNonTunnelDnsServers: usize, + ) -> WinFwPolicyStatus; + + #[link_name = "WinFw_ApplyPolicyBlocked"] + pub fn WinFw_ApplyPolicyBlocked( + settings: &WinFwSettings, + allowed_endpoint: *const WinFwAllowedEndpoint<'_>, + ) -> WinFwPolicyStatus; + + #[link_name = "WinFw_Reset"] + pub fn WinFw_Reset() -> WinFwPolicyStatus; +} + +pub type LogSink = extern "system" fn(level: log::Level, msg: *const c_char, context: *mut c_void); + +/// Logging callback implementation. +/// +/// SAFETY: +/// - `msg` must point to a valid C string or be null. +/// - `context` must point to a valid C string or be null. +pub extern "system" fn log_sink( + level: log::Level, + msg: *const std::ffi::c_char, + context: *mut std::ffi::c_void, +) { + if msg.is_null() { + log::error!("Log message from FFI boundary is NULL"); + return; + } + + let target = if context.is_null() { + "UNKNOWN".into() + } else { + // SAFETY: context is not null & caller promise that context is a valid C string. + unsafe { CStr::from_ptr(context as *const _).to_string_lossy() } + }; + + // SAFETY: msg is not null & caller promise that msg is a valid C string. + let mb_string = unsafe { CStr::from_ptr(msg) }; + + let managed_msg = match multibyte_to_wide(mb_string, CP_ACP) { + Ok(wide_str) => String::from_utf16_lossy(&wide_str), + // Best effort: + Err(_) => mb_string.to_string_lossy().into_owned(), + }; + + log::logger().log( + &log::Record::builder() + .level(level) + .target(&target) + .args(format_args!("{}", managed_msg)) + .build(), + ); +} + +/// Convert `mb_string`, with the given character encoding `codepage`, to a UTF-16 string. +pub fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Error> { + if mb_string.is_empty() { + return Ok(vec![]); + } + + // SAFETY: `mb_string` is null-terminated and valid. + let wc_size = unsafe { + MultiByteToWideChar( + codepage, + 0, + mb_string.as_ptr() as *const u8, + -1, + ptr::null_mut(), + 0, + ) + }; + + if wc_size == 0 { + return Err(io::Error::last_os_error()); + } + + let mut wc_buffer = vec![0u16; usize::try_from(wc_size).unwrap()]; + + // SAFETY: `wc_buffer` can contain up to `wc_size` characters, including a null + // terminator. + let chars_written = unsafe { + MultiByteToWideChar( + codepage, + 0, + mb_string.as_ptr() as *const u8, + -1, + wc_buffer.as_mut_ptr(), + wc_size, + ) + }; + + if chars_written == 0 { + return Err(io::Error::last_os_error()); + } + + wc_buffer.truncate(usize::try_from(chars_written - 1).unwrap()); + + Ok(wc_buffer) +} + +#[cfg(test)] +mod test { + use super::*; + use windows_sys::Win32::Globalization::CP_UTF8; + + #[test] + fn test_multibyte_to_wide() { + // € = 0x20AC in UTF-16 + let converted = multibyte_to_wide(c"€€", CP_UTF8); + const EXPECTED: &[u16] = &[0x20AC, 0x20AC]; + assert!( + matches!(converted.as_deref(), Ok(EXPECTED)), + "expected Ok({EXPECTED:?}), got {converted:?}", + ); + + // boundary case + let converted = multibyte_to_wide(c"", CP_UTF8); + assert!( + matches!(converted.as_deref(), Ok([])), + "unexpected result {converted:?}" + ); + } +} diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 8f96ff7b90..427a6e5cb3 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -30,7 +30,7 @@ impl DisconnectedState { ); } #[cfg(target_os = "macos")] - if shared_values.block_when_disconnected { + if shared_values.block_when_disconnected.bool() { if let Err(err) = Self::setup_local_dns_config(shared_values) { log::error!( "{}", @@ -64,7 +64,7 @@ impl DisconnectedState { // Being disconnected and having lockdown mode enabled implies that your internet // access is locked down #[cfg(not(target_os = "android"))] - locked_down: shared_values.block_when_disconnected, + locked_down: shared_values.block_when_disconnected.bool(), }, ) } @@ -74,7 +74,15 @@ impl DisconnectedState { shared_values: &mut SharedTunnelStateValues, should_reset_firewall: bool, ) { - let result = if shared_values.block_when_disconnected { + let result = if shared_values.block_when_disconnected.bool() { + #[cfg(target_os = "windows")] + { + // Respect the persist flag of BlockWhenDisconnected. + shared_values + .firewall + .persist(shared_values.block_when_disconnected.should_persist()); + } + let policy = FirewallPolicy::Blocked { allow_lan: shared_values.allow_lan, allowed_endpoint: Some(shared_values.allowed_endpoint.clone()), @@ -110,7 +118,7 @@ impl DisconnectedState { shared_values: &mut SharedTunnelStateValues, should_reset_firewall: bool, ) { - if should_reset_firewall && !shared_values.block_when_disconnected { + if should_reset_firewall && !shared_values.block_when_disconnected.bool() { if let Err(error) = shared_values.split_tunnel.clear_tunnel_addresses() { log::error!( "{}", @@ -198,7 +206,7 @@ impl TunnelState for DisconnectedState { #[cfg(windows)] Self::register_split_tunnel_addresses(shared_values, true); #[cfg(target_os = "macos")] - if block_when_disconnected { + if block_when_disconnected.bool() { if let Err(err) = Self::setup_local_dns_config(shared_values) { log::error!( "{}", diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 059edd417d..2deb7c8f7e 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -94,7 +94,7 @@ pub struct InitialTunnelState { pub allow_lan: bool, /// Block traffic unless connected to the VPN. #[cfg(not(target_os = "android"))] - pub block_when_disconnected: bool, + pub block_when_disconnected: BlockWhenDisconnected, /// DNS configuration to use pub dns_config: DnsConfig, /// A single endpoint that is allowed to communicate outside the tunnel, i.e. @@ -200,7 +200,7 @@ pub enum TunnelCommand { Dns(crate::dns::DnsConfig, oneshot::Sender<()>), /// Enable or disable the block_when_disconnected feature. #[cfg(not(target_os = "android"))] - BlockWhenDisconnected(bool, oneshot::Sender<()>), + BlockWhenDisconnected(BlockWhenDisconnected, oneshot::Sender<()>), /// Notify the state machine of the connectivity of the device. Connectivity(Connectivity), /// Open tunnel connection. @@ -234,6 +234,82 @@ enum EventResult { Close(Result<Option<ErrorStateCause>, oneshot::Canceled>), } +/// If firewall should apply blocking rules in the disconnected state. +/// Argument of TunnelCommand::BlockWhenDisconnected message. +/// +/// Semantically equivalent to a boolean value, but is grouped togetether with the persist +/// parameter on Windows for cohesiveness. +#[derive(Clone, Copy, Debug)] +pub enum BlockWhenDisconnected { + /// Firewall should *not* apply blocking rules. + Disabled, + /// Firewall should apply blocking rules. + Enabled { + /// If blocked state should be persisted across a reboot (restart of BFE) + persist: bool, + }, +} + +impl BlockWhenDisconnected { + /// `true`. Apply blocking firewall rules in the disconnected state. + pub const fn yes() -> Self { + BlockWhenDisconnected::Enabled { persist: true } + } + + /// `false`. Do *not* apply blocking firewall rules in the disconnected state. + pub const fn no() -> Self { + BlockWhenDisconnected::Disabled + } + + /// [self] as a boolean value. + pub const fn bool(&self) -> bool { + matches!(self, BlockWhenDisconnected::Enabled { .. }) + } + + /// If [BlockWhenDisconnected] should persist across reboots. + /// + /// Semantically meaningless on non-Windows platforms, will always return true. + pub const fn should_persist(&self) -> bool { + if cfg!(target_os = "windows") { + matches!(&self, BlockWhenDisconnected::Enabled { persist: true }) + } else { + true + } + } + + /// Semantically meaningless on non-Windows platforms + #[cfg(not(target_os = "windows"))] + pub fn persist(self, _persist: bool) -> Self { + self + } + + /// Semantically meaningless on non-Windows platforms + #[cfg(target_os = "windows")] + pub fn persist(self, persist: bool) -> Self { + match self { + BlockWhenDisconnected::Disabled => BlockWhenDisconnected::Disabled, + // Forget previous value of persist + BlockWhenDisconnected::Enabled { .. } => BlockWhenDisconnected::Enabled { persist }, + } + } +} + +impl From<bool> for BlockWhenDisconnected { + fn from(block: bool) -> Self { + if block { + BlockWhenDisconnected::yes() + } else { + BlockWhenDisconnected::no() + } + } +} + +impl PartialEq for BlockWhenDisconnected { + fn eq(&self, other: &Self) -> bool { + self.bool() == other.bool() + } +} + /// Asynchronous handling of the tunnel state machine. /// /// This type implements `Stream`, and attempts to advance the state machine based on the events @@ -295,7 +371,8 @@ impl TunnelStateMachine { let fw_args = FirewallArguments { #[cfg(not(target_os = "android"))] - initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall + initial_state: if args.settings.block_when_disconnected.bool() + || !args.settings.reset_firewall { InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone()) } else { @@ -474,7 +551,7 @@ struct SharedTunnelStateValues { allow_lan: bool, /// Should network access be allowed when in the disconnected state. #[cfg(not(target_os = "android"))] - block_when_disconnected: bool, + block_when_disconnected: BlockWhenDisconnected, /// True when the computer is known to be offline. connectivity: Connectivity, /// DNS configuration to use. diff --git a/talpid-openvpn/Cargo.toml b/talpid-openvpn/Cargo.toml index 34df09becf..031d6f996f 100644 --- a/talpid-openvpn/Cargo.toml +++ b/talpid-openvpn/Cargo.toml @@ -34,6 +34,8 @@ widestring = "1.0" winreg = { version = "0.51", features = ["transactions"] } talpid-windows = { path = "../talpid-windows" } once_cell = { workspace = true } +# Only needed because parity-tokio-ipc has forgotten to enable the winerror feature of winapi .. +winapi = { version = "0.3", features = ["winerror"] } [target.'cfg(windows)'.dependencies.windows-sys] workspace = true diff --git a/windows/winfw/src/winfw/objectpurger.cpp b/windows/winfw/src/winfw/objectpurger.cpp index dce36c99c8..52adaac187 100644 --- a/windows/winfw/src/winfw/objectpurger.cpp +++ b/windows/winfw/src/winfw/objectpurger.cpp @@ -71,6 +71,22 @@ ObjectPurger::RemovalFunctor ObjectPurger::GetRemoveNonPersistentFunctor() } //static +ObjectPurger::RemovalFunctor ObjectPurger::GetRemovePersistentFunctor() +{ + return [](wfp::FilterEngine &engine) + { + const auto registry = MullvadGuids::DetailedRegistry(MullvadGuids::IdentityQualifier::IncludePersistent); + + // Resolve correct overload. + void(*deleter)(wfp::FilterEngine &, const GUID &) = wfp::ObjectDeleter::DeleteFilter; + + RemoveRange(engine, deleter, registry.equal_range(WfpObjectType::Filter)); + RemoveRange(engine, wfp::ObjectDeleter::DeleteSublayer, registry.equal_range(WfpObjectType::Sublayer)); + RemoveRange(engine, wfp::ObjectDeleter::DeleteProvider, registry.equal_range(WfpObjectType::Provider)); + }; +} + +//static bool ObjectPurger::Execute(RemovalFunctor f) { auto engine = wfp::FilterEngine::StandardSession(); diff --git a/windows/winfw/src/winfw/objectpurger.h b/windows/winfw/src/winfw/objectpurger.h index 7728aac694..9d3ca0146e 100644 --- a/windows/winfw/src/winfw/objectpurger.h +++ b/windows/winfw/src/winfw/objectpurger.h @@ -16,6 +16,7 @@ public: static RemovalFunctor GetRemoveFiltersFunctor(); static RemovalFunctor GetRemoveAllFunctor(); static RemovalFunctor GetRemoveNonPersistentFunctor(); + static RemovalFunctor GetRemovePersistentFunctor(); static bool Execute(RemovalFunctor f); }; diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index c862de4b5a..cd90befead 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -4,6 +4,7 @@ #include "objectpurger.h" #include "mullvadobjects.h" #include "rules/persistent/blockall.h" +#include "rules/baseline/blockall.h" #include "libwfp/ipnetwork.h" #include <windows.h> #include <libcommon/error.h> @@ -167,11 +168,14 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy) delete g_fwContext; g_fwContext = nullptr; + std::stringstream ss; + ss << "Deinitializing WinFw"; + g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext); + // - // Continue blocking if this is what the caller requested + // Continue blocking with persistent rules if this is what the caller requested // and if the current policy is "(net) blocked". // - if (WINFW_CLEANUP_POLICY_CONTINUE_BLOCKING == cleanupPolicy && FwContext::Policy::Blocked == activePolicy) { @@ -182,6 +186,10 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy) rules::persistent::BlockAll blockAll; + std::stringstream ss; + ss << "Adding persistent block rules"; + g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext); + return sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine) { ObjectPurger::GetRemoveNonPersistentFunctor()(engine); @@ -205,6 +213,50 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy) } } + // + // Continue blocking with non-persistent rules if this is what the caller requested + // and if the current policy is "(net) blocked". + // + if (WINFW_CLEANUP_POLICY_BLOCK_UNTIL_REBOOT == cleanupPolicy + && FwContext::Policy::Blocked == activePolicy) + { + try + { + auto engine = wfp::FilterEngine::StandardSession(DEINITIALIZE_TIMEOUT); + auto sessionController = std::make_unique<SessionController>(std::move(engine)); + + rules::baseline::BlockAll blockAll; + + std::stringstream ss; + ss << "Adding ephemeral block rules"; + g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext); + + return sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine) + { + // Keep non-persistent filters intact, the intent is just to *not* + // persist any filters across a BFE restart, not muck around with + // any other filters. We will apply blocking filters anyway. + ObjectPurger::GetRemovePersistentFunctor()(engine); + + return controller.addProvider(*MullvadObjects::Provider()) + && controller.addSublayer(*MullvadObjects::SublayerBaseline()) + && blockAll.apply(controller); + }); + } + catch (std::exception & err) + { + if (nullptr != g_logSink) + { + g_logSink(MULLVAD_LOG_LEVEL_ERROR, err.what(), g_logSinkContext); + } + return false; + } + catch (...) + { + return false; + } + } + return WINFW_POLICY_STATUS_SUCCESS == WinFw_Reset(); } diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index ab2a136ceb..f40d835a95 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -127,6 +127,10 @@ enum WINFW_CLEANUP_POLICY : uint32_t // Remove all objects that have been registered with WFP. WINFW_CLEANUP_POLICY_RESET_FIREWALL = 1, + + // Continue blocking if this is the active policy. + // Adds ephemeral blocking filters that are active until WinFw is shut down (??) + WINFW_CLEANUP_POLICY_BLOCK_UNTIL_REBOOT = 2, }; // |
