diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-06-14 14:03:02 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 14:03:02 +0200 |
| commit | 05bc925914dc31fa74fb4b85b5d74061333f3b37 (patch) | |
| tree | e2e800efbbf6d94dd63b25c38097bbac8c825b40 | |
| parent | 463042e83845efd09f61cd99feb65db58240bb32 (diff) | |
| parent | 0ce130b6faec1272cb7708e4bf5f315e3e287237 (diff) | |
| download | mullvadvpn-05bc925914dc31fa74fb4b85b5d74061333f3b37.tar.xz mullvadvpn-05bc925914dc31fa74fb4b85b5d74061333f3b37.zip | |
Merge branch 'add-win-st-pid-list'
| -rw-r--r-- | CHANGELOG.md | 3 | ||||
| -rw-r--r-- | mullvad-cli/src/cmds/split_tunnel/windows.rs | 41 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 43 | ||||
| -rw-r--r-- | mullvad-daemon/src/management_interface.rs | 36 | ||||
| -rw-r--r-- | mullvad-management-interface/proto/management_interface.proto | 11 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 355 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 414 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/windows.rs | 99 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 34 |
9 files changed, 656 insertions, 380 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 2310942811..68b9644f75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ Line wrap the file at 100 chars. Th #### Linux - Automatically attempt to detect and set the correct MTU for Wireguard tunnels. +#### Windows +- Add CLI command for listing excluded processes. + ### Changed - Display consistent colors regardless of monitor color profile. diff --git a/mullvad-cli/src/cmds/split_tunnel/windows.rs b/mullvad-cli/src/cmds/split_tunnel/windows.rs index 0bc8d468bc..9766607ce1 100644 --- a/mullvad-cli/src/cmds/split_tunnel/windows.rs +++ b/mullvad-cli/src/cmds/split_tunnel/windows.rs @@ -1,3 +1,5 @@ +use std::{ffi::OsStr, path::Path}; + use crate::{new_rpc_client, Command, Result}; pub struct SplitTunnel; @@ -23,11 +25,13 @@ impl Command for SplitTunnel { ), ) .subcommand(clap::App::new("get").about("Display the split tunnel status")) + .subcommand(create_pid_subcommand()) } async fn run(&self, matches: &clap::ArgMatches) -> Result<()> { match matches.subcommand() { Some(("app", matches)) => Self::handle_app_subcommand(matches).await, + Some(("pid", matches)) => Self::handle_pid_subcommand(matches).await, Some(("get", _)) => self.get().await, Some(("set", matches)) => { let enabled = matches.value_of("policy").expect("missing policy"); @@ -50,6 +54,16 @@ fn create_app_subcommand() -> clap::App<'static> { .subcommand(clap::App::new("clear")) } +fn create_pid_subcommand() -> clap::App<'static> { + clap::App::new("pid") + .about("Manages processes (PIDs) excluded from the tunnel") + .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::App::new("list") + .about("List processes that are currently being excluded, i.e. their PIDs, as well as whether \ + they are excluded because of their executable paths or because they're subprocesses of \ + such processes")) +} + impl SplitTunnel { async fn handle_app_subcommand(matches: &clap::ArgMatches) -> Result<()> { match matches.subcommand() { @@ -91,6 +105,33 @@ impl SplitTunnel { } } + async fn handle_pid_subcommand(matches: &clap::ArgMatches) -> Result<()> { + match matches.subcommand() { + Some(("list", _)) => { + let processes = new_rpc_client() + .await? + .get_excluded_processes(()) + .await? + .into_inner(); + + for process in &processes.processes { + let subproc = if process.inherited { "subprocess" } else { "" }; + println!( + "{:<7}{subproc:<12}{}", + process.pid, + Path::new(&process.image) + .file_name() + .unwrap_or(OsStr::new("unknown")) + .to_string_lossy() + ); + } + + Ok(()) + } + _ => unreachable!("unhandled subcommand"), + } + } + async fn set(&self, enabled: bool) -> Result<()> { let mut rpc = new_rpc_client().await?; rpc.set_split_tunnel_state(enabled).await?; diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index c013361925..95039cd662 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -66,7 +66,7 @@ use std::{ use talpid_core::split_tunnel; use talpid_core::{ mpsc::Sender, - tunnel_state_machine::{self, TunnelCommand}, + tunnel_state_machine::{self, TunnelCommand, TunnelStateMachineHandle}, }; #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; @@ -265,9 +265,12 @@ pub enum DaemonCommand { /// Clear list of apps to exclude from the tunnel #[cfg(windows)] ClearSplitTunnelApps(ResponseTx<(), Error>), - /// Disable split tunnel + /// Enable or disable split tunneling #[cfg(windows)] SetSplitTunnelState(ResponseTx<(), Error>, bool), + /// Returns all processes currently being excluded from the tunnel + #[cfg(windows)] + GetSplitTunnelProcesses(ResponseTx<Vec<split_tunnel::ExcludedProcess>, split_tunnel::Error>), /// Toggle wireguard-nt on or off #[cfg(target_os = "windows")] UseWireGuardNt(ResponseTx<(), Error>, bool), @@ -503,7 +506,6 @@ pub trait EventListener { } pub struct Daemon<L: EventListener> { - tunnel_command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>, tunnel_state: TunnelState, target_state: PersistentTargetState, state: DaemonExecutionState, @@ -526,7 +528,7 @@ pub struct Daemon<L: EventListener> { parameters_generator: tunnel::ParametersGenerator, app_version_info: Option<AppVersionInfo>, shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>, - tunnel_state_machine_handle: tunnel_state_machine::JoinHandle, + tunnel_state_machine_handle: TunnelStateMachineHandle, #[cfg(target_os = "windows")] volume_update_tx: mpsc::UnboundedSender<()>, } @@ -647,7 +649,7 @@ where let (offline_state_tx, offline_state_rx) = mpsc::unbounded(); #[cfg(target_os = "windows")] let (volume_update_tx, volume_update_rx) = mpsc::unbounded(); - let (tunnel_command_tx, tunnel_state_machine_handle) = tunnel_state_machine::spawn( + let tunnel_state_machine_handle = tunnel_state_machine::spawn( tunnel_state_machine::InitialTunnelState { allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, @@ -672,7 +674,8 @@ where .await .map_err(Error::TunnelError)?; - endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx)); + endpoint_updater + .set_tunnel_command_tx(Arc::downgrade(tunnel_state_machine_handle.command_tx())); api::forward_offline_state(api_availability.clone(), offline_state_rx); @@ -703,7 +706,6 @@ where relay_list_updater.update().await; let daemon = Daemon { - tunnel_command_tx, tunnel_state: TunnelState::Disconnected, target_state, state: DaemonExecutionState::Running, @@ -789,7 +791,7 @@ where L, Vec<Pin<Box<dyn Future<Output = ()>>>>, mullvad_api::Runtime, - tunnel_state_machine::JoinHandle, + TunnelStateMachineHandle, ) { let Daemon { event_listener, @@ -915,12 +917,12 @@ where fn schedule_reconnect(&mut self, delay: Duration) { self.unschedule_reconnect(); - let tunnel_command_tx = self.tx.to_specialized_sender(); + let daemon_command_tx = self.tx.to_specialized_sender(); let (future, abort_handle) = abortable(Box::pin(async move { tokio::time::sleep(delay).await; log::debug!("Attempting to reconnect"); let (tx, rx) = oneshot::channel(); - let _ = tunnel_command_tx.send(DaemonCommand::Reconnect(tx)); + let _ = daemon_command_tx.send(DaemonCommand::Reconnect(tx)); // suppress "unable to send" warning: let _ = rx.await; })); @@ -1013,6 +1015,8 @@ where ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await, #[cfg(windows)] SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await, + #[cfg(windows)] + GetSplitTunnelProcesses(tx) => self.on_get_split_tunnel_processes(tx), #[cfg(target_os = "windows")] UseWireGuardNt(tx, state) => self.on_use_wireguard_nt(tx, state).await, #[cfg(target_os = "windows")] @@ -1674,6 +1678,20 @@ where } #[cfg(windows)] + fn on_get_split_tunnel_processes( + &self, + tx: ResponseTx<Vec<split_tunnel::ExcludedProcess>, split_tunnel::Error>, + ) { + Self::oneshot_send( + tx, + self.tunnel_state_machine_handle + .split_tunnel() + .get_processes(), + "get_split_tunnel_processes response", + ); + } + + #[cfg(windows)] async fn on_use_wireguard_nt(&mut self, tx: ResponseTx<(), Error>, state: bool) { let save_result = self .settings @@ -2198,8 +2216,9 @@ where } } - fn send_tunnel_command(&mut self, command: TunnelCommand) { - self.tunnel_command_tx + fn send_tunnel_command(&self, command: TunnelCommand) { + self.tunnel_state_machine_handle + .command_tx() .unbounded_send(command) .expect("Tunnel state machine has stopped"); } diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index a7478eb048..7999fd9c55 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -768,6 +768,42 @@ impl ManagementService for ManagementServiceImpl { } #[cfg(windows)] + async fn get_excluded_processes( + &self, + _: Request<()>, + ) -> ServiceResult<types::ExcludedProcessList> { + log::debug!("get_excluded_processes"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::GetSplitTunnelProcesses(tx))?; + self.wait_for_result(rx) + .await? + .map_err(map_split_tunnel_error) + .map(|processes| { + Response::new(types::ExcludedProcessList { + processes: processes + .into_iter() + .map(|process| types::ExcludedProcess { + // FIXME: This is necessarily 32 bits or less + pid: u32::try_from(process.pid).unwrap(), + image: process.image.into_os_string().to_string_lossy().to_string(), + inherited: process.inherited, + }) + .collect(), + }) + }) + } + + #[cfg(not(windows))] + async fn get_excluded_processes( + &self, + _: Request<()>, + ) -> ServiceResult<types::ExcludedProcessList> { + Ok(Response::new(types::ExcludedProcessList { + processes: vec![], + })) + } + + #[cfg(windows)] async fn set_use_wireguard_nt(&self, request: Request<bool>) -> ServiceResult<()> { log::debug!("set_use_wireguard_nt"); let state = request.into_inner(); diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 9148a10150..2f39f496d6 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -79,6 +79,7 @@ service ManagementService { rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {} rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} + rpc GetExcludedProcesses(google.protobuf.Empty) returns (ExcludedProcessList) {} rpc SetUseWireguardNt(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} @@ -476,6 +477,16 @@ message PublicKey { google.protobuf.Timestamp created = 2; } +message ExcludedProcess { + uint32 pid = 1; + string image = 2; + bool inherited = 3; +} + +message ExcludedProcessList { + repeated ExcludedProcess processes = 1; +} + message AppVersionInfo { bool supported = 1; string latest_stable = 2; diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 70fe90f45d..6b6b2b2ea7 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -1,6 +1,6 @@ use super::windows::{ - get_device_path, get_process_creation_time, get_process_device_path, open_process, - ProcessAccess, ProcessSnapshot, + get_device_path, get_process_creation_time, get_process_device_path, open_process, Event, + Overlapped, ProcessAccess, ProcessSnapshot, }; use crate::windows::as_uninit_byte_slice; use bitflags::bitflags; @@ -33,12 +33,14 @@ use winapi::{ }, }, um::{ - handleapi::CloseHandle, ioapiset::{DeviceIoControl, GetOverlappedResult}, minwinbase::OVERLAPPED, - synchapi::{CreateEventW, WaitForSingleObject}, + synchapi::{WaitForMultipleObjects, WaitForSingleObject}, tlhelp32::TH32CS_SNAPPROCESS, - winbase::{FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0}, + winbase::{ + FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED, + WAIT_OBJECT_0, + }, winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, }, }; @@ -254,24 +256,17 @@ impl DeviceHandle { } fn initialize(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::Initialize as u32, - None, - 0, - None, - )?; + device_io_control(self, DriverIoctlCode::Initialize as u32, None, 0)?; Ok(()) } fn register_processes(&self) -> io::Result<()> { let process_tree_buffer = serialize_process_tree(build_process_tree()?)?; device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::RegisterProcesses as u32, Some(&process_tree_buffer), 0, - None, )?; Ok(()) } @@ -327,11 +322,10 @@ impl DeviceHandle { let buffer = as_uninit_byte_slice(&addresses); device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::RegisterIpAddresses as u32, Some(buffer), 0, - None, )?; Ok(()) @@ -339,11 +333,10 @@ impl DeviceHandle { pub fn get_driver_state(&self) -> io::Result<DriverState> { let buffer = device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::GetState as u32, None, size_of::<u64>() as u32, - None, )? .unwrap(); @@ -381,36 +374,22 @@ impl DeviceHandle { let config = make_process_config(&device_paths); device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::SetConfiguration as u32, Some(&config), 0, - None, )?; Ok(()) } pub fn clear_config(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::ClearConfiguration as u32, - None, - 0, - None, - )?; - + device_io_control(self, DriverIoctlCode::ClearConfiguration as u32, None, 0)?; Ok(()) } fn reset(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::Reset as u32, - None, - 0, - None, - )?; + device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?; Ok(()) } } @@ -680,18 +659,10 @@ struct ErrorMessageEventHeader { /// # Panics /// /// This may panic if `buffer` contains invalid data. -pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { +pub fn parse_event_buffer(buffer: &[u8]) -> Result<(EventId, EventBody), UnknownEventId> { // SAFETY: This panics if `buffer` is too small. let raw_event_id: u32 = unsafe { deserialize_buffer(&buffer[0..mem::size_of::<u32>()]) }; - let _event_id = EventId::try_from(raw_event_id) - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to parse ST event buffer") - ); - error - }) - .ok()?; + let _event_id = EventId::try_from(raw_event_id)?; // SAFETY: The event id is known to be valid. let event_header: EventHeader = @@ -711,7 +682,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { [string_byte_offset..(string_byte_offset + event.image_name_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::SplittingEvent { process_id: event.process_id, @@ -736,7 +707,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { [string_byte_offset..(string_byte_offset + event.image_name_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::SplittingError { process_id: event.process_id, @@ -757,7 +728,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { ..(string_byte_offset + event.error_message_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::ErrorMessage { status: event.status, @@ -768,170 +739,116 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { } } -/// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0. +/// Send an IOCTL code to the given device handle, and wait for the result. +/// +/// `input` specifies an optional buffer for sending data. +/// +/// Upon success, a buffer containing at most `output_size` bytes is returned, +/// or `None` if no bytes were read. pub fn device_io_control( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, input: Option<&[MaybeUninit<u8>]>, output_size: u32, - timeout: Option<Duration>, ) -> Result<Option<Vec<u8>>, io::Error> { - struct HandleOwner { - handle: RawHandle, - } - impl Drop for HandleOwner { - fn drop(&mut self) { - unsafe { CloseHandle(self.handle) }; - } - } - - let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() }; - overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; + let mut overlapped = Overlapped::new(Some(Event::new(true, false)?))?; - if overlapped.hEvent == ptr::null_mut() { - return Err(io::Error::last_os_error()); - } - - let _handle_owner = HandleOwner { - handle: overlapped.hEvent, - }; - - let mut out_buffer = if output_size > 0 { - Some(Vec::with_capacity(output_size as usize)) + let mut buffer = vec![]; + let out_buffer = if output_size > 0 { + buffer.resize( + usize::try_from(output_size).expect("u32 must be no larger than usize"), + 0u8, + ); + Some(&mut buffer[..]) } else { None }; - device_io_control_buffer( - device, - ioctl_code, - input, - out_buffer.as_mut(), - &overlapped, - timeout, - ) - .map(|()| out_buffer) + let bytes_read = + device_io_control_buffer(device, ioctl_code, input, out_buffer, &mut overlapped)?; + if bytes_read > 0 { + buffer.truncate(usize::try_from(bytes_read).expect("u32 must be no larger than usize")); + return Ok(Some(buffer)); + } + Ok(None) } -/// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// Upon success, `output` buffer will contain at most `output.capacity()` bytes of data. +/// Send an IOCTL code to the given device handle, and wait for the result. +/// +/// `input` specifies an optional buffer for sending data. +/// +/// Upon success, `output` buffer will contain at most `output.len()` bytes of data, +/// and the function returns the number of bytes read. +/// +/// # Panics +/// +/// This function will panic if `overlapped` does not contain an event. pub fn device_io_control_buffer( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, input: Option<&[MaybeUninit<u8>]>, - mut output: Option<&mut Vec<u8>>, - overlapped: &OVERLAPPED, - timeout: Option<Duration>, -) -> Result<(), io::Error> { - let input_ptr = match input { - Some(input) => input as *const _ as *mut _, - None => ptr::null_mut(), - }; - let input_len = input.map(|input| input.len()).unwrap_or(0); - + output: Option<&mut [u8]>, + overlapped: &mut Overlapped, +) -> Result<u32, io::Error> { + let output_len = output.as_ref().map(|output| output.len()).unwrap_or(0); + let output_len = u32::try_from(output_len).map_err(|_error| { + io::Error::new( + io::ErrorKind::InvalidInput, + "the output buffer is too large", + ) + })?; let out_ptr = match output { - Some(ref mut output) => output.as_mut_ptr() as *mut _, + Some(output) => output as *mut _ as *mut _, None => ptr::null_mut(), }; - let output_size = if let Some(ref output) = output { - output.capacity() - } else { - 0 - }; - - let event = overlapped.hEvent; - - let mut returned_bytes = 0u32; - let overlapped = overlapped as *const _ as *mut _; - - let result = unsafe { - DeviceIoControl( - device as *mut _, + // SAFETY: `out_ptr` will be valid until the result has been obtained. + unsafe { + device_io_control_buffer_async( + device, ioctl_code, - input_ptr, - input_len as u32, + input, out_ptr, - output_size as u32, - &mut returned_bytes, - overlapped, - ) - }; - - if result != 0 { - return Err(io::Error::new( - io::ErrorKind::Other, - "Expected pending operation", - )); - } - - let last_error = io::Error::last_os_error(); - if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) { - return Err(last_error); - } - - let timeout = timeout - .map(|timeout| timeout.as_millis() as u32) - .unwrap_or(INFINITE); - let result = unsafe { WaitForSingleObject(event, timeout) }; - match result { - WAIT_FAILED => return Err(io::Error::last_os_error()), - WAIT_ABANDONED => return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")), - WAIT_OBJECT_0 => (), - error => return Err(io::Error::from_raw_os_error(error as i32)), - } - - let result = - unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, FALSE) }; - - if result == 0 { - return Err(io::Error::last_os_error()); - } - - if let Some(ref mut output) = output { - unsafe { output.set_len(returned_bytes as usize) }; + output_len, + overlapped.as_mut_ptr(), + )?; } - - Ok(()) + get_overlapped_result(device, overlapped) } /// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// The result must be obtained using `GetOverlappedResult[Ex]`. +/// +/// `input` specifies an optional buffer for sending data. +/// `output_ptr` specifies an optional buffer for receiving data. +/// +/// Obtain the result using [get_overlapped_result]. +/// +/// # Safety +/// +/// * `output_ptr` must either be null or a valid buffer of `output_len` bytes. It must remain valid +/// until the overlapped operation has completed. pub unsafe fn device_io_control_buffer_async( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, - mut output: Option<&mut Vec<u8>>, - input: Option<&[u8]>, - overlapped: &OVERLAPPED, + input: Option<&[MaybeUninit<u8>]>, + output_ptr: *mut u8, + output_len: u32, + overlapped: *mut OVERLAPPED, ) -> Result<(), io::Error> { let input_ptr = match input { - Some(input) => input as *const _ as *mut _, + Some(input) => input.as_ptr() as *mut _, None => ptr::null_mut(), }; let input_len = input.map(|input| input.len()).unwrap_or(0); - let out_ptr = match output { - Some(ref mut output) => output.as_mut_ptr() as *mut _, - None => ptr::null_mut(), - }; - let output_size = if let Some(ref output) = output { - output.capacity() - } else { - 0 - }; - - let overlapped = overlapped as *const _ as *mut _; - let result = DeviceIoControl( - device as *mut _, + device.as_raw_handle(), ioctl_code, input_ptr, - input_len as u32, - out_ptr, - output_size as u32, + u32::try_from(input_len).map_err(|_error| { + io::Error::new(io::ErrorKind::InvalidInput, "the input buffer is too large") + })?, + output_ptr as *mut _, + output_len, ptr::null_mut(), overlapped, ); @@ -951,6 +868,90 @@ pub unsafe fn device_io_control_buffer_async( Ok(()) } +/// Retrieves the result of an overlapped operation. On success, this returns +/// the number of bytes transferred. For device I/O, this is the number of bytes +/// written to the output buffer. +/// +/// # Panics +/// +/// This function will panic if `overlapped` does not contain an event. +pub fn get_overlapped_result( + device: &DeviceHandle, + overlapped: &mut Overlapped, +) -> io::Result<u32> { + let event = overlapped.get_event().unwrap(); + + // SAFETY: This is a valid event object. + unsafe { wait_for_single_object(event.as_raw_handle(), None) }?; + + // SAFETY: The handle and overlapped object are valid. + let mut returned_bytes = 0u32; + let result = unsafe { + GetOverlappedResult( + device.as_raw_handle(), + overlapped.as_mut_ptr(), + &mut returned_bytes, + FALSE, + ) + }; + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(returned_bytes) +} + +/// Waits for an object to be signaled, or until a timeout interval has elapsed. +/// +/// # Safety +/// +/// * `object` must be a valid object that can be signaled, such as an event object. +pub unsafe fn wait_for_single_object( + object: RawHandle, + timeout: Option<Duration>, +) -> io::Result<()> { + let timeout = match timeout { + Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| { + io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long") + })?, + None => INFINITE, + }; + let result = WaitForSingleObject(object, timeout); + match result { + WAIT_OBJECT_0 => Ok(()), + WAIT_FAILED => Err(io::Error::last_os_error()), + WAIT_ABANDONED => Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")), + error => Err(io::Error::from_raw_os_error(error as i32)), + } +} + +/// Waits for one or several objects to be signaled. On success, this returns a pointer to an +/// object in `objects` that was signaled. +/// +/// # Safety +/// +/// * `objects` must be a slice of valid objects that can be signaled, such as event objects. +pub unsafe fn wait_for_multiple_objects( + objects: &[RawHandle], + wait_all: bool, +) -> io::Result<RawHandle> { + let objects_len = u32::try_from(objects.len()) + .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?; + let result = WaitForMultipleObjects( + objects_len, + objects.as_ptr(), + if wait_all { TRUE } else { FALSE }, + INFINITE, + ); + let signaled_index = if result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + objects_len { + result - WAIT_OBJECT_0 + } else if result >= WAIT_ABANDONED_0 && result < WAIT_ABANDONED_0 + objects_len { + return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")); + } else { + return Err(io::Error::last_os_error()); + }; + Ok(objects[usize::try_from(signaled_index).expect("usize must be larger than u32")]) +} + /// Reads the value from `buffer`, zeroing any remaining bytes. /// /// # Safety diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 5b5b29e0cc..63a56be473 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -12,29 +12,20 @@ use crate::{ }; use futures::channel::{mpsc, oneshot}; use std::{ + collections::HashMap, convert::TryFrom, ffi::{OsStr, OsString}, - io, mem, + io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, - os::windows::io::{AsRawHandle, RawHandle}, - ptr, + os::windows::io::AsRawHandle, + path::{Path, PathBuf}, sync::{ atomic::{AtomicBool, Ordering}, - mpsc as sync_mpsc, Arc, Mutex, Weak, + mpsc as sync_mpsc, Arc, Mutex, RwLock, Weak, }, time::Duration, }; use talpid_types::{tunnel::ErrorStateCause, ErrorExt}; -use winapi::{ - shared::minwindef::{FALSE, TRUE}, - um::{ - handleapi::CloseHandle, - ioapiset::GetOverlappedResult, - minwinbase::OVERLAPPED, - synchapi::{CreateEventW, SetEvent, WaitForMultipleObjects, WaitForSingleObject}, - winbase::{INFINITE, WAIT_ABANDONED_0, WAIT_OBJECT_0}, - }, -}; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); @@ -84,8 +75,8 @@ pub enum Error { RequestThreadStuck, /// The request handling thread is down - #[error(display = "The ST request thread is down")] - RequestThreadDown, + #[error(display = "The split tunnel monitor is down")] + SplitTunnelDown, /// Failed to start the NTFS reparse point monitor #[error(display = "Failed to start path monitor")] @@ -101,36 +92,13 @@ pub struct SplitTunnel { runtime: tokio::runtime::Handle, request_tx: RequestTx, event_thread: Option<std::thread::JoinHandle<()>>, - quit_event: Arc<QuitEvent>, + quit_event: Arc<windows::Event>, + excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, _route_change_callback: Option<WinNetCallbackHandle>, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, async_path_update_in_progress: Arc<AtomicBool>, } -struct QuitEvent(RawHandle); - -unsafe impl Send for QuitEvent {} -unsafe impl Sync for QuitEvent {} - -impl QuitEvent { - fn new() -> Self { - Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }) - } - - fn set_event(&self) -> io::Result<()> { - if unsafe { SetEvent(self.0) } == 0 { - return Err(io::Error::last_os_error()); - } - Ok(()) - } -} - -impl Drop for QuitEvent { - fn drop(&mut self) { - unsafe { CloseHandle(self.0) }; - } -} - enum Request { SetPaths(Vec<OsString>), RegisterIps(InterfaceAddresses), @@ -148,12 +116,43 @@ struct InterfaceAddresses { internet_ipv6: Option<Ipv6Addr>, } -struct EventThreadContext { - handle: Arc<driver::DeviceHandle>, - event_overlapped: OVERLAPPED, - quit_event: Arc<QuitEvent>, +/// Represents a process that is being excluded from the tunnel. +#[derive(Debug, Clone)] +pub struct ExcludedProcess { + /// Process identifier. + pub pid: u32, + /// Path to the image that this process is an instance of. + pub image: PathBuf, + /// If true, then the process is split because its parent was split, + /// not due to its path being in the config. + pub inherited: bool, +} + +/// Cloneable handle for interacting with the split tunnel module. +#[derive(Debug, Clone)] +pub struct SplitTunnelHandle { + excluded_processes: Weak<RwLock<HashMap<usize, ExcludedProcess>>>, +} + +impl SplitTunnelHandle { + /// Return processes that are currently being excluded, including + /// their pids, paths, and reason for being excluded. + pub fn get_processes(&self) -> Result<Vec<ExcludedProcess>, Error> { + let processes = self + .excluded_processes + .upgrade() + .ok_or(Error::SplitTunnelDown)?; + let processes = processes.read().unwrap(); + Ok(processes.values().cloned().collect()) + } +} + +enum EventResult { + /// Result containing the next event. + Event(driver::EventId, driver::EventBody), + /// Quit event was signaled. + Quit, } -unsafe impl Send for EventThreadContext {} impl SplitTunnel { /// Initialize the split tunnel device. @@ -164,169 +163,207 @@ impl SplitTunnel { ) -> Result<Self, Error> { let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?; - let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() }; - event_overlapped.hEvent = - unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; - if event_overlapped.hEvent == ptr::null_mut() { - return Err(Error::EventThreadError(io::Error::last_os_error())); - } - - let quit_event = Arc::new(QuitEvent::new()); - - let event_context = EventThreadContext { - handle: handle.clone(), - event_overlapped, - quit_event: quit_event.clone(), - }; + let excluded_processes = Arc::new(RwLock::new(HashMap::new())); + let (event_thread, quit_event) = + Self::spawn_event_listener(handle, excluded_processes.clone())?; - let event_thread = std::thread::spawn(move || { - use driver::{EventBody, EventId}; + Ok(SplitTunnel { + runtime, + request_tx, + event_thread: Some(event_thread), + quit_event, + _route_change_callback: None, + daemon_tx, + async_path_update_in_progress: Arc::new(AtomicBool::new(false)), + excluded_processes, + }) + } - // Take ownership of the entire struct (Rust 2021 edition change) - let _ = &event_context; + /// Spawns an event loop thread that processes events from the driver service. + fn spawn_event_listener( + handle: Arc<driver::DeviceHandle>, + excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, + ) -> Result<(std::thread::JoinHandle<()>, Arc<windows::Event>), Error> { + let mut event_overlapped = windows::Overlapped::new(Some( + windows::Event::new(true, false).map_err(Error::EventThreadError)?, + )) + .map_err(Error::EventThreadError)?; - let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE); - let mut returned_bytes = 0u32; + let quit_event = + Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?); + let quit_event_copy = quit_event.clone(); - let event_objects = [ - event_context.event_overlapped.hEvent, - event_context.quit_event.0, - ]; + let event_thread = std::thread::spawn(move || { + let mut data_buffer = vec![]; loop { - if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } { - // Quit event was signaled - break; - } + // Wait until either the next event is received or the quit event is signaled. + let (event_id, event_body) = match Self::fetch_next_event( + &handle, + &quit_event, + &mut event_overlapped, + &mut data_buffer, + ) { + Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body), + Ok(EventResult::Quit) => break, + Err(_error) => continue, + }; - if let Err(error) = unsafe { - driver::device_io_control_buffer_async( - event_context.handle.as_raw_handle(), - driver::DriverIoctlCode::DequeEvent as u32, - Some(&mut data_buffer), - None, - &event_context.event_overlapped, - ) - } { - log::error!( - "{}", - error.display_chain_with_msg("device_io_control failed") - ); - continue; - } + Self::handle_event(event_id, event_body, &excluded_processes); + } - let result = unsafe { - WaitForMultipleObjects( - event_objects.len() as u32, - &event_objects[0], - FALSE, - INFINITE, - ) - }; + log::debug!("Stopping split tunnel event thread"); + }); - let signaled_index = if result >= WAIT_OBJECT_0 - && result < WAIT_OBJECT_0 + event_objects.len() as u32 - { - result - WAIT_OBJECT_0 - } else if result >= WAIT_ABANDONED_0 - && result < WAIT_ABANDONED_0 + event_objects.len() as u32 - { - result - WAIT_ABANDONED_0 - } else { - let error = io::Error::last_os_error(); - log::error!( - "{}", - error.display_chain_with_msg("WaitForMultipleObjects failed") - ); + Ok((event_thread, quit_event_copy)) + } - continue; - }; + fn fetch_next_event( + device: &Arc<driver::DeviceHandle>, + quit_event: &windows::Event, + overlapped: &mut windows::Overlapped, + data_buffer: &mut Vec<u8>, + ) -> io::Result<EventResult> { + if unsafe { + driver::wait_for_single_object(quit_event.as_raw_handle(), Some(Duration::ZERO)) + } + .is_ok() + { + return Ok(EventResult::Quit); + } - if event_context.quit_event.0 == event_objects[signaled_index as usize] { - // Quit event was signaled - break; - } + data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8); - let result = unsafe { - GetOverlappedResult( - event_context.handle.as_raw_handle(), - &event_context.event_overlapped as *const _ as *mut _, - &mut returned_bytes, - TRUE, - ) - }; + unsafe { + driver::device_io_control_buffer_async( + device, + driver::DriverIoctlCode::DequeEvent as u32, + None, + data_buffer.as_mut_ptr(), + u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"), + overlapped.as_mut_ptr(), + ) + } + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("DeviceIoControl failed to deque event") + ); + error + })?; + + let event_objects = [ + overlapped.get_event().unwrap().as_raw_handle(), + quit_event.as_raw_handle(), + ]; - if result == 0 { - let error = io::Error::last_os_error(); + let signaled_object = + unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) }.map_err( + |error| { log::error!( "{}", - error.display_chain_with_msg("GetOverlappedResult failed") + error.display_chain_with_msg("wait_for_multiple_objects failed") ); + error + }, + )?; - continue; - } + if signaled_object == quit_event.as_raw_handle() { + // Quit event was signaled + return Ok(EventResult::Quit); + } - unsafe { data_buffer.set_len(returned_bytes as usize) }; + let returned_bytes = + driver::get_overlapped_result(device, overlapped).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("get_overlapped_result failed for dequed event"), + ); + error + })?; - let event = driver::parse_event_buffer(&data_buffer); + data_buffer + .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32")); - let (event_id, event_body) = match event { - Some((event_id, event_body)) => (event_id, event_body), - None => continue, - }; + driver::parse_event_buffer(&data_buffer) + .map(|(id, body)| EventResult::Event(id, body)) + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to parse ST event buffer") + ); + io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer") + }) + } - let event_str = match &event_id { - EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { - "Start splitting process" - } - EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { - "Stop splitting process" - } - EventId::ErrorMessage => "ErrorMessage", - }; + fn handle_event( + event_id: driver::EventId, + event_body: driver::EventBody, + excluded_processes: &Arc<RwLock<HashMap<usize, ExcludedProcess>>>, + ) { + use driver::{EventBody, EventId}; - match event_body { - EventBody::SplittingEvent { - process_id, - reason, - image, - } => { - log::trace!( - "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", - event_str, - process_id, - reason, - image, - ); - } - EventBody::SplittingError { process_id, image } => { - log::error!( - "FAILED: {}:\n\tpid: {}\n\timage: {:?}", - event_str, + let event_str = match &event_id { + EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { + "Start splitting process" + } + EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { + "Stop splitting process" + } + EventId::ErrorMessage => "ErrorMessage", + }; + + match event_body { + EventBody::SplittingEvent { + process_id, + reason, + image, + } => { + let mut pids = excluded_processes.write().unwrap(); + match event_id { + EventId::StartSplittingProcess => { + if let Some(prev_entry) = pids.get(&process_id) { + log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry); + } + pids.insert( process_id, - image, + ExcludedProcess { + pid: u32::try_from(process_id) + .expect("PID should be containable in a DWORD"), + image: Path::new(&image).to_path_buf(), + inherited: reason + .contains(driver::SplittingChangeReason::BY_INHERITANCE), + }, ); } - EventBody::ErrorMessage { status, message } => { - log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) + EventId::StopSplittingProcess => { + if pids.remove(&process_id).is_none() { + log::error!("Inconsistent process tree: {process_id} was not found"); + } } + _ => (), } - } - - log::debug!("Stopping split tunnel event thread"); - unsafe { CloseHandle(event_context.event_overlapped.hEvent) }; - }); - - Ok(SplitTunnel { - runtime, - request_tx, - event_thread: Some(event_thread), - quit_event, - _route_change_callback: None, - daemon_tx, - async_path_update_in_progress: Arc::new(AtomicBool::new(false)), - }) + log::trace!( + "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", + event_str, + process_id, + reason, + image, + ); + } + EventBody::SplittingError { process_id, image } => { + log::error!( + "FAILED: {}:\n\tpid: {}\n\timage: {:?}", + event_str, + process_id, + image, + ); + } + EventBody::ErrorMessage { status, message } => { + log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) + } + } } fn spawn_request_thread( @@ -464,7 +501,7 @@ impl SplitTunnel { request_tx .send((request, response_tx)) - .map_err(|_| Error::RequestThreadDown)?; + .map_err(|_| Error::SplitTunnelDown)?; response_rx .recv_timeout(REQUEST_TIMEOUT) @@ -506,8 +543,8 @@ impl SplitTunnel { let wait_task = move || { request_tx .send((request, response_tx)) - .map_err(|_| Error::RequestThreadDown)?; - response_rx.recv().map_err(|_| Error::RequestThreadDown)? + .map_err(|_| Error::SplitTunnelDown)?; + response_rx.recv().map_err(|_| Error::SplitTunnelDown)? }; let in_progress = self.async_path_update_in_progress.clone(); self.runtime.spawn_blocking(move || { @@ -562,12 +599,19 @@ impl SplitTunnel { self._route_change_callback = None; self.send_request(Request::RegisterIps(InterfaceAddresses::default())) } + + /// Returns a handle used for interacting with the split tunnel module. + pub fn handle(&self) -> SplitTunnelHandle { + SplitTunnelHandle { + excluded_processes: Arc::downgrade(&self.excluded_processes), + } + } } impl Drop for SplitTunnel { fn drop(&mut self) { if let Some(_event_thread) = self.event_thread.take() { - if let Err(error) = self.quit_event.set_event() { + if let Err(error) = self.quit_event.set() { log::error!( "{}", error.display_chain_with_msg("Failed to close ST event thread") diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs index 7000211bba..1cfcce15a5 100644 --- a/talpid-core/src/split_tunnel/windows/windows.rs +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -13,15 +13,17 @@ use std::{ }; use winapi::{ shared::{ - minwindef::{DWORD, FALSE, FILETIME, TRUE}, + minwindef::{BOOL, DWORD, FALSE, FILETIME, TRUE}, ntdef::ULARGE_INTEGER, winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES}, }, um::{ fileapi::{GetFinalPathNameByHandleW, QueryDosDeviceW}, handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + minwinbase::OVERLAPPED, processthreadsapi::{GetProcessTimes, OpenProcess}, psapi::K32GetProcessImageFileNameW, + synchapi::{CreateEventW, SetEvent}, tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W}, winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION}, }, @@ -313,3 +315,98 @@ fn get_process_device_path_inner( Ok(OsStringExt::from_wide(&buffer)) } + +/// Abstraction over `OVERLAPPED`, which is used for async I/O. +pub struct Overlapped { + overlapped: OVERLAPPED, + event: Option<Event>, +} + +unsafe impl Send for Overlapped {} +unsafe impl Sync for Overlapped {} + +impl Overlapped { + /// Creates an `OVERLAPPED` object with `hEvent` set. + pub fn new(event: Option<Event>) -> io::Result<Self> { + let mut overlapped = Overlapped { + overlapped: unsafe { mem::zeroed() }, + event: None, + }; + overlapped.set_event(event); + Ok(overlapped) + } + + /// Borrows the underlying `OVERLAPPED` object. + pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED { + &mut self.overlapped + } + + /// Returns a reference to the associated event. + pub fn get_event(&self) -> Option<&Event> { + self.event.as_ref() + } + + /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`) + fn set_event(&mut self, event: Option<Event>) { + match event { + Some(event) => { + let raw_event = event.0; + self.overlapped.hEvent = raw_event; + self.event = Some(event); + } + None => { + self.overlapped.hEvent = ptr::null_mut(); + self.event = None; + } + } + } +} + +/// Abstraction over a Windows event object. +pub struct Event(RawHandle); + +unsafe impl Send for Event {} +unsafe impl Sync for Event {} + +impl Event { + pub fn new(manual_reset: bool, initial_state: bool) -> io::Result<Self> { + let event = unsafe { + CreateEventW( + ptr::null_mut(), + bool_to_winbool(manual_reset), + bool_to_winbool(initial_state), + ptr::null(), + ) + }; + if event == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(Self(event)) + } + + pub fn set(&self) -> io::Result<()> { + if unsafe { SetEvent(self.0) } == FALSE { + return Err(io::Error::last_os_error()); + } + Ok(()) + } +} + +impl AsRawHandle for Event { + fn as_raw_handle(&self) -> RawHandle { + self.0 + } +} + +impl Drop for Event { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + +const fn bool_to_winbool(val: bool) -> BOOL { + match val { + true => TRUE, + false => FALSE, + } +} diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 6c388a4680..e55cc9c0d5 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -116,7 +116,7 @@ pub async fn spawn( #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "macos")] exclusion_gid: u32, #[cfg(target_os = "android")] android_context: AndroidContext, -) -> Result<(Arc<mpsc::UnboundedSender<TunnelCommand>>, JoinHandle), Error> { +) -> Result<TunnelStateMachineHandle, Error> { let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); @@ -150,6 +150,9 @@ pub async fn spawn( ) .await?; + #[cfg(windows)] + let split_tunnel = state_machine.shared_values.split_tunnel.handle(); + tokio::task::spawn_blocking(move || { state_machine.run(state_change_listener); if shutdown_tx.send(()).is_err() { @@ -157,7 +160,12 @@ pub async fn spawn( } }); - Ok((command_tx, JoinHandle { shutdown_rx })) + Ok(TunnelStateMachineHandle { + command_tx, + shutdown_rx, + #[cfg(windows)] + split_tunnel, + }) } /// Representation of external commands for the tunnel state machine. @@ -590,18 +598,34 @@ state_wrapper! { } } -/// Handle used to wait for the tunnel state machine to shut down. -pub struct JoinHandle { +/// Handle used to control the tunnel state machine. +pub struct TunnelStateMachineHandle { + command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>, shutdown_rx: oneshot::Receiver<()>, + #[cfg(windows)] + split_tunnel: split_tunnel::SplitTunnelHandle, } -impl JoinHandle { +impl TunnelStateMachineHandle { /// Waits for the tunnel state machine to shut down. /// This may fail after a timeout of `TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT`. pub async fn try_join(self) { + drop(self.command_tx); + match tokio::time::timeout(TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, self.shutdown_rx).await { Ok(_) => log::info!("Tunnel state machine shut down"), Err(_) => log::error!("Tunnel state machine did not shut down gracefully"), } } + + /// Returns tunnel command sender. + pub fn command_tx(&self) -> &Arc<mpsc::UnboundedSender<TunnelCommand>> { + &self.command_tx + } + + /// Returns split tunnel object handle. + #[cfg(windows)] + pub fn split_tunnel(&self) -> &split_tunnel::SplitTunnelHandle { + &self.split_tunnel + } } |
