diff options
Diffstat (limited to 'talpid-core')
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/volume_monitor.rs | 204 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 9 |
3 files changed, 171 insertions, 54 deletions
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 34c9a67864..efdd75ecf9 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -157,8 +157,9 @@ impl SplitTunnel { pub fn new( runtime: tokio::runtime::Handle, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, + volume_update_rx: mpsc::UnboundedReceiver<()>, ) -> Result<Self, Error> { - let (request_tx, handle) = Self::spawn_request_thread()?; + let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?; let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() }; event_overlapped.hEvent = @@ -326,7 +327,9 @@ impl SplitTunnel { }) } - fn spawn_request_thread() -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> { + fn spawn_request_thread( + volume_update_rx: mpsc::UnboundedReceiver<()>, + ) -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> { let (tx, rx): (RequestTx, _) = sync_mpsc::channel(); let (init_tx, init_rx) = sync_mpsc::channel(); @@ -337,10 +340,11 @@ impl SplitTunnel { let path_monitor = path_monitor::PathMonitor::spawn(monitor_tx.clone()) .map_err(Error::StartPathMonitor)?; - let mut volume_monitor = volume_monitor::VolumeMonitor::spawn( + let volume_monitor = volume_monitor::VolumeMonitor::spawn( path_monitor.clone(), monitor_tx, monitored_paths.clone(), + volume_update_rx, ); std::thread::spawn(move || { @@ -403,7 +407,7 @@ impl SplitTunnel { } } - volume_monitor.close(); + drop(volume_monitor); if let Err(error) = path_monitor.shutdown() { log::error!( "{}", diff --git a/talpid-core/src/split_tunnel/windows/volume_monitor.rs b/talpid-core/src/split_tunnel/windows/volume_monitor.rs index e9060e5528..1993cc809c 100644 --- a/talpid-core/src/split_tunnel/windows/volume_monitor.rs +++ b/talpid-core/src/split_tunnel/windows/volume_monitor.rs @@ -2,11 +2,14 @@ //! tunnel config if any of the excluded paths are affected by them. use super::path_monitor::PathMonitorHandle; use crate::windows::window::{create_hidden_window, WindowCloseHandle}; +use futures::{channel::mpsc, StreamExt}; use std::{ ffi::OsString, + io, path::{self, Path}, - sync::{mpsc as sync_mpsc, Arc, Mutex}, + sync::{mpsc as sync_mpsc, Arc, Mutex, MutexGuard}, }; +use talpid_types::ErrorExt; use winapi::{ shared::minwindef::TRUE, um::{ @@ -14,83 +17,188 @@ use winapi::{ DBTF_NET, DBT_DEVICEARRIVAL, DBT_DEVICEREMOVECOMPLETE, DBT_DEVTYP_VOLUME, DEV_BROADCAST_HDR, DEV_BROADCAST_VOLUME, WM_DEVICECHANGE, }, + fileapi::GetLogicalDrives, winuser::DefWindowProcW, }, }; pub(super) struct VolumeMonitor(()); +pub(super) struct VolumeMonitorHandle { + window_handle: WindowCloseHandle, + internal_monitor_task: tokio::task::JoinHandle<()>, +} + +impl Drop for VolumeMonitorHandle { + fn drop(&mut self) { + self.window_handle.close(); + self.internal_monitor_task.abort(); + } +} + impl VolumeMonitor { pub fn spawn( path_monitor: PathMonitorHandle, update_tx: sync_mpsc::Sender<()>, paths: Arc<Mutex<Vec<OsString>>>, - ) -> WindowCloseHandle { - create_hidden_window(move |window, message, w_param, l_param| { - if message != WM_DEVICECHANGE - || (w_param != DBT_DEVICEARRIVAL && w_param != DBT_DEVICEREMOVECOMPLETE) - { - return unsafe { DefWindowProcW(window, message, w_param, l_param) }; - } + volume_update_rx: mpsc::UnboundedReceiver<()>, + ) -> VolumeMonitorHandle { + // A bitmask containing all (known) mounted drives. + let known_state = Arc::new(Mutex::new(0u32)); - let paths_guard = paths.lock().unwrap(); - let mut label_found = false; + // Lock before registering event handler + let mut known_state_guard = known_state.lock().unwrap(); - let volumes = unsafe { parse_broadcast(&*(l_param as *const _)) }; - for volume in volumes { - for path in &*paths_guard { - let path = (path as &dyn AsRef<Path>).as_ref(); - if let Some(path::Component::Prefix(prefix)) = path.components().next() { - match prefix.kind() { - path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => { - if disk == volume { - label_found = true; - break; - } - } - _ => (), - } - } - } - if label_found { - break; - } + let internal_monitor_task = tokio::spawn(frontend_monitor( + known_state.clone(), + path_monitor.clone(), + update_tx.clone(), + paths.clone(), + volume_update_rx, + )); + + let window_handle = + start_internal_monitor(known_state.clone(), path_monitor, update_tx, paths); + + *known_state_guard = get_logical_drives().unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to initialize state of mounted volumes") + ); + 0 + }); + + VolumeMonitorHandle { + window_handle, + internal_monitor_task, + } + } +} + +/// Monitors update requests from frontends. This checks if the known state of mounted volumes +/// has change, and, if so, reapplies the ST config. +async fn frontend_monitor( + known_state: Arc<Mutex<u32>>, + path_monitor: PathMonitorHandle, + update_tx: sync_mpsc::Sender<()>, + paths: Arc<Mutex<Vec<OsString>>>, + mut volume_update_rx: mpsc::UnboundedReceiver<()>, +) { + while let Some(()) = volume_update_rx.next().await { + let mut known_state_guard = known_state.lock().unwrap(); + let new_state = get_logical_drives().unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain new state of mounted volumes") + ); + *known_state_guard + }); + + // Was there a change? + let state_diff = *known_state_guard ^ new_state; + if state_diff != 0 { + *known_state_guard = new_state; + let paths_guard = paths.lock().unwrap(); + if matches_volume(state_diff, &paths_guard) { + // Reapply config + let _ = update_tx.send(()); + let _ = path_monitor.refresh(); } + } + } +} + +/// Monitors window events received by session 0. +fn start_internal_monitor( + known_state: Arc<Mutex<u32>>, + path_monitor: PathMonitorHandle, + update_tx: sync_mpsc::Sender<()>, + paths: Arc<Mutex<Vec<OsString>>>, +) -> WindowCloseHandle { + create_hidden_window(move |window, message, w_param, l_param| { + if !is_device_arrival_or_removal(message, w_param) { + return unsafe { DefWindowProcW(window, message, w_param, l_param) }; + } + let paths_guard = paths.lock().unwrap(); + let mut known_state_guard = known_state.lock().unwrap(); + + let volumes = unsafe { parse_device_volume_broadcast(&*(l_param as *const _)) }; - if label_found { + let prev_state = *known_state_guard; + let is_arrival = w_param == DBT_DEVICEARRIVAL; + if is_arrival { + *known_state_guard |= volumes; + } else { + *known_state_guard &= !volumes; + } + + // Compare against known state to ignore duplicate notifications + // from frontends + let state_diff = *known_state_guard ^ prev_state; + if state_diff != 0 { + if matches_volume(volumes, &paths_guard) { // Reapply config let _ = update_tx.send(()); let _ = path_monitor.refresh(); } + } + + // Always grant the request + TRUE as isize + }) +} + +/// Return a bitmask representing all currently available disk drives. +/// Each bit refers to a volume letter. The bit 0 refers to 'A', bit 1 +/// refers to 'B', bit 2 to 'C', etc. +fn get_logical_drives() -> io::Result<u32> { + let result = unsafe { GetLogicalDrives() }; + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(result) +} - // Always grant the request - TRUE as isize - }) +/// Return whether any of the paths in `paths_guard` reside on any volume in `volumes` (a mask). +fn matches_volume(volumes: u32, paths_guard: &MutexGuard<'_, Vec<OsString>>) -> bool { + for path in &**paths_guard { + let path = (path as &dyn AsRef<Path>).as_ref(); + if let Some(path::Component::Prefix(prefix)) = path.components().next() { + match prefix.kind() { + path::Prefix::VerbatimDisk(disk) | path::Prefix::Disk(disk) => { + if disk < 'A' as u8 || disk > 'Z' as u8 { + log::warn!("Ignoring invalid volume \"{}\"", disk as char); + continue; + } + let disk = disk - 'A' as u8; + if volumes & (1 << disk) != 0 { + return true; + } + } + _ => (), + } + } } + false } -/// Return volume labels (ASCII-encoded) affected by the device arrival or removal message, if any. -unsafe fn parse_broadcast(broadcast: &DEV_BROADCAST_HDR) -> Vec<u8> { - let mut labels = vec![]; +fn is_device_arrival_or_removal(message: u32, w_param: usize) -> bool { + message == WM_DEVICECHANGE + && (w_param == DBT_DEVICEARRIVAL || w_param == DBT_DEVICEREMOVECOMPLETE) +} +/// Return volumes affected by the device arrival or removal message as a mask. +/// This has the same format as `get_logical_drives()`. +unsafe fn parse_device_volume_broadcast(broadcast: &DEV_BROADCAST_HDR) -> u32 { if broadcast.dbch_devicetype != DBT_DEVTYP_VOLUME { - return labels; + return 0; } let volume_broadcast = &*(broadcast as *const _ as *const DEV_BROADCAST_VOLUME); if volume_broadcast.dbcv_flags & DBTF_NET != 0 { // Ignore net event - return labels; - } - - // 26 = 1 + 'Z' - 'A' - let num_drives = 1 + 'Z' as u8 - 'A' as u8; - for i in 0..num_drives { - let is_affected = ((volume_broadcast.dbcv_unitmask >> i) & 1) != 0; - if is_affected { - labels.push('A' as u8 + i); - } + return 0; } - labels + volume_broadcast.dbcv_unitmask } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 298f8fc7a6..e22975a827 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -103,6 +103,7 @@ pub async fn spawn( state_change_listener: impl Sender<TunnelStateTransition> + Send + 'static, offline_state_listener: mpsc::UnboundedSender<bool>, shutdown_tx: oneshot::Sender<()>, + #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "macos")] exclusion_gid: u32, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { @@ -130,6 +131,8 @@ pub async fn spawn( log_dir, resource_dir, command_rx, + #[cfg(target_os = "windows")] + volume_update_rx, #[cfg(target_os = "macos")] exclusion_gid, #[cfg(target_os = "android")] @@ -207,6 +210,7 @@ impl TunnelStateMachine { log_dir: Option<PathBuf>, resource_dir: PathBuf, commands_rx: mpsc::UnboundedReceiver<TunnelCommand>, + #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "macos")] exclusion_gid: u32, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Self, Error> { @@ -216,8 +220,9 @@ impl TunnelStateMachine { let filtering_resolver = crate::resolver::start_resolver().await?; #[cfg(windows)] - let split_tunnel = split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone()) - .map_err(Error::InitSplitTunneling)?; + let split_tunnel = + split_tunnel::SplitTunnel::new(runtime.clone(), command_tx.clone(), volume_update_rx) + .map_err(Error::InitSplitTunneling)?; let args = FirewallArguments { initial_state: if settings.block_when_disconnected || !settings.reset_firewall { |
