diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-10-15 13:31:16 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-07-02 09:54:19 +0200 |
| commit | ad0a2907cff1a8ffe0cc6588554fa7672bf671b4 (patch) | |
| tree | 061c8519c9a10767ffa840c21720c2bd5fe817b4 | |
| parent | 9cc3585e99c9ba798e8b2dab983b3aad96450180 (diff) | |
| download | mullvadvpn-ad0a2907cff1a8ffe0cc6588554fa7672bf671b4.tar.xz mullvadvpn-ad0a2907cff1a8ffe0cc6588554fa7672bf671b4.zip | |
Use overlapped I/O for split tunnel driver
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 84 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 44 |
2 files changed, 96 insertions, 32 deletions
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index 093012224b..b43e566680 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -18,10 +18,19 @@ use std::{ ptr, }; use winapi::{ - shared::{in6addr::IN6_ADDR, inaddr::IN_ADDR}, + shared::{ + in6addr::IN6_ADDR, + inaddr::IN_ADDR, + minwindef::{FALSE, TRUE}, + winerror::ERROR_IO_PENDING, + }, um::{ - ioapiset::DeviceIoControl, + handleapi::CloseHandle, + ioapiset::{DeviceIoControl, GetOverlappedResult}, + minwinbase::OVERLAPPED, + synchapi::CreateEventW, tlhelp32::TH32CS_SNAPPROCESS, + winbase::FILE_FLAG_OVERLAPPED, winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, }, }; @@ -80,11 +89,6 @@ pub enum EventId { ErrorStopSplittingProcess, } -pub struct Event { - event_id: EventId, - body: EventBody, -} - pub enum EventBody { SplittingEvent { process_id: u32, @@ -107,21 +111,29 @@ pub enum SplittingChangeReason { pub struct DeviceHandle { handle: fs::File, + overlapped: OVERLAPPED, } impl DeviceHandle { pub fn new() -> io::Result<Self> { + let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() }; + overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; + + if overlapped.hEvent == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + // Connect to the driver log::trace!("Connecting to the driver"); let handle = OpenOptions::new() .read(true) .write(true) .share_mode(0) - .custom_flags(0) + .custom_flags(FILE_FLAG_OVERLAPPED) .attributes(0) .open(DRIVER_SYMBOLIC_NAME)?; - let device = Self { handle }; + let device = Self { handle, overlapped }; // Initialize the driver let state = device.get_driver_state()?; @@ -146,6 +158,7 @@ impl DeviceHandle { DriverIoctlCode::Initialize as u32, None, 0, + &self.overlapped, )?; Ok(()) } @@ -157,6 +170,7 @@ impl DeviceHandle { DriverIoctlCode::RegisterProcesses as u32, Some(&process_tree_buffer), 0, + &self.overlapped, )?; Ok(()) } @@ -213,6 +227,7 @@ impl DeviceHandle { DriverIoctlCode::RegisterIpAddresses as u32, Some(buffer), 0, + &self.overlapped, )?; Ok(()) @@ -224,6 +239,7 @@ impl DeviceHandle { DriverIoctlCode::GetState as u32, None, size_of::<u64>() as u32, + &self.overlapped, )? .unwrap(); @@ -248,6 +264,7 @@ impl DeviceHandle { DriverIoctlCode::SetConfiguration as u32, Some(&config), 0, + &self.overlapped, )?; Ok(()) @@ -259,13 +276,16 @@ impl DeviceHandle { DriverIoctlCode::ClearConfiguration as u32, None, 0, + &self.overlapped, )?; Ok(()) } +} - pub fn deque_event(&self, buffer: &mut Vec<u8>) -> io::Result<(EventId, EventBody)> { - deque_event(self.handle.as_raw_handle(), buffer) +impl Drop for DeviceHandle { + fn drop(&mut self) { + unsafe { CloseHandle(self.overlapped.hEvent) }; } } @@ -275,12 +295,17 @@ impl AsRawHandle for DeviceHandle { } } -pub fn deque_event(handle: RawHandle, buffer: &mut Vec<u8>) -> io::Result<(EventId, EventBody)> { +pub fn deque_event( + handle: RawHandle, + buffer: &mut Vec<u8>, + overlapped: &mut OVERLAPPED, +) -> io::Result<(EventId, EventBody)> { device_io_control_buffer( handle, DriverIoctlCode::DequeEvent as u32, None, Some(buffer), + overlapped, )?; let mut event_header: EventHeader = unsafe { mem::zeroed() }; @@ -604,6 +629,7 @@ pub fn device_io_control( ioctl_code: u32, input: Option<&[u8]>, output_size: u32, + overlapped: &OVERLAPPED, ) -> Result<Option<Vec<u8>>, io::Error> { let mut out_buffer = if output_size > 0 { Some(Vec::with_capacity(output_size as usize)) @@ -611,7 +637,8 @@ pub fn device_io_control( None }; - device_io_control_buffer(device, ioctl_code, input, out_buffer.as_mut()).map(|()| out_buffer) + device_io_control_buffer(device, ioctl_code, input, out_buffer.as_mut(), overlapped) + .map(|()| out_buffer) } /// Send an IOCTL code to the given device handle. @@ -622,6 +649,7 @@ pub fn device_io_control_buffer( ioctl_code: u32, input: Option<&[u8]>, mut output: Option<&mut Vec<u8>>, + overlapped: &OVERLAPPED, ) -> Result<(), io::Error> { let input_ptr = match input { Some(input) => input as *const _ as *mut _, @@ -640,6 +668,7 @@ pub fn device_io_control_buffer( }; let mut returned_bytes = 0u32; + let overlapped = overlapped as *const _ as *mut _; let result = unsafe { DeviceIoControl( @@ -649,20 +678,35 @@ pub fn device_io_control_buffer( input_len as u32, out_ptr, output_size as u32, - &mut returned_bytes as *mut _, - ptr::null_mut(), // TODO + &mut returned_bytes, + overlapped, ) }; + if result != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "Expected pending operation", + )); + } + + let last_error = io::Error::last_os_error(); + if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) { + return Err(last_error); + } + + let result = + unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, TRUE) }; + + if result == 0 { + return Err(io::Error::last_os_error()); + } + if let Some(ref mut output) = output { unsafe { output.set_len(returned_bytes as usize) }; } - if result != 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) - } + Ok(()) } /// Creates a new instance of an arbitrary type from a byte buffer. diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 857a20435d..516b44de72 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -3,15 +3,16 @@ mod windows; use std::{ ffi::OsStr, - io, + io, mem, net::{Ipv4Addr, Ipv6Addr}, - os::windows::{ - io::{AsRawHandle, IntoRawHandle, RawHandle}, - thread, - }, + os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle}, + ptr, }; use talpid_types::ErrorExt; -use winapi::um::processthreadsapi::TerminateThread; +use winapi::{ + shared::minwindef::{FALSE, TRUE}, + um::{minwinbase::OVERLAPPED, processthreadsapi::TerminateThread, synchapi::CreateEventW}, +}; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; @@ -30,6 +31,10 @@ pub enum Error { /// Failed to register interface IP addresses #[error(display = "Failed to register IP addresses for exclusions")] RegisterIps(#[error(source)] io::Error), + + /// Failed to set up the driver event loop + #[error(display = "Failed to set up the driver event loop")] + EventThreadError(#[error(source)] io::Error), } /// Manages applications whose traffic to exclude from the tunnel. @@ -38,22 +43,30 @@ pub struct SplitTunnel { event_thread: Option<std::thread::JoinHandle<()>>, } -struct HandleContainer { +struct EventThreadContext { handle: RawHandle, + event_overlapped: OVERLAPPED, } -// FIXME: ! This is not safe. The handle will be invalidated when SplitTunnel is dropped -unsafe impl Send for HandleContainer {} +// FIXME: ! This is not safe. The driver handle will be invalidated when SplitTunnel is dropped +unsafe impl Send for EventThreadContext {} impl SplitTunnel { /// Initialize the driver. pub fn new() -> Result<Self, Error> { - // TODO: spawn event monitor let handle = driver::DeviceHandle::new().map_err(Error::InitializationFailed)?; // FIXME: Want to use same pointer, but must be certain that the thread dies after this dies - let raw_handle = HandleContainer { + let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() }; + event_overlapped.hEvent = + unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) }; + if event_overlapped.hEvent == ptr::null_mut() { + return Err(Error::EventThreadError(io::Error::last_os_error())); + } + + let mut event_context = EventThreadContext { handle: handle.as_raw_handle(), + event_overlapped, }; let event_thread = std::thread::spawn(move || { @@ -62,7 +75,11 @@ impl SplitTunnel { let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE); loop { - match driver::deque_event(raw_handle.handle, &mut data_buffer) { + match driver::deque_event( + event_context.handle, + &mut data_buffer, + &mut event_context.event_overlapped, + ) { Ok((event_id, event_body)) => { let event_str = match &event_id { EventId::StartSplittingProcess @@ -91,6 +108,9 @@ impl SplitTunnel { // TODO: Quit when signaled. Overlapping + WaitForMultipleObjects? } + + // FIXME: The event object will not be destroyed since we use TerminateThread + // unsafe { CloseHandle(event_overlapped.hEvent) }; }); Ok(SplitTunnel { |
