diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-05-29 09:45:17 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-07-02 09:54:19 +0200 |
| commit | e5baa0e08816d535a031b3d8575701b8d43fb0c2 (patch) | |
| tree | c4bf2ec1956977676bc25c2630bd38789f43dade /talpid-core/src | |
| parent | 207ab239223686ff72c43a8a5d615565ab81b5ab (diff) | |
| download | mullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.tar.xz mullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.zip | |
Support Windows split tunneling driver
Diffstat (limited to 'talpid-core/src')
| -rw-r--r-- | talpid-core/src/split_tunnel/mod.rs | 7 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 514 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 69 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/windows.rs | 259 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 231 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnected_state.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/disconnecting_state.rs | 31 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/error_state.rs | 23 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 34 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 27 |
11 files changed, 1230 insertions, 7 deletions
diff --git a/talpid-core/src/split_tunnel/mod.rs b/talpid-core/src/split_tunnel/mod.rs index c7c366d6ea..3c3f6af294 100644 --- a/talpid-core/src/split_tunnel/mod.rs +++ b/talpid-core/src/split_tunnel/mod.rs @@ -4,3 +4,10 @@ mod imp; #[cfg(target_os = "linux")] pub use imp::*; + +#[cfg(windows)] +#[path = "windows/mod.rs"] +mod imp; + +#[cfg(windows)] +pub use imp::*; diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs new file mode 100644 index 0000000000..26495a5877 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -0,0 +1,514 @@ +use super::windows::{ + get_final_path_name, get_process_creation_time, get_process_device_path, open_process, + ProcessAccess, ProcessSnapshot, +}; +use std::{ + cell::RefCell, + collections::HashMap, + ffi::{OsStr, OsString}, + fs::{self, OpenOptions}, + io, + mem::{self, size_of}, + net::{Ipv4Addr, Ipv6Addr}, + os::windows::{ + ffi::OsStrExt, + fs::OpenOptionsExt, + io::{AsRawHandle, RawHandle}, + }, + ptr, +}; +use winapi::{ + shared::{in6addr::IN6_ADDR, inaddr::IN_ADDR}, + um::{ + ioapiset::DeviceIoControl, + tlhelp32::TH32CS_SNAPPROCESS, + winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, + }, +}; + +const DRIVER_SYMBOLIC_NAME: &str = "\\\\.\\MULLVADSPLITTUNNEL"; +const ST_DEVICE_TYPE: u32 = 0x8000; + +const fn ctl_code(device_type: u32, function: u32, method: u32, access: u32) -> u32 { + device_type << 16 | access << 14 | function << 2 | method +} + +#[repr(u32)] +#[allow(dead_code)] +enum DriverIoctlCode { + Initialize = ctl_code(ST_DEVICE_TYPE, 1, METHOD_NEITHER, FILE_ANY_ACCESS), + DequeEvent = ctl_code(ST_DEVICE_TYPE, 2, METHOD_BUFFERED, FILE_ANY_ACCESS), + RegisterProcesses = ctl_code(ST_DEVICE_TYPE, 3, METHOD_BUFFERED, FILE_ANY_ACCESS), + RegisterIpAddresses = ctl_code(ST_DEVICE_TYPE, 4, METHOD_BUFFERED, FILE_ANY_ACCESS), + GetIpAddresses = ctl_code(ST_DEVICE_TYPE, 5, METHOD_BUFFERED, FILE_ANY_ACCESS), + SetConfiguration = ctl_code(ST_DEVICE_TYPE, 6, METHOD_BUFFERED, FILE_ANY_ACCESS), + GetConfiguration = ctl_code(ST_DEVICE_TYPE, 7, METHOD_BUFFERED, FILE_ANY_ACCESS), + ClearConfiguration = ctl_code(ST_DEVICE_TYPE, 8, METHOD_NEITHER, FILE_ANY_ACCESS), + GetState = ctl_code(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS), + QueryProcess = ctl_code(ST_DEVICE_TYPE, 10, METHOD_BUFFERED, FILE_ANY_ACCESS), +} + +#[derive(Debug, PartialEq)] +#[repr(u32)] +#[allow(dead_code)] +pub enum DriverState { + // Default state after being loaded. + None = 0, + // DriverEntry has completed successfully. + // Basically only driver and device objects are created at this point. + Started = 1, + // All subsystems are initialized. + Initialized = 2, + // User mode has registered all processes in the system. + Ready = 3, + // IP addresses are registered. + // A valid configuration is registered. + Engaged = 4, + // Driver is unloading. + Terminating = 5, +} + +pub struct DeviceHandle { + handle: fs::File, +} + +impl DeviceHandle { + pub fn new() -> io::Result<Self> { + // Connect to the driver + log::trace!("Connecting to the driver"); + let handle = OpenOptions::new() + .read(true) + .write(true) + .share_mode(0) + .custom_flags(0) + .attributes(0) + .open(DRIVER_SYMBOLIC_NAME)?; + + let device = Self { handle }; + + // Initialize the driver + let state = device.get_driver_state()?; + if state == DriverState::Started { + log::trace!("Initializing driver"); + device.initialize()?; + } + + // Initialize process tree + let state = device.get_driver_state()?; + if state == DriverState::Initialized { + log::trace!("Registering processes"); + device.register_processes()?; + } + + Ok(device) + } + + fn initialize(&self) -> io::Result<()> { + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::Initialize as u32, + None, + 0, + )?; + Ok(()) + } + + fn register_processes(&self) -> io::Result<()> { + let process_tree_buffer = serialize_process_tree(build_process_tree()?)?; + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::RegisterProcesses as u32, + Some(&process_tree_buffer), + 0, + )?; + Ok(()) + } + + pub fn register_ips( + &self, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> io::Result<()> { + let mut addresses: SplitTunnelAddresses = unsafe { mem::zeroed() }; + + unsafe { + let tunnel_ipv4 = tunnel_ipv4.octets(); + ptr::copy_nonoverlapping( + &tunnel_ipv4[0] as *const u8, + &mut addresses.tunnel_ipv4 as *mut _ as *mut u8, + tunnel_ipv4.len(), + ); + + if let Some(tunnel_ipv6) = tunnel_ipv6 { + let tunnel_ipv6 = tunnel_ipv6.octets(); + ptr::copy_nonoverlapping( + &tunnel_ipv6[0] as *const u8, + &mut addresses.tunnel_ipv6 as *mut _ as *mut u8, + tunnel_ipv6.len(), + ); + } + + let internet_ipv4 = internet_ipv4.octets(); + ptr::copy_nonoverlapping( + &internet_ipv4[0] as *const u8, + &mut addresses.internet_ipv4 as *mut _ as *mut u8, + internet_ipv4.len(), + ); + + if let Some(internet_ipv6) = internet_ipv6 { + let internet_ipv6 = internet_ipv6.octets(); + ptr::copy_nonoverlapping( + &internet_ipv6[0] as *const u8, + &mut addresses.internet_ipv6 as *mut _ as *mut u8, + internet_ipv6.len(), + ); + } + } + + let buffer = &addresses as *const _ as *const u8; + let buffer = + unsafe { std::slice::from_raw_parts(buffer, size_of::<SplitTunnelAddresses>()) }; + + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::RegisterIpAddresses as u32, + Some(buffer), + 0, + )?; + + Ok(()) + } + + pub fn get_driver_state(&self) -> io::Result<DriverState> { + let buffer = device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::GetState as u32, + None, + size_of::<u64>() as u32, + )? + .unwrap(); + + Ok(unsafe { deserialize_buffer(&buffer) }) + } + + pub fn set_config<T: AsRef<OsStr>>(&self, apps: &[T]) -> io::Result<()> { + let mut device_paths = Vec::with_capacity(apps.len()); + for app in apps.as_ref() { + device_paths.push(get_final_path_name(app)?); + } + + log::debug!("Excluded device paths:"); + for path in &device_paths { + log::debug!(" {:?}", path); + } + + let config = make_process_config(&device_paths); + + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::SetConfiguration as u32, + Some(&config), + 0, + )?; + + Ok(()) + } + + pub fn clear_config(&self) -> io::Result<()> { + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::ClearConfiguration as u32, + None, + 0, + )?; + + Ok(()) + } +} + +#[repr(C)] +struct SplitTunnelAddresses { + tunnel_ipv4: IN_ADDR, + internet_ipv4: IN_ADDR, + tunnel_ipv6: IN6_ADDR, + internet_ipv6: IN6_ADDR, +} + +#[repr(C)] +struct ConfigurationHeader { + // Number of entries immediately following the header. + num_entries: usize, + // Total byte length: header + entries + string buffer. + total_length: usize, +} + +#[repr(C)] +struct ConfigurationEntry { + // Offset into buffer region that follows all entries. + // The image name uses the physical path. + name_offset: usize, + // Byte length for non-null terminated wide char string. + name_length: u16, +} + +/// Create a buffer containing a `ConfigurationHeader` and number of `ConfigurationEntry`s, +/// followed by the same number of paths to those entries. +fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { + let apps: Vec<Vec<u16>> = apps + .iter() + .map(|app| app.as_ref().encode_wide().collect()) + .collect(); + + let total_string_size: usize = apps.iter().map(|app| size_of::<u16>() * app.len()).sum(); + + let total_buffer_size = size_of::<ConfigurationHeader>() + + size_of::<ConfigurationEntry>() * apps.len() + + total_string_size; + + let mut buffer = Vec::<u8>::new(); + buffer.resize(total_buffer_size, 0); + + let (header, tail) = buffer.split_at_mut(size_of::<ConfigurationHeader>()); + + // Serialize configuration header + let header_struct = ConfigurationHeader { + num_entries: apps.len(), + total_length: total_buffer_size, + }; + header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + + // Serialize configuration entries and strings + let (entries, string_data) = tail.split_at_mut(apps.len() * size_of::<ConfigurationEntry>()); + let mut string_offset = 0; + + for (i, app) in apps.iter().enumerate() { + write_string_to_buffer(string_data, string_offset, &app); + + let app_bytelen = size_of::<u16>() * app.len(); + let entry = ConfigurationEntry { + name_offset: string_offset, + name_length: app_bytelen as u16, + }; + let entry_offset = size_of::<ConfigurationEntry>() * i; + entries[entry_offset..entry_offset + size_of::<ConfigurationEntry>()] + .copy_from_slice(unsafe { as_u8_slice(&entry) }); + + string_offset += app_bytelen; + } + + buffer +} + +#[derive(Debug)] +struct ProcessInfo { + pid: u32, + parent_pid: u32, + creation_time: u64, + device_path: Vec<u16>, +} + +/// List process identifiers, their parents, and their device paths. +fn build_process_tree() -> io::Result<Vec<ProcessInfo>> { + let mut process_info = HashMap::new(); + + let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?; + for entry in snap.entries() { + let entry = entry?; + + let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) { + Ok(handle) => Ok(handle), + Err(error) => { + // Skip process objects that cannot be opened + match error.kind() { + // System process + io::ErrorKind::PermissionDenied => continue, + // System idle or csrss process + io::ErrorKind::InvalidInput => continue, + _ => Err(error), + } + } + }?; + + // TODO: Skip objects whose paths or timestamps cannot be obtained? + + process_info.insert( + entry.pid, + RefCell::new(ProcessInfo { + pid: entry.pid, + parent_pid: entry.parent_pid, + creation_time: get_process_creation_time(process.get_raw()).unwrap_or(0), + device_path: get_process_device_path(process.get_raw()) + .unwrap_or(OsString::from("")) + .encode_wide() + .collect(), + }), + ); + } + + // Handle pid recycling + // If the "parent" is younger than the process itself, it is not our parent. + for info in process_info.values() { + let mut info = info.borrow_mut(); + let parent_pid = info.parent_pid; + if parent_pid == 0 { + continue; + } + if let Some(parent_info) = process_info.get(&parent_pid) { + if parent_info.borrow_mut().creation_time > info.creation_time { + info.parent_pid = 0; + } + } + } + + Ok(process_info + .into_iter() + .map(|(_, info)| info.into_inner()) + .collect()) +} + +#[repr(C)] +struct ProcessRegistryHeader { + // Number of entries immediately following the header. + num_entries: usize, + // Total byte length: header + entries + string buffer. + total_length: usize, +} + +#[repr(C)] +struct ProcessRegistryEntry { + pid: RawHandle, + parent_pid: RawHandle, + // Image name offset (following the last entry). + image_name_offset: usize, + // Image name length. + image_name_size: u16, +} + +fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Error> { + // Construct a buffer: + // ProcessRegistryHeader + // ProcessRegistryEntry.. + // Image names.. + + let total_string_size: usize = processes + .iter() + .map(|info| size_of::<u16>() * info.device_path.len()) + .sum(); + let total_buffer_size = size_of::<ProcessRegistryHeader>() + + size_of::<ProcessRegistryEntry>() * processes.len() + + total_string_size; + + let mut buffer = Vec::<u8>::new(); + buffer.resize(total_buffer_size, 0); + + let (header, tail) = buffer.split_at_mut(size_of::<ProcessRegistryHeader>()); + let header_struct = ProcessRegistryHeader { + num_entries: processes.len(), + total_length: total_buffer_size, + }; + header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + + let (entries, string_data) = + tail.split_at_mut(size_of::<ProcessRegistryEntry>() * processes.len()); + + let mut string_offset = 0; + + for (i, entry) in processes.into_iter().enumerate() { + let mut out_entry = ProcessRegistryEntry { + pid: entry.pid as usize as RawHandle, + parent_pid: entry.parent_pid as usize as RawHandle, + image_name_size: 0, + image_name_offset: 0, + }; + + if !entry.device_path.is_empty() { + write_string_to_buffer(string_data, string_offset, &entry.device_path); + + out_entry.image_name_size = (entry.device_path.len() * size_of::<u16>()) as u16; + out_entry.image_name_offset = string_offset; + + string_offset += size_of::<u16>() * entry.device_path.len(); + } + + let entry_offset = size_of::<ProcessRegistryEntry>() * i; + entries[entry_offset..entry_offset + size_of::<ProcessRegistryEntry>()] + .copy_from_slice(unsafe { as_u8_slice(&out_entry) }); + } + + Ok(buffer) +} + +/// Send an IOCTL code to the given device handle. +/// `input` specifies an optional buffer to send. +/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0. +pub fn device_io_control( + device: RawHandle, + ioctl_code: u32, + input: Option<&[u8]>, + output_size: u32, +) -> Result<Option<Vec<u8>>, io::Error> { + let input_ptr = match input { + Some(input) => input as *const _ as *mut _, + None => ptr::null_mut(), + }; + let input_len = input.map(|input| input.len()).unwrap_or(0); + + let mut out_buffer = if output_size > 0 { + Some(Vec::with_capacity(output_size as usize)) + } else { + None + }; + + let out_ptr = match out_buffer { + Some(ref mut out_buffer) => out_buffer.as_mut_ptr() as *mut _, + None => ptr::null_mut(), + }; + + let mut returned_bytes = 0u32; + + let result = unsafe { + DeviceIoControl( + device as *mut _, + ioctl_code, + input_ptr, + input_len as u32, + out_ptr, + output_size, + &mut returned_bytes as *mut _, + ptr::null_mut(), // TODO + ) + }; + + if let Some(ref mut out_buffer) = out_buffer { + unsafe { out_buffer.set_len(returned_bytes as usize) }; + } + + if result != 0 { + Ok(out_buffer) + } else { + Err(io::Error::last_os_error()) + } +} + +/// Creates a new instance of an arbitrary type from a byte buffer. +pub unsafe fn deserialize_buffer<T: Sized>(buffer: &Vec<u8>) -> T { + let mut instance: T = mem::zeroed(); + ptr::copy_nonoverlapping(buffer.as_ptr() as *const T, &mut instance as *mut _, 1); + instance +} + +fn write_string_to_buffer(buffer: &mut [u8], byte_offset: usize, string: &[u16]) { + for (i, byte) in string + .iter() + .flat_map(|word| std::array::IntoIter::new(word.to_ne_bytes())) + .enumerate() + { + buffer[byte_offset + i] = byte; + } +} + +unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] { + std::slice::from_raw_parts(object as *const _ as *const _, size_of::<T>()) +} diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs new file mode 100644 index 0000000000..c6b8fae332 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -0,0 +1,69 @@ +mod driver; +mod windows; + +use std::{ + ffi::OsStr, + io, + net::{Ipv4Addr, Ipv6Addr}, +}; +use talpid_types::ErrorExt; + +/// Errors that may occur in [`SplitTunnel`]. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to identify or initialize the driver + #[error(display = "Failed to find or initialize driver")] + InitializationFailed(#[error(source)] io::Error), + + /// Failed to set paths to excluded applications + #[error(display = "Failed to set list of excluded applications")] + SetConfiguration(#[error(source)] io::Error), + + /// Failed to register interface IP addresses + #[error(display = "Failed to register IP addresses for exclusions")] + RegisterIps(#[error(source)] io::Error), +} + +/// Manages applications whose traffic to exclude from the tunnel. +pub struct SplitTunnel(driver::DeviceHandle); + +impl SplitTunnel { + /// Initialize the driver. + pub fn new() -> Result<Self, Error> { + Ok(SplitTunnel( + driver::DeviceHandle::new().map_err(Error::InitializationFailed)?, + )) + } + + /// Set a list of applications to exclude from the tunnel. + pub fn set_paths<T: AsRef<OsStr>>(&self, paths: &[T]) -> Result<(), Error> { + if paths.len() > 0 { + self.0.set_config(paths).map_err(Error::SetConfiguration) + } else { + self.0.clear_config().map_err(Error::SetConfiguration) + } + } + + /// Configures IP addresses used for socket rebinding. + pub fn register_ips( + &self, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> Result<(), Error> { + self.0 + .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6) + .map_err(Error::RegisterIps) + } +} + +impl Drop for SplitTunnel { + fn drop(&mut self) { + let paths: [&OsStr; 0] = []; + if let Err(error) = self.set_paths(&paths) { + log::error!("{}", error.display_chain()); + } + } +} diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs new file mode 100644 index 0000000000..be8631d53c --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -0,0 +1,259 @@ +// TODO: The snapshot code could be combined with the mostly-identical code in +// the windows_exception_logging module. + +use std::{ + ffi::{OsStr, OsString}, + fs::OpenOptions, + io, mem, + os::windows::{ + ffi::OsStringExt, + io::{AsRawHandle, RawHandle}, + }, + ptr, +}; +use winapi::{ + shared::{ + minwindef::{DWORD, FALSE, FILETIME, TRUE}, + ntdef::ULARGE_INTEGER, + winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES}, + }, + um::{ + fileapi::GetFinalPathNameByHandleW, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + processthreadsapi::{GetProcessTimes, OpenProcess}, + psapi::K32GetProcessImageFileNameW, + tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W}, + winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION}, + }, +}; + +/// Return path with the volume device path. +const VOLUME_NAME_NT: u32 = 0x02; + +pub struct ProcessSnapshot { + handle: HANDLE, +} + +impl ProcessSnapshot { + pub fn new(flags: DWORD, process_id: DWORD) -> io::Result<ProcessSnapshot> { + let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; + + if snap == INVALID_HANDLE_VALUE { + Err(io::Error::last_os_error()) + } else { + Ok(ProcessSnapshot { handle: snap }) + } + } + + pub fn handle(&self) -> HANDLE { + self.handle + } + + pub fn entries(&self) -> ProcessSnapshotEntries<'_> { + let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() }; + entry.dwSize = mem::size_of::<PROCESSENTRY32W>() as u32; + + ProcessSnapshotEntries { + snapshot: self, + iter_started: false, + temp_entry: entry, + } + } +} + +impl Drop for ProcessSnapshot { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +pub struct ProcessEntry { + pub pid: u32, + pub parent_pid: u32, +} + +pub struct ProcessSnapshotEntries<'a> { + snapshot: &'a ProcessSnapshot, + iter_started: bool, + temp_entry: PROCESSENTRY32W, +} + +impl Iterator for ProcessSnapshotEntries<'_> { + type Item = io::Result<ProcessEntry>; + + fn next(&mut self) -> Option<io::Result<ProcessEntry>> { + if self.iter_started { + if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + let last_error = io::Error::last_os_error(); + + return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { + None + } else { + Some(Err(last_error)) + }; + } + } else { + if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + return Some(Err(io::Error::last_os_error())); + } + self.iter_started = true; + } + + Some(Ok(ProcessEntry { + pid: self.temp_entry.th32ProcessID, + parent_pid: self.temp_entry.th32ParentProcessID, + })) + } +} + +pub fn get_final_path_name<T: AsRef<OsStr>>(path: T) -> Result<OsString, io::Error> { + // TODO: verify that all flags, including security flags, are ok + // TODO: verify that the file is a PE executable? + // TODO: verify that the executable is on a physical drive? + let file = OpenOptions::new().read(true).open(path.as_ref())?; + get_final_path_name_by_handle(file.as_raw_handle()) +} + +pub fn get_final_path_name_by_handle(raw_handle: RawHandle) -> Result<OsString, io::Error> { + let buffer_size = unsafe { + GetFinalPathNameByHandleW(raw_handle as *mut _, ptr::null_mut(), 0u32, VOLUME_NAME_NT) + } as usize; + + if buffer_size == 0 { + return Err(io::Error::last_os_error()); + } + + let mut buffer = Vec::new(); + buffer.reserve_exact(buffer_size); + + let status = unsafe { + GetFinalPathNameByHandleW( + raw_handle as *mut _, + buffer.as_mut_ptr(), + buffer_size as u32, + VOLUME_NAME_NT, + ) + } as usize; + + if status == 0 { + return Err(io::Error::last_os_error()); + } + + unsafe { buffer.set_len(buffer_size - 1) }; + + // TODO: can this be done by stealing 'buffer' instead of copying it? + Ok(OsStringExt::from_wide(&buffer)) +} + +/// Object that frees its handle when dropped. +pub struct WinHandle(RawHandle); + +impl WinHandle { + pub fn get_raw(&self) -> RawHandle { + self.0 + } +} + +impl Drop for WinHandle { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + +#[repr(u32)] +pub enum ProcessAccess { + QueryLimitedInformation = PROCESS_QUERY_LIMITED_INFORMATION, + // TODO: could be extended +} + +/// Open an existing process object. +pub fn open_process( + access: ProcessAccess, + inherit_handle: bool, + pid: u32, +) -> Result<WinHandle, io::Error> { + let handle = unsafe { + OpenProcess( + access as u32, + if inherit_handle { TRUE } else { FALSE }, + pid, + ) + }; + + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(WinHandle(handle)) +} + +/// Returns the age of a running process. +pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> { + // TODO: FileTimeToSystemTime -> chrono::NaiveDateTime + let mut creation_time: FILETIME = unsafe { mem::zeroed() }; + let mut dummy: FILETIME = unsafe { mem::zeroed() }; + if unsafe { + GetProcessTimes( + handle, + &mut creation_time as *mut _, + &mut dummy as *mut _, + &mut dummy as *mut _, + &mut dummy as *mut _, + ) + } == 0 + { + return Err(io::Error::last_os_error()); + } + + let mut uli_time: ULARGE_INTEGER = unsafe { mem::zeroed() }; + unsafe { + uli_time.s_mut().LowPart = creation_time.dwLowDateTime; + uli_time.s_mut().HighPart = creation_time.dwHighDateTime; + } + + Ok(*unsafe { uli_time.QuadPart() }) +} + +/// Returns the device path for a running process. +pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> { + let mut initial_capacity = 512; + loop { + let result = get_process_device_path_inner(handle, initial_capacity); + match result { + Ok(path) => return Ok(path), + Err(error) => { + if ERROR_INSUFFICIENT_BUFFER == error.raw_os_error().unwrap() as u32 { + // Try again with a larger buffer capacity. + initial_capacity *= 2; + continue; + } + return Err(error); + } + } + } +} + +fn get_process_device_path_inner( + handle: RawHandle, + buffer_capacity: usize, +) -> Result<OsString, io::Error> { + let mut buffer = Vec::<u16>::new(); + buffer.reserve_exact(buffer_capacity); + + let written = unsafe { + K32GetProcessImageFileNameW( + handle, + buffer.as_mut_ptr() as *mut _, + buffer.capacity() as u32, + ) + }; + if written == 0 { + return Err(io::Error::last_os_error()); + } + + // `written` does not include a null terminator + unsafe { buffer.set_len(written as usize) }; + + Ok(OsStringExt::from_wide(&buffer)) +} diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 72af2b6f38..d7bf24b49e 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -7,9 +7,20 @@ use crate::{ firewall::FirewallPolicy, tunnel::{CloseHandle, TunnelEvent, TunnelMetadata}, }; +#[cfg(windows)] +use crate::{ + split_tunnel::{self, SplitTunnel}, + winnet::{self, get_best_default_route, interface_luid_to_ip, WinNetAddrFamily}, +}; use cfg_if::cfg_if; use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::net::IpAddr; +#[cfg(windows)] +use std::{ + ffi::OsStr, + net::{Ipv4Addr, Ipv6Addr}, + sync::{Arc, Mutex}, +}; use talpid_types::{ net::TunnelParameters, tunnel::{ErrorStateCause, FirewallPolicyError}, @@ -116,6 +127,137 @@ impl ConnectedState { } } + #[cfg(target_os = "windows")] + pub unsafe extern "system" fn split_tunnel_default_route_change_handler( + event_type: winnet::WinNetDefaultRouteChangeEventType, + address_family: WinNetAddrFamily, + default_route: winnet::WinNetDefaultRoute, + ctx: *mut libc::c_void, + ) { + // Update the "internet interface" IP when best default route changes + let ctx = &mut *(ctx as *mut SplitTunnelDefaultRouteChangeHandlerContext); + + let result = match event_type { + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => { + let ip = interface_luid_to_ip(address_family.clone(), default_route.interface_luid); + + // TODO: Should we block here? + let ip = match ip { + Ok(Some(ip)) => ip, + Ok(None) => { + log::error!("Failed to obtain new default route address: none found",); + // Early return + return; + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to obtain new default route address" + ) + ); + // Early return + return; + } + }; + + match address_family { + WinNetAddrFamily::IPV4 => { + let ip = Ipv4Addr::from(ip); + ctx.internet_ipv4 = ip; + } + WinNetAddrFamily::IPV6 => { + let ip = Ipv6Addr::from(ip); + ctx.internet_ipv6 = Some(ip); + } + } + + ctx.register_ips() + } + // no default route + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => { + match address_family { + WinNetAddrFamily::IPV4 => { + ctx.internet_ipv4 = Ipv4Addr::new(0, 0, 0, 0); + } + WinNetAddrFamily::IPV6 => { + ctx.internet_ipv6 = None; + } + } + ctx.register_ips() + } + }; + + if let Err(error) = result { + // TODO: Should we block here? + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to register new addresses in split tunnel driver" + ) + ); + } + } + + #[cfg(windows)] + fn update_split_tunnel_addresses( + &self, + shared_values: &mut SharedTunnelStateValues, + ) -> Result<(), BoxedError> { + // Identify tunnel IP addresses + // TODO: Multiple IP addresses? + let mut tunnel_ipv4 = None; + let mut tunnel_ipv6 = None; + + for ip in &self.metadata.ips { + match ip { + IpAddr::V4(address) => tunnel_ipv4 = Some(address.clone()), + IpAddr::V6(address) => tunnel_ipv6 = Some(address.clone()), + } + } + + // Identify IP address that gives us Internet access + let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4) + .map_err(BoxedError::new)? + .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV4, route.interface_luid)) + .transpose() + .map_err(BoxedError::new)? + .flatten(); + let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6) + .map_err(BoxedError::new)? + .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV6, route.interface_luid)) + .transpose() + .map_err(BoxedError::new)? + .flatten(); + + let tunnel_ipv4 = tunnel_ipv4.unwrap_or(Ipv4Addr::new(0, 0, 0, 0)); + let internet_ipv4 = Ipv4Addr::from(internet_ipv4.unwrap_or_default()); + let internet_ipv6 = internet_ipv6.map(|addr| Ipv6Addr::from(addr)); + + let context = SplitTunnelDefaultRouteChangeHandlerContext::new( + shared_values.split_tunnel.clone(), + tunnel_ipv4, + tunnel_ipv6, + internet_ipv4, + internet_ipv6, + ); + + shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex") + .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6) + .map_err(BoxedError::new)?; + + #[cfg(target_os = "windows")] + shared_values.route_manager.add_default_route_callback( + Some(Self::split_tunnel_default_route_change_handler), + context, + ); + + Ok(()) + } + fn set_dns(&self, shared_values: &mut SharedTunnelStateValues) -> Result<(), BoxedError> { let dns_ips = self.get_dns_servers(shared_values); shared_values @@ -150,6 +292,18 @@ impl ConnectedState { } } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn disconnect( self, shared_values: &mut SharedTunnelStateValues, @@ -158,6 +312,24 @@ impl ConnectedState { Self::reset_dns(shared_values); Self::reset_routes(shared_values); + #[cfg(windows)] + if let Err(error) = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex") + .register_ips( + Ipv4Addr::new(0, 0, 0, 0), + None, + Ipv4Addr::new(0, 0, 0, 0), + None, + ) + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to unregister IP addresses") + ); + } + EventConsequence::NewState(DisconnectingState::enter( shared_values, (self.close_handle, self.tunnel_close_event, after_disconnect), @@ -257,6 +429,11 @@ impl ConnectedState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } @@ -326,6 +503,19 @@ impl TunnelState for ConnectedState { ), ) } else { + #[cfg(windows)] + if let Err(error) = connected_state.update_split_tunnel_addresses(shared_values) { + log::error!("{}", error.display_chain()); + return DisconnectingState::enter( + shared_values, + ( + connected_state.close_handle, + connected_state.tunnel_close_event, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ), + ); + } + ( TunnelStateWrapper::from(connected_state), TunnelStateTransition::Connected(tunnel_endpoint), @@ -360,3 +550,44 @@ impl TunnelState for ConnectedState { } } } + +#[cfg(target_os = "windows")] +struct SplitTunnelDefaultRouteChangeHandlerContext { + split_tunnel: Arc<Mutex<SplitTunnel>>, + pub tunnel_ipv4: Ipv4Addr, + pub tunnel_ipv6: Option<Ipv6Addr>, + pub internet_ipv4: Ipv4Addr, + pub internet_ipv6: Option<Ipv6Addr>, +} + +#[cfg(target_os = "windows")] +impl SplitTunnelDefaultRouteChangeHandlerContext { + pub fn new( + split_tunnel: Arc<Mutex<SplitTunnel>>, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> Self { + SplitTunnelDefaultRouteChangeHandlerContext { + split_tunnel, + tunnel_ipv4, + tunnel_ipv6, + internet_ipv4, + internet_ipv6, + } + } + + pub fn register_ips(&self) -> Result<(), split_tunnel::Error> { + let split_tunnel = self + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.register_ips( + self.tunnel_ipv4, + self.tunnel_ipv6, + self.internet_ipv4, + self.internet_ipv6, + ) + } +} diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index b0c87acdb4..34e9eeb2be 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -3,6 +3,8 @@ use super::{ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::{ firewall::FirewallPolicy, routing::RouteManager, @@ -17,6 +19,8 @@ use futures::{ FutureExt, StreamExt, }; use log::{debug, error, info, trace, warn}; +#[cfg(windows)] +use std::ffi::OsStr; use std::{ path::{Path, PathBuf}, thread, @@ -89,6 +93,18 @@ impl ConnectingState { }) } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn start_tunnel( runtime: tokio::runtime::Handle, parameters: TunnelParameters, @@ -314,6 +330,11 @@ impl ConnectingState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 8d2c9bc0fa..cfa7794af7 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -3,7 +3,11 @@ use super::{ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; +#[cfg(windows)] +use crate::split_tunnel; use futures::StreamExt; +#[cfg(windows)] +use std::ffi::OsStr; use talpid_types::ErrorExt; /// No tunnel is running. @@ -36,6 +40,18 @@ impl DisconnectedState { log::error!("{}", error_chain); } } + + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } } impl TunnelState for DisconnectedState { @@ -115,6 +131,11 @@ impl TunnelState for DisconnectedState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } Some(_) => SameState(self.into()), None => Finished, } diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 7d308d5971..e488896042 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -3,8 +3,12 @@ use super::{ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::tunnel::CloseHandle; use futures::{future::FusedFuture, StreamExt}; +#[cfg(windows)] +use std::ffi::OsStr; use std::thread; use talpid_types::{ tunnel::{ActionAfterDisconnect, ErrorStateCause}, @@ -59,6 +63,11 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Nothing } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Nothing + } }, AfterDisconnect::Block(reason) => match command { Some(TunnelCommand::AllowLan(allow_lan)) => { @@ -96,6 +105,11 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Block(reason) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Block(reason) + } None => AfterDisconnect::Block(reason), }, AfterDisconnect::Reconnect(retry_attempt) => match command { @@ -134,12 +148,29 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Reconnect(retry_attempt) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Reconnect(retry_attempt) + } }, }; EventConsequence::SameState(self.into()) } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn after_disconnect( self, block_reason: Option<ErrorStateCause>, diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 5e647c8201..6d772a62b8 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -3,7 +3,11 @@ use super::{ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; +#[cfg(windows)] +use crate::split_tunnel; use futures::StreamExt; +#[cfg(windows)] +use std::ffi::OsStr; use talpid_types::{ tunnel::{self as talpid_tunnel, ErrorStateCause, FirewallPolicyError}, ErrorExt, @@ -61,6 +65,18 @@ impl ErrorState { } } } + + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } } impl TunnelState for ErrorState { @@ -151,12 +167,17 @@ impl TunnelState for ErrorState { Some(TunnelCommand::Block(reason)) => { NewState(ErrorState::enter(shared_values, reason)) } - #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + // TODO: Do nothing here? + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 38496865a4..969c2116a1 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -11,6 +11,8 @@ use self::{ disconnecting_state::{AfterDisconnect, DisconnectingState}, error_state::ErrorState, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::{ dns::DnsMonitor, firewall::{Firewall, FirewallArguments}, @@ -19,6 +21,11 @@ use crate::{ routing::RouteManager, tunnel::{tun_provider::TunProvider, TunnelEvent}, }; +#[cfg(windows)] +use std::ffi::OsString; +#[cfg(windows)] +use std::sync::Mutex; + use futures::{ channel::{mpsc, oneshot}, stream, StreamExt, @@ -47,9 +54,9 @@ pub enum Error { OfflineMonitorError(#[error(source)] crate::offline::Error), /// Unable to set up split tunneling - #[cfg(target_os = "linux")] + #[cfg(target_os = "windows")] #[error(display = "Failed to initialize split tunneling")] - InitSplitTunneling(#[error(source)] crate::split_tunnel::Error), + InitSplitTunneling(#[error(source)] split_tunnel::Error), /// Failed to initialize the system firewall integration. #[error(display = "Failed to initialize the system firewall integration")] @@ -86,6 +93,7 @@ pub async fn spawn( shutdown_tx: oneshot::Sender<()>, reset_firewall: bool, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(windows)] exclude_paths: Vec<OsString>, ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); @@ -122,6 +130,8 @@ pub async fn spawn( reset_firewall, #[cfg(target_os = "android")] android_context, + #[cfg(windows)] + exclude_paths, )); let state_machine = match state_machine { Ok(state_machine) => { @@ -169,6 +179,12 @@ pub enum TunnelCommand { /// Bypass a socket, allowing traffic to flow through outside the tunnel. #[cfg(target_os = "android")] BypassSocket(RawFd, oneshot::Sender<()>), + /// Set applications that are allowed to send and receive traffic outside of the tunnel. + #[cfg(windows)] + SetExcludedApps( + oneshot::Sender<Result<(), split_tunnel::Error>>, + Vec<OsString>, + ), } type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>; @@ -207,6 +223,7 @@ impl TunnelStateMachine { commands: mpsc::UnboundedReceiver<TunnelCommand>, reset_firewall: bool, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(windows)] exclude_paths: Vec<OsString>, ) -> Result<Self, Error> { let args = FirewallArguments { initialize_blocked: block_when_disconnected || !reset_firewall, @@ -239,6 +256,14 @@ impl TunnelStateMachine { .await .map_err(Error::OfflineMonitorError)?; let is_offline = offline_monitor.is_offline().await; + + #[cfg(windows)] + let split_tunnel = split_tunnel::SplitTunnel::new().map_err(Error::InitSplitTunneling)?; + #[cfg(windows)] + split_tunnel + .set_paths(&exclude_paths) + .map_err(Error::InitSplitTunneling)?; + let mut shared_values = SharedTunnelStateValues { runtime, firewall, @@ -256,6 +281,8 @@ impl TunnelStateMachine { resource_dir, #[cfg(target_os = "linux")] connectivity_check_was_enabled: None, + #[cfg(windows)] + split_tunnel: Arc::new(Mutex::new(split_tunnel)), }; let (initial_state, _) = DisconnectedState::enter(&mut shared_values, reset_firewall); @@ -337,6 +364,9 @@ struct SharedTunnelStateValues { /// NetworkManager's connecitivity check state. #[cfg(target_os = "linux")] connectivity_check_was_enabled: Option<bool>, + /// Management of excluded apps. + #[cfg(windows)] + split_tunnel: Arc<Mutex<split_tunnel::SplitTunnel>>, } impl SharedTunnelStateValues { diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 79008f8cbc..8b3d503462 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -85,6 +85,7 @@ pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, E } } +#[derive(Debug, Clone)] #[allow(dead_code)] #[repr(u32)] pub enum WinNetAddrFamily { @@ -121,15 +122,33 @@ pub struct WinNetDefaultRoute { pub gateway: WinNetIp, } -impl From<WinNetIp> for IpAddr { - fn from(addr: WinNetIp) -> IpAddr { +impl From<WinNetIp> for Ipv4Addr { + fn from(addr: WinNetIp) -> Ipv4Addr { match addr.addr_family { WinNetAddrFamily::IPV4 => { let mut bytes: [u8; 4] = Default::default(); bytes.clone_from_slice(&addr.ip_bytes[..4]); - IpAddr::V4(Ipv4Addr::from(bytes)) + Ipv4Addr::from(bytes) } - WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)), + WinNetAddrFamily::IPV6 => panic!("address family mismatch"), + } + } +} + +impl From<WinNetIp> for Ipv6Addr { + fn from(addr: WinNetIp) -> Ipv6Addr { + match addr.addr_family { + WinNetAddrFamily::IPV4 => panic!("address family mismatch"), + WinNetAddrFamily::IPV6 => Ipv6Addr::from(addr.ip_bytes), + } + } +} + +impl From<WinNetIp> for IpAddr { + fn from(addr: WinNetIp) -> IpAddr { + match addr.addr_family { + WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::from(addr)), + WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr)), } } } |
