summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/split_tunnel/windows/driver.rs355
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs389
-rw-r--r--talpid-core/src/split_tunnel/windows/windows.rs99
3 files changed, 456 insertions, 387 deletions
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs
index 70fe90f45d..6b6b2b2ea7 100644
--- a/talpid-core/src/split_tunnel/windows/driver.rs
+++ b/talpid-core/src/split_tunnel/windows/driver.rs
@@ -1,6 +1,6 @@
use super::windows::{
- get_device_path, get_process_creation_time, get_process_device_path, open_process,
- ProcessAccess, ProcessSnapshot,
+ get_device_path, get_process_creation_time, get_process_device_path, open_process, Event,
+ Overlapped, ProcessAccess, ProcessSnapshot,
};
use crate::windows::as_uninit_byte_slice;
use bitflags::bitflags;
@@ -33,12 +33,14 @@ use winapi::{
},
},
um::{
- handleapi::CloseHandle,
ioapiset::{DeviceIoControl, GetOverlappedResult},
minwinbase::OVERLAPPED,
- synchapi::{CreateEventW, WaitForSingleObject},
+ synchapi::{WaitForMultipleObjects, WaitForSingleObject},
tlhelp32::TH32CS_SNAPPROCESS,
- winbase::{FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0},
+ winbase::{
+ FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED,
+ WAIT_OBJECT_0,
+ },
winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER},
},
};
@@ -254,24 +256,17 @@ impl DeviceHandle {
}
fn initialize(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::Initialize as u32,
- None,
- 0,
- None,
- )?;
+ device_io_control(self, DriverIoctlCode::Initialize as u32, None, 0)?;
Ok(())
}
fn register_processes(&self) -> io::Result<()> {
let process_tree_buffer = serialize_process_tree(build_process_tree()?)?;
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::RegisterProcesses as u32,
Some(&process_tree_buffer),
0,
- None,
)?;
Ok(())
}
@@ -327,11 +322,10 @@ impl DeviceHandle {
let buffer = as_uninit_byte_slice(&addresses);
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::RegisterIpAddresses as u32,
Some(buffer),
0,
- None,
)?;
Ok(())
@@ -339,11 +333,10 @@ impl DeviceHandle {
pub fn get_driver_state(&self) -> io::Result<DriverState> {
let buffer = device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::GetState as u32,
None,
size_of::<u64>() as u32,
- None,
)?
.unwrap();
@@ -381,36 +374,22 @@ impl DeviceHandle {
let config = make_process_config(&device_paths);
device_io_control(
- self.handle.as_raw_handle(),
+ self,
DriverIoctlCode::SetConfiguration as u32,
Some(&config),
0,
- None,
)?;
Ok(())
}
pub fn clear_config(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::ClearConfiguration as u32,
- None,
- 0,
- None,
- )?;
-
+ device_io_control(self, DriverIoctlCode::ClearConfiguration as u32, None, 0)?;
Ok(())
}
fn reset(&self) -> io::Result<()> {
- device_io_control(
- self.handle.as_raw_handle(),
- DriverIoctlCode::Reset as u32,
- None,
- 0,
- None,
- )?;
+ device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?;
Ok(())
}
}
@@ -680,18 +659,10 @@ struct ErrorMessageEventHeader {
/// # Panics
///
/// This may panic if `buffer` contains invalid data.
-pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
+pub fn parse_event_buffer(buffer: &[u8]) -> Result<(EventId, EventBody), UnknownEventId> {
// SAFETY: This panics if `buffer` is too small.
let raw_event_id: u32 = unsafe { deserialize_buffer(&buffer[0..mem::size_of::<u32>()]) };
- let _event_id = EventId::try_from(raw_event_id)
- .map_err(|error| {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to parse ST event buffer")
- );
- error
- })
- .ok()?;
+ let _event_id = EventId::try_from(raw_event_id)?;
// SAFETY: The event id is known to be valid.
let event_header: EventHeader =
@@ -711,7 +682,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
[string_byte_offset..(string_byte_offset + event.image_name_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::SplittingEvent {
process_id: event.process_id,
@@ -736,7 +707,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
[string_byte_offset..(string_byte_offset + event.image_name_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::SplittingError {
process_id: event.process_id,
@@ -757,7 +728,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
..(string_byte_offset + event.error_message_length as usize)],
);
- Some((
+ Ok((
event_header.event_id,
EventBody::ErrorMessage {
status: event.status,
@@ -768,170 +739,116 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> {
}
}
-/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0.
+/// Send an IOCTL code to the given device handle, and wait for the result.
+///
+/// `input` specifies an optional buffer for sending data.
+///
+/// Upon success, a buffer containing at most `output_size` bytes is returned,
+/// or `None` if no bytes were read.
pub fn device_io_control(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
input: Option<&[MaybeUninit<u8>]>,
output_size: u32,
- timeout: Option<Duration>,
) -> Result<Option<Vec<u8>>, io::Error> {
- struct HandleOwner {
- handle: RawHandle,
- }
- impl Drop for HandleOwner {
- fn drop(&mut self) {
- unsafe { CloseHandle(self.handle) };
- }
- }
-
- let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() };
- overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+ let mut overlapped = Overlapped::new(Some(Event::new(true, false)?))?;
- if overlapped.hEvent == ptr::null_mut() {
- return Err(io::Error::last_os_error());
- }
-
- let _handle_owner = HandleOwner {
- handle: overlapped.hEvent,
- };
-
- let mut out_buffer = if output_size > 0 {
- Some(Vec::with_capacity(output_size as usize))
+ let mut buffer = vec![];
+ let out_buffer = if output_size > 0 {
+ buffer.resize(
+ usize::try_from(output_size).expect("u32 must be no larger than usize"),
+ 0u8,
+ );
+ Some(&mut buffer[..])
} else {
None
};
- device_io_control_buffer(
- device,
- ioctl_code,
- input,
- out_buffer.as_mut(),
- &overlapped,
- timeout,
- )
- .map(|()| out_buffer)
+ let bytes_read =
+ device_io_control_buffer(device, ioctl_code, input, out_buffer, &mut overlapped)?;
+ if bytes_read > 0 {
+ buffer.truncate(usize::try_from(bytes_read).expect("u32 must be no larger than usize"));
+ return Ok(Some(buffer));
+ }
+ Ok(None)
}
-/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// Upon success, `output` buffer will contain at most `output.capacity()` bytes of data.
+/// Send an IOCTL code to the given device handle, and wait for the result.
+///
+/// `input` specifies an optional buffer for sending data.
+///
+/// Upon success, `output` buffer will contain at most `output.len()` bytes of data,
+/// and the function returns the number of bytes read.
+///
+/// # Panics
+///
+/// This function will panic if `overlapped` does not contain an event.
pub fn device_io_control_buffer(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
input: Option<&[MaybeUninit<u8>]>,
- mut output: Option<&mut Vec<u8>>,
- overlapped: &OVERLAPPED,
- timeout: Option<Duration>,
-) -> Result<(), io::Error> {
- let input_ptr = match input {
- Some(input) => input as *const _ as *mut _,
- None => ptr::null_mut(),
- };
- let input_len = input.map(|input| input.len()).unwrap_or(0);
-
+ output: Option<&mut [u8]>,
+ overlapped: &mut Overlapped,
+) -> Result<u32, io::Error> {
+ let output_len = output.as_ref().map(|output| output.len()).unwrap_or(0);
+ let output_len = u32::try_from(output_len).map_err(|_error| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "the output buffer is too large",
+ )
+ })?;
let out_ptr = match output {
- Some(ref mut output) => output.as_mut_ptr() as *mut _,
+ Some(output) => output as *mut _ as *mut _,
None => ptr::null_mut(),
};
- let output_size = if let Some(ref output) = output {
- output.capacity()
- } else {
- 0
- };
-
- let event = overlapped.hEvent;
-
- let mut returned_bytes = 0u32;
- let overlapped = overlapped as *const _ as *mut _;
-
- let result = unsafe {
- DeviceIoControl(
- device as *mut _,
+ // SAFETY: `out_ptr` will be valid until the result has been obtained.
+ unsafe {
+ device_io_control_buffer_async(
+ device,
ioctl_code,
- input_ptr,
- input_len as u32,
+ input,
out_ptr,
- output_size as u32,
- &mut returned_bytes,
- overlapped,
- )
- };
-
- if result != 0 {
- return Err(io::Error::new(
- io::ErrorKind::Other,
- "Expected pending operation",
- ));
- }
-
- let last_error = io::Error::last_os_error();
- if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) {
- return Err(last_error);
- }
-
- let timeout = timeout
- .map(|timeout| timeout.as_millis() as u32)
- .unwrap_or(INFINITE);
- let result = unsafe { WaitForSingleObject(event, timeout) };
- match result {
- WAIT_FAILED => return Err(io::Error::last_os_error()),
- WAIT_ABANDONED => return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")),
- WAIT_OBJECT_0 => (),
- error => return Err(io::Error::from_raw_os_error(error as i32)),
- }
-
- let result =
- unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, FALSE) };
-
- if result == 0 {
- return Err(io::Error::last_os_error());
- }
-
- if let Some(ref mut output) = output {
- unsafe { output.set_len(returned_bytes as usize) };
+ output_len,
+ overlapped.as_mut_ptr(),
+ )?;
}
-
- Ok(())
+ get_overlapped_result(device, overlapped)
}
/// Send an IOCTL code to the given device handle.
-/// `input` specifies an optional buffer to send.
-/// The result must be obtained using `GetOverlappedResult[Ex]`.
+///
+/// `input` specifies an optional buffer for sending data.
+/// `output_ptr` specifies an optional buffer for receiving data.
+///
+/// Obtain the result using [get_overlapped_result].
+///
+/// # Safety
+///
+/// * `output_ptr` must either be null or a valid buffer of `output_len` bytes. It must remain valid
+/// until the overlapped operation has completed.
pub unsafe fn device_io_control_buffer_async(
- device: RawHandle,
+ device: &DeviceHandle,
ioctl_code: u32,
- mut output: Option<&mut Vec<u8>>,
- input: Option<&[u8]>,
- overlapped: &OVERLAPPED,
+ input: Option<&[MaybeUninit<u8>]>,
+ output_ptr: *mut u8,
+ output_len: u32,
+ overlapped: *mut OVERLAPPED,
) -> Result<(), io::Error> {
let input_ptr = match input {
- Some(input) => input as *const _ as *mut _,
+ Some(input) => input.as_ptr() as *mut _,
None => ptr::null_mut(),
};
let input_len = input.map(|input| input.len()).unwrap_or(0);
- let out_ptr = match output {
- Some(ref mut output) => output.as_mut_ptr() as *mut _,
- None => ptr::null_mut(),
- };
- let output_size = if let Some(ref output) = output {
- output.capacity()
- } else {
- 0
- };
-
- let overlapped = overlapped as *const _ as *mut _;
-
let result = DeviceIoControl(
- device as *mut _,
+ device.as_raw_handle(),
ioctl_code,
input_ptr,
- input_len as u32,
- out_ptr,
- output_size as u32,
+ u32::try_from(input_len).map_err(|_error| {
+ io::Error::new(io::ErrorKind::InvalidInput, "the input buffer is too large")
+ })?,
+ output_ptr as *mut _,
+ output_len,
ptr::null_mut(),
overlapped,
);
@@ -951,6 +868,90 @@ pub unsafe fn device_io_control_buffer_async(
Ok(())
}
+/// Retrieves the result of an overlapped operation. On success, this returns
+/// the number of bytes transferred. For device I/O, this is the number of bytes
+/// written to the output buffer.
+///
+/// # Panics
+///
+/// This function will panic if `overlapped` does not contain an event.
+pub fn get_overlapped_result(
+ device: &DeviceHandle,
+ overlapped: &mut Overlapped,
+) -> io::Result<u32> {
+ let event = overlapped.get_event().unwrap();
+
+ // SAFETY: This is a valid event object.
+ unsafe { wait_for_single_object(event.as_raw_handle(), None) }?;
+
+ // SAFETY: The handle and overlapped object are valid.
+ let mut returned_bytes = 0u32;
+ let result = unsafe {
+ GetOverlappedResult(
+ device.as_raw_handle(),
+ overlapped.as_mut_ptr(),
+ &mut returned_bytes,
+ FALSE,
+ )
+ };
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(returned_bytes)
+}
+
+/// Waits for an object to be signaled, or until a timeout interval has elapsed.
+///
+/// # Safety
+///
+/// * `object` must be a valid object that can be signaled, such as an event object.
+pub unsafe fn wait_for_single_object(
+ object: RawHandle,
+ timeout: Option<Duration>,
+) -> io::Result<()> {
+ let timeout = match timeout {
+ Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| {
+ io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long")
+ })?,
+ None => INFINITE,
+ };
+ let result = WaitForSingleObject(object, timeout);
+ match result {
+ WAIT_OBJECT_0 => Ok(()),
+ WAIT_FAILED => Err(io::Error::last_os_error()),
+ WAIT_ABANDONED => Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")),
+ error => Err(io::Error::from_raw_os_error(error as i32)),
+ }
+}
+
+/// Waits for one or several objects to be signaled. On success, this returns a pointer to an
+/// object in `objects` that was signaled.
+///
+/// # Safety
+///
+/// * `objects` must be a slice of valid objects that can be signaled, such as event objects.
+pub unsafe fn wait_for_multiple_objects(
+ objects: &[RawHandle],
+ wait_all: bool,
+) -> io::Result<RawHandle> {
+ let objects_len = u32::try_from(objects.len())
+ .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?;
+ let result = WaitForMultipleObjects(
+ objects_len,
+ objects.as_ptr(),
+ if wait_all { TRUE } else { FALSE },
+ INFINITE,
+ );
+ let signaled_index = if result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + objects_len {
+ result - WAIT_OBJECT_0
+ } else if result >= WAIT_ABANDONED_0 && result < WAIT_ABANDONED_0 + objects_len {
+ return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex"));
+ } else {
+ return Err(io::Error::last_os_error());
+ };
+ Ok(objects[usize::try_from(signaled_index).expect("usize must be larger than u32")])
+}
+
/// Reads the value from `buffer`, zeroing any remaining bytes.
///
/// # Safety
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index fc75e80331..63a56be473 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -15,11 +15,10 @@ use std::{
collections::HashMap,
convert::TryFrom,
ffi::{OsStr, OsString},
- io, mem,
+ io,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
- os::windows::io::{AsRawHandle, RawHandle},
+ os::windows::io::AsRawHandle,
path::{Path, PathBuf},
- ptr,
sync::{
atomic::{AtomicBool, Ordering},
mpsc as sync_mpsc, Arc, Mutex, RwLock, Weak,
@@ -27,16 +26,6 @@ use std::{
time::Duration,
};
use talpid_types::{tunnel::ErrorStateCause, ErrorExt};
-use winapi::{
- shared::minwindef::{FALSE, TRUE},
- um::{
- handleapi::CloseHandle,
- ioapiset::GetOverlappedResult,
- minwinbase::OVERLAPPED,
- synchapi::{CreateEventW, SetEvent, WaitForMultipleObjects, WaitForSingleObject},
- winbase::{INFINITE, WAIT_ABANDONED_0, WAIT_OBJECT_0},
- },
-};
const DRIVER_EVENT_BUFFER_SIZE: usize = 2048;
const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123);
@@ -103,37 +92,13 @@ pub struct SplitTunnel {
runtime: tokio::runtime::Handle,
request_tx: RequestTx,
event_thread: Option<std::thread::JoinHandle<()>>,
- quit_event: Arc<QuitEvent>,
+ quit_event: Arc<windows::Event>,
excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
_route_change_callback: Option<WinNetCallbackHandle>,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
async_path_update_in_progress: Arc<AtomicBool>,
}
-struct QuitEvent(RawHandle);
-
-unsafe impl Send for QuitEvent {}
-unsafe impl Sync for QuitEvent {}
-
-impl QuitEvent {
- fn new() -> Self {
- Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) })
- }
-
- fn set_event(&self) -> io::Result<()> {
- if unsafe { SetEvent(self.0) } == 0 {
- return Err(io::Error::last_os_error());
- }
- Ok(())
- }
-}
-
-impl Drop for QuitEvent {
- fn drop(&mut self) {
- unsafe { CloseHandle(self.0) };
- }
-}
-
enum Request {
SetPaths(Vec<OsString>),
RegisterIps(InterfaceAddresses),
@@ -182,13 +147,12 @@ impl SplitTunnelHandle {
}
}
-struct EventThreadContext {
- handle: Arc<driver::DeviceHandle>,
- event_overlapped: OVERLAPPED,
- quit_event: Arc<QuitEvent>,
- excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
+enum EventResult {
+ /// Result containing the next event.
+ Event(driver::EventId, driver::EventBody),
+ /// Quit event was signaled.
+ Quit,
}
-unsafe impl Send for EventThreadContext {}
impl SplitTunnel {
/// Initialize the split tunnel device.
@@ -199,200 +163,207 @@ impl SplitTunnel {
) -> Result<Self, Error> {
let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?;
- let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() };
- event_overlapped.hEvent =
- unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
- if event_overlapped.hEvent == ptr::null_mut() {
- return Err(Error::EventThreadError(io::Error::last_os_error()));
- }
-
- let quit_event = Arc::new(QuitEvent::new());
let excluded_processes = Arc::new(RwLock::new(HashMap::new()));
+ let (event_thread, quit_event) =
+ Self::spawn_event_listener(handle, excluded_processes.clone())?;
- let event_context = EventThreadContext {
- handle: handle.clone(),
- event_overlapped,
- quit_event: quit_event.clone(),
- excluded_processes: excluded_processes.clone(),
- };
-
- let event_thread = std::thread::spawn(move || {
- use driver::{EventBody, EventId};
+ Ok(SplitTunnel {
+ runtime,
+ request_tx,
+ event_thread: Some(event_thread),
+ quit_event,
+ _route_change_callback: None,
+ daemon_tx,
+ async_path_update_in_progress: Arc::new(AtomicBool::new(false)),
+ excluded_processes,
+ })
+ }
- // Take ownership of the entire struct (Rust 2021 edition change)
- let _ = &event_context;
+ /// Spawns an event loop thread that processes events from the driver service.
+ fn spawn_event_listener(
+ handle: Arc<driver::DeviceHandle>,
+ excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
+ ) -> Result<(std::thread::JoinHandle<()>, Arc<windows::Event>), Error> {
+ let mut event_overlapped = windows::Overlapped::new(Some(
+ windows::Event::new(true, false).map_err(Error::EventThreadError)?,
+ ))
+ .map_err(Error::EventThreadError)?;
- let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE);
- let mut returned_bytes = 0u32;
+ let quit_event =
+ Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?);
+ let quit_event_copy = quit_event.clone();
- let event_objects = [
- event_context.event_overlapped.hEvent,
- event_context.quit_event.0,
- ];
+ let event_thread = std::thread::spawn(move || {
+ let mut data_buffer = vec![];
loop {
- if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } {
- // Quit event was signaled
- break;
- }
+ // Wait until either the next event is received or the quit event is signaled.
+ let (event_id, event_body) = match Self::fetch_next_event(
+ &handle,
+ &quit_event,
+ &mut event_overlapped,
+ &mut data_buffer,
+ ) {
+ Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body),
+ Ok(EventResult::Quit) => break,
+ Err(_error) => continue,
+ };
- if let Err(error) = unsafe {
- driver::device_io_control_buffer_async(
- event_context.handle.as_raw_handle(),
- driver::DriverIoctlCode::DequeEvent as u32,
- Some(&mut data_buffer),
- None,
- &event_context.event_overlapped,
- )
- } {
- log::error!(
- "{}",
- error.display_chain_with_msg("device_io_control failed")
- );
- continue;
- }
+ Self::handle_event(event_id, event_body, &excluded_processes);
+ }
- let result = unsafe {
- WaitForMultipleObjects(
- event_objects.len() as u32,
- &event_objects[0],
- FALSE,
- INFINITE,
- )
- };
+ log::debug!("Stopping split tunnel event thread");
+ });
- let signaled_index = if result >= WAIT_OBJECT_0
- && result < WAIT_OBJECT_0 + event_objects.len() as u32
- {
- result - WAIT_OBJECT_0
- } else if result >= WAIT_ABANDONED_0
- && result < WAIT_ABANDONED_0 + event_objects.len() as u32
- {
- result - WAIT_ABANDONED_0
- } else {
- let error = io::Error::last_os_error();
- log::error!(
- "{}",
- error.display_chain_with_msg("WaitForMultipleObjects failed")
- );
+ Ok((event_thread, quit_event_copy))
+ }
- continue;
- };
+ fn fetch_next_event(
+ device: &Arc<driver::DeviceHandle>,
+ quit_event: &windows::Event,
+ overlapped: &mut windows::Overlapped,
+ data_buffer: &mut Vec<u8>,
+ ) -> io::Result<EventResult> {
+ if unsafe {
+ driver::wait_for_single_object(quit_event.as_raw_handle(), Some(Duration::ZERO))
+ }
+ .is_ok()
+ {
+ return Ok(EventResult::Quit);
+ }
- if event_context.quit_event.0 == event_objects[signaled_index as usize] {
- // Quit event was signaled
- break;
- }
+ data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8);
- let result = unsafe {
- GetOverlappedResult(
- event_context.handle.as_raw_handle(),
- &event_context.event_overlapped as *const _ as *mut _,
- &mut returned_bytes,
- TRUE,
- )
- };
+ unsafe {
+ driver::device_io_control_buffer_async(
+ device,
+ driver::DriverIoctlCode::DequeEvent as u32,
+ None,
+ data_buffer.as_mut_ptr(),
+ u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"),
+ overlapped.as_mut_ptr(),
+ )
+ }
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("DeviceIoControl failed to deque event")
+ );
+ error
+ })?;
- if result == 0 {
- let error = io::Error::last_os_error();
+ let event_objects = [
+ overlapped.get_event().unwrap().as_raw_handle(),
+ quit_event.as_raw_handle(),
+ ];
+
+ let signaled_object =
+ unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) }.map_err(
+ |error| {
log::error!(
"{}",
- error.display_chain_with_msg("GetOverlappedResult failed")
+ error.display_chain_with_msg("wait_for_multiple_objects failed")
);
+ error
+ },
+ )?;
- continue;
- }
+ if signaled_object == quit_event.as_raw_handle() {
+ // Quit event was signaled
+ return Ok(EventResult::Quit);
+ }
- unsafe { data_buffer.set_len(returned_bytes as usize) };
+ let returned_bytes =
+ driver::get_overlapped_result(device, overlapped).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("get_overlapped_result failed for dequed event"),
+ );
+ error
+ })?;
- let event = driver::parse_event_buffer(&data_buffer);
+ data_buffer
+ .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32"));
- let (event_id, event_body) = match event {
- Some((event_id, event_body)) => (event_id, event_body),
- None => continue,
- };
+ driver::parse_event_buffer(&data_buffer)
+ .map(|(id, body)| EventResult::Event(id, body))
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to parse ST event buffer")
+ );
+ io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer")
+ })
+ }
- let event_str = match &event_id {
- EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => {
- "Start splitting process"
- }
- EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => {
- "Stop splitting process"
- }
- EventId::ErrorMessage => "ErrorMessage",
- };
+ fn handle_event(
+ event_id: driver::EventId,
+ event_body: driver::EventBody,
+ excluded_processes: &Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
+ ) {
+ use driver::{EventBody, EventId};
- match event_body {
- EventBody::SplittingEvent {
- process_id,
- reason,
- image,
- } => {
- let mut pids = event_context.excluded_processes.write().unwrap();
- match event_id {
- EventId::StartSplittingProcess => {
- if let Some(prev_entry) = pids.get(&process_id) {
- log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry);
- }
- pids.insert(
- process_id,
- ExcludedProcess {
- pid: u32::try_from(process_id)
- .expect("PID should be containable in a DWORD"),
- image: Path::new(&image).to_path_buf(),
- inherited: reason.contains(
- driver::SplittingChangeReason::BY_INHERITANCE,
- ),
- },
- );
- }
- EventId::StopSplittingProcess => {
- if pids.remove(&process_id).is_none() {
- log::error!(
- "Inconsistent process tree: {process_id} was not found"
- );
- }
- }
- _ => (),
- }
+ let event_str = match &event_id {
+ EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => {
+ "Start splitting process"
+ }
+ EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => {
+ "Stop splitting process"
+ }
+ EventId::ErrorMessage => "ErrorMessage",
+ };
- log::trace!(
- "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}",
- event_str,
- process_id,
- reason,
- image,
- );
- }
- EventBody::SplittingError { process_id, image } => {
- log::error!(
- "FAILED: {}:\n\tpid: {}\n\timage: {:?}",
- event_str,
+ match event_body {
+ EventBody::SplittingEvent {
+ process_id,
+ reason,
+ image,
+ } => {
+ let mut pids = excluded_processes.write().unwrap();
+ match event_id {
+ EventId::StartSplittingProcess => {
+ if let Some(prev_entry) = pids.get(&process_id) {
+ log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry);
+ }
+ pids.insert(
process_id,
- image,
+ ExcludedProcess {
+ pid: u32::try_from(process_id)
+ .expect("PID should be containable in a DWORD"),
+ image: Path::new(&image).to_path_buf(),
+ inherited: reason
+ .contains(driver::SplittingChangeReason::BY_INHERITANCE),
+ },
);
}
- EventBody::ErrorMessage { status, message } => {
- log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy())
+ EventId::StopSplittingProcess => {
+ if pids.remove(&process_id).is_none() {
+ log::error!("Inconsistent process tree: {process_id} was not found");
+ }
}
+ _ => (),
}
- }
-
- log::debug!("Stopping split tunnel event thread");
-
- unsafe { CloseHandle(event_context.event_overlapped.hEvent) };
- });
- Ok(SplitTunnel {
- runtime,
- request_tx,
- event_thread: Some(event_thread),
- quit_event,
- _route_change_callback: None,
- daemon_tx,
- async_path_update_in_progress: Arc::new(AtomicBool::new(false)),
- excluded_processes,
- })
+ log::trace!(
+ "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}",
+ event_str,
+ process_id,
+ reason,
+ image,
+ );
+ }
+ EventBody::SplittingError { process_id, image } => {
+ log::error!(
+ "FAILED: {}:\n\tpid: {}\n\timage: {:?}",
+ event_str,
+ process_id,
+ image,
+ );
+ }
+ EventBody::ErrorMessage { status, message } => {
+ log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy())
+ }
+ }
}
fn spawn_request_thread(
@@ -640,7 +611,7 @@ impl SplitTunnel {
impl Drop for SplitTunnel {
fn drop(&mut self) {
if let Some(_event_thread) = self.event_thread.take() {
- if let Err(error) = self.quit_event.set_event() {
+ if let Err(error) = self.quit_event.set() {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close ST event thread")
diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs
index 7000211bba..1cfcce15a5 100644
--- a/talpid-core/src/split_tunnel/windows/windows.rs
+++ b/talpid-core/src/split_tunnel/windows/windows.rs
@@ -13,15 +13,17 @@ use std::{
};
use winapi::{
shared::{
- minwindef::{DWORD, FALSE, FILETIME, TRUE},
+ minwindef::{BOOL, DWORD, FALSE, FILETIME, TRUE},
ntdef::ULARGE_INTEGER,
winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES},
},
um::{
fileapi::{GetFinalPathNameByHandleW, QueryDosDeviceW},
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ minwinbase::OVERLAPPED,
processthreadsapi::{GetProcessTimes, OpenProcess},
psapi::K32GetProcessImageFileNameW,
+ synchapi::{CreateEventW, SetEvent},
tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W},
winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION},
},
@@ -313,3 +315,98 @@ fn get_process_device_path_inner(
Ok(OsStringExt::from_wide(&buffer))
}
+
+/// Abstraction over `OVERLAPPED`, which is used for async I/O.
+pub struct Overlapped {
+ overlapped: OVERLAPPED,
+ event: Option<Event>,
+}
+
+unsafe impl Send for Overlapped {}
+unsafe impl Sync for Overlapped {}
+
+impl Overlapped {
+ /// Creates an `OVERLAPPED` object with `hEvent` set.
+ pub fn new(event: Option<Event>) -> io::Result<Self> {
+ let mut overlapped = Overlapped {
+ overlapped: unsafe { mem::zeroed() },
+ event: None,
+ };
+ overlapped.set_event(event);
+ Ok(overlapped)
+ }
+
+ /// Borrows the underlying `OVERLAPPED` object.
+ pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED {
+ &mut self.overlapped
+ }
+
+ /// Returns a reference to the associated event.
+ pub fn get_event(&self) -> Option<&Event> {
+ self.event.as_ref()
+ }
+
+ /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`)
+ fn set_event(&mut self, event: Option<Event>) {
+ match event {
+ Some(event) => {
+ let raw_event = event.0;
+ self.overlapped.hEvent = raw_event;
+ self.event = Some(event);
+ }
+ None => {
+ self.overlapped.hEvent = ptr::null_mut();
+ self.event = None;
+ }
+ }
+ }
+}
+
+/// Abstraction over a Windows event object.
+pub struct Event(RawHandle);
+
+unsafe impl Send for Event {}
+unsafe impl Sync for Event {}
+
+impl Event {
+ pub fn new(manual_reset: bool, initial_state: bool) -> io::Result<Self> {
+ let event = unsafe {
+ CreateEventW(
+ ptr::null_mut(),
+ bool_to_winbool(manual_reset),
+ bool_to_winbool(initial_state),
+ ptr::null(),
+ )
+ };
+ if event == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(Self(event))
+ }
+
+ pub fn set(&self) -> io::Result<()> {
+ if unsafe { SetEvent(self.0) } == FALSE {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+}
+
+impl AsRawHandle for Event {
+ fn as_raw_handle(&self) -> RawHandle {
+ self.0
+ }
+}
+
+impl Drop for Event {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.0) };
+ }
+}
+
+const fn bool_to_winbool(val: bool) -> BOOL {
+ match val {
+ true => TRUE,
+ false => FALSE,
+ }
+}