summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-02-01 11:57:52 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-02-08 17:43:12 +0100
commite2f7cf1ba90fa59ea04d26dae28ae576db92bb07 (patch)
treec20956d1bb25f8703186171d97059fd23d6e1cec
parent022874449ba862aba6788bb430099923ba8a1a6c (diff)
downloadmullvadvpn-e2f7cf1ba90fa59ea04d26dae28ae576db92bb07.tar.xz
mullvadvpn-e2f7cf1ba90fa59ea04d26dae28ae576db92bb07.zip
Reapply excluded paths when the frontend receives messages for device
arrivals or removals
-rw-r--r--gui/src/main/daemon-rpc.ts4
-rw-r--r--gui/src/main/index.ts14
-rw-r--r--mullvad-daemon/src/lib.rs20
-rw-r--r--mullvad-daemon/src/management_interface.rs16
-rw-r--r--mullvad-management-interface/proto/management_interface.proto3
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs12
-rw-r--r--talpid-core/src/split_tunnel/windows/volume_monitor.rs204
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs9
8 files changed, 228 insertions, 54 deletions
diff --git a/gui/src/main/daemon-rpc.ts b/gui/src/main/daemon-rpc.ts
index ac677f4ded..c00508dc57 100644
--- a/gui/src/main/daemon-rpc.ts
+++ b/gui/src/main/daemon-rpc.ts
@@ -501,6 +501,10 @@ export class DaemonRpc {
await this.callBool(this.client.setSplitTunnelState, enabled);
}
+ public async checkVolumes(): Promise<void> {
+ await this.callEmpty(this.client.checkVolumes);
+ }
+
private subscriptionId(): number {
const current = this.nextSubscriptionId;
this.nextSubscriptionId += 1;
diff --git a/gui/src/main/index.ts b/gui/src/main/index.ts
index a9daa98367..5b7504e711 100644
--- a/gui/src/main/index.ts
+++ b/gui/src/main/index.ts
@@ -1857,6 +1857,20 @@ class ApplicationMain {
// https://github.com/electron/electron/blob/main/docs/faq.md#the-font-looks-blurry-what-is-this-and-what-can-i-do
backgroundColor: '#fff',
});
+ const WM_DEVICECHANGE = 0x0219;
+ const DBT_DEVICEARRIVAL = 0x8000;
+ const DBT_DEVICEREMOVECOMPLETE = 0x8004;
+ appWindow.hookWindowMessage(WM_DEVICECHANGE, (wParam) => {
+ const wParamL = wParam.readBigInt64LE(0);
+ if (wParamL != DBT_DEVICEARRIVAL && wParamL != DBT_DEVICEREMOVECOMPLETE) {
+ return;
+ }
+ this.daemonRpc
+ .checkVolumes()
+ .catch((error) =>
+ log.error(`Unable to notify daemon of device event: ${error.message}`),
+ );
+ });
appWindow.removeMenu();
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index b54eae5908..2ef7bccc39 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -293,6 +293,9 @@ pub enum DaemonCommand {
/// Toggle wireguard-nt on or off
#[cfg(target_os = "windows")]
UseWireGuardNt(ResponseTx<(), Error>, bool),
+ /// Notify the split tunnel monitor that a volume was mounted or dismounted
+ #[cfg(target_os = "windows")]
+ CheckVolumes(ResponseTx<(), Error>),
/// Makes the daemon exit the main loop and quit.
Shutdown,
/// Saves the target tunnel state and enters a blocking state. The state is restored
@@ -546,6 +549,8 @@ pub struct Daemon<L: EventListener> {
shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
/// oneshot channel that completes once the tunnel state machine has been shut down
tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>,
+ #[cfg(target_os = "windows")]
+ volume_update_tx: mpsc::UnboundedSender<()>,
}
impl<L> Daemon<L>
@@ -627,6 +632,8 @@ where
Self::get_allowed_endpoint(rpc_runtime.address_cache.peek_address());
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
+ #[cfg(target_os = "windows")]
+ let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
let tunnel_command_tx = tunnel_state_machine::spawn(
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
@@ -643,6 +650,8 @@ where
internal_event_tx.to_specialized_sender(),
offline_state_tx,
tunnel_state_machine_shutdown_tx,
+ #[cfg(target_os = "windows")]
+ volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
@@ -742,6 +751,8 @@ where
app_version_info,
shutdown_tasks: vec![],
tunnel_state_machine_shutdown_signal,
+ #[cfg(target_os = "windows")]
+ volume_update_tx,
};
daemon.ensure_wireguard_keys_for_current_account().await;
@@ -1241,6 +1252,8 @@ where
SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await,
#[cfg(target_os = "windows")]
UseWireGuardNt(tx, state) => self.on_use_wireguard_nt(tx, state).await,
+ #[cfg(target_os = "windows")]
+ CheckVolumes(tx) => self.on_check_volumes(tx).await,
Shutdown => self.trigger_shutdown_event(),
PrepareRestart => self.on_prepare_restart(),
#[cfg(target_os = "android")]
@@ -1980,6 +1993,13 @@ where
}
}
+ #[cfg(windows)]
+ async fn on_check_volumes(&mut self, tx: ResponseTx<(), Error>) {
+ if self.volume_update_tx.unbounded_send(()).is_ok() {
+ let _ = tx.send(Ok(()));
+ }
+ }
+
async fn on_update_relay_settings(
&mut self,
tx: ResponseTx<(), settings::Error>,
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 2136312541..ba828ed903 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -708,6 +708,22 @@ impl ManagementService for ManagementServiceImpl {
async fn set_use_wireguard_nt(&self, _: Request<bool>) -> ServiceResult<()> {
Ok(Response::new(()))
}
+
+ #[cfg(windows)]
+ async fn check_volumes(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("check_volumes");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::CheckVolumes(tx))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+
+ #[cfg(not(windows))]
+ async fn check_volumes(&self, _: Request<()>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
}
impl ManagementServiceImpl {
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index c4d5575ae3..e690557aae 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -71,6 +71,9 @@ service ManagementService {
rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
rpc SetUseWireguardNt(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
+
+ // Notify the split tunnel monitor that a volume was mounted or dismounted (Windows).
+ rpc CheckVolumes(google.protobuf.Empty) returns (google.protobuf.Empty) {}
}
message RelaySettingsUpdate {
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index 34c9a67864..efdd75ecf9 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -157,8 +157,9 @@ impl SplitTunnel {
pub fn new(
runtime: tokio::runtime::Handle,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
) -> Result<Self, Error> {
- let (request_tx, handle) = Self::spawn_request_thread()?;
+ let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?;
let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() };
event_overlapped.hEvent =
@@ -326,7 +327,9 @@ impl SplitTunnel {
})
}
- fn spawn_request_thread() -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
+ fn spawn_request_thread(
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ ) -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
let (tx, rx): (RequestTx, _) = sync_mpsc::channel();
let (init_tx, init_rx) = sync_mpsc::channel();
@@ -337,10 +340,11 @@ impl SplitTunnel {
let path_monitor = path_monitor::PathMonitor::spawn(monitor_tx.clone())
.map_err(Error::StartPathMonitor)?;
- let mut volume_monitor = volume_monitor::VolumeMonitor::spawn(
+ let volume_monitor = volume_monitor::VolumeMonitor::spawn(
path_monitor.clone(),
monitor_tx,
monitored_paths.clone(),
+ volume_update_rx,
);
std::thread::spawn(move || {
@@ -403,7 +407,7 @@ impl SplitTunnel {
}
}
- volume_monitor.close();
+ drop(volume_monitor);
if let Err(error) = path_monitor.shutdown() {
log::error!(
"{}",
diff --git a/talpid-core/src/split_tunnel/windows/volume_monitor.rs b/talpid-core/src/split_tunnel/windows/volume_monitor.rs
index e9060e5528..1993cc809c 100644
--- a/talpid-core/src/split_tunnel/windows/volume_monitor.rs
+++ b/talpid-core/src/split_tunnel/windows/volume_monitor.rs
@@ -2,11 +2,14 @@
//! tunnel config if any of the excluded paths are affected by them.
use super::path_monitor::PathMonitorHandle;
use crate::windows::window::{create_hidden_window, WindowCloseHandle};
+use futures::{channel::mpsc, StreamExt};
use std::{
ffi::OsString,
+ io,
path::{self, Path},
- sync::{mpsc as sync_mpsc, Arc, Mutex},
+ sync::{mpsc as sync_mpsc, Arc, Mutex, MutexGuard},
};
+use talpid_types::ErrorExt;
use winapi::{
shared::minwindef::TRUE,
um::{
@@ -14,83 +17,188 @@ use winapi::{
DBTF_NET, DBT_DEVICEARRIVAL, DBT_DEVICEREMOVECOMPLETE, DBT_DEVTYP_VOLUME,
DEV_BROADCAST_HDR, DEV_BROADCAST_VOLUME, WM_DEVICECHANGE,
},
+ fileapi::GetLogicalDrives,
winuser::DefWindowProcW,
},
};
pub(super) struct VolumeMonitor(());
+pub(super) struct VolumeMonitorHandle {
+ window_handle: WindowCloseHandle,
+ internal_monitor_task: tokio::task::JoinHandle<()>,
+}
+
+impl Drop for VolumeMonitorHandle {
+ fn drop(&mut self) {
+ self.window_handle.close();
+ self.internal_monitor_task.abort();
+ }
+}
+
impl VolumeMonitor {
pub fn spawn(
path_monitor: PathMonitorHandle,
update_tx: sync_mpsc::Sender<()>,
paths: Arc<Mutex<Vec<OsString>>>,
- ) -> WindowCloseHandle {
- create_hidden_window(move |window, message, w_param, l_param| {
- if message != WM_DEVICECHANGE
- || (w_param != DBT_DEVICEARRIVAL && w_param != DBT_DEVICEREMOVECOMPLETE)
- {
- return unsafe { DefWindowProcW(window, message, w_param, l_param) };
- }
+ volume_update_rx: mpsc::UnboundedReceiver<()>,
+ ) -> VolumeMonitorHandle {
+ // A bitmask containing all (known) mounted drives.
+ let known_state = Arc::new(Mutex::new(0u32));
- let paths_guard = paths.lock().unwrap();
- let mut label_found = false;
+ // Lock before registering event handler
+ let mut known_state_guard = known_state.lock().unwrap();
- let volumes = unsafe { parse_broadcast(&*(l_param as *const _)) };
- for volume in volumes {
- for path in &*paths_guard {
- let path = (path as &dyn AsRef<Path>).as_ref();
- if let Some(path::Component::Prefix(prefix)) = path.components().next() {
- match prefix.kind() {
- path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => {
- if disk == volume {
- label_found = true;
- break;
- }
- }
- _ => (),
- }
- }
- }
- if label_found {
- break;
- }
+ let internal_monitor_task = tokio::spawn(frontend_monitor(
+ known_state.clone(),
+ path_monitor.clone(),
+ update_tx.clone(),
+ paths.clone(),
+ volume_update_rx,
+ ));
+
+ let window_handle =
+ start_internal_monitor(known_state.clone(), path_monitor, update_tx, paths);
+
+ *known_state_guard = get_logical_drives().unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to initialize state of mounted volumes")
+ );
+ 0
+ });
+
+ VolumeMonitorHandle {
+ window_handle,
+ internal_monitor_task,
+ }
+ }
+}
+
+/// Monitors update requests from frontends. This checks if the known state of mounted volumes
+/// has change, and, if so, reapplies the ST config.
+async fn frontend_monitor(
+ known_state: Arc<Mutex<u32>>,
+ path_monitor: PathMonitorHandle,
+ update_tx: sync_mpsc::Sender<()>,
+ paths: Arc<Mutex<Vec<OsString>>>,
+ mut volume_update_rx: mpsc::UnboundedReceiver<()>,
+) {
+ while let Some(()) = volume_update_rx.next().await {
+ let mut known_state_guard = known_state.lock().unwrap();
+ let new_state = get_logical_drives().unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain new state of mounted volumes")
+ );
+ *known_state_guard
+ });
+
+ // Was there a change?
+ let state_diff = *known_state_guard ^ new_state;
+ if state_diff != 0 {
+ *known_state_guard = new_state;
+ let paths_guard = paths.lock().unwrap();
+ if matches_volume(state_diff, &paths_guard) {
+ // Reapply config
+ let _ = update_tx.send(());
+ let _ = path_monitor.refresh();
}
+ }
+ }
+}
+
+/// Monitors window events received by session 0.
+fn start_internal_monitor(
+ known_state: Arc<Mutex<u32>>,
+ path_monitor: PathMonitorHandle,
+ update_tx: sync_mpsc::Sender<()>,
+ paths: Arc<Mutex<Vec<OsString>>>,
+) -> WindowCloseHandle {
+ create_hidden_window(move |window, message, w_param, l_param| {
+ if !is_device_arrival_or_removal(message, w_param) {
+ return unsafe { DefWindowProcW(window, message, w_param, l_param) };
+ }
+ let paths_guard = paths.lock().unwrap();
+ let mut known_state_guard = known_state.lock().unwrap();
+
+ let volumes = unsafe { parse_device_volume_broadcast(&*(l_param as *const _)) };
- if label_found {
+ let prev_state = *known_state_guard;
+ let is_arrival = w_param == DBT_DEVICEARRIVAL;
+ if is_arrival {
+ *known_state_guard |= volumes;
+ } else {
+ *known_state_guard &= !volumes;
+ }
+
+ // Compare against known state to ignore duplicate notifications
+ // from frontends
+ let state_diff = *known_state_guard ^ prev_state;
+ if state_diff != 0 {
+ if matches_volume(volumes, &paths_guard) {
// Reapply config
let _ = update_tx.send(());
let _ = path_monitor.refresh();
}
+ }
+
+ // Always grant the request
+ TRUE as isize
+ })
+}
+
+/// Return a bitmask representing all currently available disk drives.
+/// Each bit refers to a volume letter. The bit 0 refers to 'A', bit 1
+/// refers to 'B', bit 2 to 'C', etc.
+fn get_logical_drives() -> io::Result<u32> {
+ let result = unsafe { GetLogicalDrives() };
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(result)
+}
- // Always grant the request
- TRUE as isize
- })
+/// Return whether any of the paths in `paths_guard` reside on any volume in `volumes` (a mask).
+fn matches_volume(volumes: u32, paths_guard: &MutexGuard<'_, Vec<OsString>>) -> bool {
+ for path in &**paths_guard {
+ let path = (path as &dyn AsRef<Path>).as_ref();
+ if let Some(path::Component::Prefix(prefix)) = path.components().next() {
+ match prefix.kind() {
+ path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => {
+ if disk < 'A' as u8 || disk > 'Z' as u8 {
+ log::warn!("Ignoring invalid volume \"{}\"", disk as char);
+ continue;
+ }
+ let disk = disk - 'A' as u8;
+ if volumes & (1 << disk) != 0 {
+ return true;
+ }
+ }
+ _ => (),
+ }
+ }
}
+ false
}
-/// Return volume labels (ASCII-encoded) affected by the device arrival or removal message, if any.
-unsafe fn parse_broadcast(broadcast: &DEV_BROADCAST_HDR) -> Vec<u8> {
- let mut labels = vec![];
+fn is_device_arrival_or_removal(message: u32, w_param: usize) -> bool {
+ message == WM_DEVICECHANGE
+ && (w_param == DBT_DEVICEARRIVAL || w_param == DBT_DEVICEREMOVECOMPLETE)
+}
+/// Return volumes affected by the device arrival or removal message as a mask.
+/// This has the same format as `get_logical_drives()`.
+unsafe fn parse_device_volume_broadcast(broadcast: &DEV_BROADCAST_HDR) -> u32 {
if broadcast.dbch_devicetype != DBT_DEVTYP_VOLUME {
- return labels;
+ return 0;
}
let volume_broadcast = &*(broadcast as *const _ as *const DEV_BROADCAST_VOLUME);
if volume_broadcast.dbcv_flags & DBTF_NET != 0 {
// Ignore net event
- return labels;
- }
-
- // 26 = 1 + 'Z' - 'A'
- let num_drives = 1 + 'Z' as u8 - 'A' as u8;
- for i in 0..num_drives {
- let is_affected = ((volume_broadcast.dbcv_unitmask >> i) & 1) != 0;
- if is_affected {
- labels.push('A' as u8 + i);
- }
+ return 0;
}
- labels
+ volume_broadcast.dbcv_unitmask
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 298f8fc7a6..e22975a827 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -103,6 +103,7 @@ pub async fn spawn(
state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static,
offline_state_listener: mpsc::UnboundedSender<bool>,
shutdown_tx: oneshot::Sender<()>,
+ #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
@@ -130,6 +131,8 @@ pub async fn spawn(
log_dir,
resource_dir,
command_rx,
+ #[cfg(target_os = "windows")]
+ volume_update_rx,
#[cfg(target_os = "macos")]
exclusion_gid,
#[cfg(target_os = "android")]
@@ -207,6 +210,7 @@ impl TunnelStateMachine {
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
+ #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Self, Error> {
@@ -216,8 +220,9 @@ impl TunnelStateMachine {
let filtering_resolver = crate::resolver::start_resolver().await?;
#[cfg(windows)]
- let split_tunnel = split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone())
- .map_err(Error::InitSplitTunneling)?;
+ let split_tunnel =
+ split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx)
+ .map_err(Error::InitSplitTunneling)?;
let args = FirewallArguments {
initial_state: if settings.block_when_disconnected || !settings.reset_firewall {