summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--mullvad-daemon/src/lib.rs30
-rw-r--r--talpid-core/src/firewall/mod.rs6
-rw-r--r--talpid-core/src/firewall/windows/mod.rs604
-rw-r--r--talpid-core/src/firewall/windows/winfw/mod.rs328
-rw-r--r--talpid-core/src/firewall/windows/winfw/sys.rs277
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs18
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs85
-rw-r--r--talpid-openvpn/Cargo.toml2
-rw-r--r--windows/winfw/src/winfw/objectpurger.cpp16
-rw-r--r--windows/winfw/src/winfw/objectpurger.h1
-rw-r--r--windows/winfw/src/winfw/winfw.cpp56
-rw-r--r--windows/winfw/src/winfw/winfw.h4
13 files changed, 862 insertions, 566 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 5aa2d3b30d..482657c481 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5445,6 +5445,7 @@ dependencies = [
"triggered",
"uuid",
"widestring",
+ "winapi",
"windows-sys 0.52.0",
"winreg 0.51.0",
]
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index efdbea5cc6..c33ecf542b 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -83,6 +83,8 @@ use std::{
sync::{Arc, Weak},
time::Duration,
};
+#[cfg(not(target_os = "android"))]
+use talpid_core::tunnel_state_machine::BlockWhenDisconnected;
use talpid_core::{
mpsc::Sender,
split_tunnel,
@@ -875,7 +877,9 @@ impl Daemon {
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
#[cfg(not(target_os = "android"))]
- block_when_disconnected: settings.block_when_disconnected,
+ block_when_disconnected: BlockWhenDisconnected::from(
+ settings.block_when_disconnected,
+ ),
dns_config: dns::addresses_from_options(&settings.tunnel_options.dns_options),
allowed_endpoint: access_mode_handler
.get_current()
@@ -2385,7 +2389,7 @@ impl Daemon {
Ok(settings_changed) => {
if settings_changed {
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
- block_when_disconnected,
+ BlockWhenDisconnected::from(block_when_disconnected),
oneshot_map(tx, |tx, ()| {
Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
}),
@@ -3087,7 +3091,7 @@ impl Daemon {
{
let (tx, _rx) = oneshot::channel();
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
- self.settings.block_when_disconnected,
+ BlockWhenDisconnected::from(self.settings.block_when_disconnected),
tx,
));
}
@@ -3149,7 +3153,10 @@ impl Daemon {
{
log::debug!("Blocking firewall during shutdown");
let (tx, _rx) = oneshot::channel();
- self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
+ self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
+ BlockWhenDisconnected::yes(),
+ tx,
+ ));
}
self.disconnect_tunnel();
@@ -3164,8 +3171,21 @@ impl Daemon {
// without causing the service to be restarted.
#[cfg(not(target_os = "android"))]
if *self.target_state == TargetState::Secured {
+ let persist = if cfg!(target_os = "windows") {
+ // During app upgrades, as a safety measure, we make the firewall filters
+ // non-persistent. If the installation of the new version fails and
+ // the user is left in blocked state with no app, they can reboot
+ // to regain internet access.
+ self.settings.settings().block_when_disconnected
+ || self.settings.settings().auto_connect
+ } else {
+ true
+ };
let (tx, _rx) = oneshot::channel();
- self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
+ self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
+ BlockWhenDisconnected::yes().persist(persist),
+ tx,
+ ));
}
self.target_state.lock();
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index 053317f2a5..4255854624 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -326,4 +326,10 @@ impl Firewall {
log::info!("Resetting firewall policy");
self.inner.reset_policy()
}
+
+ /// Sets whether the firewall should persist the blocking rules across a reboot.
+ #[cfg(target_os = "windows")]
+ pub fn persist(&mut self, persist: bool) {
+ self.inner.persist(persist);
+ }
}
diff --git a/talpid-core/src/firewall/windows/mod.rs b/talpid-core/src/firewall/windows/mod.rs
index 3ef49938dd..13d3f5db90 100644
--- a/talpid-core/src/firewall/windows/mod.rs
+++ b/talpid-core/src/firewall/windows/mod.rs
@@ -1,20 +1,18 @@
-#![allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare.
+use std::{net::IpAddr, sync::LazyLock};
-use crate::{dns::ResolvedDnsConfig, tunnel::TunnelMetadata};
-
-use std::{ffi::CStr, io, net::IpAddr, ptr, sync::LazyLock};
-
-use self::winfw::*;
-use super::{FirewallArguments, FirewallPolicy, InitialFirewallState};
use talpid_types::{
net::{AllowedEndpoint, AllowedTunnelTraffic},
tunnel::FirewallPolicyError,
ErrorExt,
};
use widestring::WideCString;
-use windows_sys::Win32::Globalization::{MultiByteToWideChar, CP_ACP};
+
+use self::winfw::*;
+use super::{FirewallArguments, FirewallPolicy, InitialFirewallState};
+use crate::{dns::ResolvedDnsConfig, tunnel::TunnelMetadata};
mod hyperv;
+mod winfw;
const HYPERV_LEAK_WARNING_MSG: &str = "Hyper-V (e.g. WSL machines) may leak in blocked states.";
@@ -74,13 +72,19 @@ pub enum Error {
ResettingPolicy(#[source] FirewallPolicyError),
}
-/// Timeout for acquiring the WFP transaction lock
-const WINFW_TIMEOUT_SECONDS: u32 = 5;
-
-const LOGGING_CONTEXT: &[u8] = b"WinFw\0";
-
/// The Windows implementation for the firewall.
-pub struct Firewall(());
+pub struct Firewall {
+ /// If firewall rules should even if firewall module is shut down or dies.
+ ///
+ /// This should only very cautiously be turned off.
+ persist: bool,
+}
+
+impl Default for Firewall {
+ fn default() -> Self {
+ Self { persist: true }
+ }
+}
impl Firewall {
pub fn from_args(args: FirewallArguments) -> Result<Self, Error> {
@@ -92,35 +96,16 @@ impl Firewall {
}
pub fn new() -> Result<Self, Error> {
- unsafe {
- WinFw_Initialize(
- WINFW_TIMEOUT_SECONDS,
- Some(log_sink),
- LOGGING_CONTEXT.as_ptr(),
- )
- .into_result()?
- };
-
+ winfw::initialize()?;
log::trace!("Successfully initialized windows firewall module");
- Ok(Firewall(()))
+ Ok(Firewall::default())
}
fn initialize_blocked(
allowed_endpoint: AllowedEndpoint,
allow_lan: bool,
) -> Result<Self, Error> {
- let cfg = &WinFwSettings::new(allow_lan);
- let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint);
- unsafe {
- WinFw_InitializeBlocked(
- WINFW_TIMEOUT_SECONDS,
- cfg,
- &allowed_endpoint.as_endpoint(),
- Some(log_sink),
- LOGGING_CONTEXT.as_ptr(),
- )
- .into_result()?
- };
+ winfw::initialize_blocked(allowed_endpoint, allow_lan)?;
log::trace!("Successfully initialized windows firewall module to a blocking state");
with_wmi_if_enabled(|wmi| {
@@ -128,7 +113,7 @@ impl Firewall {
consume_and_log_hyperv_err("Add block-all Hyper-V filter", result);
});
- Ok(Firewall(()))
+ Ok(Firewall::default())
}
pub fn apply_policy(&mut self, policy: FirewallPolicy) -> Result<(), Error> {
@@ -150,8 +135,8 @@ impl Firewall {
self.set_connecting_state(
&peer_endpoint,
cfg,
- &tunnel,
- &WinFwAllowedEndpointContainer::from(allowed_endpoint).as_endpoint(),
+ tunnel.as_ref(),
+ allowed_endpoint,
&allowed_tunnel_traffic,
)
}
@@ -190,7 +175,7 @@ impl Firewall {
}
pub fn reset_policy(&mut self) -> Result<(), Error> {
- unsafe { WinFw_Reset().into_result().map_err(Error::ResettingPolicy) }?;
+ winfw::reset().map_err(Error::ResettingPolicy)?;
with_wmi_if_enabled(|wmi| {
let result = hyperv::remove_blocking_hyperv_firewall_rules(wmi);
@@ -200,113 +185,28 @@ impl Firewall {
Ok(())
}
+ pub fn persist(&mut self, persist: bool) {
+ self.persist = persist;
+ }
+
fn set_connecting_state(
&mut self,
- endpoint: &AllowedEndpoint,
+ peer_endpoint: &AllowedEndpoint,
winfw_settings: &WinFwSettings,
- tunnel_metadata: &Option<TunnelMetadata>,
- allowed_endpoint: &WinFwAllowedEndpoint<'_>,
+ tunnel_metadata: Option<&TunnelMetadata>,
+ allowed_endpoint: AllowedEndpoint,
allowed_tunnel_traffic: &AllowedTunnelTraffic,
) -> Result<(), Error> {
log::trace!("Applying 'connecting' firewall policy");
- let ip_str = widestring_ip(endpoint.endpoint.address.ip());
- let winfw_relay = WinFwEndpoint {
- ip: ip_str.as_ptr(),
- port: endpoint.endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.endpoint.protocol),
- };
-
- // SAFETY: `endpoint1_ip`, `endpoint2_ip`, `endpoint1`, `endpoint2`, `relay_client_wstrs`
- // must not be dropped until `WinFw_ApplyPolicyConnecting` has returned.
-
- let relay_client_wstrs: Vec<_> = endpoint
- .clients
- .iter()
- .map(WideCString::from_os_str_truncate)
- .collect();
- let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs
- .iter()
- .map(|wstr| wstr.as_ptr())
- .collect();
- let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len();
-
- let interface_wstr = tunnel_metadata
- .as_ref()
- .map(|metadata| WideCString::from_str_truncate(&metadata.interface));
- let interface_wstr_ptr = if let Some(ref wstr) = interface_wstr {
- wstr.as_ptr()
- } else {
- ptr::null()
- };
-
- let mut endpoint1_ip = WideCString::new();
- let mut endpoint2_ip = WideCString::new();
- let (endpoint1, endpoint2) = match allowed_tunnel_traffic {
- AllowedTunnelTraffic::One(endpoint) => {
- endpoint1_ip = widestring_ip(endpoint.address.ip());
- (
- Some(WinFwEndpoint {
- ip: endpoint1_ip.as_ptr(),
- port: endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.protocol),
- }),
- None,
- )
- }
- AllowedTunnelTraffic::Two(endpoint1, endpoint2) => {
- endpoint1_ip = widestring_ip(endpoint1.address.ip());
- let endpoint1 = Some(WinFwEndpoint {
- ip: endpoint1_ip.as_ptr(),
- port: endpoint1.address.port(),
- protocol: WinFwProt::from(endpoint1.protocol),
- });
- endpoint2_ip = widestring_ip(endpoint2.address.ip());
- let endpoint2 = Some(WinFwEndpoint {
- ip: endpoint2_ip.as_ptr(),
- port: endpoint2.address.port(),
- protocol: WinFwProt::from(endpoint2.protocol),
- });
- (endpoint1, endpoint2)
- }
- AllowedTunnelTraffic::None | AllowedTunnelTraffic::All => (None, None),
- };
-
- let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic {
- type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic),
- endpoint1: endpoint1
- .as_ref()
- .map(|ep| ep as *const _)
- .unwrap_or(ptr::null()),
- endpoint2: endpoint2
- .as_ref()
- .map(|ep| ep as *const _)
- .unwrap_or(ptr::null()),
- };
-
- let res = unsafe {
- WinFw_ApplyPolicyConnecting(
- winfw_settings,
- &winfw_relay,
- relay_client_wstr_ptrs.as_ptr(),
- relay_client_wstr_ptrs_len,
- interface_wstr_ptr,
- allowed_endpoint,
- &allowed_tunnel_traffic,
- )
- .into_result()
- .map_err(Error::ApplyingConnectingPolicy)
- };
- // SAFETY: All of these hold stack allocated memory which is pointed to by
- // `allowed_tunnel_traffic` and must remain allocated until `WinFw_ApplyPolicyConnecting`
- // has returned.
- drop(endpoint1_ip);
- drop(endpoint2_ip);
- #[allow(clippy::drop_non_drop)]
- drop(endpoint1);
- #[allow(clippy::drop_non_drop)]
- drop(endpoint2);
- drop(relay_client_wstrs);
- res
+ let tunnel_interface = tunnel_metadata.map(|metadata| metadata.interface.as_ref());
+ winfw::apply_policy_connecting(
+ peer_endpoint,
+ winfw_settings,
+ tunnel_interface,
+ allowed_endpoint,
+ allowed_tunnel_traffic,
+ )
+ .map_err(Error::ApplyingConnectingPolicy)
}
fn set_connected_state(
@@ -317,69 +217,9 @@ impl Firewall {
dns_config: &ResolvedDnsConfig,
) -> Result<(), Error> {
log::trace!("Applying 'connected' firewall policy");
- let ip_str = widestring_ip(endpoint.endpoint.address.ip());
-
- let tunnel_alias = WideCString::from_str_truncate(&tunnel_metadata.interface);
-
- // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay
- let winfw_relay = WinFwEndpoint {
- ip: ip_str.as_ptr(),
- port: endpoint.endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.endpoint.protocol),
- };
-
- // SAFETY: `relay_client_wstrs` must not be dropped until `WinFw_ApplyPolicyConnected` has
- // returned.
- let relay_client_wstrs: Vec<_> = endpoint
- .clients
- .iter()
- .map(WideCString::from_os_str_truncate)
- .collect();
- let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs
- .iter()
- .map(|wstr| wstr.as_ptr())
- .collect();
- let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len();
-
- let tunnel_dns_servers: Vec<WideCString> = dns_config
- .tunnel_config()
- .iter()
- .cloned()
- .map(widestring_ip)
- .collect();
- let tunnel_dns_servers: Vec<*const u16> =
- tunnel_dns_servers.iter().map(|ip| ip.as_ptr()).collect();
- let non_tunnel_dns_servers: Vec<WideCString> = dns_config
- .non_tunnel_config()
- .iter()
- .cloned()
- .map(widestring_ip)
- .collect();
- let non_tunnel_dns_servers: Vec<*const u16> = non_tunnel_dns_servers
- .iter()
- .map(|ip| ip.as_ptr())
- .collect();
-
- let result = unsafe {
- WinFw_ApplyPolicyConnected(
- winfw_settings,
- &winfw_relay,
- relay_client_wstr_ptrs.as_ptr(),
- relay_client_wstr_ptrs_len,
- tunnel_alias.as_ptr(),
- tunnel_dns_servers.as_ptr(),
- tunnel_dns_servers.len(),
- non_tunnel_dns_servers.as_ptr(),
- non_tunnel_dns_servers.len(),
- )
- .into_result()
+ let tunnel_interface = &tunnel_metadata.interface;
+ winfw::apply_policy_connected(endpoint, winfw_settings, tunnel_interface, dns_config)
.map_err(Error::ApplyingConnectedPolicy)
- };
-
- // SAFETY: `relay_client_wstrs` holds memory pointed to by pointers used in C++ and must
- // not be dropped until after `WinFw_ApplyPolicyConnected` has returned.
- drop(relay_client_wstrs);
- result
}
fn set_blocked_state(
@@ -388,35 +228,25 @@ impl Firewall {
allowed_endpoint: Option<WinFwAllowedEndpointContainer>,
) -> Result<(), Error> {
log::trace!("Applying 'blocked' firewall policy");
- let endpoint = allowed_endpoint
- .as_ref()
- .map(WinFwAllowedEndpointContainer::as_endpoint);
-
- unsafe {
- WinFw_ApplyPolicyBlocked(
- winfw_settings,
- endpoint
- .as_ref()
- .map(|container| container as *const _)
- .unwrap_or(ptr::null()),
- )
- .into_result()
+ winfw::apply_policy_blocked(winfw_settings, allowed_endpoint)
.map_err(Error::ApplyingBlockedPolicy)
- }
}
}
impl Drop for Firewall {
fn drop(&mut self) {
- if unsafe {
- WinFw_Deinitialize(WinFwCleanupPolicy::ContinueBlocking)
- .into_result()
- .is_ok()
- } {
- log::trace!("Successfully deinitialized windows firewall module");
+ // Deinitialize WinFW with or without persistent filters.
+ // All other filters should still remain intact.
+ let cleanup_policy = if self.persist {
+ WinFwCleanupPolicy::ContinueBlocking
} else {
- log::error!("Failed to deinitialize windows firewall module");
+ WinFwCleanupPolicy::BlockingUntilReboot
};
+
+ match winfw::deinit(cleanup_policy) {
+ Ok(()) => log::trace!("Successfully deinitialized windows firewall module"),
+ Err(_) => log::error!("Failed to deinitialize windows firewall module"),
+ }
}
}
@@ -424,109 +254,6 @@ fn widestring_ip(ip: IpAddr) -> WideCString {
WideCString::from_str_truncate(ip.to_string())
}
-/// Logging callback implementation.
-pub extern "system" fn log_sink(
- level: log::Level,
- msg: *const std::ffi::c_char,
- context: *mut std::ffi::c_void,
-) {
- if msg.is_null() {
- log::error!("Log message from FFI boundary is NULL");
- } else {
- let target = if context.is_null() {
- "UNKNOWN".into()
- } else {
- unsafe { CStr::from_ptr(context as *const _).to_string_lossy() }
- };
-
- let mb_string = unsafe { CStr::from_ptr(msg) };
-
- let managed_msg = match multibyte_to_wide(mb_string, CP_ACP) {
- Ok(wide_str) => String::from_utf16_lossy(&wide_str),
- // Best effort:
- Err(_) => mb_string.to_string_lossy().into_owned(),
- };
-
- log::logger().log(
- &log::Record::builder()
- .level(level)
- .target(&target)
- .args(format_args!("{}", managed_msg))
- .build(),
- );
- }
-}
-
-/// Convert `mb_string`, with the given character encoding `codepage`, to a UTF-16 string.
-fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Error> {
- if mb_string.is_empty() {
- return Ok(vec![]);
- }
-
- // SAFETY: `mb_string` is null-terminated and valid.
- let wc_size = unsafe {
- MultiByteToWideChar(
- codepage,
- 0,
- mb_string.as_ptr() as *const u8,
- -1,
- ptr::null_mut(),
- 0,
- )
- };
-
- if wc_size == 0 {
- return Err(io::Error::last_os_error());
- }
-
- let mut wc_buffer = vec![0u16; usize::try_from(wc_size).unwrap()];
-
- // SAFETY: `wc_buffer` can contain up to `wc_size` characters, including a null
- // terminator.
- let chars_written = unsafe {
- MultiByteToWideChar(
- codepage,
- 0,
- mb_string.as_ptr() as *const u8,
- -1,
- wc_buffer.as_mut_ptr(),
- wc_size,
- )
- };
-
- if chars_written == 0 {
- return Err(io::Error::last_os_error());
- }
-
- wc_buffer.truncate(usize::try_from(chars_written - 1).unwrap());
-
- Ok(wc_buffer)
-}
-
-#[cfg(test)]
-mod test {
- use super::multibyte_to_wide;
- use windows_sys::Win32::Globalization::CP_UTF8;
-
- #[test]
- fn test_multibyte_to_wide() {
- // € = 0x20AC in UTF-16
- let converted = multibyte_to_wide(c"€€", CP_UTF8);
- const EXPECTED: &[u16] = &[0x20AC, 0x20AC];
- assert!(
- matches!(converted.as_deref(), Ok(EXPECTED)),
- "expected Ok({EXPECTED:?}), got {converted:?}",
- );
-
- // boundary case
- let converted = multibyte_to_wide(c"", CP_UTF8);
- assert!(
- matches!(converted.as_deref(), Ok([])),
- "unexpected result {converted:?}"
- );
- }
-}
-
// Convert `result` into an option and log the error, if any.
fn consume_and_log_hyperv_err<T>(
action: &'static str,
@@ -553,226 +280,3 @@ fn with_wmi_if_enabled(f: impl FnOnce(&wmi::WMIConnection)) {
}
})
}
-
-#[allow(non_snake_case)]
-mod winfw {
- use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString};
- use std::ffi::{c_char, c_void};
- use talpid_types::net::TransportProtocol;
-
- type LogSink = extern "system" fn(level: log::Level, msg: *const c_char, context: *mut c_void);
-
- pub struct WinFwAllowedEndpointContainer {
- _clients: Box<[WideCString]>,
- clients_ptrs: Box<[*const u16]>,
- ip: WideCString,
- port: u16,
- protocol: WinFwProt,
- }
-
- impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer {
- fn from(endpoint: AllowedEndpoint) -> Self {
- let clients = endpoint
- .clients
- .iter()
- .map(WideCString::from_os_str_truncate)
- .collect::<Box<_>>();
- let clients_ptrs = clients
- .iter()
- .map(|client| client.as_ptr())
- .collect::<Box<_>>();
- let ip = widestring_ip(endpoint.endpoint.address.ip());
-
- WinFwAllowedEndpointContainer {
- _clients: clients,
- clients_ptrs,
- ip,
- port: endpoint.endpoint.address.port(),
- protocol: WinFwProt::from(endpoint.endpoint.protocol),
- }
- }
- }
-
- impl WinFwAllowedEndpointContainer {
- pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> {
- WinFwAllowedEndpoint {
- num_clients: self.clients_ptrs.len() as u32,
- clients: self.clients_ptrs.as_ptr(),
- endpoint: WinFwEndpoint {
- ip: self.ip.as_ptr(),
- port: self.port,
- protocol: self.protocol,
- },
-
- _phantom: std::marker::PhantomData,
- }
- }
- }
-
- #[repr(C)]
- pub struct WinFwAllowedEndpoint<'a> {
- num_clients: u32,
- clients: *const *const libc::wchar_t,
- endpoint: WinFwEndpoint,
-
- _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>,
- }
-
- #[repr(C)]
- pub struct WinFwAllowedTunnelTraffic {
- pub type_: WinFwAllowedTunnelTrafficType,
- pub endpoint1: *const WinFwEndpoint,
- pub endpoint2: *const WinFwEndpoint,
- }
-
- #[repr(u8)]
- #[derive(Clone, Copy)]
- pub enum WinFwAllowedTunnelTrafficType {
- None,
- All,
- One,
- Two,
- }
-
- impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType {
- fn from(traffic: &AllowedTunnelTraffic) -> Self {
- match traffic {
- AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None,
- AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All,
- AllowedTunnelTraffic::One(..) => WinFwAllowedTunnelTrafficType::One,
- AllowedTunnelTraffic::Two(..) => WinFwAllowedTunnelTrafficType::Two,
- }
- }
- }
-
- #[repr(C)]
- pub struct WinFwEndpoint {
- pub ip: *const libc::wchar_t,
- pub port: u16,
- pub protocol: WinFwProt,
- }
-
- #[repr(u8)]
- #[derive(Clone, Copy)]
- pub enum WinFwProt {
- Tcp = 0u8,
- Udp = 1u8,
- }
-
- impl From<TransportProtocol> for WinFwProt {
- fn from(prot: TransportProtocol) -> WinFwProt {
- match prot {
- TransportProtocol::Tcp => WinFwProt::Tcp,
- TransportProtocol::Udp => WinFwProt::Udp,
- }
- }
- }
-
- #[repr(C)]
- pub struct WinFwSettings {
- permitDhcp: bool,
- permitLan: bool,
- }
-
- impl WinFwSettings {
- pub fn new(permit_lan: bool) -> WinFwSettings {
- WinFwSettings {
- permitDhcp: true,
- permitLan: permit_lan,
- }
- }
- }
-
- #[allow(dead_code)]
- #[repr(u32)]
- #[derive(Clone, Copy)]
- pub enum WinFwCleanupPolicy {
- ContinueBlocking = 0,
- ResetFirewall = 1,
- }
-
- ffi_error!(InitializationResult, Error::Initialization);
- ffi_error!(DeinitializationResult, Error::Deinitialization);
-
- #[derive(Debug)]
- #[allow(dead_code)]
- #[repr(u32)]
- pub enum WinFwPolicyStatus {
- Success = 0,
- GeneralFailure = 1,
- LockTimeout = 2,
- }
-
- impl WinFwPolicyStatus {
- pub fn into_result(self) -> Result<(), super::FirewallPolicyError> {
- match self {
- WinFwPolicyStatus::Success => Ok(()),
- WinFwPolicyStatus::GeneralFailure => Err(super::FirewallPolicyError::Generic),
- WinFwPolicyStatus::LockTimeout => {
- // TODO: Obtain application name and string from WinFw
- Err(super::FirewallPolicyError::Locked(None))
- }
- }
- }
- }
-
- impl From<WinFwPolicyStatus> for Result<(), super::FirewallPolicyError> {
- fn from(val: WinFwPolicyStatus) -> Self {
- val.into_result()
- }
- }
-
- unsafe extern "system" {
- #[link_name = "WinFw_Initialize"]
- pub fn WinFw_Initialize(
- timeout: libc::c_uint,
- sink: Option<LogSink>,
- sink_context: *const u8,
- ) -> InitializationResult;
-
- #[link_name = "WinFw_InitializeBlocked"]
- pub fn WinFw_InitializeBlocked(
- timeout: libc::c_uint,
- settings: &WinFwSettings,
- allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
- sink: Option<LogSink>,
- sink_context: *const u8,
- ) -> InitializationResult;
-
- #[link_name = "WinFw_Deinitialize"]
- pub fn WinFw_Deinitialize(cleanupPolicy: WinFwCleanupPolicy) -> DeinitializationResult;
-
- #[link_name = "WinFw_ApplyPolicyConnecting"]
- pub fn WinFw_ApplyPolicyConnecting(
- settings: &WinFwSettings,
- relay: &WinFwEndpoint,
- relayClient: *const *const libc::wchar_t,
- relayClientLen: usize,
- tunnelIfaceAlias: *const libc::wchar_t,
- allowedEndpoint: *const WinFwAllowedEndpoint<'_>,
- allowedTunnelTraffic: &WinFwAllowedTunnelTraffic,
- ) -> WinFwPolicyStatus;
-
- #[link_name = "WinFw_ApplyPolicyConnected"]
- pub fn WinFw_ApplyPolicyConnected(
- settings: &WinFwSettings,
- relay: &WinFwEndpoint,
- relayClient: *const *const libc::wchar_t,
- relayClientLen: usize,
- tunnelIfaceAlias: *const libc::wchar_t,
- tunnelDnsServers: *const *const libc::wchar_t,
- numTunnelDnsServers: usize,
- nonTunnelDnsServers: *const *const libc::wchar_t,
- numNonTunnelDnsServers: usize,
- ) -> WinFwPolicyStatus;
-
- #[link_name = "WinFw_ApplyPolicyBlocked"]
- pub fn WinFw_ApplyPolicyBlocked(
- settings: &WinFwSettings,
- allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
- ) -> WinFwPolicyStatus;
-
- #[link_name = "WinFw_Reset"]
- pub fn WinFw_Reset() -> WinFwPolicyStatus;
- }
-}
diff --git a/talpid-core/src/firewall/windows/winfw/mod.rs b/talpid-core/src/firewall/windows/winfw/mod.rs
new file mode 100644
index 0000000000..a13ed7c7da
--- /dev/null
+++ b/talpid-core/src/firewall/windows/winfw/mod.rs
@@ -0,0 +1,328 @@
+//! Safe bindings for the WinFW library.
+
+use super::{widestring_ip, AllowedEndpoint, AllowedTunnelTraffic, Error, WideCString};
+use std::ptr;
+use talpid_types::{net::TransportProtocol, tunnel::FirewallPolicyError};
+
+mod sys;
+use sys::*;
+pub use sys::{WinFwAllowedEndpointContainer, WinFwCleanupPolicy, WinFwSettings};
+
+/// Timeout for acquiring the WFP transaction lock
+const WINFW_TIMEOUT_SECONDS: u32 = 5;
+
+/// Initialize WinFw module. Returns an initialization error if called multiple times without
+/// interleaving [Self::deinit].
+pub(super) fn initialize() -> Result<(), Error> {
+ // SAFETY: This function is always safe to call.
+ let init = unsafe {
+ WinFw_Initialize(
+ WINFW_TIMEOUT_SECONDS,
+ Some(log_sink),
+ LOGGING_CONTEXT.as_ptr(),
+ )
+ };
+
+ init.into_result()
+}
+
+/// Initialize WinFw module and apply blocking rules. Returns an initialization error if called
+/// multiple times without interleaving [Self::deinit].
+pub(super) fn initialize_blocked(
+ allowed_endpoint: AllowedEndpoint,
+ allow_lan: bool,
+) -> Result<(), Error> {
+ let cfg = WinFwSettings::new(allow_lan);
+ let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint);
+ // SAFETY: This function is always safe to call.
+ let init = unsafe {
+ WinFw_InitializeBlocked(
+ WINFW_TIMEOUT_SECONDS,
+ &cfg,
+ &allowed_endpoint.as_endpoint(),
+ Some(log_sink),
+ LOGGING_CONTEXT.as_ptr(),
+ )
+ };
+ init.into_result()
+}
+
+/// Deinitialize WinFw module. Trying to use WinFw after calling deinit will result in an
+/// error before [Self::initialize] is called.
+pub(super) fn deinit(cleanup_policy: WinFwCleanupPolicy) -> Result<(), Error> {
+ // SAFETY: WinFw_Deinitialize is always safe to call.
+ // Will simply return false if WinFw already has been deinitialized.
+ let deinit = unsafe { WinFw_Deinitialize(cleanup_policy) };
+ deinit.into_result()
+}
+
+/// Reset all firewall policies applied by [winfw].
+///
+/// Sets the underlying active policy to None.
+pub(super) fn reset() -> Result<(), FirewallPolicyError> {
+ // SAFETY: WinFw_Reset is always safe to call, even before WinFW has been
+ // initialized and after WinFW has been deinitialized.
+ let reset = unsafe { WinFw_Reset() };
+ reset.into_result()
+}
+
+/// Apply blocking firewall rules Sets the underlying active policy to Blocked. Exceptions
+/// permitted through the firewall is defined by `winfw_settings` and `allowed_endpoint`. See
+/// the BlockAll class for more information.
+///
+/// Returns an error if [winfw] is not initialized.
+pub(super) fn apply_policy_blocked(
+ winfw_settings: &WinFwSettings,
+ allowed_endpoint: Option<WinFwAllowedEndpointContainer>,
+) -> Result<(), FirewallPolicyError> {
+ let allowed_endpoint = allowed_endpoint
+ .as_ref()
+ .map(WinFwAllowedEndpointContainer::as_endpoint)
+ .as_ref()
+ .map(ptr::from_ref)
+ .unwrap_or(ptr::null());
+ // SAFETY: This function is always safe to call
+ let application = unsafe { WinFw_ApplyPolicyBlocked(winfw_settings, allowed_endpoint) };
+ application.into_result()
+}
+
+pub(super) fn apply_policy_connecting(
+ peer_endpoint: &AllowedEndpoint,
+ winfw_settings: &WinFwSettings,
+ tunnel_interface: Option<&str>,
+ allowed_endpoint: AllowedEndpoint,
+ allowed_tunnel_traffic: &AllowedTunnelTraffic,
+) -> Result<(), FirewallPolicyError> {
+ let ip_str = widestring_ip(peer_endpoint.endpoint.address.ip());
+ let winfw_relay = WinFwEndpoint {
+ ip: ip_str.as_ptr(),
+ port: peer_endpoint.endpoint.address.port(),
+ protocol: WinFwProt::from(peer_endpoint.endpoint.protocol),
+ };
+
+ // SAFETY: `endpoint1_ip`, `endpoint2_ip`, `endpoint1`, `endpoint2`, `relay_client_wstrs`
+ // must not be dropped until `WinFw_ApplyPolicyConnecting` has returned.
+
+ let relay_client_wstrs: Vec<_> = peer_endpoint
+ .clients
+ .iter()
+ .map(WideCString::from_os_str_truncate)
+ .collect();
+ let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs
+ .iter()
+ .map(|wstr| wstr.as_ptr())
+ .collect();
+ let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len();
+
+ let interface_wstr = tunnel_interface
+ .as_ref()
+ .map(WideCString::from_str_truncate);
+ let interface_wstr_ptr = if let Some(ref wstr) = interface_wstr {
+ wstr.as_ptr()
+ } else {
+ ptr::null()
+ };
+
+ let mut endpoint1_ip = WideCString::new();
+ let mut endpoint2_ip = WideCString::new();
+ let (endpoint1, endpoint2) = match allowed_tunnel_traffic {
+ AllowedTunnelTraffic::One(endpoint) => {
+ endpoint1_ip = widestring_ip(endpoint.address.ip());
+ (
+ Some(WinFwEndpoint {
+ ip: endpoint1_ip.as_ptr(),
+ port: endpoint.address.port(),
+ protocol: WinFwProt::from(endpoint.protocol),
+ }),
+ None,
+ )
+ }
+ AllowedTunnelTraffic::Two(endpoint1, endpoint2) => {
+ endpoint1_ip = widestring_ip(endpoint1.address.ip());
+ let endpoint1 = Some(WinFwEndpoint {
+ ip: endpoint1_ip.as_ptr(),
+ port: endpoint1.address.port(),
+ protocol: WinFwProt::from(endpoint1.protocol),
+ });
+ endpoint2_ip = widestring_ip(endpoint2.address.ip());
+ let endpoint2 = Some(WinFwEndpoint {
+ ip: endpoint2_ip.as_ptr(),
+ port: endpoint2.address.port(),
+ protocol: WinFwProt::from(endpoint2.protocol),
+ });
+ (endpoint1, endpoint2)
+ }
+ AllowedTunnelTraffic::None | AllowedTunnelTraffic::All => (None, None),
+ };
+
+ let allowed_endpoint = WinFwAllowedEndpointContainer::from(allowed_endpoint);
+ let allowed_endpoint = allowed_endpoint.as_endpoint();
+
+ let allowed_tunnel_traffic = WinFwAllowedTunnelTraffic {
+ type_: WinFwAllowedTunnelTrafficType::from(allowed_tunnel_traffic),
+ endpoint1: endpoint1
+ .as_ref()
+ .map(|ep| ep as *const _)
+ .unwrap_or(ptr::null()),
+ endpoint2: endpoint2
+ .as_ref()
+ .map(|ep| ep as *const _)
+ .unwrap_or(ptr::null()),
+ };
+
+ #[allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare.
+ let res = unsafe {
+ WinFw_ApplyPolicyConnecting(
+ winfw_settings,
+ &winfw_relay,
+ relay_client_wstr_ptrs.as_ptr(),
+ relay_client_wstr_ptrs_len,
+ interface_wstr_ptr,
+ &allowed_endpoint,
+ &allowed_tunnel_traffic,
+ )
+ };
+ // SAFETY: All of these hold stack allocated memory which is pointed to by
+ // `allowed_tunnel_traffic` and must remain allocated until `WinFw_ApplyPolicyConnecting`
+ // has returned.
+ drop(endpoint1_ip);
+ drop(endpoint2_ip);
+ #[allow(clippy::drop_non_drop)]
+ drop(endpoint1);
+ #[allow(clippy::drop_non_drop)]
+ drop(endpoint2);
+ drop(relay_client_wstrs);
+ res.into_result()
+}
+
+pub(super) fn apply_policy_connected(
+ endpoint: &AllowedEndpoint,
+ winfw_settings: &WinFwSettings,
+ tunnel_interface: &str,
+ dns_config: &crate::dns::ResolvedDnsConfig,
+) -> Result<(), FirewallPolicyError> {
+ let ip_str = widestring_ip(endpoint.endpoint.address.ip());
+
+ let tunnel_alias = WideCString::from_str_truncate(tunnel_interface);
+
+ // ip_str, gateway_str and tunnel_alias have to outlive winfw_relay
+ let winfw_relay = WinFwEndpoint {
+ ip: ip_str.as_ptr(),
+ port: endpoint.endpoint.address.port(),
+ protocol: WinFwProt::from(endpoint.endpoint.protocol),
+ };
+
+ // SAFETY: `relay_client_wstrs` must not be dropped until `WinFw_ApplyPolicyConnected` has
+ // returned.
+ let relay_client_wstrs: Vec<_> = endpoint
+ .clients
+ .iter()
+ .map(WideCString::from_os_str_truncate)
+ .collect();
+ let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs
+ .iter()
+ .map(|wstr| wstr.as_ptr())
+ .collect();
+ let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len();
+
+ let tunnel_dns_servers: Vec<WideCString> = dns_config
+ .tunnel_config()
+ .iter()
+ .cloned()
+ .map(widestring_ip)
+ .collect();
+ let tunnel_dns_servers: Vec<*const u16> =
+ tunnel_dns_servers.iter().map(|ip| ip.as_ptr()).collect();
+ let non_tunnel_dns_servers: Vec<WideCString> = dns_config
+ .non_tunnel_config()
+ .iter()
+ .cloned()
+ .map(widestring_ip)
+ .collect();
+ let non_tunnel_dns_servers: Vec<*const u16> = non_tunnel_dns_servers
+ .iter()
+ .map(|ip| ip.as_ptr())
+ .collect();
+
+ #[allow(clippy::undocumented_unsafe_blocks)] // Remove me if you dare.
+ let result = unsafe {
+ WinFw_ApplyPolicyConnected(
+ winfw_settings,
+ &winfw_relay,
+ relay_client_wstr_ptrs.as_ptr(),
+ relay_client_wstr_ptrs_len,
+ tunnel_alias.as_ptr(),
+ tunnel_dns_servers.as_ptr(),
+ tunnel_dns_servers.len(),
+ non_tunnel_dns_servers.as_ptr(),
+ non_tunnel_dns_servers.len(),
+ )
+ };
+
+ // SAFETY: `relay_client_wstrs` holds memory pointed to by pointers used in C++ and must
+ // not be dropped until after `WinFw_ApplyPolicyConnected` has returned.
+ drop(relay_client_wstrs);
+ result.into_result()
+}
+
+impl From<AllowedEndpoint> for WinFwAllowedEndpointContainer {
+ fn from(endpoint: AllowedEndpoint) -> Self {
+ let clients = endpoint
+ .clients
+ .iter()
+ .map(WideCString::from_os_str_truncate)
+ .collect::<Box<_>>();
+ let clients_ptrs = clients
+ .iter()
+ .map(|client| client.as_ptr())
+ .collect::<Box<_>>();
+ let ip = widestring_ip(endpoint.endpoint.address.ip());
+
+ WinFwAllowedEndpointContainer {
+ _clients: clients,
+ clients_ptrs,
+ ip,
+ port: endpoint.endpoint.address.port(),
+ protocol: WinFwProt::from(endpoint.endpoint.protocol),
+ }
+ }
+}
+
+impl From<&AllowedTunnelTraffic> for WinFwAllowedTunnelTrafficType {
+ fn from(traffic: &AllowedTunnelTraffic) -> Self {
+ match traffic {
+ AllowedTunnelTraffic::None => WinFwAllowedTunnelTrafficType::None,
+ AllowedTunnelTraffic::All => WinFwAllowedTunnelTrafficType::All,
+ AllowedTunnelTraffic::One(..) => WinFwAllowedTunnelTrafficType::One,
+ AllowedTunnelTraffic::Two(..) => WinFwAllowedTunnelTrafficType::Two,
+ }
+ }
+}
+
+impl From<TransportProtocol> for WinFwProt {
+ fn from(prot: TransportProtocol) -> WinFwProt {
+ match prot {
+ TransportProtocol::Tcp => WinFwProt::Tcp,
+ TransportProtocol::Udp => WinFwProt::Udp,
+ }
+ }
+}
+
+impl WinFwPolicyStatus {
+ pub fn into_result(self) -> Result<(), super::FirewallPolicyError> {
+ match self {
+ WinFwPolicyStatus::Success => Ok(()),
+ WinFwPolicyStatus::GeneralFailure => Err(super::FirewallPolicyError::Generic),
+ WinFwPolicyStatus::LockTimeout => {
+ // TODO: Obtain application name and string from WinFw
+ Err(super::FirewallPolicyError::Locked(None))
+ }
+ }
+ }
+}
+
+impl From<WinFwPolicyStatus> for Result<(), super::FirewallPolicyError> {
+ fn from(val: WinFwPolicyStatus) -> Self {
+ val.into_result()
+ }
+}
diff --git a/talpid-core/src/firewall/windows/winfw/sys.rs b/talpid-core/src/firewall/windows/winfw/sys.rs
new file mode 100644
index 0000000000..3349205cf3
--- /dev/null
+++ b/talpid-core/src/firewall/windows/winfw/sys.rs
@@ -0,0 +1,277 @@
+//! Data types and thin wrappers around WinFW C FFI.
+
+use std::ffi::{c_char, c_void, CStr};
+use std::io;
+use std::ptr;
+
+use windows_sys::Win32::Globalization::{MultiByteToWideChar, CP_ACP};
+
+use super::{Error, WideCString};
+
+pub const LOGGING_CONTEXT: &CStr = c"WinFw";
+
+#[repr(C)]
+#[allow(non_snake_case)]
+pub struct WinFwSettings {
+ permitDhcp: bool,
+ permitLan: bool,
+}
+
+impl WinFwSettings {
+ pub fn new(permit_lan: bool) -> WinFwSettings {
+ WinFwSettings {
+ permitDhcp: true,
+ permitLan: permit_lan,
+ }
+ }
+}
+
+#[allow(dead_code)]
+#[repr(u32)]
+#[derive(Clone, Copy)]
+pub enum WinFwCleanupPolicy {
+ ContinueBlocking = 0,
+ ResetFirewall = 1,
+ BlockingUntilReboot = 2,
+}
+
+#[derive(Debug)]
+#[allow(dead_code)]
+#[repr(u32)]
+pub enum WinFwPolicyStatus {
+ Success = 0,
+ GeneralFailure = 1,
+ LockTimeout = 2,
+}
+
+#[repr(C)]
+pub struct WinFwEndpoint {
+ pub ip: *const libc::wchar_t,
+ pub port: u16,
+ pub protocol: WinFwProt,
+}
+
+#[repr(u8)]
+#[derive(Clone, Copy)]
+pub enum WinFwProt {
+ Tcp = 0u8,
+ Udp = 1u8,
+}
+
+#[repr(C)]
+pub struct WinFwAllowedEndpoint<'a> {
+ num_clients: u32,
+ clients: *const *const libc::wchar_t,
+ endpoint: WinFwEndpoint,
+
+ _phantom: std::marker::PhantomData<&'a WinFwAllowedEndpointContainer>,
+}
+
+pub struct WinFwAllowedEndpointContainer {
+ pub _clients: Box<[WideCString]>,
+ pub clients_ptrs: Box<[*const u16]>,
+ pub ip: WideCString,
+ pub port: u16,
+ pub protocol: WinFwProt,
+}
+
+impl WinFwAllowedEndpointContainer {
+ pub fn as_endpoint(&self) -> WinFwAllowedEndpoint<'_> {
+ WinFwAllowedEndpoint {
+ num_clients: self.clients_ptrs.len() as u32,
+ clients: self.clients_ptrs.as_ptr(),
+ endpoint: WinFwEndpoint {
+ ip: self.ip.as_ptr(),
+ port: self.port,
+ protocol: self.protocol,
+ },
+
+ _phantom: std::marker::PhantomData,
+ }
+ }
+}
+
+#[repr(C)]
+pub struct WinFwAllowedTunnelTraffic {
+ pub type_: WinFwAllowedTunnelTrafficType,
+ pub endpoint1: *const WinFwEndpoint,
+ pub endpoint2: *const WinFwEndpoint,
+}
+
+#[repr(u8)]
+#[derive(Clone, Copy)]
+pub enum WinFwAllowedTunnelTrafficType {
+ None,
+ All,
+ One,
+ Two,
+}
+
+ffi_error!(InitializationResult, Error::Initialization);
+ffi_error!(DeinitializationResult, Error::Deinitialization);
+
+unsafe extern "system" {
+ #[link_name = "WinFw_Initialize"]
+ pub fn WinFw_Initialize(
+ timeout: libc::c_uint,
+ sink: Option<LogSink>,
+ sink_context: *const c_char,
+ ) -> InitializationResult;
+
+ #[link_name = "WinFw_InitializeBlocked"]
+ pub fn WinFw_InitializeBlocked(
+ timeout: libc::c_uint,
+ settings: &WinFwSettings,
+ allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
+ sink: Option<LogSink>,
+ sink_context: *const c_char,
+ ) -> InitializationResult;
+
+ #[link_name = "WinFw_Deinitialize"]
+ pub fn WinFw_Deinitialize(cleanupPolicy: WinFwCleanupPolicy) -> DeinitializationResult;
+
+ #[link_name = "WinFw_ApplyPolicyConnecting"]
+ pub fn WinFw_ApplyPolicyConnecting(
+ settings: &WinFwSettings,
+ relay: &WinFwEndpoint,
+ relayClient: *const *const libc::wchar_t,
+ relayClientLen: usize,
+ tunnelIfaceAlias: *const libc::wchar_t,
+ allowedEndpoint: *const WinFwAllowedEndpoint<'_>,
+ allowedTunnelTraffic: &WinFwAllowedTunnelTraffic,
+ ) -> WinFwPolicyStatus;
+
+ #[link_name = "WinFw_ApplyPolicyConnected"]
+ pub fn WinFw_ApplyPolicyConnected(
+ settings: &WinFwSettings,
+ relay: &WinFwEndpoint,
+ relayClient: *const *const libc::wchar_t,
+ relayClientLen: usize,
+ tunnelIfaceAlias: *const libc::wchar_t,
+ tunnelDnsServers: *const *const libc::wchar_t,
+ numTunnelDnsServers: usize,
+ nonTunnelDnsServers: *const *const libc::wchar_t,
+ numNonTunnelDnsServers: usize,
+ ) -> WinFwPolicyStatus;
+
+ #[link_name = "WinFw_ApplyPolicyBlocked"]
+ pub fn WinFw_ApplyPolicyBlocked(
+ settings: &WinFwSettings,
+ allowed_endpoint: *const WinFwAllowedEndpoint<'_>,
+ ) -> WinFwPolicyStatus;
+
+ #[link_name = "WinFw_Reset"]
+ pub fn WinFw_Reset() -> WinFwPolicyStatus;
+}
+
+pub type LogSink = extern "system" fn(level: log::Level, msg: *const c_char, context: *mut c_void);
+
+/// Logging callback implementation.
+///
+/// SAFETY:
+/// - `msg` must point to a valid C string or be null.
+/// - `context` must point to a valid C string or be null.
+pub extern "system" fn log_sink(
+ level: log::Level,
+ msg: *const std::ffi::c_char,
+ context: *mut std::ffi::c_void,
+) {
+ if msg.is_null() {
+ log::error!("Log message from FFI boundary is NULL");
+ return;
+ }
+
+ let target = if context.is_null() {
+ "UNKNOWN".into()
+ } else {
+ // SAFETY: context is not null & caller promise that context is a valid C string.
+ unsafe { CStr::from_ptr(context as *const _).to_string_lossy() }
+ };
+
+ // SAFETY: msg is not null & caller promise that msg is a valid C string.
+ let mb_string = unsafe { CStr::from_ptr(msg) };
+
+ let managed_msg = match multibyte_to_wide(mb_string, CP_ACP) {
+ Ok(wide_str) => String::from_utf16_lossy(&wide_str),
+ // Best effort:
+ Err(_) => mb_string.to_string_lossy().into_owned(),
+ };
+
+ log::logger().log(
+ &log::Record::builder()
+ .level(level)
+ .target(&target)
+ .args(format_args!("{}", managed_msg))
+ .build(),
+ );
+}
+
+/// Convert `mb_string`, with the given character encoding `codepage`, to a UTF-16 string.
+pub fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Error> {
+ if mb_string.is_empty() {
+ return Ok(vec![]);
+ }
+
+ // SAFETY: `mb_string` is null-terminated and valid.
+ let wc_size = unsafe {
+ MultiByteToWideChar(
+ codepage,
+ 0,
+ mb_string.as_ptr() as *const u8,
+ -1,
+ ptr::null_mut(),
+ 0,
+ )
+ };
+
+ if wc_size == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ let mut wc_buffer = vec![0u16; usize::try_from(wc_size).unwrap()];
+
+ // SAFETY: `wc_buffer` can contain up to `wc_size` characters, including a null
+ // terminator.
+ let chars_written = unsafe {
+ MultiByteToWideChar(
+ codepage,
+ 0,
+ mb_string.as_ptr() as *const u8,
+ -1,
+ wc_buffer.as_mut_ptr(),
+ wc_size,
+ )
+ };
+
+ if chars_written == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ wc_buffer.truncate(usize::try_from(chars_written - 1).unwrap());
+
+ Ok(wc_buffer)
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use windows_sys::Win32::Globalization::CP_UTF8;
+
+ #[test]
+ fn test_multibyte_to_wide() {
+ // € = 0x20AC in UTF-16
+ let converted = multibyte_to_wide(c"€€", CP_UTF8);
+ const EXPECTED: &[u16] = &[0x20AC, 0x20AC];
+ assert!(
+ matches!(converted.as_deref(), Ok(EXPECTED)),
+ "expected Ok({EXPECTED:?}), got {converted:?}",
+ );
+
+ // boundary case
+ let converted = multibyte_to_wide(c"", CP_UTF8);
+ assert!(
+ matches!(converted.as_deref(), Ok([])),
+ "unexpected result {converted:?}"
+ );
+ }
+}
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 8f96ff7b90..427a6e5cb3 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -30,7 +30,7 @@ impl DisconnectedState {
);
}
#[cfg(target_os = "macos")]
- if shared_values.block_when_disconnected {
+ if shared_values.block_when_disconnected.bool() {
if let Err(err) = Self::setup_local_dns_config(shared_values) {
log::error!(
"{}",
@@ -64,7 +64,7 @@ impl DisconnectedState {
// Being disconnected and having lockdown mode enabled implies that your internet
// access is locked down
#[cfg(not(target_os = "android"))]
- locked_down: shared_values.block_when_disconnected,
+ locked_down: shared_values.block_when_disconnected.bool(),
},
)
}
@@ -74,7 +74,15 @@ impl DisconnectedState {
shared_values: &mut SharedTunnelStateValues,
should_reset_firewall: bool,
) {
- let result = if shared_values.block_when_disconnected {
+ let result = if shared_values.block_when_disconnected.bool() {
+ #[cfg(target_os = "windows")]
+ {
+ // Respect the persist flag of BlockWhenDisconnected.
+ shared_values
+ .firewall
+ .persist(shared_values.block_when_disconnected.should_persist());
+ }
+
let policy = FirewallPolicy::Blocked {
allow_lan: shared_values.allow_lan,
allowed_endpoint: Some(shared_values.allowed_endpoint.clone()),
@@ -110,7 +118,7 @@ impl DisconnectedState {
shared_values: &mut SharedTunnelStateValues,
should_reset_firewall: bool,
) {
- if should_reset_firewall && !shared_values.block_when_disconnected {
+ if should_reset_firewall && !shared_values.block_when_disconnected.bool() {
if let Err(error) = shared_values.split_tunnel.clear_tunnel_addresses() {
log::error!(
"{}",
@@ -198,7 +206,7 @@ impl TunnelState for DisconnectedState {
#[cfg(windows)]
Self::register_split_tunnel_addresses(shared_values, true);
#[cfg(target_os = "macos")]
- if block_when_disconnected {
+ if block_when_disconnected.bool() {
if let Err(err) = Self::setup_local_dns_config(shared_values) {
log::error!(
"{}",
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 059edd417d..2deb7c8f7e 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -94,7 +94,7 @@ pub struct InitialTunnelState {
pub allow_lan: bool,
/// Block traffic unless connected to the VPN.
#[cfg(not(target_os = "android"))]
- pub block_when_disconnected: bool,
+ pub block_when_disconnected: BlockWhenDisconnected,
/// DNS configuration to use
pub dns_config: DnsConfig,
/// A single endpoint that is allowed to communicate outside the tunnel, i.e.
@@ -200,7 +200,7 @@ pub enum TunnelCommand {
Dns(crate::dns::DnsConfig, oneshot::Sender<()>),
/// Enable or disable the block_when_disconnected feature.
#[cfg(not(target_os = "android"))]
- BlockWhenDisconnected(bool, oneshot::Sender<()>),
+ BlockWhenDisconnected(BlockWhenDisconnected, oneshot::Sender<()>),
/// Notify the state machine of the connectivity of the device.
Connectivity(Connectivity),
/// Open tunnel connection.
@@ -234,6 +234,82 @@ enum EventResult {
Close(Result<Option<ErrorStateCause>, oneshot::Canceled>),
}
+/// If firewall should apply blocking rules in the disconnected state.
+/// Argument of TunnelCommand::BlockWhenDisconnected message.
+///
+/// Semantically equivalent to a boolean value, but is grouped togetether with the persist
+/// parameter on Windows for cohesiveness.
+#[derive(Clone, Copy, Debug)]
+pub enum BlockWhenDisconnected {
+ /// Firewall should *not* apply blocking rules.
+ Disabled,
+ /// Firewall should apply blocking rules.
+ Enabled {
+ /// If blocked state should be persisted across a reboot (restart of BFE)
+ persist: bool,
+ },
+}
+
+impl BlockWhenDisconnected {
+ /// `true`. Apply blocking firewall rules in the disconnected state.
+ pub const fn yes() -> Self {
+ BlockWhenDisconnected::Enabled { persist: true }
+ }
+
+ /// `false`. Do *not* apply blocking firewall rules in the disconnected state.
+ pub const fn no() -> Self {
+ BlockWhenDisconnected::Disabled
+ }
+
+ /// [self] as a boolean value.
+ pub const fn bool(&self) -> bool {
+ matches!(self, BlockWhenDisconnected::Enabled { .. })
+ }
+
+ /// If [BlockWhenDisconnected] should persist across reboots.
+ ///
+ /// Semantically meaningless on non-Windows platforms, will always return true.
+ pub const fn should_persist(&self) -> bool {
+ if cfg!(target_os = "windows") {
+ matches!(&self, BlockWhenDisconnected::Enabled { persist: true })
+ } else {
+ true
+ }
+ }
+
+ /// Semantically meaningless on non-Windows platforms
+ #[cfg(not(target_os = "windows"))]
+ pub fn persist(self, _persist: bool) -> Self {
+ self
+ }
+
+ /// Semantically meaningless on non-Windows platforms
+ #[cfg(target_os = "windows")]
+ pub fn persist(self, persist: bool) -> Self {
+ match self {
+ BlockWhenDisconnected::Disabled => BlockWhenDisconnected::Disabled,
+ // Forget previous value of persist
+ BlockWhenDisconnected::Enabled { .. } => BlockWhenDisconnected::Enabled { persist },
+ }
+ }
+}
+
+impl From<bool> for BlockWhenDisconnected {
+ fn from(block: bool) -> Self {
+ if block {
+ BlockWhenDisconnected::yes()
+ } else {
+ BlockWhenDisconnected::no()
+ }
+ }
+}
+
+impl PartialEq for BlockWhenDisconnected {
+ fn eq(&self, other: &Self) -> bool {
+ self.bool() == other.bool()
+ }
+}
+
/// Asynchronous handling of the tunnel state machine.
///
/// This type implements `Stream`, and attempts to advance the state machine based on the events
@@ -295,7 +371,8 @@ impl TunnelStateMachine {
let fw_args = FirewallArguments {
#[cfg(not(target_os = "android"))]
- initial_state: if args.settings.block_when_disconnected || !args.settings.reset_firewall
+ initial_state: if args.settings.block_when_disconnected.bool()
+ || !args.settings.reset_firewall
{
InitialFirewallState::Blocked(args.settings.allowed_endpoint.clone())
} else {
@@ -474,7 +551,7 @@ struct SharedTunnelStateValues {
allow_lan: bool,
/// Should network access be allowed when in the disconnected state.
#[cfg(not(target_os = "android"))]
- block_when_disconnected: bool,
+ block_when_disconnected: BlockWhenDisconnected,
/// True when the computer is known to be offline.
connectivity: Connectivity,
/// DNS configuration to use.
diff --git a/talpid-openvpn/Cargo.toml b/talpid-openvpn/Cargo.toml
index 34df09becf..031d6f996f 100644
--- a/talpid-openvpn/Cargo.toml
+++ b/talpid-openvpn/Cargo.toml
@@ -34,6 +34,8 @@ widestring = "1.0"
winreg = { version = "0.51", features = ["transactions"] }
talpid-windows = { path = "../talpid-windows" }
once_cell = { workspace = true }
+# Only needed because parity-tokio-ipc has forgotten to enable the winerror feature of winapi ..
+winapi = { version = "0.3", features = ["winerror"] }
[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
diff --git a/windows/winfw/src/winfw/objectpurger.cpp b/windows/winfw/src/winfw/objectpurger.cpp
index dce36c99c8..52adaac187 100644
--- a/windows/winfw/src/winfw/objectpurger.cpp
+++ b/windows/winfw/src/winfw/objectpurger.cpp
@@ -71,6 +71,22 @@ ObjectPurger::RemovalFunctor ObjectPurger::GetRemoveNonPersistentFunctor()
}
//static
+ObjectPurger::RemovalFunctor ObjectPurger::GetRemovePersistentFunctor()
+{
+ return [](wfp::FilterEngine &engine)
+ {
+ const auto registry = MullvadGuids::DetailedRegistry(MullvadGuids::IdentityQualifier::IncludePersistent);
+
+ // Resolve correct overload.
+ void(*deleter)(wfp::FilterEngine &, const GUID &) = wfp::ObjectDeleter::DeleteFilter;
+
+ RemoveRange(engine, deleter, registry.equal_range(WfpObjectType::Filter));
+ RemoveRange(engine, wfp::ObjectDeleter::DeleteSublayer, registry.equal_range(WfpObjectType::Sublayer));
+ RemoveRange(engine, wfp::ObjectDeleter::DeleteProvider, registry.equal_range(WfpObjectType::Provider));
+ };
+}
+
+//static
bool ObjectPurger::Execute(RemovalFunctor f)
{
auto engine = wfp::FilterEngine::StandardSession();
diff --git a/windows/winfw/src/winfw/objectpurger.h b/windows/winfw/src/winfw/objectpurger.h
index 7728aac694..9d3ca0146e 100644
--- a/windows/winfw/src/winfw/objectpurger.h
+++ b/windows/winfw/src/winfw/objectpurger.h
@@ -16,6 +16,7 @@ public:
static RemovalFunctor GetRemoveFiltersFunctor();
static RemovalFunctor GetRemoveAllFunctor();
static RemovalFunctor GetRemoveNonPersistentFunctor();
+ static RemovalFunctor GetRemovePersistentFunctor();
static bool Execute(RemovalFunctor f);
};
diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp
index c862de4b5a..cd90befead 100644
--- a/windows/winfw/src/winfw/winfw.cpp
+++ b/windows/winfw/src/winfw/winfw.cpp
@@ -4,6 +4,7 @@
#include "objectpurger.h"
#include "mullvadobjects.h"
#include "rules/persistent/blockall.h"
+#include "rules/baseline/blockall.h"
#include "libwfp/ipnetwork.h"
#include <windows.h>
#include <libcommon/error.h>
@@ -167,11 +168,14 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy)
delete g_fwContext;
g_fwContext = nullptr;
+ std::stringstream ss;
+ ss << "Deinitializing WinFw";
+ g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext);
+
//
- // Continue blocking if this is what the caller requested
+ // Continue blocking with persistent rules if this is what the caller requested
// and if the current policy is "(net) blocked".
//
-
if (WINFW_CLEANUP_POLICY_CONTINUE_BLOCKING == cleanupPolicy
&& FwContext::Policy::Blocked == activePolicy)
{
@@ -182,6 +186,10 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy)
rules::persistent::BlockAll blockAll;
+ std::stringstream ss;
+ ss << "Adding persistent block rules";
+ g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext);
+
return sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine)
{
ObjectPurger::GetRemoveNonPersistentFunctor()(engine);
@@ -205,6 +213,50 @@ WinFw_Deinitialize(WINFW_CLEANUP_POLICY cleanupPolicy)
}
}
+ //
+ // Continue blocking with non-persistent rules if this is what the caller requested
+ // and if the current policy is "(net) blocked".
+ //
+ if (WINFW_CLEANUP_POLICY_BLOCK_UNTIL_REBOOT == cleanupPolicy
+ && FwContext::Policy::Blocked == activePolicy)
+ {
+ try
+ {
+ auto engine = wfp::FilterEngine::StandardSession(DEINITIALIZE_TIMEOUT);
+ auto sessionController = std::make_unique<SessionController>(std::move(engine));
+
+ rules::baseline::BlockAll blockAll;
+
+ std::stringstream ss;
+ ss << "Adding ephemeral block rules";
+ g_logSink(MULLVAD_LOG_LEVEL_WARNING, ss.str().c_str(), g_logSinkContext);
+
+ return sessionController->executeTransaction([&](SessionController &controller, wfp::FilterEngine &engine)
+ {
+ // Keep non-persistent filters intact, the intent is just to *not*
+ // persist any filters across a BFE restart, not muck around with
+ // any other filters. We will apply blocking filters anyway.
+ ObjectPurger::GetRemovePersistentFunctor()(engine);
+
+ return controller.addProvider(*MullvadObjects::Provider())
+ && controller.addSublayer(*MullvadObjects::SublayerBaseline())
+ && blockAll.apply(controller);
+ });
+ }
+ catch (std::exception & err)
+ {
+ if (nullptr != g_logSink)
+ {
+ g_logSink(MULLVAD_LOG_LEVEL_ERROR, err.what(), g_logSinkContext);
+ }
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+ }
+
return WINFW_POLICY_STATUS_SUCCESS == WinFw_Reset();
}
diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h
index ab2a136ceb..f40d835a95 100644
--- a/windows/winfw/src/winfw/winfw.h
+++ b/windows/winfw/src/winfw/winfw.h
@@ -127,6 +127,10 @@ enum WINFW_CLEANUP_POLICY : uint32_t
// Remove all objects that have been registered with WFP.
WINFW_CLEANUP_POLICY_RESET_FIREWALL = 1,
+
+ // Continue blocking if this is the active policy.
+ // Adds ephemeral blocking filters that are active until WinFw is shut down (??)
+ WINFW_CLEANUP_POLICY_BLOCK_UNTIL_REBOOT = 2,
};
//