summaryrefslogtreecommitdiffhomepage
path: root/talpid-core
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-02-01 11:57:52 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-02-08 17:43:12 +0100
commite2f7cf1ba90fa59ea04d26dae28ae576db92bb07 (patch)
treec20956d1bb25f8703186171d97059fd23d6e1cec /talpid-core
parent022874449ba862aba6788bb430099923ba8a1a6c (diff)
downloadmullvadvpn-e2f7cf1ba90fa59ea04d26dae28ae576db92bb07.tar.xz
mullvadvpn-e2f7cf1ba90fa59ea04d26dae28ae576db92bb07.zip
Reapply excluded paths when the frontend receives messages for device
arrivals or removals
Diffstat (limited to 'talpid-core')
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs12
-rw-r--r--talpid-core/src/split_tunnel/windows/volume_monitor.rs204
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs9
3 files changed, 171 insertions, 54 deletions
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index 34c9a67864..efdd75ecf9 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -157,8 +157,9 @@ impl SplitTunnel {
pub fn new(
runtime: tokio::runtime::Handle,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
) -> Result<Self, Error> {
- let (request_tx, handle) = Self::spawn_request_thread()?;
+ let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?;
let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() };
event_overlapped.hEvent =
@@ -326,7 +327,9 @@ impl SplitTunnel {
})
}
- fn spawn_request_thread() -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
+ fn spawn_request_thread(
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ ) -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
let (tx, rx): (RequestTx, _) = sync_mpsc::channel();
let (init_tx, init_rx) = sync_mpsc::channel();
@@ -337,10 +340,11 @@ impl SplitTunnel {
let path_monitor = path_monitor::PathMonitor::spawn(monitor_tx.clone())
.map_err(Error::StartPathMonitor)?;
- let mut volume_monitor = volume_monitor::VolumeMonitor::spawn(
+ let volume_monitor = volume_monitor::VolumeMonitor::spawn(
path_monitor.clone(),
monitor_tx,
monitored_paths.clone(),
+ volume_update_rx,
);
std::thread::spawn(move || {
@@ -403,7 +407,7 @@ impl SplitTunnel {
}
}
- volume_monitor.close();
+ drop(volume_monitor);
if let Err(error) = path_monitor.shutdown() {
log::error!(
"{}",
diff --git a/talpid-core/src/split_tunnel/windows/volume_monitor.rs b/talpid-core/src/split_tunnel/windows/volume_monitor.rs
index e9060e5528..1993cc809c 100644
--- a/talpid-core/src/split_tunnel/windows/volume_monitor.rs
+++ b/talpid-core/src/split_tunnel/windows/volume_monitor.rs
@@ -2,11 +2,14 @@
//! tunnel config if any of the excluded paths are affected by them.
use super::path_monitor::PathMonitorHandle;
use crate::windows::window::{create_hidden_window, WindowCloseHandle};
+use futures::{channel::mpsc, StreamExt};
use std::{
ffi::OsString,
+ io,
path::{self, Path},
- sync::{mpsc as sync_mpsc, Arc, Mutex},
+ sync::{mpsc as sync_mpsc, Arc, Mutex, MutexGuard},
};
+use talpid_types::ErrorExt;
use winapi::{
shared::minwindef::TRUE,
um::{
@@ -14,83 +17,188 @@ use winapi::{
DBTF_NET, DBT_DEVICEARRIVAL, DBT_DEVICEREMOVECOMPLETE, DBT_DEVTYP_VOLUME,
DEV_BROADCAST_HDR, DEV_BROADCAST_VOLUME, WM_DEVICECHANGE,
},
+ fileapi::GetLogicalDrives,
winuser::DefWindowProcW,
},
};
pub(super) struct VolumeMonitor(());
+pub(super) struct VolumeMonitorHandle {
+ window_handle: WindowCloseHandle,
+ internal_monitor_task: tokio::task::JoinHandle<()>,
+}
+
+impl Drop for VolumeMonitorHandle {
+ fn drop(&mut self) {
+ self.window_handle.close();
+ self.internal_monitor_task.abort();
+ }
+}
+
impl VolumeMonitor {
pub fn spawn(
path_monitor: PathMonitorHandle,
update_tx: sync_mpsc::Sender<()>,
paths: Arc<Mutex<Vec<OsString>>>,
- ) -> WindowCloseHandle {
- create_hidden_window(move |window, message, w_param, l_param| {
- if message != WM_DEVICECHANGE
- || (w_param != DBT_DEVICEARRIVAL && w_param != DBT_DEVICEREMOVECOMPLETE)
- {
- return unsafe { DefWindowProcW(window, message, w_param, l_param) };
- }
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ ) -> VolumeMonitorHandle {
+ // A bitmask containing all (known) mounted drives.
+ let known_state = Arc::new(Mutex::new(0u32));
- let paths_guard = paths.lock().unwrap();
- let mut label_found = false;
+ // Lock before registering event handler
+ let mut known_state_guard = known_state.lock().unwrap();
- let volumes = unsafe { parse_broadcast(&*(l_param as *const _)) };
- for volume in volumes {
- for path in &*paths_guard {
- let path = (path as &dyn AsRef<Path>).as_ref();
- if let Some(path::Component::Prefix(prefix)) = path.components().next() {
- match prefix.kind() {
- path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => {
- if disk == volume {
- label_found = true;
- break;
- }
- }
- _ => (),
- }
- }
- }
- if label_found {
- break;
- }
+ let internal_monitor_task = tokio::spawn(frontend_monitor(
+ known_state.clone(),
+ path_monitor.clone(),
+ update_tx.clone(),
+ paths.clone(),
+ volume_update_rx,
+ ));
+
+ let window_handle =
+ start_internal_monitor(known_state.clone(), path_monitor, update_tx, paths);
+
+ *known_state_guard = get_logical_drives().unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to initialize state of mounted volumes")
+ );
+ 0
+ });
+
+ VolumeMonitorHandle {
+ window_handle,
+ internal_monitor_task,
+ }
+ }
+}
+
+/// Monitors update requests from frontends. This checks if the known state of mounted volumes
+/// has change, and, if so, reapplies the ST config.
+async fn frontend_monitor(
+ known_state: Arc<Mutex<u32>>,
+ path_monitor: PathMonitorHandle,
+ update_tx: sync_mpsc::Sender<()>,
+ paths: Arc<Mutex<Vec<OsString>>>,
+ mut volume_update_rx: mpsc::UnboundedReceiver<()>,
+) {
+ while let Some(()) = volume_update_rx.next().await {
+ let mut known_state_guard = known_state.lock().unwrap();
+ let new_state = get_logical_drives().unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain new state of mounted volumes")
+ );
+ *known_state_guard
+ });
+
+ // Was there a change?
+ let state_diff = *known_state_guard ^ new_state;
+ if state_diff != 0 {
+ *known_state_guard = new_state;
+ let paths_guard = paths.lock().unwrap();
+ if matches_volume(state_diff, &paths_guard) {
+ // Reapply config
+ let _ = update_tx.send(());
+ let _ = path_monitor.refresh();
}
+ }
+ }
+}
+
+/// Monitors window events received by session 0.
+fn start_internal_monitor(
+ known_state: Arc<Mutex<u32>>,
+ path_monitor: PathMonitorHandle,
+ update_tx: sync_mpsc::Sender<()>,
+ paths: Arc<Mutex<Vec<OsString>>>,
+) -> WindowCloseHandle {
+ create_hidden_window(move |window, message, w_param, l_param| {
+ if !is_device_arrival_or_removal(message, w_param) {
+ return unsafe { DefWindowProcW(window, message, w_param, l_param) };
+ }
+ let paths_guard = paths.lock().unwrap();
+ let mut known_state_guard = known_state.lock().unwrap();
+
+ let volumes = unsafe { parse_device_volume_broadcast(&*(l_param as *const _)) };
- if label_found {
+ let prev_state = *known_state_guard;
+ let is_arrival = w_param == DBT_DEVICEARRIVAL;
+ if is_arrival {
+ *known_state_guard |= volumes;
+ } else {
+ *known_state_guard &= !volumes;
+ }
+
+ // Compare against known state to ignore duplicate notifications
+ // from frontends
+ let state_diff = *known_state_guard ^ prev_state;
+ if state_diff != 0 {
+ if matches_volume(volumes, &paths_guard) {
// Reapply config
let _ = update_tx.send(());
let _ = path_monitor.refresh();
}
+ }
+
+ // Always grant the request
+ TRUE as isize
+ })
+}
+
+/// Return a bitmask representing all currently available disk drives.
+/// Each bit refers to a volume letter. The bit 0 refers to 'A', bit 1
+/// refers to 'B', bit 2 to 'C', etc.
+fn get_logical_drives() -> io::Result<u32> {
+ let result = unsafe { GetLogicalDrives() };
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(result)
+}
- // Always grant the request
- TRUE as isize
- })
+/// Return whether any of the paths in `paths_guard` reside on any volume in `volumes` (a mask).
+fn matches_volume(volumes: u32, paths_guard: &MutexGuard<'_, Vec<OsString>>) -> bool {
+ for path in &**paths_guard {
+ let path = (path as &dyn AsRef<Path>).as_ref();
+ if let Some(path::Component::Prefix(prefix)) = path.components().next() {
+ match prefix.kind() {
+ path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => {
+ if disk < 'A' as u8 || disk > 'Z' as u8 {
+ log::warn!("Ignoring invalid volume \"{}\"", disk as char);
+ continue;
+ }
+ let disk = disk - 'A' as u8;
+ if volumes & (1 << disk) != 0 {
+ return true;
+ }
+ }
+ _ => (),
+ }
+ }
}
+ false
}
-/// Return volume labels (ASCII-encoded) affected by the device arrival or removal message, if any.
-unsafe fn parse_broadcast(broadcast: &DEV_BROADCAST_HDR) -> Vec<u8> {
- let mut labels = vec![];
+fn is_device_arrival_or_removal(message: u32, w_param: usize) -> bool {
+ message == WM_DEVICECHANGE
+ && (w_param == DBT_DEVICEARRIVAL || w_param == DBT_DEVICEREMOVECOMPLETE)
+}
+/// Return volumes affected by the device arrival or removal message as a mask.
+/// This has the same format as `get_logical_drives()`.
+unsafe fn parse_device_volume_broadcast(broadcast: &DEV_BROADCAST_HDR) -> u32 {
if broadcast.dbch_devicetype != DBT_DEVTYP_VOLUME {
- return labels;
+ return 0;
}
let volume_broadcast = &*(broadcast as *const _ as *const DEV_BROADCAST_VOLUME);
if volume_broadcast.dbcv_flags & DBTF_NET != 0 {
// Ignore net event
- return labels;
- }
-
- // 26 = 1 + 'Z' - 'A'
- let num_drives = 1 + 'Z' as u8 - 'A' as u8;
- for i in 0..num_drives {
- let is_affected = ((volume_broadcast.dbcv_unitmask >> i) & 1) != 0;
- if is_affected {
- labels.push('A' as u8 + i);
- }
+ return 0;
}
- labels
+ volume_broadcast.dbcv_unitmask
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 298f8fc7a6..e22975a827 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -103,6 +103,7 @@ pub async fn spawn(
state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static,
offline_state_listener: mpsc::UnboundedSender<bool>,
shutdown_tx: oneshot::Sender<()>,
+ #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
@@ -130,6 +131,8 @@ pub async fn spawn(
log_dir,
resource_dir,
command_rx,
+ #[cfg(target_os = "windows")]
+ volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
@@ -207,6 +210,7 @@ impl TunnelStateMachine {
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
+ #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Self, Error> {
@@ -216,8 +220,9 @@ impl TunnelStateMachine {
let filtering_resolver = crate::resolver::start_resolver().await?;
#[cfg(windows)]
- let split_tunnel = split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone())
- .map_err(Error::InitSplitTunneling)?;
+ let split_tunnel =
+ split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
+ .map_err(Error::InitSplitTunneling)?;
let args = FirewallArguments {
initial_state: if settings.block_when_disconnected || !settings.reset_firewall {