diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-06-08 15:35:34 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 13:59:51 +0200 |
| commit | 0ce130b6faec1272cb7708e4bf5f315e3e287237 (patch) | |
| tree | e2e800efbbf6d94dd63b25c38097bbac8c825b40 /talpid-core/src | |
| parent | dc80e6b604d33b0ff1ea8139ba5728996ddafbf2 (diff) | |
| download | mullvadvpn-0ce130b6faec1272cb7708e4bf5f315e3e287237.tar.xz mullvadvpn-0ce130b6faec1272cb7708e4bf5f315e3e287237.zip | |
Refactor split tunnel module event thread and synchronization
Diffstat (limited to 'talpid-core/src')
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 355 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 389 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/windows.rs | 99 |
3 files changed, 456 insertions, 387 deletions
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 70fe90f45d..6b6b2b2ea7 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -1,6 +1,6 @@ use super::windows::{ - get_device_path, get_process_creation_time, get_process_device_path, open_process, - ProcessAccess, ProcessSnapshot, + get_device_path, get_process_creation_time, get_process_device_path, open_process, Event, + Overlapped, ProcessAccess, ProcessSnapshot, }; use crate::windows::as_uninit_byte_slice; use bitflags::bitflags; @@ -33,12 +33,14 @@ use winapi::{ }, }, um::{ - handleapi::CloseHandle, ioapiset::{DeviceIoControl, GetOverlappedResult}, minwinbase::OVERLAPPED, - synchapi::{CreateEventW, WaitForSingleObject}, + synchapi::{WaitForMultipleObjects, WaitForSingleObject}, tlhelp32::TH32CS_SNAPPROCESS, - winbase::{FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0}, + winbase::{ + FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_FAILED, + WAIT_OBJECT_0, + }, winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, }, }; @@ -254,24 +256,17 @@ impl DeviceHandle { } fn initialize(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::Initialize as u32, - None, - 0, - None, - )?; + device_io_control(self, DriverIoctlCode::Initialize as u32, None, 0)?; Ok(()) } fn register_processes(&self) -> io::Result<()> { let process_tree_buffer = serialize_process_tree(build_process_tree()?)?; device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::RegisterProcesses as u32, Some(&process_tree_buffer), 0, - None, )?; Ok(()) } @@ -327,11 +322,10 @@ impl DeviceHandle { let buffer = as_uninit_byte_slice(&addresses); device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::RegisterIpAddresses as u32, Some(buffer), 0, - None, )?; Ok(()) @@ -339,11 +333,10 @@ impl DeviceHandle { pub fn get_driver_state(&self) -> io::Result<DriverState> { let buffer = device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::GetState as u32, None, size_of::<u64>() as u32, - None, )? .unwrap(); @@ -381,36 +374,22 @@ impl DeviceHandle { let config = make_process_config(&device_paths); device_io_control( - self.handle.as_raw_handle(), + self, DriverIoctlCode::SetConfiguration as u32, Some(&config), 0, - None, )?; Ok(()) } pub fn clear_config(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::ClearConfiguration as u32, - None, - 0, - None, - )?; - + device_io_control(self, DriverIoctlCode::ClearConfiguration as u32, None, 0)?; Ok(()) } fn reset(&self) -> io::Result<()> { - device_io_control( - self.handle.as_raw_handle(), - DriverIoctlCode::Reset as u32, - None, - 0, - None, - )?; + device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?; Ok(()) } } @@ -680,18 +659,10 @@ struct ErrorMessageEventHeader { /// # Panics /// /// This may panic if `buffer` contains invalid data. -pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { +pub fn parse_event_buffer(buffer: &[u8]) -> Result<(EventId, EventBody), UnknownEventId> { // SAFETY: This panics if `buffer` is too small. let raw_event_id: u32 = unsafe { deserialize_buffer(&buffer[0..mem::size_of::<u32>()]) }; - let _event_id = EventId::try_from(raw_event_id) - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to parse ST event buffer") - ); - error - }) - .ok()?; + let _event_id = EventId::try_from(raw_event_id)?; // SAFETY: The event id is known to be valid. let event_header: EventHeader = @@ -711,7 +682,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { [string_byte_offset..(string_byte_offset + event.image_name_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::SplittingEvent { process_id: event.process_id, @@ -736,7 +707,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { [string_byte_offset..(string_byte_offset + event.image_name_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::SplittingError { process_id: event.process_id, @@ -757,7 +728,7 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { ..(string_byte_offset + event.error_message_length as usize)], ); - Some(( + Ok(( event_header.event_id, EventBody::ErrorMessage { status: event.status, @@ -768,170 +739,116 @@ pub fn parse_event_buffer(buffer: &[u8]) -> Option<(EventId, EventBody)> { } } -/// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0. +/// Send an IOCTL code to the given device handle, and wait for the result. +/// +/// `input` specifies an optional buffer for sending data. +/// +/// Upon success, a buffer containing at most `output_size` bytes is returned, +/// or `None` if no bytes were read. pub fn device_io_control( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, input: Option<&[MaybeUninit<u8>]>, output_size: u32, - timeout: Option<Duration>, ) -> Result<Option<Vec<u8>>, io::Error> { - struct HandleOwner { - handle: RawHandle, - } - impl Drop for HandleOwner { - fn drop(&mut self) { - unsafe { CloseHandle(self.handle) }; - } - } - - let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() }; - overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; + let mut overlapped = Overlapped::new(Some(Event::new(true, false)?))?; - if overlapped.hEvent == ptr::null_mut() { - return Err(io::Error::last_os_error()); - } - - let _handle_owner = HandleOwner { - handle: overlapped.hEvent, - }; - - let mut out_buffer = if output_size > 0 { - Some(Vec::with_capacity(output_size as usize)) + let mut buffer = vec![]; + let out_buffer = if output_size > 0 { + buffer.resize( + usize::try_from(output_size).expect("u32 must be no larger than usize"), + 0u8, + ); + Some(&mut buffer[..]) } else { None }; - device_io_control_buffer( - device, - ioctl_code, - input, - out_buffer.as_mut(), - &overlapped, - timeout, - ) - .map(|()| out_buffer) + let bytes_read = + device_io_control_buffer(device, ioctl_code, input, out_buffer, &mut overlapped)?; + if bytes_read > 0 { + buffer.truncate(usize::try_from(bytes_read).expect("u32 must be no larger than usize")); + return Ok(Some(buffer)); + } + Ok(None) } -/// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// Upon success, `output` buffer will contain at most `output.capacity()` bytes of data. +/// Send an IOCTL code to the given device handle, and wait for the result. +/// +/// `input` specifies an optional buffer for sending data. +/// +/// Upon success, `output` buffer will contain at most `output.len()` bytes of data, +/// and the function returns the number of bytes read. +/// +/// # Panics +/// +/// This function will panic if `overlapped` does not contain an event. pub fn device_io_control_buffer( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, input: Option<&[MaybeUninit<u8>]>, - mut output: Option<&mut Vec<u8>>, - overlapped: &OVERLAPPED, - timeout: Option<Duration>, -) -> Result<(), io::Error> { - let input_ptr = match input { - Some(input) => input as *const _ as *mut _, - None => ptr::null_mut(), - }; - let input_len = input.map(|input| input.len()).unwrap_or(0); - + output: Option<&mut [u8]>, + overlapped: &mut Overlapped, +) -> Result<u32, io::Error> { + let output_len = output.as_ref().map(|output| output.len()).unwrap_or(0); + let output_len = u32::try_from(output_len).map_err(|_error| { + io::Error::new( + io::ErrorKind::InvalidInput, + "the output buffer is too large", + ) + })?; let out_ptr = match output { - Some(ref mut output) => output.as_mut_ptr() as *mut _, + Some(output) => output as *mut _ as *mut _, None => ptr::null_mut(), }; - let output_size = if let Some(ref output) = output { - output.capacity() - } else { - 0 - }; - - let event = overlapped.hEvent; - - let mut returned_bytes = 0u32; - let overlapped = overlapped as *const _ as *mut _; - - let result = unsafe { - DeviceIoControl( - device as *mut _, + // SAFETY: `out_ptr` will be valid until the result has been obtained. + unsafe { + device_io_control_buffer_async( + device, ioctl_code, - input_ptr, - input_len as u32, + input, out_ptr, - output_size as u32, - &mut returned_bytes, - overlapped, - ) - }; - - if result != 0 { - return Err(io::Error::new( - io::ErrorKind::Other, - "Expected pending operation", - )); - } - - let last_error = io::Error::last_os_error(); - if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) { - return Err(last_error); - } - - let timeout = timeout - .map(|timeout| timeout.as_millis() as u32) - .unwrap_or(INFINITE); - let result = unsafe { WaitForSingleObject(event, timeout) }; - match result { - WAIT_FAILED => return Err(io::Error::last_os_error()), - WAIT_ABANDONED => return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")), - WAIT_OBJECT_0 => (), - error => return Err(io::Error::from_raw_os_error(error as i32)), - } - - let result = - unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, FALSE) }; - - if result == 0 { - return Err(io::Error::last_os_error()); - } - - if let Some(ref mut output) = output { - unsafe { output.set_len(returned_bytes as usize) }; + output_len, + overlapped.as_mut_ptr(), + )?; } - - Ok(()) + get_overlapped_result(device, overlapped) } /// Send an IOCTL code to the given device handle. -/// `input` specifies an optional buffer to send. -/// The result must be obtained using `GetOverlappedResult[Ex]`. +/// +/// `input` specifies an optional buffer for sending data. +/// `output_ptr` specifies an optional buffer for receiving data. +/// +/// Obtain the result using [get_overlapped_result]. +/// +/// # Safety +/// +/// * `output_ptr` must either be null or a valid buffer of `output_len` bytes. It must remain valid +/// until the overlapped operation has completed. pub unsafe fn device_io_control_buffer_async( - device: RawHandle, + device: &DeviceHandle, ioctl_code: u32, - mut output: Option<&mut Vec<u8>>, - input: Option<&[u8]>, - overlapped: &OVERLAPPED, + input: Option<&[MaybeUninit<u8>]>, + output_ptr: *mut u8, + output_len: u32, + overlapped: *mut OVERLAPPED, ) -> Result<(), io::Error> { let input_ptr = match input { - Some(input) => input as *const _ as *mut _, + Some(input) => input.as_ptr() as *mut _, None => ptr::null_mut(), }; let input_len = input.map(|input| input.len()).unwrap_or(0); - let out_ptr = match output { - Some(ref mut output) => output.as_mut_ptr() as *mut _, - None => ptr::null_mut(), - }; - let output_size = if let Some(ref output) = output { - output.capacity() - } else { - 0 - }; - - let overlapped = overlapped as *const _ as *mut _; - let result = DeviceIoControl( - device as *mut _, + device.as_raw_handle(), ioctl_code, input_ptr, - input_len as u32, - out_ptr, - output_size as u32, + u32::try_from(input_len).map_err(|_error| { + io::Error::new(io::ErrorKind::InvalidInput, "the input buffer is too large") + })?, + output_ptr as *mut _, + output_len, ptr::null_mut(), overlapped, ); @@ -951,6 +868,90 @@ pub unsafe fn device_io_control_buffer_async( Ok(()) } +/// Retrieves the result of an overlapped operation. On success, this returns +/// the number of bytes transferred. For device I/O, this is the number of bytes +/// written to the output buffer. +/// +/// # Panics +/// +/// This function will panic if `overlapped` does not contain an event. +pub fn get_overlapped_result( + device: &DeviceHandle, + overlapped: &mut Overlapped, +) -> io::Result<u32> { + let event = overlapped.get_event().unwrap(); + + // SAFETY: This is a valid event object. + unsafe { wait_for_single_object(event.as_raw_handle(), None) }?; + + // SAFETY: The handle and overlapped object are valid. + let mut returned_bytes = 0u32; + let result = unsafe { + GetOverlappedResult( + device.as_raw_handle(), + overlapped.as_mut_ptr(), + &mut returned_bytes, + FALSE, + ) + }; + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(returned_bytes) +} + +/// Waits for an object to be signaled, or until a timeout interval has elapsed. +/// +/// # Safety +/// +/// * `object` must be a valid object that can be signaled, such as an event object. +pub unsafe fn wait_for_single_object( + object: RawHandle, + timeout: Option<Duration>, +) -> io::Result<()> { + let timeout = match timeout { + Some(timeout) => u32::try_from(timeout.as_millis()).map_err(|_error| { + io::Error::new(io::ErrorKind::InvalidInput, "the duration is too long") + })?, + None => INFINITE, + }; + let result = WaitForSingleObject(object, timeout); + match result { + WAIT_OBJECT_0 => Ok(()), + WAIT_FAILED => Err(io::Error::last_os_error()), + WAIT_ABANDONED => Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")), + error => Err(io::Error::from_raw_os_error(error as i32)), + } +} + +/// Waits for one or several objects to be signaled. On success, this returns a pointer to an +/// object in `objects` that was signaled. +/// +/// # Safety +/// +/// * `objects` must be a slice of valid objects that can be signaled, such as event objects. +pub unsafe fn wait_for_multiple_objects( + objects: &[RawHandle], + wait_all: bool, +) -> io::Result<RawHandle> { + let objects_len = u32::try_from(objects.len()) + .map_err(|_error| io::Error::new(io::ErrorKind::InvalidInput, "too many objects"))?; + let result = WaitForMultipleObjects( + objects_len, + objects.as_ptr(), + if wait_all { TRUE } else { FALSE }, + INFINITE, + ); + let signaled_index = if result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + objects_len { + result - WAIT_OBJECT_0 + } else if result >= WAIT_ABANDONED_0 && result < WAIT_ABANDONED_0 + objects_len { + return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")); + } else { + return Err(io::Error::last_os_error()); + }; + Ok(objects[usize::try_from(signaled_index).expect("usize must be larger than u32")]) +} + /// Reads the value from `buffer`, zeroing any remaining bytes. /// /// # Safety diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index fc75e80331..63a56be473 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -15,11 +15,10 @@ use std::{ collections::HashMap, convert::TryFrom, ffi::{OsStr, OsString}, - io, mem, + io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, - os::windows::io::{AsRawHandle, RawHandle}, + os::windows::io::AsRawHandle, path::{Path, PathBuf}, - ptr, sync::{ atomic::{AtomicBool, Ordering}, mpsc as sync_mpsc, Arc, Mutex, RwLock, Weak, @@ -27,16 +26,6 @@ use std::{ time::Duration, }; use talpid_types::{tunnel::ErrorStateCause, ErrorExt}; -use winapi::{ - shared::minwindef::{FALSE, TRUE}, - um::{ - handleapi::CloseHandle, - ioapiset::GetOverlappedResult, - minwinbase::OVERLAPPED, - synchapi::{CreateEventW, SetEvent, WaitForMultipleObjects, WaitForSingleObject}, - winbase::{INFINITE, WAIT_ABANDONED_0, WAIT_OBJECT_0}, - }, -}; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); @@ -103,37 +92,13 @@ pub struct SplitTunnel { runtime: tokio::runtime::Handle, request_tx: RequestTx, event_thread: Option<std::thread::JoinHandle<()>>, - quit_event: Arc<QuitEvent>, + quit_event: Arc<windows::Event>, excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, _route_change_callback: Option<WinNetCallbackHandle>, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, async_path_update_in_progress: Arc<AtomicBool>, } -struct QuitEvent(RawHandle); - -unsafe impl Send for QuitEvent {} -unsafe impl Sync for QuitEvent {} - -impl QuitEvent { - fn new() -> Self { - Self(unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }) - } - - fn set_event(&self) -> io::Result<()> { - if unsafe { SetEvent(self.0) } == 0 { - return Err(io::Error::last_os_error()); - } - Ok(()) - } -} - -impl Drop for QuitEvent { - fn drop(&mut self) { - unsafe { CloseHandle(self.0) }; - } -} - enum Request { SetPaths(Vec<OsString>), RegisterIps(InterfaceAddresses), @@ -182,13 +147,12 @@ impl SplitTunnelHandle { } } -struct EventThreadContext { - handle: Arc<driver::DeviceHandle>, - event_overlapped: OVERLAPPED, - quit_event: Arc<QuitEvent>, - excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, +enum EventResult { + /// Result containing the next event. + Event(driver::EventId, driver::EventBody), + /// Quit event was signaled. + Quit, } -unsafe impl Send for EventThreadContext {} impl SplitTunnel { /// Initialize the split tunnel device. @@ -199,200 +163,207 @@ impl SplitTunnel { ) -> Result<Self, Error> { let (request_tx, handle) = Self::spawn_request_thread(volume_update_rx)?; - let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() }; - event_overlapped.hEvent = - unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; - if event_overlapped.hEvent == ptr::null_mut() { - return Err(Error::EventThreadError(io::Error::last_os_error())); - } - - let quit_event = Arc::new(QuitEvent::new()); let excluded_processes = Arc::new(RwLock::new(HashMap::new())); + let (event_thread, quit_event) = + Self::spawn_event_listener(handle, excluded_processes.clone())?; - let event_context = EventThreadContext { - handle: handle.clone(), - event_overlapped, - quit_event: quit_event.clone(), - excluded_processes: excluded_processes.clone(), - }; - - let event_thread = std::thread::spawn(move || { - use driver::{EventBody, EventId}; + Ok(SplitTunnel { + runtime, + request_tx, + event_thread: Some(event_thread), + quit_event, + _route_change_callback: None, + daemon_tx, + async_path_update_in_progress: Arc::new(AtomicBool::new(false)), + excluded_processes, + }) + } - // Take ownership of the entire struct (Rust 2021 edition change) - let _ = &event_context; + /// Spawns an event loop thread that processes events from the driver service. + fn spawn_event_listener( + handle: Arc<driver::DeviceHandle>, + excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, + ) -> Result<(std::thread::JoinHandle<()>, Arc<windows::Event>), Error> { + let mut event_overlapped = windows::Overlapped::new(Some( + windows::Event::new(true, false).map_err(Error::EventThreadError)?, + )) + .map_err(Error::EventThreadError)?; - let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE); - let mut returned_bytes = 0u32; + let quit_event = + Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?); + let quit_event_copy = quit_event.clone(); - let event_objects = [ - event_context.event_overlapped.hEvent, - event_context.quit_event.0, - ]; + let event_thread = std::thread::spawn(move || { + let mut data_buffer = vec![]; loop { - if unsafe { WaitForSingleObject(event_context.quit_event.0, 0) == WAIT_OBJECT_0 } { - // Quit event was signaled - break; - } + // Wait until either the next event is received or the quit event is signaled. + let (event_id, event_body) = match Self::fetch_next_event( + &handle, + &quit_event, + &mut event_overlapped, + &mut data_buffer, + ) { + Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body), + Ok(EventResult::Quit) => break, + Err(_error) => continue, + }; - if let Err(error) = unsafe { - driver::device_io_control_buffer_async( - event_context.handle.as_raw_handle(), - driver::DriverIoctlCode::DequeEvent as u32, - Some(&mut data_buffer), - None, - &event_context.event_overlapped, - ) - } { - log::error!( - "{}", - error.display_chain_with_msg("device_io_control failed") - ); - continue; - } + Self::handle_event(event_id, event_body, &excluded_processes); + } - let result = unsafe { - WaitForMultipleObjects( - event_objects.len() as u32, - &event_objects[0], - FALSE, - INFINITE, - ) - }; + log::debug!("Stopping split tunnel event thread"); + }); - let signaled_index = if result >= WAIT_OBJECT_0 - && result < WAIT_OBJECT_0 + event_objects.len() as u32 - { - result - WAIT_OBJECT_0 - } else if result >= WAIT_ABANDONED_0 - && result < WAIT_ABANDONED_0 + event_objects.len() as u32 - { - result - WAIT_ABANDONED_0 - } else { - let error = io::Error::last_os_error(); - log::error!( - "{}", - error.display_chain_with_msg("WaitForMultipleObjects failed") - ); + Ok((event_thread, quit_event_copy)) + } - continue; - }; + fn fetch_next_event( + device: &Arc<driver::DeviceHandle>, + quit_event: &windows::Event, + overlapped: &mut windows::Overlapped, + data_buffer: &mut Vec<u8>, + ) -> io::Result<EventResult> { + if unsafe { + driver::wait_for_single_object(quit_event.as_raw_handle(), Some(Duration::ZERO)) + } + .is_ok() + { + return Ok(EventResult::Quit); + } - if event_context.quit_event.0 == event_objects[signaled_index as usize] { - // Quit event was signaled - break; - } + data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8); - let result = unsafe { - GetOverlappedResult( - event_context.handle.as_raw_handle(), - &event_context.event_overlapped as *const _ as *mut _, - &mut returned_bytes, - TRUE, - ) - }; + unsafe { + driver::device_io_control_buffer_async( + device, + driver::DriverIoctlCode::DequeEvent as u32, + None, + data_buffer.as_mut_ptr(), + u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"), + overlapped.as_mut_ptr(), + ) + } + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("DeviceIoControl failed to deque event") + ); + error + })?; - if result == 0 { - let error = io::Error::last_os_error(); + let event_objects = [ + overlapped.get_event().unwrap().as_raw_handle(), + quit_event.as_raw_handle(), + ]; + + let signaled_object = + unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) }.map_err( + |error| { log::error!( "{}", - error.display_chain_with_msg("GetOverlappedResult failed") + error.display_chain_with_msg("wait_for_multiple_objects failed") ); + error + }, + )?; - continue; - } + if signaled_object == quit_event.as_raw_handle() { + // Quit event was signaled + return Ok(EventResult::Quit); + } - unsafe { data_buffer.set_len(returned_bytes as usize) }; + let returned_bytes = + driver::get_overlapped_result(device, overlapped).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("get_overlapped_result failed for dequed event"), + ); + error + })?; - let event = driver::parse_event_buffer(&data_buffer); + data_buffer + .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32")); - let (event_id, event_body) = match event { - Some((event_id, event_body)) => (event_id, event_body), - None => continue, - }; + driver::parse_event_buffer(&data_buffer) + .map(|(id, body)| EventResult::Event(id, body)) + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to parse ST event buffer") + ); + io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer") + }) + } - let event_str = match &event_id { - EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { - "Start splitting process" - } - EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { - "Stop splitting process" - } - EventId::ErrorMessage => "ErrorMessage", - }; + fn handle_event( + event_id: driver::EventId, + event_body: driver::EventBody, + excluded_processes: &Arc<RwLock<HashMap<usize, ExcludedProcess>>>, + ) { + use driver::{EventBody, EventId}; - match event_body { - EventBody::SplittingEvent { - process_id, - reason, - image, - } => { - let mut pids = event_context.excluded_processes.write().unwrap(); - match event_id { - EventId::StartSplittingProcess => { - if let Some(prev_entry) = pids.get(&process_id) { - log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry); - } - pids.insert( - process_id, - ExcludedProcess { - pid: u32::try_from(process_id) - .expect("PID should be containable in a DWORD"), - image: Path::new(&image).to_path_buf(), - inherited: reason.contains( - driver::SplittingChangeReason::BY_INHERITANCE, - ), - }, - ); - } - EventId::StopSplittingProcess => { - if pids.remove(&process_id).is_none() { - log::error!( - "Inconsistent process tree: {process_id} was not found" - ); - } - } - _ => (), - } + let event_str = match &event_id { + EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { + "Start splitting process" + } + EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { + "Stop splitting process" + } + EventId::ErrorMessage => "ErrorMessage", + }; - log::trace!( - "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", - event_str, - process_id, - reason, - image, - ); - } - EventBody::SplittingError { process_id, image } => { - log::error!( - "FAILED: {}:\n\tpid: {}\n\timage: {:?}", - event_str, + match event_body { + EventBody::SplittingEvent { + process_id, + reason, + image, + } => { + let mut pids = excluded_processes.write().unwrap(); + match event_id { + EventId::StartSplittingProcess => { + if let Some(prev_entry) = pids.get(&process_id) { + log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry); + } + pids.insert( process_id, - image, + ExcludedProcess { + pid: u32::try_from(process_id) + .expect("PID should be containable in a DWORD"), + image: Path::new(&image).to_path_buf(), + inherited: reason + .contains(driver::SplittingChangeReason::BY_INHERITANCE), + }, ); } - EventBody::ErrorMessage { status, message } => { - log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) + EventId::StopSplittingProcess => { + if pids.remove(&process_id).is_none() { + log::error!("Inconsistent process tree: {process_id} was not found"); + } } + _ => (), } - } - - log::debug!("Stopping split tunnel event thread"); - - unsafe { CloseHandle(event_context.event_overlapped.hEvent) }; - }); - Ok(SplitTunnel { - runtime, - request_tx, - event_thread: Some(event_thread), - quit_event, - _route_change_callback: None, - daemon_tx, - async_path_update_in_progress: Arc::new(AtomicBool::new(false)), - excluded_processes, - }) + log::trace!( + "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", + event_str, + process_id, + reason, + image, + ); + } + EventBody::SplittingError { process_id, image } => { + log::error!( + "FAILED: {}:\n\tpid: {}\n\timage: {:?}", + event_str, + process_id, + image, + ); + } + EventBody::ErrorMessage { status, message } => { + log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) + } + } } fn spawn_request_thread( @@ -640,7 +611,7 @@ impl SplitTunnel { impl Drop for SplitTunnel { fn drop(&mut self) { if let Some(_event_thread) = self.event_thread.take() { - if let Err(error) = self.quit_event.set_event() { + if let Err(error) = self.quit_event.set() { log::error!( "{}", error.display_chain_with_msg("Failed to close ST event thread") diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs index 7000211bba..1cfcce15a5 100644 --- a/talpid-core/src/split_tunnel/windows/windows.rs +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -13,15 +13,17 @@ use std::{ }; use winapi::{ shared::{ - minwindef::{DWORD, FALSE, FILETIME, TRUE}, + minwindef::{BOOL, DWORD, FALSE, FILETIME, TRUE}, ntdef::ULARGE_INTEGER, winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES}, }, um::{ fileapi::{GetFinalPathNameByHandleW, QueryDosDeviceW}, handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + minwinbase::OVERLAPPED, processthreadsapi::{GetProcessTimes, OpenProcess}, psapi::K32GetProcessImageFileNameW, + synchapi::{CreateEventW, SetEvent}, tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W}, winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION}, }, @@ -313,3 +315,98 @@ fn get_process_device_path_inner( Ok(OsStringExt::from_wide(&buffer)) } + +/// Abstraction over `OVERLAPPED`, which is used for async I/O. +pub struct Overlapped { + overlapped: OVERLAPPED, + event: Option<Event>, +} + +unsafe impl Send for Overlapped {} +unsafe impl Sync for Overlapped {} + +impl Overlapped { + /// Creates an `OVERLAPPED` object with `hEvent` set. + pub fn new(event: Option<Event>) -> io::Result<Self> { + let mut overlapped = Overlapped { + overlapped: unsafe { mem::zeroed() }, + event: None, + }; + overlapped.set_event(event); + Ok(overlapped) + } + + /// Borrows the underlying `OVERLAPPED` object. + pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED { + &mut self.overlapped + } + + /// Returns a reference to the associated event. + pub fn get_event(&self) -> Option<&Event> { + self.event.as_ref() + } + + /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`) + fn set_event(&mut self, event: Option<Event>) { + match event { + Some(event) => { + let raw_event = event.0; + self.overlapped.hEvent = raw_event; + self.event = Some(event); + } + None => { + self.overlapped.hEvent = ptr::null_mut(); + self.event = None; + } + } + } +} + +/// Abstraction over a Windows event object. +pub struct Event(RawHandle); + +unsafe impl Send for Event {} +unsafe impl Sync for Event {} + +impl Event { + pub fn new(manual_reset: bool, initial_state: bool) -> io::Result<Self> { + let event = unsafe { + CreateEventW( + ptr::null_mut(), + bool_to_winbool(manual_reset), + bool_to_winbool(initial_state), + ptr::null(), + ) + }; + if event == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(Self(event)) + } + + pub fn set(&self) -> io::Result<()> { + if unsafe { SetEvent(self.0) } == FALSE { + return Err(io::Error::last_os_error()); + } + Ok(()) + } +} + +impl AsRawHandle for Event { + fn as_raw_handle(&self) -> RawHandle { + self.0 + } +} + +impl Drop for Event { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + +const fn bool_to_winbool(val: bool) -> BOOL { + match val { + true => TRUE, + false => FALSE, + } +} |
