summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-06-14 14:03:02 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-06-14 14:03:02 +0200
commit05bc925914dc31fa74fb4b85b5d74061333f3b37 (patch)
treee2e800efbbf6d94dd63b25c38097bbac8c825b40
parent463042e83845efd09f61cd99feb65db58240bb32 (diff)
parent0ce130b6faec1272cb7708e4bf5f315e3e287237 (diff)
downloadmullvadvpn-05bc925914dc31fa74fb4b85b5d74061333f3b37.tar.xz
mullvadvpn-05bc925914dc31fa74fb4b85b5d74061333f3b37.zip
Merge branch 'add-win-st-pid-list'
-rw-r--r--CHANGELOG.md3
-rw-r--r--mullvad-cli/src/cmds/split_tunnel/windows.rs41
-rw-r--r--mullvad-daemon/src/lib.rs43
-rw-r--r--mullvad-daemon/src/management_interface.rs36
-rw-r--r--mullvad-management-interface/proto/management_interface.proto11
-rw-r--r--talpid-core/src/split_tunnel/windows/driver.rs355
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs414
-rw-r--r--talpid-core/src/split_tunnel/windows/windows.rs99
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs34
9 files changed, 656 insertions, 380 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2310942811..68b9644f75 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,9 @@ Line wrap the file at 100 chars. Th
#### Linux
- Automatically attempt to detect and set the correct MTU for Wireguard tunnels.
+#### Windows
+- Add CLI command for listing excluded processes.
+
### Changed
- Display consistent colors regardless of monitor color profile.
diff --git a/mullvad-cli/src/cmds/split_tunnel/windows.rs b/mullvad-cli/src/cmds/split_tunnel/windows.rs
index 0bc8d468bc..9766607ce1 100644
--- a/mullvad-cli/src/cmds/split_tunnel/windows.rs
+++ b/mullvad-cli/src/cmds/split_tunnel/windows.rs
@@ -1,3 +1,5 @@
+use std::{ffi::OsStr, path::Path};
+
use crate::{new_rpc_client, Command, Result};
pub struct SplitTunnel;
@@ -23,11 +25,13 @@ impl Command for SplitTunnel {
),
)
.subcommand(clap::App::new("get").about("Display the split tunnel status"))
+ .subcommand(create_pid_subcommand())
}
async fn run(&self, matches: &clap::ArgMatches) -> Result<()> {
match matches.subcommand() {
Some(("app", matches)) => Self::handle_app_subcommand(matches).await,
+ Some(("pid", matches)) => Self::handle_pid_subcommand(matches).await,
Some(("get", _)) => self.get().await,
Some(("set", matches)) => {
let enabled = matches.value_of("policy").expect("missing policy");
@@ -50,6 +54,16 @@ fn create_app_subcommand() -> clap::App<'static> {
.subcommand(clap::App::new("clear"))
}
+fn create_pid_subcommand() -> clap::App<'static> {
+ clap::App::new("pid")
+ .about("Manages processes (PIDs) excluded from the tunnel")
+ .setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::App::new("list")
+ .about("List processes that are currently being excluded, i.e. their PIDs, as well as whether \
+ they are excluded because of their executable paths or because they're subprocesses of \
+ such processes"))
+}
+
impl SplitTunnel {
async fn handle_app_subcommand(matches: &clap::ArgMatches) -> Result<()> {
match matches.subcommand() {
@@ -91,6 +105,33 @@ impl SplitTunnel {
}
}
+ async fn handle_pid_subcommand(matches: &clap::ArgMatches) -> Result<()> {
+ match matches.subcommand() {
+ Some(("list", _)) => {
+ let processes = new_rpc_client()
+ .await?
+ .get_excluded_processes(())
+ .await?
+ .into_inner();
+
+ for process in &processes.processes {
+ let subproc = if process.inherited { "subprocess" } else { "" };
+ println!(
+ "{:<7}{subproc:<12}{}",
+ process.pid,
+ Path::new(&process.image)
+ .file_name()
+ .unwrap_or(OsStr::new("unknown"))
+ .to_string_lossy()
+ );
+ }
+
+ Ok(())
+ }
+ _ => unreachable!("unhandled subcommand"),
+ }
+ }
+
async fn set(&self, enabled: bool) -> Result<()> {
let mut rpc = new_rpc_client().await?;
rpc.set_split_tunnel_state(enabled).await?;
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index c013361925..95039cd662 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -66,7 +66,7 @@ use std::{
use talpid_core::split_tunnel;
use talpid_core::{
mpsc::Sender,
- tunnel_state_machine::{self, TunnelCommand},
+ tunnel_state_machine::{self, TunnelCommand, TunnelStateMachineHandle},
};
#[cfg(target_os = "android")]
use talpid_types::android::AndroidContext;
@@ -265,9 +265,12 @@ pub enum DaemonCommand {
/// Clear list of apps to exclude from the tunnel
#[cfg(windows)]
ClearSplitTunnelApps(ResponseTx<(), Error>),
- /// Disable split tunnel
+ /// Enable or disable split tunneling
#[cfg(windows)]
SetSplitTunnelState(ResponseTx<(), Error>, bool),
+ /// Returns all processes currently being excluded from the tunnel
+ #[cfg(windows)]
+ GetSplitTunnelProcesses(ResponseTx<Vec<split_tunnel::ExcludedProcess>, split_tunnel::Error>),
/// Toggle wireguard-nt on or off
#[cfg(target_os = "windows")]
UseWireGuardNt(ResponseTx<(), Error>, bool),
@@ -503,7 +506,6 @@ pub trait EventListener {
}
pub struct Daemon<L: EventListener> {
- tunnel_command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>,
tunnel_state: TunnelState,
target_state: PersistentTargetState,
state: DaemonExecutionState,
@@ -526,7 +528,7 @@ pub struct Daemon<L: EventListener> {
parameters_generator: tunnel::ParametersGenerator,
app_version_info: Option<AppVersionInfo>,
shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
- tunnel_state_machine_handle: tunnel_state_machine::JoinHandle,
+ tunnel_state_machine_handle: TunnelStateMachineHandle,
#[cfg(target_os = "windows")]
volume_update_tx: mpsc::UnboundedSender<()>,
}
@@ -647,7 +649,7 @@ where
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
- let (tunnel_command_tx, tunnel_state_machine_handle) = tunnel_state_machine::spawn(
+ let tunnel_state_machine_handle = tunnel_state_machine::spawn(
tunnel_state_machine::InitialTunnelState {
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
@@ -672,7 +674,8 @@ where
.await
.map_err(Error::TunnelError)?;
- endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx));
+ endpoint_updater
+ .set_tunnel_command_tx(Arc::downgrade(tunnel_state_machine_handle.command_tx()));
api::forward_offline_state(api_availability.clone(), offline_state_rx);
@@ -703,7 +706,6 @@ where
relay_list_updater.update().await;
let daemon = Daemon {
- tunnel_command_tx,
tunnel_state: TunnelState::Disconnected,
target_state,
state: DaemonExecutionState::Running,
@@ -789,7 +791,7 @@ where
L,
Vec<Pin<Box<dyn Future<Output = ()>>>>,
mullvad_api::Runtime,
- tunnel_state_machine::JoinHandle,
+ TunnelStateMachineHandle,
) {
let Daemon {
event_listener,
@@ -915,12 +917,12 @@ where
fn schedule_reconnect(&mut self, delay: Duration) {
self.unschedule_reconnect();
- let tunnel_command_tx = self.tx.to_specialized_sender();
+ let daemon_command_tx = self.tx.to_specialized_sender();
let (future, abort_handle) = abortable(Box::pin(async move {
tokio::time::sleep(delay).await;
log::debug!("Attempting to reconnect");
let (tx, rx) = oneshot::channel();
- let _ = tunnel_command_tx.send(DaemonCommand::Reconnect(tx));
+ let _ = daemon_command_tx.send(DaemonCommand::Reconnect(tx));
// suppress "unable to send" warning:
let _ = rx.await;
}));
@@ -1013,6 +1015,8 @@ where
ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await,
#[cfg(windows)]
SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await,
+ #[cfg(windows)]
+ GetSplitTunnelProcesses(tx) => self.on_get_split_tunnel_processes(tx),
#[cfg(target_os = "windows")]
UseWireGuardNt(tx, state) => self.on_use_wireguard_nt(tx, state).await,
#[cfg(target_os = "windows")]
@@ -1674,6 +1678,20 @@ where
}
#[cfg(windows)]
+ fn on_get_split_tunnel_processes(
+ &self,
+ tx: ResponseTx<Vec<split_tunnel::ExcludedProcess>, split_tunnel::Error>,
+ ) {
+ Self::oneshot_send(
+ tx,
+ self.tunnel_state_machine_handle
+ .split_tunnel()
+ .get_processes(),
+ "get_split_tunnel_processes response",
+ );
+ }
+
+ #[cfg(windows)]
async fn on_use_wireguard_nt(&mut self, tx: ResponseTx<(), Error>, state: bool) {
let save_result = self
.settings
@@ -2198,8 +2216,9 @@ where
}
}
- fn send_tunnel_command(&mut self, command: TunnelCommand) {
- self.tunnel_command_tx
+ fn send_tunnel_command(&self, command: TunnelCommand) {
+ self.tunnel_state_machine_handle
+ .command_tx()
.unbounded_send(command)
.expect("Tunnel state machine has stopped");
}
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index a7478eb048..7999fd9c55 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -768,6 +768,42 @@ impl ManagementService for ManagementServiceImpl {
}
#[cfg(windows)]
+ async fn get_excluded_processes(
+ &self,
+ _: Request<()>,
+ ) -> ServiceResult<types::ExcludedProcessList> {
+ log::debug!("get_excluded_processes");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::GetSplitTunnelProcesses(tx))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_split_tunnel_error)
+ .map(|processes| {
+ Response::new(types::ExcludedProcessList {
+ processes: processes
+ .into_iter()
+ .map(|process| types::ExcludedProcess {
+ // FIXME: This is necessarily 32 bits or less
+ pid: u32::try_from(process.pid).unwrap(),
+ image: process.image.into_os_string().to_string_lossy().to_string(),
+ inherited: process.inherited,
+ })
+ .collect(),
+ })
+ })
+ }
+
+ #[cfg(not(windows))]
+ async fn get_excluded_processes(
+ &self,
+ _: Request<()>,
+ ) -> ServiceResult<types::ExcludedProcessList> {
+ Ok(Response::new(types::ExcludedProcessList {
+ processes: vec![],
+ }))
+ }
+
+ #[cfg(windows)]
async fn set_use_wireguard_nt(&self, request: Request<bool>) -> ServiceResult<()> {
log::debug!("set_use_wireguard_nt");
let state = request.into_inner();
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 9148a10150..2f39f496d6 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -79,6 +79,7 @@ service ManagementService {
rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
+ rpc GetExcludedProcesses(google.protobuf.Empty) returns (ExcludedProcessList) {}
rpc SetUseWireguardNt(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
@@ -476,6 +477,16 @@ message PublicKey {
google.protobuf.Timestamp created = 2;
}
+message ExcludedProcess {
+ uint32 pid = 1;
+ string image = 2;
+ bool inherited = 3;
+}
+
+message ExcludedProcessList {
+ repeated ExcludedProcess processes = 1;
+}
+
message AppVersionInfo {
bool supported = 1;
string latest_stable = 2;
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs
index 70fe90f45d..6b6b2b2ea7 100644
--- a/talpid-core/src/split_tunnel/windows/driver.rs
+++ b/talpid-core/src/split_tunnel/windows/driver.rs
@@ -1,6 +1,6 @@
use super::windows::{
- get_device_path, get_process_creation_time, get_process_device_path, open_process,
- ProcessAccess, ProcessSnapshot,
+ get_device_path, get_process_creation_time, get_process_device_path, open_process, Event,
+ Overlapped, ProcessAccess, ProcessSnapshot,
};
use crate::windows::as_uninit_byte_slice;
use bitflags::bitflags;
@@ -33,12 +33,14 @@ use winapi::{
},
},
um::{
- handleapi::CloseHandle,
ioapiset::{DeviceIoControl, GetOverlappedResult},
minwinbase::OVERLAPPED,
- synchapi::{CreateEventW, WaitForSingleObject},
+ synchapi::{WaitForMultipleObjects, WaitForSingleObject},
tlhelp32::TH32CS_SNAPPROCESS,
- winbase::{FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0},
+ winbase::{
+ FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED,
+ WAIT_OBJECT_0,
+ },
winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER},
},
};
@@ -254,24 +256,17 @@ impl DeviceHandle {
}
fn initialize(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::Initialize as u32,
- None,
- 0,
- None,
- )?;
+ device_io_control(self, DriverIoctlCode::Initialize as u32, None, 0)?;
Ok(())
}
fn register_processes(&self) -> io::Result<()> {
let process_tree_buffer = serialize_process_tree(build_process_tree()?)?;
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::RegisterProcesses as u32,
Some(&process_tree_buffer),
0,
- None,
)?;
Ok(())
}
@@ -327,11 +322,10 @@ impl DeviceHandle {
let buffer = as_uninit_byte_slice(&addresses);
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::RegisterIpAddresses as u32,
Some(buffer),
0,
- None,
)?;
Ok(())
@@ -339,11 +333,10 @@ impl DeviceHandle {
pub fn get_driver_state(&self) -> io::Result<DriverState> {
let buffer = device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::GetState as u32,
None,
size_of::<u64>() as u32,
- None,
)?
.unwrap();
@@ -381,36 +374,22 @@ impl DeviceHandle {
let config = make_process_config(&device_paths);
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::SetConfiguration as u32,
Some(&config),
0,
- None,
)?;
Ok(())
}
pub fn clear_config(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::ClearConfiguration as u32,
- None,
- 0,
- None,
- )?;
-
+ device_io_control(self, DriverIoctlCode::ClearConfiguration as u32, None, 0)?;
Ok(())
}
fn reset(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::Reset as u32,
- None,
- 0,
- None,
- )?;
+ device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?;
Ok(())
}
}
@@ -680,18 +659,10 @@ struct ErrorMessageEventHeader {
/// # Panics
///
/// This may panic if `buffer` contains invalid data.
-pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
+pub fn parse_event_buffer(buffer: &[u8]) -> Result<(EventId, EventBody), UnknownEventId> {
// SAFETY: This panics if `buffer` is too small.
let raw_event_id: u32 = unsafe { deserialize_buffer(&buffer[0..mem::size_of::<u32>()]) };
- let _event_id = EventId::try_from(raw_event_id)
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to parse ST event buffer")
- );
- error
- })
- .ok()?;
+ let _event_id = EventId::try_from(raw_event_id)?;
// SAFETY: The event id is known to be valid.
let event_header: EventHeader =
@@ -711,7 +682,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
[string_byte_offset..(string_byte_offset + event.image_name_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::SplittingEvent {
process_id: event.process_id,
@@ -736,7 +707,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
[string_byte_offset..(string_byte_offset + event.image_name_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::SplittingError {
process_id: event.process_id,
@@ -757,7 +728,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
..(string_byte_offset + event.error_message_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::ErrorMessage {
status: event.status,
@@ -768,170 +739,116 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
}
}
-/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0.
+/// Send an IOCTL code to the given device handle, and wait for the result.
+///
+/// `input` specifies an optional buffer for sending data.
+///
+/// Upon success, a buffer containing at most `output_size` bytes is returned,
+/// or `None` if no bytes were read.
pub fn device_io_control(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
input: Option<&[MaybeUninit<u8>]>,
output_size: u32,
- timeout: Option<Duration>,
) -> Result<Option<Vec<u8>>, io::Error> {
- struct HandleOwner {
- handle: RawHandle,
- }
- impl Drop for HandleOwner {
- fn drop(&mut self) {
- unsafe { CloseHandle(self.handle) };
- }
- }
-
- let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() };
- overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+ let mut overlapped = Overlapped::new(Some(Event::new(true, false)?))?;
- if overlapped.hEvent == ptr::null_mut() {
- return Err(io::Error::last_os_error());
- }
-
- let _handle_owner = HandleOwner {
- handle: overlapped.hEvent,
- };
-
- let mut out_buffer = if output_size > 0 {
- Some(Vec::with_capacity(output_size as usize))
+ let mut buffer = vec![];
+ let out_buffer = if output_size > 0 {
+ buffer.resize(
+ usize::try_from(output_size).expect("u32 must be no larger than usize"),
+ 0u8,
+ );
+ Some(&mut buffer[..])
} else {
None
};
- device_io_control_buffer(
- device,
- ioctl_code,
- input,
- out_buffer.as_mut(),
- &overlapped,
- timeout,
- )
- .map(|()| out_buffer)
+ let bytes_read =
+ device_io_control_buffer(device, ioctl_code, input, out_buffer, &mut overlapped)?;
+ if bytes_read > 0 {
+ buffer.truncate(usize::try_from(bytes_read).expect("u32 must be no larger than usize"));
+ return Ok(Some(buffer));
+ }
+ Ok(None)
}
-/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// Upon success, `output` buffer will contain at most `output.capacity()` bytes of data.
+/// Send an IOCTL code to the given device handle, and wait for the result.
+///
+/// `input` specifies an optional buffer for sending data.
+///
+/// Upon success, `output` buffer will contain at most `output.len()` bytes of data,
+/// and the function returns the number of bytes read.
+///
+/// # Panics
+///
+/// This function will panic if `overlapped` does not contain an event.
pub fn device_io_control_buffer(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
input: Option<&[MaybeUninit<u8>]>,
- mut output: Option<&mut Vec<u8>>,
- overlapped: &OVERLAPPED,
- timeout: Option<Duration>,
-) -> Result<(), io::Error> {
- let input_ptr = match input {
- Some(input) => input as *const _ as *mut _,
- None => ptr::null_mut(),
- };
- let input_len = input.map(|input| input.len()).unwrap_or(0);
-
+ output: Option<&mut [u8]>,
+ overlapped: &mut Overlapped,
+) -> Result<u32, io::Error> {
+ let output_len = output.as_ref().map(|output| output.len()).unwrap_or(0);
+ let output_len = u32::try_from(output_len).map_err(|_error| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "the output buffer is too large",
+ )
+ })?;
let out_ptr = match output {
- Some(ref mut output) => output.as_mut_ptr() as *mut _,
+ Some(output) => output as *mut _ as *mut _,
None => ptr::null_mut(),
};
- let output_size = if let Some(ref output) = output {
- output.capacity()
- } else {
- 0
- };
-
- let event = overlapped.hEvent;
-
- let mut returned_bytes = 0u32;
- let overlapped = overlapped as *const _ as *mut _;
-
- let result = unsafe {
- DeviceIoControl(
- device as *mut _,
+ // SAFETY: `out_ptr` will be valid until the result has been obtained.
+ unsafe {
+ device_io_control_buffer_async(
+ device,
ioctl_code,
- input_ptr,
- input_len as u32,
+ input,
out_ptr,
- output_size as u32,
- &mut returned_bytes,
- overlapped,
- )
- };
-
- if result != 0 {
- return Err(io::Error::new(
- io::ErrorKind::Other,
- "Expected pending operation",
- ));
- }
-
- let last_error = io::Error::last_os_error();
- if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) {
- return Err(last_error);
- }
-
- let timeout = timeout
- .map(|timeout| timeout.as_millis() as u32)
- .unwrap_or(INFINITE);
- let result = unsafe { WaitForSingleObject(event, timeout) };
- match result {
- WAIT_FAILED => return Err(io::Error::last_os_error()),
- WAIT_ABANDONED => return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")),
- WAIT_OBJECT_0 => (),
- error => return Err(io::Error::from_raw_os_error(error as i32)),
- }
-
- let result =
- unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, FALSE) };
-
- if result == 0 {
- return Err(io::Error::last_os_error());
- }
-
- if let Some(ref mut output) = output {
- unsafe { output.set_len(returned_bytes as usize) };
+ output_len,
+ overlapped.as_mut_ptr(),
+ )?;
}
-
- Ok(())
+ get_overlapped_result(device, overlapped)
}
/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// The result must be obtained using `GetOverlappedResult[Ex]`.
+///
+/// `input` specifies an optional buffer for sending data.
+/// `output_ptr` specifies an optional buffer for receiving data.
+///
+/// Obtain the result using [get_overlapped_result].
+///
+/// # Safety
+///
+/// * `output_ptr` must either be null or a valid buffer of `output_len` bytes. It must remain valid
+/// until the overlapped operation has completed.
pub unsafe fn device_io_control_buffer_async(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
- mut output: Option<&mut Vec<u8>>,
- input: Option<&[u8]>,
- overlapped: &OVERLAPPED,
+ input: Option<&[MaybeUninit<u8>]>,
+ output_ptr: *mut u8,
+ output_len: u32,
+ overlapped: *mut OVERLAPPED,
) -> Result<(), io::Error> {
let input_ptr = match input {
- Some(input) => input as *const _ as *mut _,
+ Some(input) => input.as_ptr() as *mut _,
None => ptr::null_mut(),
};
let input_len = input.map(|input| input.len()).unwrap_or(0);
- let out_ptr = match output {
- Some(ref mut output) => output.as_mut_ptr() as *mut _,
- None => ptr::null_mut(),
- };
- let output_size = if let Some(ref output) = output {
- output.capacity()
- } else {
- 0
- };
-
- let overlapped = overlapped as *const _ as *mut _;
-
let result = DeviceIoControl(
- device as *mut _,
+ device.as_raw_handle(),
ioctl_code,
input_ptr,
- input_len as u32,
- out_ptr,
- output_size as u32,
+ u32::try_from(input_len).map_err(|_error| {
+ io::Error::new(io::ErrorKind::InvalidInput, "the input buffer is too large")
+ })?,
+ output_ptr as *mut _,
+ output_len,
ptr::null_mut(),
overlapped,
);
@@ -951,6 +868,90 @@ pub unsafe fn device_io_control_buffer_async(
Ok(())
}
+/// Retrieves the result of an overlapped operation. On success, this returns
+/// the number of bytes transferred. For device I/O, this is the number of bytes
+/// written to the output buffer.
+///
+/// # Panics
+///
+/// This function will panic if `overlapped` does not contain an event.
+pub fn get_overlapped_result(
+ device: &DeviceHandle,
+ overlapped: &mut Overlapped,
+) -> io::Result<u32> {
+ let event = overlapped.get_event().unwrap();
+
+ // SAFETY: This is a valid event object.
+ unsafe { wait_for_single_object(event.as_raw_handle(), None) }?;
+
+ // SAFETY: The handle and overlapped object are valid.
+ let mut returned_bytes = 0u32;
+ let result = unsafe {
+ GetOverlappedResult(
+ device.as_raw_handle(),
+ overlapped.as_mut_ptr(),
+ &mut returned_bytes,
+ FALSE,
+ )
+ };
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(returned_bytes)
+}
+
+/// Waits for an object to be signaled, or until a timeout interval has elapsed.
+///
+/// # Safety
+///
+/// * `object` must be a valid object that can be signaled, such as an event object.
+pub unsafe fn wait_for_single_object(
+ object: RawHandle,
+ timeout: Option<Duration>,
+) -> io::Result<()> {
+ let timeout = match timeout {
+ Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| {
+ io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long")
+ })?,
+ None => INFINITE,
+ };
+ let result = WaitForSingleObject(object, timeout);
+ match result {
+ WAIT_OBJECT_0 => Ok(()),
+ WAIT_FAILED => Err(io::Error::last_os_error()),
+ WAIT_ABANDONED => Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")),
+ error => Err(io::Error::from_raw_os_error(error as i32)),
+ }
+}
+
+/// Waits for one or several objects to be signaled. On success, this returns a pointer to an
+/// object in `objects` that was signaled.
+///
+/// # Safety
+///
+/// * `objects` must be a slice of valid objects that can be signaled, such as event objects.
+pub unsafe fn wait_for_multiple_objects(
+ objects: &[RawHandle],
+ wait_all: bool,
+) -> io::Result<RawHandle> {
+ let objects_len = u32::try_from(objects.len())
+ .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?;
+ let result = WaitForMultipleObjects(
+ objects_len,
+ objects.as_ptr(),
+ if wait_all { TRUE } else { FALSE },
+ INFINITE,
+ );
+ let signaled_index = if result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + objects_len {
+ result - WAIT_OBJECT_0
+ } else if result >= WAIT_ABANDONED_0 && result < WAIT_ABANDONED_0 + objects_len {
+ return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex"));
+ } else {
+ return Err(io::Error::last_os_error());
+ };
+ Ok(objects[usize::try_from(signaled_index).expect("usize must be larger than u32")])
+}
+
/// Reads the value from `buffer`, zeroing any remaining bytes.
///
/// # Safety
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index 5b5b29e0cc..63a56be473 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -12,29 +12,20 @@ use crate::{
};
use futures::channel::{mpsc, oneshot};
use std::{
+ collections::HashMap,
convert::TryFrom,
ffi::{OsStr, OsString},
- io, mem,
+ io,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
- os::windows::io::{AsRawHandle, RawHandle},
- ptr,
+ os::windows::io::AsRawHandle,
+ path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
- mpsc as sync_mpsc, Arc, Mutex, Weak,
+ mpsc as sync_mpsc, Arc, Mutex, RwLock, Weak,
},
time::Duration,
};
use talpid_types::{tunnel::ErrorStateCause, ErrorExt};
-use winapi::{
- shared::minwindef::{FALSE, TRUE},
- um::{
- handleapi::CloseHandle,
- ioapiset::GetOverlappedResult,
- minwinbase::OVERLAPPED,
- synchapi::{CreateEventW, SetEvent, WaitForMultipleObjects, WaitForSingleObject},
- winbase::{INFINITE, WAIT_ABANDONED_0, WAIT_OBJECT_0},
- },
-};
const DRIVER_EVENT_BUFFER_SIZE: usize = 2048;
const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123);
@@ -84,8 +75,8 @@ pub enum Error {
RequestThreadStuck,
/// The request handling thread is down
- #[error(display = "The ST request thread is down")]
- RequestThreadDown,
+ #[error(display = "The split tunnel monitor is down")]
+ SplitTunnelDown,
/// Failed to start the NTFS reparse point monitor
#[error(display = "Failed to start path monitor")]
@@ -101,36 +92,13 @@ pub struct SplitTunnel {
runtime: tokio::runtime::Handle,
request_tx: RequestTx,
event_thread: Option<std::thread::JoinHandle<()>>,
- quit_event: Arc<QuitEvent>,
+ quit_event: Arc<windows::Event>,
+ excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
_route_change_callback: Option<WinNetCallbackHandle>,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
async_path_update_in_progress: Arc<AtomicBool>,
}
-struct QuitEvent(RawHandle);
-
-unsafe impl Send for QuitEvent {}
-unsafe impl Sync for QuitEvent {}
-
-impl QuitEvent {
- fn new() -> Self {
- Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) })
- }
-
- fn set_event(&self) -> io::Result<()> {
- if unsafe { SetEvent(self.0) } == 0 {
- return Err(io::Error::last_os_error());
- }
- Ok(())
- }
-}
-
-impl Drop for QuitEvent {
- fn drop(&mut self) {
- unsafe { CloseHandle(self.0) };
- }
-}
-
enum Request {
SetPaths(Vec<OsString>),
RegisterIps(InterfaceAddresses),
@@ -148,12 +116,43 @@ struct InterfaceAddresses {
internet_ipv6: Option<Ipv6Addr>,
}
-struct EventThreadContext {
- handle: Arc<driver::DeviceHandle>,
- event_overlapped: OVERLAPPED,
- quit_event: Arc<QuitEvent>,
+/// Represents a process that is being excluded from the tunnel.
+#[derive(Debug, Clone)]
+pub struct ExcludedProcess {
+ /// Process identifier.
+ pub pid: u32,
+ /// Path to the image that this process is an instance of.
+ pub image: PathBuf,
+ /// If true, then the process is split because its parent was split,
+ /// not due to its path being in the config.
+ pub inherited: bool,
+}
+
+/// Cloneable handle for interacting with the split tunnel module.
+#[derive(Debug, Clone)]
+pub struct SplitTunnelHandle {
+ excluded_processes: Weak<RwLock<HashMap<usize, ExcludedProcess>>>,
+}
+
+impl SplitTunnelHandle {
+ /// Return processes that are currently being excluded, including
+ /// their pids, paths, and reason for being excluded.
+ pub fn get_processes(&self) -> Result<Vec<ExcludedProcess>, Error> {
+ let processes = self
+ .excluded_processes
+ .upgrade()
+ .ok_or(Error::SplitTunnelDown)?;
+ let processes = processes.read().unwrap();
+ Ok(processes.values().cloned().collect())
+ }
+}
+
+enum EventResult {
+ /// Result containing the next event.
+ Event(driver::EventId, driver::EventBody),
+ /// Quit event was signaled.
+ Quit,
}
-unsafe impl Send for EventThreadContext {}
impl SplitTunnel {
/// Initialize the split tunnel device.
@@ -164,169 +163,207 @@ impl SplitTunnel {
) -> Result<Self, Error> {
let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?;
- let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() };
- event_overlapped.hEvent =
- unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
- if event_overlapped.hEvent == ptr::null_mut() {
- return Err(Error::EventThreadError(io::Error::last_os_error()));
- }
-
- let quit_event = Arc::new(QuitEvent::new());
-
- let event_context = EventThreadContext {
- handle: handle.clone(),
- event_overlapped,
- quit_event: quit_event.clone(),
- };
+ let excluded_processes = Arc::new(RwLock::new(HashMap::new()));
+ let (event_thread, quit_event) =
+ Self::spawn_event_listener(handle, excluded_processes.clone())?;
- let event_thread = std::thread::spawn(move || {
- use driver::{EventBody, EventId};
+ Ok(SplitTunnel {
+ runtime,
+ request_tx,
+ event_thread: Some(event_thread),
+ quit_event,
+ _route_change_callback: None,
+ daemon_tx,
+ async_path_update_in_progress: Arc::new(AtomicBool::new(false)),
+ excluded_processes,
+ })
+ }
- // Take ownership of the entire struct (Rust 2021 edition change)
- let _ = &event_context;
+ /// Spawns an event loop thread that processes events from the driver service.
+ fn spawn_event_listener(
+ handle: Arc<driver::DeviceHandle>,
+ excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
+ ) -> Result<(std::thread::JoinHandle<()>, Arc<windows::Event>), Error> {
+ let mut event_overlapped = windows::Overlapped::new(Some(
+ windows::Event::new(true, false).map_err(Error::EventThreadError)?,
+ ))
+ .map_err(Error::EventThreadError)?;
- let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE);
- let mut returned_bytes = 0u32;
+ let quit_event =
+ Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?);
+ let quit_event_copy = quit_event.clone();
- let event_objects = [
- event_context.event_overlapped.hEvent,
- event_context.quit_event.0,
- ];
+ let event_thread = std::thread::spawn(move || {
+ let mut data_buffer = vec![];
loop {
- if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } {
- // Quit event was signaled
- break;
- }
+ // Wait until either the next event is received or the quit event is signaled.
+ let (event_id, event_body) = match Self::fetch_next_event(
+ &handle,
+ &quit_event,
+ &mut event_overlapped,
+ &mut data_buffer,
+ ) {
+ Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body),
+ Ok(EventResult::Quit) => break,
+ Err(_error) => continue,
+ };
- if let Err(error) = unsafe {
- driver::device_io_control_buffer_async(
- event_context.handle.as_raw_handle(),
- driver::DriverIoctlCode::DequeEvent as u32,
- Some(&mut data_buffer),
- None,
- &event_context.event_overlapped,
- )
- } {
- log::error!(
- "{}",
- error.display_chain_with_msg("device_io_control failed")
- );
- continue;
- }
+ Self::handle_event(event_id, event_body, &excluded_processes);
+ }
- let result = unsafe {
- WaitForMultipleObjects(
- event_objects.len() as u32,
- &event_objects[0],
- FALSE,
- INFINITE,
- )
- };
+ log::debug!("Stopping split tunnel event thread");
+ });
- let signaled_index = if result >= WAIT_OBJECT_0
- && result < WAIT_OBJECT_0 + event_objects.len() as u32
- {
- result - WAIT_OBJECT_0
- } else if result >= WAIT_ABANDONED_0
- && result < WAIT_ABANDONED_0 + event_objects.len() as u32
- {
- result - WAIT_ABANDONED_0
- } else {
- let error = io::Error::last_os_error();
- log::error!(
- "{}",
- error.display_chain_with_msg("WaitForMultipleObjects failed")
- );
+ Ok((event_thread, quit_event_copy))
+ }
- continue;
- };
+ fn fetch_next_event(
+ device: &Arc<driver::DeviceHandle>,
+ quit_event: &windows::Event,
+ overlapped: &mut windows::Overlapped,
+ data_buffer: &mut Vec<u8>,
+ ) -> io::Result<EventResult> {
+ if unsafe {
+ driver::wait_for_single_object(quit_event.as_raw_handle(), Some(Duration::ZERO))
+ }
+ .is_ok()
+ {
+ return Ok(EventResult::Quit);
+ }
- if event_context.quit_event.0 == event_objects[signaled_index as usize] {
- // Quit event was signaled
- break;
- }
+ data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8);
- let result = unsafe {
- GetOverlappedResult(
- event_context.handle.as_raw_handle(),
- &event_context.event_overlapped as *const _ as *mut _,
- &mut returned_bytes,
- TRUE,
- )
- };
+ unsafe {
+ driver::device_io_control_buffer_async(
+ device,
+ driver::DriverIoctlCode::DequeEvent as u32,
+ None,
+ data_buffer.as_mut_ptr(),
+ u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"),
+ overlapped.as_mut_ptr(),
+ )
+ }
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("DeviceIoControl failed to deque event")
+ );
+ error
+ })?;
+
+ let event_objects = [
+ overlapped.get_event().unwrap().as_raw_handle(),
+ quit_event.as_raw_handle(),
+ ];
- if result == 0 {
- let error = io::Error::last_os_error();
+ let signaled_object =
+ unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) }.map_err(
+ |error| {
log::error!(
"{}",
- error.display_chain_with_msg("GetOverlappedResult failed")
+ error.display_chain_with_msg("wait_for_multiple_objects failed")
);
+ error
+ },
+ )?;
- continue;
- }
+ if signaled_object == quit_event.as_raw_handle() {
+ // Quit event was signaled
+ return Ok(EventResult::Quit);
+ }
- unsafe { data_buffer.set_len(returned_bytes as usize) };
+ let returned_bytes =
+ driver::get_overlapped_result(device, overlapped).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("get_overlapped_result failed for dequed event"),
+ );
+ error
+ })?;
- let event = driver::parse_event_buffer(&data_buffer);
+ data_buffer
+ .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32"));
- let (event_id, event_body) = match event {
- Some((event_id, event_body)) => (event_id, event_body),
- None => continue,
- };
+ driver::parse_event_buffer(&data_buffer)
+ .map(|(id, body)| EventResult::Event(id, body))
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to parse ST event buffer")
+ );
+ io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer")
+ })
+ }
- let event_str = match &event_id {
- EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => {
- "Start splitting process"
- }
- EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => {
- "Stop splitting process"
- }
- EventId::ErrorMessage => "ErrorMessage",
- };
+ fn handle_event(
+ event_id: driver::EventId,
+ event_body: driver::EventBody,
+ excluded_processes: &Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
+ ) {
+ use driver::{EventBody, EventId};
- match event_body {
- EventBody::SplittingEvent {
- process_id,
- reason,
- image,
- } => {
- log::trace!(
- "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}",
- event_str,
- process_id,
- reason,
- image,
- );
- }
- EventBody::SplittingError { process_id, image } => {
- log::error!(
- "FAILED: {}:\n\tpid: {}\n\timage: {:?}",
- event_str,
+ let event_str = match &event_id {
+ EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => {
+ "Start splitting process"
+ }
+ EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => {
+ "Stop splitting process"
+ }
+ EventId::ErrorMessage => "ErrorMessage",
+ };
+
+ match event_body {
+ EventBody::SplittingEvent {
+ process_id,
+ reason,
+ image,
+ } => {
+ let mut pids = excluded_processes.write().unwrap();
+ match event_id {
+ EventId::StartSplittingProcess => {
+ if let Some(prev_entry) = pids.get(&process_id) {
+ log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry);
+ }
+ pids.insert(
process_id,
- image,
+ ExcludedProcess {
+ pid: u32::try_from(process_id)
+ .expect("PID should be containable in a DWORD"),
+ image: Path::new(&image).to_path_buf(),
+ inherited: reason
+ .contains(driver::SplittingChangeReason::BY_INHERITANCE),
+ },
);
}
- EventBody::ErrorMessage { status, message } => {
- log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy())
+ EventId::StopSplittingProcess => {
+ if pids.remove(&process_id).is_none() {
+ log::error!("Inconsistent process tree: {process_id} was not found");
+ }
}
+ _ => (),
}
- }
-
- log::debug!("Stopping split tunnel event thread");
- unsafe { CloseHandle(event_context.event_overlapped.hEvent) };
- });
-
- Ok(SplitTunnel {
- runtime,
- request_tx,
- event_thread: Some(event_thread),
- quit_event,
- _route_change_callback: None,
- daemon_tx,
- async_path_update_in_progress: Arc::new(AtomicBool::new(false)),
- })
+ log::trace!(
+ "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}",
+ event_str,
+ process_id,
+ reason,
+ image,
+ );
+ }
+ EventBody::SplittingError { process_id, image } => {
+ log::error!(
+ "FAILED: {}:\n\tpid: {}\n\timage: {:?}",
+ event_str,
+ process_id,
+ image,
+ );
+ }
+ EventBody::ErrorMessage { status, message } => {
+ log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy())
+ }
+ }
}
fn spawn_request_thread(
@@ -464,7 +501,7 @@ impl SplitTunnel {
request_tx
.send((request, response_tx))
- .map_err(|_| Error::RequestThreadDown)?;
+ .map_err(|_| Error::SplitTunnelDown)?;
response_rx
.recv_timeout(REQUEST_TIMEOUT)
@@ -506,8 +543,8 @@ impl SplitTunnel {
let wait_task = move || {
request_tx
.send((request, response_tx))
- .map_err(|_| Error::RequestThreadDown)?;
- response_rx.recv().map_err(|_| Error::RequestThreadDown)?
+ .map_err(|_| Error::SplitTunnelDown)?;
+ response_rx.recv().map_err(|_| Error::SplitTunnelDown)?
};
let in_progress = self.async_path_update_in_progress.clone();
self.runtime.spawn_blocking(move || {
@@ -562,12 +599,19 @@ impl SplitTunnel {
self._route_change_callback = None;
self.send_request(Request::RegisterIps(InterfaceAddresses::default()))
}
+
+ /// Returns a handle used for interacting with the split tunnel module.
+ pub fn handle(&self) -> SplitTunnelHandle {
+ SplitTunnelHandle {
+ excluded_processes: Arc::downgrade(&self.excluded_processes),
+ }
+ }
}
impl Drop for SplitTunnel {
fn drop(&mut self) {
if let Some(_event_thread) = self.event_thread.take() {
- if let Err(error) = self.quit_event.set_event() {
+ if let Err(error) = self.quit_event.set() {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close ST event thread")
diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs
index 7000211bba..1cfcce15a5 100644
--- a/talpid-core/src/split_tunnel/windows/windows.rs
+++ b/talpid-core/src/split_tunnel/windows/windows.rs
@@ -13,15 +13,17 @@ use std::{
};
use winapi::{
shared::{
- minwindef::{DWORD, FALSE, FILETIME, TRUE},
+ minwindef::{BOOL, DWORD, FALSE, FILETIME, TRUE},
ntdef::ULARGE_INTEGER,
winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES},
},
um::{
fileapi::{GetFinalPathNameByHandleW, QueryDosDeviceW},
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ minwinbase::OVERLAPPED,
processthreadsapi::{GetProcessTimes, OpenProcess},
psapi::K32GetProcessImageFileNameW,
+ synchapi::{CreateEventW, SetEvent},
tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W},
winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION},
},
@@ -313,3 +315,98 @@ fn get_process_device_path_inner(
Ok(OsStringExt::from_wide(&buffer))
}
+
+/// Abstraction over `OVERLAPPED`, which is used for async I/O.
+pub struct Overlapped {
+ overlapped: OVERLAPPED,
+ event: Option<Event>,
+}
+
+unsafe impl Send for Overlapped {}
+unsafe impl Sync for Overlapped {}
+
+impl Overlapped {
+ /// Creates an `OVERLAPPED` object with `hEvent` set.
+ pub fn new(event: Option<Event>) -> io::Result<Self> {
+ let mut overlapped = Overlapped {
+ overlapped: unsafe { mem::zeroed() },
+ event: None,
+ };
+ overlapped.set_event(event);
+ Ok(overlapped)
+ }
+
+ /// Borrows the underlying `OVERLAPPED` object.
+ pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED {
+ &mut self.overlapped
+ }
+
+ /// Returns a reference to the associated event.
+ pub fn get_event(&self) -> Option<&Event> {
+ self.event.as_ref()
+ }
+
+ /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`)
+ fn set_event(&mut self, event: Option<Event>) {
+ match event {
+ Some(event) => {
+ let raw_event = event.0;
+ self.overlapped.hEvent = raw_event;
+ self.event = Some(event);
+ }
+ None => {
+ self.overlapped.hEvent = ptr::null_mut();
+ self.event = None;
+ }
+ }
+ }
+}
+
+/// Abstraction over a Windows event object.
+pub struct Event(RawHandle);
+
+unsafe impl Send for Event {}
+unsafe impl Sync for Event {}
+
+impl Event {
+ pub fn new(manual_reset: bool, initial_state: bool) -> io::Result<Self> {
+ let event = unsafe {
+ CreateEventW(
+ ptr::null_mut(),
+ bool_to_winbool(manual_reset),
+ bool_to_winbool(initial_state),
+ ptr::null(),
+ )
+ };
+ if event == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(Self(event))
+ }
+
+ pub fn set(&self) -> io::Result<()> {
+ if unsafe { SetEvent(self.0) } == FALSE {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+}
+
+impl AsRawHandle for Event {
+ fn as_raw_handle(&self) -> RawHandle {
+ self.0
+ }
+}
+
+impl Drop for Event {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.0) };
+ }
+}
+
+const fn bool_to_winbool(val: bool) -> BOOL {
+ match val {
+ true => TRUE,
+ false => FALSE,
+ }
+}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 6c388a4680..e55cc9c0d5 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -116,7 +116,7 @@ pub async fn spawn(
#[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>,
#[cfg(target_os = "macos")] exclusion_gid: u32,
#[cfg(target_os = "android")] android_context: AndroidContext,
-) -> Result<(Arc<mpsc::UnboundedSender<TunnelCommand>>, JoinHandle), Error> {
+) -> Result<TunnelStateMachineHandle, Error> {
let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
@@ -150,6 +150,9 @@ pub async fn spawn(
)
.await?;
+ #[cfg(windows)]
+ let split_tunnel = state_machine.shared_values.split_tunnel.handle();
+
tokio::task::spawn_blocking(move || {
state_machine.run(state_change_listener);
if shutdown_tx.send(()).is_err() {
@@ -157,7 +160,12 @@ pub async fn spawn(
}
});
- Ok((command_tx, JoinHandle { shutdown_rx }))
+ Ok(TunnelStateMachineHandle {
+ command_tx,
+ shutdown_rx,
+ #[cfg(windows)]
+ split_tunnel,
+ })
}
/// Representation of external commands for the tunnel state machine.
@@ -590,18 +598,34 @@ state_wrapper! {
}
}
-/// Handle used to wait for the tunnel state machine to shut down.
-pub struct JoinHandle {
+/// Handle used to control the tunnel state machine.
+pub struct TunnelStateMachineHandle {
+ command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>,
shutdown_rx: oneshot::Receiver<()>,
+ #[cfg(windows)]
+ split_tunnel: split_tunnel::SplitTunnelHandle,
}
-impl JoinHandle {
+impl TunnelStateMachineHandle {
/// Waits for the tunnel state machine to shut down.
/// This may fail after a timeout of `TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT`.
pub async fn try_join(self) {
+ drop(self.command_tx);
+
match tokio::time::timeout(TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, self.shutdown_rx).await {
Ok(_) => log::info!("Tunnel state machine shut down"),
Err(_) => log::error!("Tunnel state machine did not shut down gracefully"),
}
}
+
+ /// Returns tunnel command sender.
+ pub fn command_tx(&self) -> &Arc<mpsc::UnboundedSender<TunnelCommand>> {
+ &self.command_tx
+ }
+
+ /// Returns split tunnel object handle.
+ #[cfg(windows)]
+ pub fn split_tunnel(&self) -> &split_tunnel::SplitTunnelHandle {
+ &self.split_tunnel
+ }
}