summaryrefslogtreecommitdiffhomepage
path: root/talpid-core/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-05-29 09:45:17 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-07-02 09:54:19 +0200
commite5baa0e08816d535a031b3d8575701b8d43fb0c2 (patch)
treec4bf2ec1956977676bc25c2630bd38789f43dade /talpid-core/src
parent207ab239223686ff72c43a8a5d615565ab81b5ab (diff)
downloadmullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.tar.xz
mullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.zip
Support Windows split tunneling driver
Diffstat (limited to 'talpid-core/src')
-rw-r--r--talpid-core/src/split_tunnel/mod.rs7
-rw-r--r--talpid-core/src/split_tunnel/windows/driver.rs514
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs69
-rw-r--r--talpid-core/src/split_tunnel/windows/windows.rs259
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs231
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs21
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs21
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs31
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs23
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs34
-rw-r--r--talpid-core/src/winnet.rs27
11 files changed, 1230 insertions, 7 deletions
diff --git a/talpid-core/src/split_tunnel/mod.rs b/talpid-core/src/split_tunnel/mod.rs
index c7c366d6ea..3c3f6af294 100644
--- a/talpid-core/src/split_tunnel/mod.rs
+++ b/talpid-core/src/split_tunnel/mod.rs
@@ -4,3 +4,10 @@ mod imp;
#[cfg(target_os = "linux")]
pub use imp::*;
+
+#[cfg(windows)]
+#[path = "windows/mod.rs"]
+mod imp;
+
+#[cfg(windows)]
+pub use imp::*;
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs
new file mode 100644
index 0000000000..26495a5877
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/driver.rs
@@ -0,0 +1,514 @@
+use super::windows::{
+ get_final_path_name, get_process_creation_time, get_process_device_path, open_process,
+ ProcessAccess, ProcessSnapshot,
+};
+use std::{
+ cell::RefCell,
+ collections::HashMap,
+ ffi::{OsStr, OsString},
+ fs::{self, OpenOptions},
+ io,
+ mem::{self, size_of},
+ net::{Ipv4Addr, Ipv6Addr},
+ os::windows::{
+ ffi::OsStrExt,
+ fs::OpenOptionsExt,
+ io::{AsRawHandle, RawHandle},
+ },
+ ptr,
+};
+use winapi::{
+ shared::{in6addr::IN6_ADDR, inaddr::IN_ADDR},
+ um::{
+ ioapiset::DeviceIoControl,
+ tlhelp32::TH32CS_SNAPPROCESS,
+ winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER},
+ },
+};
+
+const DRIVER_SYMBOLIC_NAME: &str = "\\\\.\\MULLVADSPLITTUNNEL";
+const ST_DEVICE_TYPE: u32 = 0x8000;
+
+const fn ctl_code(device_type: u32, function: u32, method: u32, access: u32) -> u32 {
+ device_type << 16 | access << 14 | function << 2 | method
+}
+
+#[repr(u32)]
+#[allow(dead_code)]
+enum DriverIoctlCode {
+ Initialize = ctl_code(ST_DEVICE_TYPE, 1, METHOD_NEITHER, FILE_ANY_ACCESS),
+ DequeEvent = ctl_code(ST_DEVICE_TYPE, 2, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ RegisterProcesses = ctl_code(ST_DEVICE_TYPE, 3, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ RegisterIpAddresses = ctl_code(ST_DEVICE_TYPE, 4, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ GetIpAddresses = ctl_code(ST_DEVICE_TYPE, 5, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ SetConfiguration = ctl_code(ST_DEVICE_TYPE, 6, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ GetConfiguration = ctl_code(ST_DEVICE_TYPE, 7, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ ClearConfiguration = ctl_code(ST_DEVICE_TYPE, 8, METHOD_NEITHER, FILE_ANY_ACCESS),
+ GetState = ctl_code(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ QueryProcess = ctl_code(ST_DEVICE_TYPE, 10, METHOD_BUFFERED, FILE_ANY_ACCESS),
+}
+
+#[derive(Debug, PartialEq)]
+#[repr(u32)]
+#[allow(dead_code)]
+pub enum DriverState {
+ // Default state after being loaded.
+ None = 0,
+ // DriverEntry has completed successfully.
+ // Basically only driver and device objects are created at this point.
+ Started = 1,
+ // All subsystems are initialized.
+ Initialized = 2,
+ // User mode has registered all processes in the system.
+ Ready = 3,
+ // IP addresses are registered.
+ // A valid configuration is registered.
+ Engaged = 4,
+ // Driver is unloading.
+ Terminating = 5,
+}
+
+pub struct DeviceHandle {
+ handle: fs::File,
+}
+
+impl DeviceHandle {
+ pub fn new() -> io::Result<Self> {
+ // Connect to the driver
+ log::trace!("Connecting to the driver");
+ let handle = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .share_mode(0)
+ .custom_flags(0)
+ .attributes(0)
+ .open(DRIVER_SYMBOLIC_NAME)?;
+
+ let device = Self { handle };
+
+ // Initialize the driver
+ let state = device.get_driver_state()?;
+ if state == DriverState::Started {
+ log::trace!("Initializing driver");
+ device.initialize()?;
+ }
+
+ // Initialize process tree
+ let state = device.get_driver_state()?;
+ if state == DriverState::Initialized {
+ log::trace!("Registering processes");
+ device.register_processes()?;
+ }
+
+ Ok(device)
+ }
+
+ fn initialize(&self) -> io::Result<()> {
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::Initialize as u32,
+ None,
+ 0,
+ )?;
+ Ok(())
+ }
+
+ fn register_processes(&self) -> io::Result<()> {
+ let process_tree_buffer = serialize_process_tree(build_process_tree()?)?;
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::RegisterProcesses as u32,
+ Some(&process_tree_buffer),
+ 0,
+ )?;
+ Ok(())
+ }
+
+ pub fn register_ips(
+ &self,
+ tunnel_ipv4: Ipv4Addr,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Ipv4Addr,
+ internet_ipv6: Option<Ipv6Addr>,
+ ) -> io::Result<()> {
+ let mut addresses: SplitTunnelAddresses = unsafe { mem::zeroed() };
+
+ unsafe {
+ let tunnel_ipv4 = tunnel_ipv4.octets();
+ ptr::copy_nonoverlapping(
+ &tunnel_ipv4[0] as *const u8,
+ &mut addresses.tunnel_ipv4 as *mut _ as *mut u8,
+ tunnel_ipv4.len(),
+ );
+
+ if let Some(tunnel_ipv6) = tunnel_ipv6 {
+ let tunnel_ipv6 = tunnel_ipv6.octets();
+ ptr::copy_nonoverlapping(
+ &tunnel_ipv6[0] as *const u8,
+ &mut addresses.tunnel_ipv6 as *mut _ as *mut u8,
+ tunnel_ipv6.len(),
+ );
+ }
+
+ let internet_ipv4 = internet_ipv4.octets();
+ ptr::copy_nonoverlapping(
+ &internet_ipv4[0] as *const u8,
+ &mut addresses.internet_ipv4 as *mut _ as *mut u8,
+ internet_ipv4.len(),
+ );
+
+ if let Some(internet_ipv6) = internet_ipv6 {
+ let internet_ipv6 = internet_ipv6.octets();
+ ptr::copy_nonoverlapping(
+ &internet_ipv6[0] as *const u8,
+ &mut addresses.internet_ipv6 as *mut _ as *mut u8,
+ internet_ipv6.len(),
+ );
+ }
+ }
+
+ let buffer = &addresses as *const _ as *const u8;
+ let buffer =
+ unsafe { std::slice::from_raw_parts(buffer, size_of::<SplitTunnelAddresses>()) };
+
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::RegisterIpAddresses as u32,
+ Some(buffer),
+ 0,
+ )?;
+
+ Ok(())
+ }
+
+ pub fn get_driver_state(&self) -> io::Result<DriverState> {
+ let buffer = device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::GetState as u32,
+ None,
+ size_of::<u64>() as u32,
+ )?
+ .unwrap();
+
+ Ok(unsafe { deserialize_buffer(&buffer) })
+ }
+
+ pub fn set_config<T: AsRef<OsStr>>(&self, apps: &[T]) -> io::Result<()> {
+ let mut device_paths = Vec::with_capacity(apps.len());
+ for app in apps.as_ref() {
+ device_paths.push(get_final_path_name(app)?);
+ }
+
+ log::debug!("Excluded device paths:");
+ for path in &device_paths {
+ log::debug!(" {:?}", path);
+ }
+
+ let config = make_process_config(&device_paths);
+
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::SetConfiguration as u32,
+ Some(&config),
+ 0,
+ )?;
+
+ Ok(())
+ }
+
+ pub fn clear_config(&self) -> io::Result<()> {
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::ClearConfiguration as u32,
+ None,
+ 0,
+ )?;
+
+ Ok(())
+ }
+}
+
+#[repr(C)]
+struct SplitTunnelAddresses {
+ tunnel_ipv4: IN_ADDR,
+ internet_ipv4: IN_ADDR,
+ tunnel_ipv6: IN6_ADDR,
+ internet_ipv6: IN6_ADDR,
+}
+
+#[repr(C)]
+struct ConfigurationHeader {
+ // Number of entries immediately following the header.
+ num_entries: usize,
+ // Total byte length: header + entries + string buffer.
+ total_length: usize,
+}
+
+#[repr(C)]
+struct ConfigurationEntry {
+ // Offset into buffer region that follows all entries.
+ // The image name uses the physical path.
+ name_offset: usize,
+ // Byte length for non-null terminated wide char string.
+ name_length: u16,
+}
+
+/// Create a buffer containing a `ConfigurationHeader` and number of `ConfigurationEntry`s,
+/// followed by the same number of paths to those entries.
+fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> {
+ let apps: Vec<Vec<u16>> = apps
+ .iter()
+ .map(|app| app.as_ref().encode_wide().collect())
+ .collect();
+
+ let total_string_size: usize = apps.iter().map(|app| size_of::<u16>() * app.len()).sum();
+
+ let total_buffer_size = size_of::<ConfigurationHeader>()
+ + size_of::<ConfigurationEntry>() * apps.len()
+ + total_string_size;
+
+ let mut buffer = Vec::<u8>::new();
+ buffer.resize(total_buffer_size, 0);
+
+ let (header, tail) = buffer.split_at_mut(size_of::<ConfigurationHeader>());
+
+ // Serialize configuration header
+ let header_struct = ConfigurationHeader {
+ num_entries: apps.len(),
+ total_length: total_buffer_size,
+ };
+ header.copy_from_slice(unsafe { as_u8_slice(&header_struct) });
+
+ // Serialize configuration entries and strings
+ let (entries, string_data) = tail.split_at_mut(apps.len() * size_of::<ConfigurationEntry>());
+ let mut string_offset = 0;
+
+ for (i, app) in apps.iter().enumerate() {
+ write_string_to_buffer(string_data, string_offset, &app);
+
+ let app_bytelen = size_of::<u16>() * app.len();
+ let entry = ConfigurationEntry {
+ name_offset: string_offset,
+ name_length: app_bytelen as u16,
+ };
+ let entry_offset = size_of::<ConfigurationEntry>() * i;
+ entries[entry_offset..entry_offset + size_of::<ConfigurationEntry>()]
+ .copy_from_slice(unsafe { as_u8_slice(&entry) });
+
+ string_offset += app_bytelen;
+ }
+
+ buffer
+}
+
+#[derive(Debug)]
+struct ProcessInfo {
+ pid: u32,
+ parent_pid: u32,
+ creation_time: u64,
+ device_path: Vec<u16>,
+}
+
+/// List process identifiers, their parents, and their device paths.
+fn build_process_tree() -> io::Result<Vec<ProcessInfo>> {
+ let mut process_info = HashMap::new();
+
+ let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?;
+ for entry in snap.entries() {
+ let entry = entry?;
+
+ let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) {
+ Ok(handle) => Ok(handle),
+ Err(error) => {
+ // Skip process objects that cannot be opened
+ match error.kind() {
+ // System process
+ io::ErrorKind::PermissionDenied => continue,
+ // System idle or csrss process
+ io::ErrorKind::InvalidInput => continue,
+ _ => Err(error),
+ }
+ }
+ }?;
+
+ // TODO: Skip objects whose paths or timestamps cannot be obtained?
+
+ process_info.insert(
+ entry.pid,
+ RefCell::new(ProcessInfo {
+ pid: entry.pid,
+ parent_pid: entry.parent_pid,
+ creation_time: get_process_creation_time(process.get_raw()).unwrap_or(0),
+ device_path: get_process_device_path(process.get_raw())
+ .unwrap_or(OsString::from(""))
+ .encode_wide()
+ .collect(),
+ }),
+ );
+ }
+
+ // Handle pid recycling
+ // If the "parent" is younger than the process itself, it is not our parent.
+ for info in process_info.values() {
+ let mut info = info.borrow_mut();
+ let parent_pid = info.parent_pid;
+ if parent_pid == 0 {
+ continue;
+ }
+ if let Some(parent_info) = process_info.get(&parent_pid) {
+ if parent_info.borrow_mut().creation_time > info.creation_time {
+ info.parent_pid = 0;
+ }
+ }
+ }
+
+ Ok(process_info
+ .into_iter()
+ .map(|(_, info)| info.into_inner())
+ .collect())
+}
+
+#[repr(C)]
+struct ProcessRegistryHeader {
+ // Number of entries immediately following the header.
+ num_entries: usize,
+ // Total byte length: header + entries + string buffer.
+ total_length: usize,
+}
+
+#[repr(C)]
+struct ProcessRegistryEntry {
+ pid: RawHandle,
+ parent_pid: RawHandle,
+ // Image name offset (following the last entry).
+ image_name_offset: usize,
+ // Image name length.
+ image_name_size: u16,
+}
+
+fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Error> {
+ // Construct a buffer:
+ // ProcessRegistryHeader
+ // ProcessRegistryEntry..
+ // Image names..
+
+ let total_string_size: usize = processes
+ .iter()
+ .map(|info| size_of::<u16>() * info.device_path.len())
+ .sum();
+ let total_buffer_size = size_of::<ProcessRegistryHeader>()
+ + size_of::<ProcessRegistryEntry>() * processes.len()
+ + total_string_size;
+
+ let mut buffer = Vec::<u8>::new();
+ buffer.resize(total_buffer_size, 0);
+
+ let (header, tail) = buffer.split_at_mut(size_of::<ProcessRegistryHeader>());
+ let header_struct = ProcessRegistryHeader {
+ num_entries: processes.len(),
+ total_length: total_buffer_size,
+ };
+ header.copy_from_slice(unsafe { as_u8_slice(&header_struct) });
+
+ let (entries, string_data) =
+ tail.split_at_mut(size_of::<ProcessRegistryEntry>() * processes.len());
+
+ let mut string_offset = 0;
+
+ for (i, entry) in processes.into_iter().enumerate() {
+ let mut out_entry = ProcessRegistryEntry {
+ pid: entry.pid as usize as RawHandle,
+ parent_pid: entry.parent_pid as usize as RawHandle,
+ image_name_size: 0,
+ image_name_offset: 0,
+ };
+
+ if !entry.device_path.is_empty() {
+ write_string_to_buffer(string_data, string_offset, &entry.device_path);
+
+ out_entry.image_name_size = (entry.device_path.len() * size_of::<u16>()) as u16;
+ out_entry.image_name_offset = string_offset;
+
+ string_offset += size_of::<u16>() * entry.device_path.len();
+ }
+
+ let entry_offset = size_of::<ProcessRegistryEntry>() * i;
+ entries[entry_offset..entry_offset + size_of::<ProcessRegistryEntry>()]
+ .copy_from_slice(unsafe { as_u8_slice(&out_entry) });
+ }
+
+ Ok(buffer)
+}
+
+/// Send an IOCTL code to the given device handle.
+/// `input` specifies an optional buffer to send.
+/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0.
+pub fn device_io_control(
+ device: RawHandle,
+ ioctl_code: u32,
+ input: Option<&[u8]>,
+ output_size: u32,
+) -> Result<Option<Vec<u8>>, io::Error> {
+ let input_ptr = match input {
+ Some(input) => input as *const _ as *mut _,
+ None => ptr::null_mut(),
+ };
+ let input_len = input.map(|input| input.len()).unwrap_or(0);
+
+ let mut out_buffer = if output_size > 0 {
+ Some(Vec::with_capacity(output_size as usize))
+ } else {
+ None
+ };
+
+ let out_ptr = match out_buffer {
+ Some(ref mut out_buffer) => out_buffer.as_mut_ptr() as *mut _,
+ None => ptr::null_mut(),
+ };
+
+ let mut returned_bytes = 0u32;
+
+ let result = unsafe {
+ DeviceIoControl(
+ device as *mut _,
+ ioctl_code,
+ input_ptr,
+ input_len as u32,
+ out_ptr,
+ output_size,
+ &mut returned_bytes as *mut _,
+ ptr::null_mut(), // TODO
+ )
+ };
+
+ if let Some(ref mut out_buffer) = out_buffer {
+ unsafe { out_buffer.set_len(returned_bytes as usize) };
+ }
+
+ if result != 0 {
+ Ok(out_buffer)
+ } else {
+ Err(io::Error::last_os_error())
+ }
+}
+
+/// Creates a new instance of an arbitrary type from a byte buffer.
+pub unsafe fn deserialize_buffer<T: Sized>(buffer: &Vec<u8>) -> T {
+ let mut instance: T = mem::zeroed();
+ ptr::copy_nonoverlapping(buffer.as_ptr() as *const T, &mut instance as *mut _, 1);
+ instance
+}
+
+fn write_string_to_buffer(buffer: &mut [u8], byte_offset: usize, string: &[u16]) {
+ for (i, byte) in string
+ .iter()
+ .flat_map(|word| std::array::IntoIter::new(word.to_ne_bytes()))
+ .enumerate()
+ {
+ buffer[byte_offset + i] = byte;
+ }
+}
+
+unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] {
+ std::slice::from_raw_parts(object as *const _ as *const _, size_of::<T>())
+}
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
new file mode 100644
index 0000000000..c6b8fae332
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -0,0 +1,69 @@
+mod driver;
+mod windows;
+
+use std::{
+ ffi::OsStr,
+ io,
+ net::{Ipv4Addr, Ipv6Addr},
+};
+use talpid_types::ErrorExt;
+
+/// Errors that may occur in [`SplitTunnel`].
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to identify or initialize the driver
+ #[error(display = "Failed to find or initialize driver")]
+ InitializationFailed(#[error(source)] io::Error),
+
+ /// Failed to set paths to excluded applications
+ #[error(display = "Failed to set list of excluded applications")]
+ SetConfiguration(#[error(source)] io::Error),
+
+ /// Failed to register interface IP addresses
+ #[error(display = "Failed to register IP addresses for exclusions")]
+ RegisterIps(#[error(source)] io::Error),
+}
+
+/// Manages applications whose traffic to exclude from the tunnel.
+pub struct SplitTunnel(driver::DeviceHandle);
+
+impl SplitTunnel {
+ /// Initialize the driver.
+ pub fn new() -> Result<Self, Error> {
+ Ok(SplitTunnel(
+ driver::DeviceHandle::new().map_err(Error::InitializationFailed)?,
+ ))
+ }
+
+ /// Set a list of applications to exclude from the tunnel.
+ pub fn set_paths<T: AsRef<OsStr>>(&self, paths: &[T]) -> Result<(), Error> {
+ if paths.len() > 0 {
+ self.0.set_config(paths).map_err(Error::SetConfiguration)
+ } else {
+ self.0.clear_config().map_err(Error::SetConfiguration)
+ }
+ }
+
+ /// Configures IP addresses used for socket rebinding.
+ pub fn register_ips(
+ &self,
+ tunnel_ipv4: Ipv4Addr,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Ipv4Addr,
+ internet_ipv6: Option<Ipv6Addr>,
+ ) -> Result<(), Error> {
+ self.0
+ .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6)
+ .map_err(Error::RegisterIps)
+ }
+}
+
+impl Drop for SplitTunnel {
+ fn drop(&mut self) {
+ let paths: [&OsStr; 0] = [];
+ if let Err(error) = self.set_paths(&paths) {
+ log::error!("{}", error.display_chain());
+ }
+ }
+}
diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs
new file mode 100644
index 0000000000..be8631d53c
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/windows.rs
@@ -0,0 +1,259 @@
+// TODO: The snapshot code could be combined with the mostly-identical code in
+// the windows_exception_logging module.
+
+use std::{
+ ffi::{OsStr, OsString},
+ fs::OpenOptions,
+ io, mem,
+ os::windows::{
+ ffi::OsStringExt,
+ io::{AsRawHandle, RawHandle},
+ },
+ ptr,
+};
+use winapi::{
+ shared::{
+ minwindef::{DWORD, FALSE, FILETIME, TRUE},
+ ntdef::ULARGE_INTEGER,
+ winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES},
+ },
+ um::{
+ fileapi::GetFinalPathNameByHandleW,
+ handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ processthreadsapi::{GetProcessTimes, OpenProcess},
+ psapi::K32GetProcessImageFileNameW,
+ tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W},
+ winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION},
+ },
+};
+
+/// Return path with the volume device path.
+const VOLUME_NAME_NT: u32 = 0x02;
+
+pub struct ProcessSnapshot {
+ handle: HANDLE,
+}
+
+impl ProcessSnapshot {
+ pub fn new(flags: DWORD, process_id: DWORD) -> io::Result<ProcessSnapshot> {
+ let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) };
+
+ if snap == INVALID_HANDLE_VALUE {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(ProcessSnapshot { handle: snap })
+ }
+ }
+
+ pub fn handle(&self) -> HANDLE {
+ self.handle
+ }
+
+ pub fn entries(&self) -> ProcessSnapshotEntries<'_> {
+ let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() };
+ entry.dwSize = mem::size_of::<PROCESSENTRY32W>() as u32;
+
+ ProcessSnapshotEntries {
+ snapshot: self,
+ iter_started: false,
+ temp_entry: entry,
+ }
+ }
+}
+
+impl Drop for ProcessSnapshot {
+ fn drop(&mut self) {
+ unsafe {
+ CloseHandle(self.handle);
+ }
+ }
+}
+
+pub struct ProcessEntry {
+ pub pid: u32,
+ pub parent_pid: u32,
+}
+
+pub struct ProcessSnapshotEntries<'a> {
+ snapshot: &'a ProcessSnapshot,
+ iter_started: bool,
+ temp_entry: PROCESSENTRY32W,
+}
+
+impl Iterator for ProcessSnapshotEntries<'_> {
+ type Item = io::Result<ProcessEntry>;
+
+ fn next(&mut self) -> Option<io::Result<ProcessEntry>> {
+ if self.iter_started {
+ if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE {
+ let last_error = io::Error::last_os_error();
+
+ return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES {
+ None
+ } else {
+ Some(Err(last_error))
+ };
+ }
+ } else {
+ if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE {
+ return Some(Err(io::Error::last_os_error()));
+ }
+ self.iter_started = true;
+ }
+
+ Some(Ok(ProcessEntry {
+ pid: self.temp_entry.th32ProcessID,
+ parent_pid: self.temp_entry.th32ParentProcessID,
+ }))
+ }
+}
+
+pub fn get_final_path_name<T: AsRef<OsStr>>(path: T) -> Result<OsString, io::Error> {
+ // TODO: verify that all flags, including security flags, are ok
+ // TODO: verify that the file is a PE executable?
+ // TODO: verify that the executable is on a physical drive?
+ let file = OpenOptions::new().read(true).open(path.as_ref())?;
+ get_final_path_name_by_handle(file.as_raw_handle())
+}
+
+pub fn get_final_path_name_by_handle(raw_handle: RawHandle) -> Result<OsString, io::Error> {
+ let buffer_size = unsafe {
+ GetFinalPathNameByHandleW(raw_handle as *mut _, ptr::null_mut(), 0u32, VOLUME_NAME_NT)
+ } as usize;
+
+ if buffer_size == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ let mut buffer = Vec::new();
+ buffer.reserve_exact(buffer_size);
+
+ let status = unsafe {
+ GetFinalPathNameByHandleW(
+ raw_handle as *mut _,
+ buffer.as_mut_ptr(),
+ buffer_size as u32,
+ VOLUME_NAME_NT,
+ )
+ } as usize;
+
+ if status == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ unsafe { buffer.set_len(buffer_size - 1) };
+
+ // TODO: can this be done by stealing 'buffer' instead of copying it?
+ Ok(OsStringExt::from_wide(&buffer))
+}
+
+/// Object that frees its handle when dropped.
+pub struct WinHandle(RawHandle);
+
+impl WinHandle {
+ pub fn get_raw(&self) -> RawHandle {
+ self.0
+ }
+}
+
+impl Drop for WinHandle {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.0) };
+ }
+}
+
+#[repr(u32)]
+pub enum ProcessAccess {
+ QueryLimitedInformation = PROCESS_QUERY_LIMITED_INFORMATION,
+ // TODO: could be extended
+}
+
+/// Open an existing process object.
+pub fn open_process(
+ access: ProcessAccess,
+ inherit_handle: bool,
+ pid: u32,
+) -> Result<WinHandle, io::Error> {
+ let handle = unsafe {
+ OpenProcess(
+ access as u32,
+ if inherit_handle { TRUE } else { FALSE },
+ pid,
+ )
+ };
+
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(WinHandle(handle))
+}
+
+/// Returns the age of a running process.
+pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> {
+ // TODO: FileTimeToSystemTime -> chrono::NaiveDateTime
+ let mut creation_time: FILETIME = unsafe { mem::zeroed() };
+ let mut dummy: FILETIME = unsafe { mem::zeroed() };
+ if unsafe {
+ GetProcessTimes(
+ handle,
+ &mut creation_time as *mut _,
+ &mut dummy as *mut _,
+ &mut dummy as *mut _,
+ &mut dummy as *mut _,
+ )
+ } == 0
+ {
+ return Err(io::Error::last_os_error());
+ }
+
+ let mut uli_time: ULARGE_INTEGER = unsafe { mem::zeroed() };
+ unsafe {
+ uli_time.s_mut().LowPart = creation_time.dwLowDateTime;
+ uli_time.s_mut().HighPart = creation_time.dwHighDateTime;
+ }
+
+ Ok(*unsafe { uli_time.QuadPart() })
+}
+
+/// Returns the device path for a running process.
+pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> {
+ let mut initial_capacity = 512;
+ loop {
+ let result = get_process_device_path_inner(handle, initial_capacity);
+ match result {
+ Ok(path) => return Ok(path),
+ Err(error) => {
+ if ERROR_INSUFFICIENT_BUFFER == error.raw_os_error().unwrap() as u32 {
+ // Try again with a larger buffer capacity.
+ initial_capacity *= 2;
+ continue;
+ }
+ return Err(error);
+ }
+ }
+ }
+}
+
+fn get_process_device_path_inner(
+ handle: RawHandle,
+ buffer_capacity: usize,
+) -> Result<OsString, io::Error> {
+ let mut buffer = Vec::<u16>::new();
+ buffer.reserve_exact(buffer_capacity);
+
+ let written = unsafe {
+ K32GetProcessImageFileNameW(
+ handle,
+ buffer.as_mut_ptr() as *mut _,
+ buffer.capacity() as u32,
+ )
+ };
+ if written == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ // `written` does not include a null terminator
+ unsafe { buffer.set_len(written as usize) };
+
+ Ok(OsStringExt::from_wide(&buffer))
+}
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 72af2b6f38..d7bf24b49e 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -7,9 +7,20 @@ use crate::{
firewall::FirewallPolicy,
tunnel::{CloseHandle, TunnelEvent, TunnelMetadata},
};
+#[cfg(windows)]
+use crate::{
+ split_tunnel::{self, SplitTunnel},
+ winnet::{self, get_best_default_route, interface_luid_to_ip, WinNetAddrFamily},
+};
use cfg_if::cfg_if;
use futures::{channel::mpsc, stream::Fuse, StreamExt};
use std::net::IpAddr;
+#[cfg(windows)]
+use std::{
+ ffi::OsStr,
+ net::{Ipv4Addr, Ipv6Addr},
+ sync::{Arc, Mutex},
+};
use talpid_types::{
net::TunnelParameters,
tunnel::{ErrorStateCause, FirewallPolicyError},
@@ -116,6 +127,137 @@ impl ConnectedState {
}
}
+ #[cfg(target_os = "windows")]
+ pub unsafe extern "system" fn split_tunnel_default_route_change_handler(
+ event_type: winnet::WinNetDefaultRouteChangeEventType,
+ address_family: WinNetAddrFamily,
+ default_route: winnet::WinNetDefaultRoute,
+ ctx: *mut libc::c_void,
+ ) {
+ // Update the "internet interface" IP when best default route changes
+ let ctx = &mut *(ctx as *mut SplitTunnelDefaultRouteChangeHandlerContext);
+
+ let result = match event_type {
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ let ip = interface_luid_to_ip(address_family.clone(), default_route.interface_luid);
+
+ // TODO: Should we block here?
+ let ip = match ip {
+ Ok(Some(ip)) => ip,
+ Ok(None) => {
+ log::error!("Failed to obtain new default route address: none found",);
+ // Early return
+ return;
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to obtain new default route address"
+ )
+ );
+ // Early return
+ return;
+ }
+ };
+
+ match address_family {
+ WinNetAddrFamily::IPV4 => {
+ let ip = Ipv4Addr::from(ip);
+ ctx.internet_ipv4 = ip;
+ }
+ WinNetAddrFamily::IPV6 => {
+ let ip = Ipv6Addr::from(ip);
+ ctx.internet_ipv6 = Some(ip);
+ }
+ }
+
+ ctx.register_ips()
+ }
+ // no default route
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => {
+ match address_family {
+ WinNetAddrFamily::IPV4 => {
+ ctx.internet_ipv4 = Ipv4Addr::new(0, 0, 0, 0);
+ }
+ WinNetAddrFamily::IPV6 => {
+ ctx.internet_ipv6 = None;
+ }
+ }
+ ctx.register_ips()
+ }
+ };
+
+ if let Err(error) = result {
+ // TODO: Should we block here?
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to register new addresses in split tunnel driver"
+ )
+ );
+ }
+ }
+
+ #[cfg(windows)]
+ fn update_split_tunnel_addresses(
+ &self,
+ shared_values: &mut SharedTunnelStateValues,
+ ) -> Result<(), BoxedError> {
+ // Identify tunnel IP addresses
+ // TODO: Multiple IP addresses?
+ let mut tunnel_ipv4 = None;
+ let mut tunnel_ipv6 = None;
+
+ for ip in &self.metadata.ips {
+ match ip {
+ IpAddr::V4(address) => tunnel_ipv4 = Some(address.clone()),
+ IpAddr::V6(address) => tunnel_ipv6 = Some(address.clone()),
+ }
+ }
+
+ // Identify IP address that gives us Internet access
+ let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4)
+ .map_err(BoxedError::new)?
+ .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV4, route.interface_luid))
+ .transpose()
+ .map_err(BoxedError::new)?
+ .flatten();
+ let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6)
+ .map_err(BoxedError::new)?
+ .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV6, route.interface_luid))
+ .transpose()
+ .map_err(BoxedError::new)?
+ .flatten();
+
+ let tunnel_ipv4 = tunnel_ipv4.unwrap_or(Ipv4Addr::new(0, 0, 0, 0));
+ let internet_ipv4 = Ipv4Addr::from(internet_ipv4.unwrap_or_default());
+ let internet_ipv6 = internet_ipv6.map(|addr| Ipv6Addr::from(addr));
+
+ let context = SplitTunnelDefaultRouteChangeHandlerContext::new(
+ shared_values.split_tunnel.clone(),
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ );
+
+ shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex")
+ .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6)
+ .map_err(BoxedError::new)?;
+
+ #[cfg(target_os = "windows")]
+ shared_values.route_manager.add_default_route_callback(
+ Some(Self::split_tunnel_default_route_change_handler),
+ context,
+ );
+
+ Ok(())
+ }
+
fn set_dns(&self, shared_values: &mut SharedTunnelStateValues) -> Result<(), BoxedError> {
let dns_ips = self.get_dns_servers(shared_values);
shared_values
@@ -150,6 +292,18 @@ impl ConnectedState {
}
}
+ #[cfg(windows)]
+ fn apply_split_tunnel_config<T: AsRef<OsStr>>(
+ shared_values: &SharedTunnelStateValues,
+ paths: &[T],
+ ) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.set_paths(paths)
+ }
+
fn disconnect(
self,
shared_values: &mut SharedTunnelStateValues,
@@ -158,6 +312,24 @@ impl ConnectedState {
Self::reset_dns(shared_values);
Self::reset_routes(shared_values);
+ #[cfg(windows)]
+ if let Err(error) = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex")
+ .register_ips(
+ Ipv4Addr::new(0, 0, 0, 0),
+ None,
+ Ipv4Addr::new(0, 0, 0, 0),
+ None,
+ )
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to unregister IP addresses")
+ );
+ }
+
EventConsequence::NewState(DisconnectingState::enter(
shared_values,
(self.close_handle, self.tunnel_close_event, after_disconnect),
@@ -257,6 +429,11 @@ impl ConnectedState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ SameState(self.into())
+ }
}
}
@@ -326,6 +503,19 @@ impl TunnelState for ConnectedState {
),
)
} else {
+ #[cfg(windows)]
+ if let Err(error) = connected_state.update_split_tunnel_addresses(shared_values) {
+ log::error!("{}", error.display_chain());
+ return DisconnectingState::enter(
+ shared_values,
+ (
+ connected_state.close_handle,
+ connected_state.tunnel_close_event,
+ AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
+ ),
+ );
+ }
+
(
TunnelStateWrapper::from(connected_state),
TunnelStateTransition::Connected(tunnel_endpoint),
@@ -360,3 +550,44 @@ impl TunnelState for ConnectedState {
}
}
}
+
+#[cfg(target_os = "windows")]
+struct SplitTunnelDefaultRouteChangeHandlerContext {
+ split_tunnel: Arc<Mutex<SplitTunnel>>,
+ pub tunnel_ipv4: Ipv4Addr,
+ pub tunnel_ipv6: Option<Ipv6Addr>,
+ pub internet_ipv4: Ipv4Addr,
+ pub internet_ipv6: Option<Ipv6Addr>,
+}
+
+#[cfg(target_os = "windows")]
+impl SplitTunnelDefaultRouteChangeHandlerContext {
+ pub fn new(
+ split_tunnel: Arc<Mutex<SplitTunnel>>,
+ tunnel_ipv4: Ipv4Addr,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Ipv4Addr,
+ internet_ipv6: Option<Ipv6Addr>,
+ ) -> Self {
+ SplitTunnelDefaultRouteChangeHandlerContext {
+ split_tunnel,
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ }
+ }
+
+ pub fn register_ips(&self) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = self
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.register_ips(
+ self.tunnel_ipv4,
+ self.tunnel_ipv6,
+ self.internet_ipv4,
+ self.internet_ipv6,
+ )
+ }
+}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index b0c87acdb4..34e9eeb2be 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -3,6 +3,8 @@ use super::{
EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
+#[cfg(windows)]
+use crate::split_tunnel;
use crate::{
firewall::FirewallPolicy,
routing::RouteManager,
@@ -17,6 +19,8 @@ use futures::{
FutureExt, StreamExt,
};
use log::{debug, error, info, trace, warn};
+#[cfg(windows)]
+use std::ffi::OsStr;
use std::{
path::{Path, PathBuf},
thread,
@@ -89,6 +93,18 @@ impl ConnectingState {
})
}
+ #[cfg(windows)]
+ fn apply_split_tunnel_config<T: AsRef<OsStr>>(
+ shared_values: &SharedTunnelStateValues,
+ paths: &[T],
+ ) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.set_paths(paths)
+ }
+
fn start_tunnel(
runtime: tokio::runtime::Handle,
parameters: TunnelParameters,
@@ -314,6 +330,11 @@ impl ConnectingState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ SameState(self.into())
+ }
}
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 8d2c9bc0fa..cfa7794af7 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -3,7 +3,11 @@ use super::{
TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::firewall::FirewallPolicy;
+#[cfg(windows)]
+use crate::split_tunnel;
use futures::StreamExt;
+#[cfg(windows)]
+use std::ffi::OsStr;
use talpid_types::ErrorExt;
/// No tunnel is running.
@@ -36,6 +40,18 @@ impl DisconnectedState {
log::error!("{}", error_chain);
}
}
+
+ #[cfg(windows)]
+ fn apply_split_tunnel_config<T: AsRef<OsStr>>(
+ shared_values: &SharedTunnelStateValues,
+ paths: &[T],
+ ) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.set_paths(paths)
+ }
}
impl TunnelState for DisconnectedState {
@@ -115,6 +131,11 @@ impl TunnelState for DisconnectedState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ SameState(self.into())
+ }
Some(_) => SameState(self.into()),
None => Finished,
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 7d308d5971..e488896042 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -3,8 +3,12 @@ use super::{
EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver,
TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
+#[cfg(windows)]
+use crate::split_tunnel;
use crate::tunnel::CloseHandle;
use futures::{future::FusedFuture, StreamExt};
+#[cfg(windows)]
+use std::ffi::OsStr;
use std::thread;
use talpid_types::{
tunnel::{ActionAfterDisconnect, ErrorStateCause},
@@ -59,6 +63,11 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Nothing
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ AfterDisconnect::Nothing
+ }
},
AfterDisconnect::Block(reason) => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
@@ -96,6 +105,11 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Block(reason)
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ AfterDisconnect::Block(reason)
+ }
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
@@ -134,12 +148,29 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Reconnect(retry_attempt)
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ AfterDisconnect::Reconnect(retry_attempt)
+ }
},
};
EventConsequence::SameState(self.into())
}
+ #[cfg(windows)]
+ fn apply_split_tunnel_config<T: AsRef<OsStr>>(
+ shared_values: &SharedTunnelStateValues,
+ paths: &[T],
+ ) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.set_paths(paths)
+ }
+
fn after_disconnect(
self,
block_reason: Option<ErrorStateCause>,
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 5e647c8201..6d772a62b8 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -3,7 +3,11 @@ use super::{
TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper,
};
use crate::firewall::FirewallPolicy;
+#[cfg(windows)]
+use crate::split_tunnel;
use futures::StreamExt;
+#[cfg(windows)]
+use std::ffi::OsStr;
use talpid_types::{
tunnel::{self as talpid_tunnel, ErrorStateCause, FirewallPolicyError},
ErrorExt,
@@ -61,6 +65,18 @@ impl ErrorState {
}
}
}
+
+ #[cfg(windows)]
+ fn apply_split_tunnel_config<T: AsRef<OsStr>>(
+ shared_values: &SharedTunnelStateValues,
+ paths: &[T],
+ ) -> Result<(), split_tunnel::Error> {
+ let split_tunnel = shared_values
+ .split_tunnel
+ .lock()
+ .expect("Thread unexpectedly panicked while holding the mutex");
+ split_tunnel.set_paths(paths)
+ }
}
impl TunnelState for ErrorState {
@@ -151,12 +167,17 @@ impl TunnelState for ErrorState {
Some(TunnelCommand::Block(reason)) => {
NewState(ErrorState::enter(shared_values, reason))
}
-
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ // TODO: Do nothing here?
+ let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths));
+ SameState(self.into())
+ }
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 38496865a4..969c2116a1 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -11,6 +11,8 @@ use self::{
disconnecting_state::{AfterDisconnect, DisconnectingState},
error_state::ErrorState,
};
+#[cfg(windows)]
+use crate::split_tunnel;
use crate::{
dns::DnsMonitor,
firewall::{Firewall, FirewallArguments},
@@ -19,6 +21,11 @@ use crate::{
routing::RouteManager,
tunnel::{tun_provider::TunProvider, TunnelEvent},
};
+#[cfg(windows)]
+use std::ffi::OsString;
+#[cfg(windows)]
+use std::sync::Mutex;
+
use futures::{
channel::{mpsc, oneshot},
stream, StreamExt,
@@ -47,9 +54,9 @@ pub enum Error {
OfflineMonitorError(#[error(source)] crate::offline::Error),
/// Unable to set up split tunneling
- #[cfg(target_os = "linux")]
+ #[cfg(target_os = "windows")]
#[error(display = "Failed to initialize split tunneling")]
- InitSplitTunneling(#[error(source)] crate::split_tunnel::Error),
+ InitSplitTunneling(#[error(source)] split_tunnel::Error),
/// Failed to initialize the system firewall integration.
#[error(display = "Failed to initialize the system firewall integration")]
@@ -86,6 +93,7 @@ pub async fn spawn(
shutdown_tx: oneshot::Sender<()>,
reset_firewall: bool,
#[cfg(target_os = "android")] android_context: AndroidContext,
+ #[cfg(windows)] exclude_paths: Vec<OsString>,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
@@ -122,6 +130,8 @@ pub async fn spawn(
reset_firewall,
#[cfg(target_os = "android")]
android_context,
+ #[cfg(windows)]
+ exclude_paths,
));
let state_machine = match state_machine {
Ok(state_machine) => {
@@ -169,6 +179,12 @@ pub enum TunnelCommand {
/// Bypass a socket, allowing traffic to flow through outside the tunnel.
#[cfg(target_os = "android")]
BypassSocket(RawFd, oneshot::Sender<()>),
+ /// Set applications that are allowed to send and receive traffic outside of the tunnel.
+ #[cfg(windows)]
+ SetExcludedApps(
+ oneshot::Sender<Result<(), split_tunnel::Error>>,
+ Vec<OsString>,
+ ),
}
type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>;
@@ -207,6 +223,7 @@ impl TunnelStateMachine {
commands: mpsc::UnboundedReceiver<TunnelCommand>,
reset_firewall: bool,
#[cfg(target_os = "android")] android_context: AndroidContext,
+ #[cfg(windows)] exclude_paths: Vec<OsString>,
) -> Result<Self, Error> {
let args = FirewallArguments {
initialize_blocked: block_when_disconnected || !reset_firewall,
@@ -239,6 +256,14 @@ impl TunnelStateMachine {
.await
.map_err(Error::OfflineMonitorError)?;
let is_offline = offline_monitor.is_offline().await;
+
+ #[cfg(windows)]
+ let split_tunnel = split_tunnel::SplitTunnel::new().map_err(Error::InitSplitTunneling)?;
+ #[cfg(windows)]
+ split_tunnel
+ .set_paths(&exclude_paths)
+ .map_err(Error::InitSplitTunneling)?;
+
let mut shared_values = SharedTunnelStateValues {
runtime,
firewall,
@@ -256,6 +281,8 @@ impl TunnelStateMachine {
resource_dir,
#[cfg(target_os = "linux")]
connectivity_check_was_enabled: None,
+ #[cfg(windows)]
+ split_tunnel: Arc::new(Mutex::new(split_tunnel)),
};
let (initial_state, _) = DisconnectedState::enter(&mut shared_values, reset_firewall);
@@ -337,6 +364,9 @@ struct SharedTunnelStateValues {
/// NetworkManager's connecitivity check state.
#[cfg(target_os = "linux")]
connectivity_check_was_enabled: Option<bool>,
+ /// Management of excluded apps.
+ #[cfg(windows)]
+ split_tunnel: Arc<Mutex<split_tunnel::SplitTunnel>>,
}
impl SharedTunnelStateValues {
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 79008f8cbc..8b3d503462 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -85,6 +85,7 @@ pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, E
}
}
+#[derive(Debug, Clone)]
#[allow(dead_code)]
#[repr(u32)]
pub enum WinNetAddrFamily {
@@ -121,15 +122,33 @@ pub struct WinNetDefaultRoute {
pub gateway: WinNetIp,
}
-impl From<WinNetIp> for IpAddr {
- fn from(addr: WinNetIp) -> IpAddr {
+impl From<WinNetIp> for Ipv4Addr {
+ fn from(addr: WinNetIp) -> Ipv4Addr {
match addr.addr_family {
WinNetAddrFamily::IPV4 => {
let mut bytes: [u8; 4] = Default::default();
bytes.clone_from_slice(&addr.ip_bytes[..4]);
- IpAddr::V4(Ipv4Addr::from(bytes))
+ Ipv4Addr::from(bytes)
}
- WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)),
+ WinNetAddrFamily::IPV6 => panic!("address family mismatch"),
+ }
+ }
+}
+
+impl From<WinNetIp> for Ipv6Addr {
+ fn from(addr: WinNetIp) -> Ipv6Addr {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => panic!("address family mismatch"),
+ WinNetAddrFamily::IPV6 => Ipv6Addr::from(addr.ip_bytes),
+ }
+ }
+}
+
+impl From<WinNetIp> for IpAddr {
+ fn from(addr: WinNetIp) -> IpAddr {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::from(addr)),
+ WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr)),
}
}
}