diff options
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 3 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 244 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/windows.rs | 132 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/windows.rs | 133 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/connectivity_check.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 100 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 18 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 9 | ||||
| -rw-r--r-- | wireguard/libwg/libwg_windows.go | 48 |
12 files changed, 435 insertions, 264 deletions
diff --git a/Cargo.lock b/Cargo.lock index 781f984822..7fb4f90439 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2524,6 +2524,7 @@ dependencies = [ name = "talpid-core" version = "0.1.0" dependencies = [ + "async-trait", "atty", "byteorder", "cfg-if 1.0.0", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 6f0c60c08d..f84e4a5bd4 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" publish = false [dependencies] +async-trait = "0.1" atty = "0.2" cfg-if = "1.0" duct = "0.13" diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index 25c80899e5..ca5fbb8ea2 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -98,6 +98,11 @@ impl RouteManager { } } + /// Retrieve handle for the tokio runtime. + pub fn runtime_handle(&self) -> tokio::runtime::Handle { + self.runtime.clone() + } + async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) { while let Some(command) = manage_rx.next().await { match command { diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 063feacc3d..69cd842138 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -24,6 +24,9 @@ pub mod wireguard; /// A module for low level platform specific tunnel device management. pub(crate) mod tun_provider; +#[cfg(target_os = "windows")] +mod windows; + const OPENVPN_LOG_FILENAME: &str = "openvpn.log"; const WIREGUARD_LOG_FILENAME: &str = "wireguard.log"; diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index 8d1aa03607..44efa96668 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -22,7 +22,7 @@ use std::{ process::ExitStatus, sync::{ atomic::{AtomicBool, Ordering}, - mpsc, Arc, + mpsc, Arc, Mutex, }, thread, time::Duration, @@ -33,12 +33,9 @@ use std::{collections::HashSet, net::IpAddr}; use std::{ ffi::{OsStr, OsString}, os::windows::ffi::OsStrExt, - sync::Mutex, time::Instant, }; -use talpid_types::net::openvpn; -#[cfg(any(windows, target_os = "linux"))] -use talpid_types::ErrorExt; +use talpid_types::{net::openvpn, ErrorExt}; use tokio::task; #[cfg(target_os = "linux")] use which; @@ -180,6 +177,10 @@ pub enum Error { #[error(display = "OpenVPN process died unexpectedly")] ChildProcessDied, + /// Failed before OpenVPN started + #[error(display = "Failed to start OpenVPN")] + StartProcessError, + /// The IP routing program was not found. #[cfg(target_os = "linux")] #[error(display = "The IP routing program `ip` was not found")] @@ -260,9 +261,15 @@ const OPENVPN_BIN_FILENAME: &str = "openvpn.exe"; /// Struct for monitoring an OpenVPN process. #[derive(Debug)] pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> { - child: Arc<C::ProcessHandle>, + spawn_task: Option< + tokio::task::JoinHandle< + std::result::Result<io::Result<C::ProcessHandle>, futures::future::Aborted>, + >, + >, + abort_spawn: futures::future::AbortHandle, + + child: Arc<Mutex<Option<Arc<C::ProcessHandle>>>>, proxy_monitor: Option<Box<dyn ProxyMonitor>>, - log_path: Option<PathBuf>, closed: Arc<AtomicBool>, /// Keep the `TempFile` for the user-pass file in the struct, so it's removed on drop. _user_pass_file: mktemp::TempFile, @@ -274,9 +281,52 @@ pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> { server_join_handle: Option<task::JoinHandle<std::result::Result<(), event_server::Error>>>, #[cfg(windows)] - wintun_adapter: Option<windows::TemporaryWintunAdapter>, - #[cfg(windows)] - _wintun_logger: Option<windows::WintunLoggerHandle>, + wintun: Arc<Box<dyn WintunContext>>, +} + +#[cfg(windows)] +#[async_trait::async_trait] +trait WintunContext: Send + Sync { + fn luid(&self) -> NET_LUID; + fn ipv6(&self) -> bool; + async fn wait_for_interfaces(&self) -> io::Result<()>; +} + +#[cfg(windows)] +impl std::fmt::Debug for dyn WintunContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "WintunContext {{ luid: {}, ipv6: {} }}", + self.luid().Value, + self.ipv6() + ) + } +} + +#[cfg(windows)] +#[derive(Debug)] +struct WintunContextImpl { + adapter: windows::TemporaryWintunAdapter, + wait_v6_interface: bool, + _logger: windows::WintunLoggerHandle, +} + +#[cfg(windows)] +#[async_trait::async_trait] +impl WintunContext for WintunContextImpl { + fn luid(&self) -> NET_LUID { + self.adapter.adapter().luid() + } + + fn ipv6(&self) -> bool { + self.wait_v6_interface + } + + async fn wait_for_interfaces(&self) -> io::Result<()> { + let luid = self.adapter.adapter().luid(); + super::windows::wait_for_interfaces(luid, true, self.wait_v6_interface).await + } } @@ -399,14 +449,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { log::warn!("You may need to restart Windows to complete the install of Wintun"); } - log::debug!("Wait for IP interfaces"); - windows::wait_for_interfaces( - &adapter.adapter().luid(), - true, - params.generic_options.enable_ipv6, - ) - .map_err(Error::IpInterfacesError)?; - let assigned_guid = adapter.adapter().guid(); let assigned_guid = assigned_guid.as_ref().unwrap_or_else(|error| { log::error!( @@ -475,15 +517,17 @@ impl OpenVpnMonitor<OpenVpnCommand> { Self::new_internal( cmd, on_openvpn_event, - &plugin_path, + plugin_path, log_path, user_pass_file, proxy_auth_file, proxy_monitor, #[cfg(windows)] - Some(wintun_adapter), - #[cfg(windows)] - Some(wintun_logger), + Box::new(WintunContextImpl { + adapter: wintun_adapter, + wait_v6_interface: params.generic_options.enable_ipv6, + _logger: wintun_logger, + }), ) } } @@ -527,17 +571,16 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute Ok(routes) } -impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { +impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { fn new_internal<L>( mut cmd: C, on_event: L, - plugin_path: impl AsRef<Path>, + plugin_path: PathBuf, log_path: Option<PathBuf>, user_pass_file: mktemp::TempFile, proxy_auth_file: Option<mktemp::TempFile>, proxy_monitor: Option<Box<dyn ProxyMonitor>>, - #[cfg(windows)] wintun_adapter: Option<windows::TemporaryWintunAdapter>, - #[cfg(windows)] wintun_logger: Option<windows::WintunLoggerHandle>, + #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static, @@ -574,16 +617,23 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { .unwrap_err()); } - let child = cmd - .plugin(plugin_path, vec![ipc_path]) - .log(log_path.as_ref().map(|p| p.as_path())) - .start() - .map_err(|e| Error::ChildProcessError("Failed to start", e))?; + #[cfg(windows)] + let wintun = Arc::new(wintun); + + cmd.plugin(plugin_path, vec![ipc_path]) + .log(log_path.as_ref().map(|p| p.as_path())); + let (spawn_task, abort_spawn) = futures::future::abortable(Self::prepare_process( + cmd, + #[cfg(windows)] + wintun.clone(), + )); + let spawn_task = runtime.spawn(spawn_task); Ok(OpenVpnMonitor { - child: Arc::new(child), + spawn_task: Some(spawn_task), + abort_spawn, + child: Arc::new(Mutex::new(None)), proxy_monitor, - log_path, closed: Arc::new(AtomicBool::new(false)), _user_pass_file: user_pass_file, _proxy_auth_file: proxy_auth_file, @@ -593,17 +643,28 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { server_join_handle: Some(server_join_handle), #[cfg(windows)] - wintun_adapter, - #[cfg(windows)] - _wintun_logger: wintun_logger, + wintun, }) } + async fn prepare_process( + cmd: C, + #[cfg(windows)] wintun: Arc<Box<dyn WintunContext>>, + ) -> io::Result<C::ProcessHandle> { + #[cfg(windows)] + { + log::debug!("Wait for IP interfaces"); + wintun.wait_for_interfaces().await?; + } + cmd.start() + } + /// Creates a handle to this monitor, allowing the tunnel to be closed while some other /// thread is blocked in `wait`. pub fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> { OpenVpnCloseHandle { child: self.child.clone(), + abort_spawn: self.abort_spawn.clone(), closed: self.closed.clone(), } } @@ -656,9 +717,19 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { } /// Supplement `inner_wait_tunnel()` with logging and error handling. - fn wait_tunnel(&mut self) -> Result<()> { + fn wait_tunnel(self) -> Result<()> { let result = self.inner_wait_tunnel(); match result { + WaitResult::Preparation(result) => match result { + Err(error) => { + log::debug!( + "{}", + error.display_chain_with_msg("Failed to start OpenVPN") + ); + Err(Error::StartProcessError) + } + _ => Ok(()), + }, WaitResult::Child(Ok(exit_status), closed) => { if exit_status.success() || closed { log::debug!( @@ -684,8 +755,28 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { /// Waits for both the child process and the event dispatcher in parallel. After both have /// returned this returns the earliest result. - fn inner_wait_tunnel(&mut self) -> WaitResult { - let child_wait_handle = self.child.clone(); + fn inner_wait_tunnel(mut self) -> WaitResult { + let child = match self + .runtime + .block_on(self.spawn_task.take().unwrap()) + .expect("spawn task panicked") + { + Ok(Ok(child)) => Arc::new(child), + Ok(Err(error)) => { + self.closed.swap(true, Ordering::SeqCst); + return WaitResult::Preparation(Err(error)); + } + Err(_) => return WaitResult::Preparation(Ok(())), + }; + + if self.closed.load(Ordering::SeqCst) { + return WaitResult::Preparation(Ok(())); + } + + { + self.child.lock().unwrap().replace(child.clone()); + } + let closed_handle = self.closed.clone(); let child_close_handle = self.close_handle(); @@ -695,7 +786,7 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { let event_server_abort_tx = self.event_server_abort_tx.clone(); thread::spawn(move || { - let result = child_wait_handle.wait(); + let result = child.wait(); let closed = closed_handle.load(Ordering::SeqCst); child_tx.send(WaitResult::Child(result, closed)).unwrap(); event_server_abort_tx.trigger(); @@ -835,7 +926,8 @@ impl<C: OpenVpnBuilder + 'static> OpenVpnMonitor<C> { /// A handle to an `OpenVpnMonitor` for closing it. #[derive(Debug, Clone)] pub struct OpenVpnCloseHandle<H: ProcessHandle = OpenVpnProcHandle> { - child: Arc<H>, + child: Arc<Mutex<Option<Arc<H>>>>, + abort_spawn: futures::future::AbortHandle, closed: Arc<AtomicBool>, } @@ -843,7 +935,12 @@ impl<H: ProcessHandle> OpenVpnCloseHandle<H> { /// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return. pub fn close(self) -> io::Result<()> { if !self.closed.swap(true, Ordering::SeqCst) { - self.child.kill() + self.abort_spawn.abort(); + if let Some(child) = self.child.lock().unwrap().as_ref() { + child.kill() + } else { + Ok(()) + } } else { Ok(()) } @@ -853,6 +950,7 @@ impl<H: ProcessHandle> OpenVpnCloseHandle<H> { /// Internal enum to differentiate between if the child process or the event dispatcher died first. #[derive(Debug)] enum WaitResult { + Preparation(io::Result<()>), Child(io::Result<ExitStatus>, bool), EventDispatcher, } @@ -1152,6 +1250,24 @@ mod tests { sync::Arc, }; + #[cfg(windows)] + #[derive(Debug)] + struct TestWintunContext {} + + #[cfg(windows)] + #[async_trait::async_trait] + impl WintunContext for TestWintunContext { + fn luid(&self) -> NET_LUID { + NET_LUID { Value: 0u64 } + } + fn ipv6(&self) -> bool { + false + } + async fn wait_for_interfaces(&self) -> io::Result<()> { + Ok(()) + } + } + #[derive(Debug, Default, Clone)] struct TestOpenVpnBuilder { pub plugin: Arc<Mutex<Option<PathBuf>>>, @@ -1205,15 +1321,13 @@ mod tests { let _ = OpenVpnMonitor::new_internal( builder.clone(), |_, _| {}, - "./my_test_plugin", + "./my_test_plugin".into(), None, TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ); assert_eq!( Some(PathBuf::from("./my_test_plugin")), @@ -1227,15 +1341,13 @@ mod tests { let _ = OpenVpnMonitor::new_internal( builder.clone(), |_, _| {}, - "", + "".into(), Some(PathBuf::from("./my_test_log_file")), TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ); assert_eq!( Some(PathBuf::from("./my_test_log_file")), @@ -1250,15 +1362,13 @@ mod tests { let testee = OpenVpnMonitor::new_internal( builder, |_, _| {}, - "", + "".into(), None, TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ) .unwrap(); assert!(testee.wait().is_ok()); @@ -1271,15 +1381,13 @@ mod tests { let testee = OpenVpnMonitor::new_internal( builder, |_, _| {}, - "", + "".into(), None, TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ) .unwrap(); assert!(testee.wait().is_err()); @@ -1292,15 +1400,13 @@ mod tests { let testee = OpenVpnMonitor::new_internal( builder, |_, _| {}, - "", + "".into(), None, TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ) .unwrap(); testee.close_handle().close().unwrap(); @@ -1310,22 +1416,20 @@ mod tests { #[test] fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); - let error = OpenVpnMonitor::new_internal( + let result = OpenVpnMonitor::new_internal( builder, |_, _| {}, - "", + "".into(), None, TempFile::new(), None, None, #[cfg(windows)] - None, - #[cfg(windows)] - None, + Box::new(TestWintunContext {}), ) - .unwrap_err(); - match error { - Error::ChildProcessError(..) => (), + .unwrap(); + match result.wait() { + Err(Error::StartProcessError) => (), _ => panic!("Wrong error"), } } diff --git a/talpid-core/src/tunnel/openvpn/windows.rs b/talpid-core/src/tunnel/openvpn/windows.rs index a88c6c756f..9b907e101a 100644 --- a/talpid-core/src/tunnel/openvpn/windows.rs +++ b/talpid-core/src/tunnel/openvpn/windows.rs @@ -4,8 +4,7 @@ use std::{ os::windows::{ffi::OsStrExt, io::RawHandle}, path::Path, ptr, - sync::{Arc, Mutex}, - time::Duration, + sync::Arc, }; use talpid_types::ErrorExt; use widestring::{U16CStr, U16CString}; @@ -14,13 +13,8 @@ use winapi::{ guiddef::GUID, ifdef::NET_LUID, minwindef::{BOOL, FARPROC, HINSTANCE, HMODULE}, - netioapi::{ - CancelMibChangeNotify2, ConvertInterfaceLuidToGuid, GetIpInterfaceEntry, - MibAddInstance, NotifyIpInterfaceChange, MIB_IPINTERFACE_ROW, - }, - ntdef::FALSE, + netioapi::ConvertInterfaceLuidToGuid, winerror::NO_ERROR, - ws2def::{AF_INET, AF_INET6, AF_UNSPEC}, }, um::{ libloaderapi::{ @@ -35,8 +29,6 @@ use winreg::{enums::HKEY_LOCAL_MACHINE, RegKey}; /// Longest possible adapter name (in characters), including null terminator const MAX_ADAPTER_NAME: usize = 128; -const INTERFACE_WAIT_TIMEOUT: Duration = Duration::from_secs(5); - type WintunOpenAdapterFn = unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle; @@ -142,6 +134,7 @@ impl fmt::Debug for WintunAdapter { } unsafe impl Send for WintunAdapter {} +unsafe impl Sync for WintunAdapter {} impl WintunAdapter { pub fn open(dll_handle: Arc<WintunDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> { @@ -406,6 +399,7 @@ pub fn string_from_guid(guid: &GUID) -> String { } } +/// Returns the registry key for a network device identified by its GUID. pub fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result<RegKey> { let net_devs = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_with_flags( r"SYSTEM\CurrentControlSet\Control\Class\{4d36e972-e325-11ce-bfc1-08002be10318}", @@ -433,124 +427,6 @@ pub fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Re Err(io::Error::new(io::ErrorKind::NotFound, "device not found")) } -pub struct IpNotifierHandle<'a> { - mutex: Mutex<()>, - callback: Option<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>>, - handle: RawHandle, -} - -impl<'a> Drop for IpNotifierHandle<'a> { - fn drop(&mut self) { - // Inner callback may be called while destructing - unsafe { CancelMibChangeNotify2(self.handle as *mut _) }; - - let _ = self - .mutex - .lock() - .expect("NotifyIpInterfaceChange mutex poisoned"); - let _ = self.callback.take(); - } -} - -unsafe extern "system" fn inner_callback( - context: *mut winapi::ctypes::c_void, - row: *mut MIB_IPINTERFACE_ROW, - notify_type: u32, -) { - let context = &mut *(context as *mut IpNotifierHandle<'_>); - let _ = context - .mutex - .lock() - .expect("NotifyIpInterfaceChange mutex poisoned"); - - if let Some(ref mut callback) = context.callback { - callback(&*row, notify_type); - } -} - -pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>( - callback: T, - family: u16, -) -> io::Result<Box<IpNotifierHandle<'a>>> { - let mut context = Box::new(IpNotifierHandle { - mutex: Mutex::default(), - callback: Some(Box::new(callback)), - handle: std::ptr::null_mut(), - }); - - let status = unsafe { - NotifyIpInterfaceChange( - family, - Some(inner_callback), - &mut *context as *mut _ as *mut _, - FALSE, - (&mut context.handle) as *mut _, - ) - }; - - if status != NO_ERROR { - return Err(io::Error::last_os_error()); - } - - Ok(context) -} - -pub fn get_ip_interface_entry(family: u16, luid: &NET_LUID) -> io::Result<MIB_IPINTERFACE_ROW> { - let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() }; - row.Family = family; - row.InterfaceLuid = *luid; - - let result = unsafe { GetIpInterfaceEntry(&mut row as *mut _) }; - if result != NO_ERROR { - return Err(io::Error::last_os_error()); - } - - Ok(row) -} - -pub fn wait_for_interfaces(luid: &NET_LUID, ipv4: bool, ipv6: bool) -> io::Result<()> { - let (tx, rx) = std::sync::mpsc::channel(); - - let mut found_ipv4 = if ipv4 { false } else { true }; - let mut found_ipv6 = if ipv6 { false } else { true }; - - let _handle = notify_ip_interface_change( - move |row, notification_type| { - if found_ipv4 && found_ipv6 { - return; - } - if notification_type != MibAddInstance { - return; - } - if row.InterfaceLuid.Value != luid.Value { - return; - } - match row.Family as i32 { - AF_INET => found_ipv4 = true, - AF_INET6 => found_ipv6 = true, - _ => (), - } - if found_ipv4 && found_ipv6 { - let _ = tx.send(()); - } - }, - AF_UNSPEC as u16, - )?; - - // Make sure they don't already exist - if (!ipv4 || get_ip_interface_entry(AF_INET as u16, luid).is_ok()) - && (!ipv6 || get_ip_interface_entry(AF_INET6 as u16, luid).is_ok()) - { - return Ok(()); - } - - let _ = rx - .recv_timeout(INTERFACE_WAIT_TIMEOUT) - .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out waiting on interfaces"))?; - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/talpid-core/src/tunnel/windows.rs b/talpid-core/src/tunnel/windows.rs new file mode 100644 index 0000000000..2d4fcf85e5 --- /dev/null +++ b/talpid-core/src/tunnel/windows.rs @@ -0,0 +1,133 @@ +use std::{io, mem, os::windows::io::RawHandle, sync::Mutex}; +use winapi::shared::{ + ifdef::NET_LUID, + netioapi::{ + CancelMibChangeNotify2, GetIpInterfaceEntry, MibAddInstance, NotifyIpInterfaceChange, + MIB_IPINTERFACE_ROW, + }, + ntdef::FALSE, + winerror::{ERROR_NOT_FOUND, NO_ERROR}, + ws2def::{AF_INET, AF_INET6, AF_UNSPEC}, +}; + +/// Context for [`notify_ip_interface_change`]. When it is dropped, +/// the callback is unregistered. +pub struct IpNotifierHandle<'a> { + callback: Mutex<Box<dyn FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>>, + handle: RawHandle, +} + +unsafe impl Send for IpNotifierHandle<'_> {} + +impl<'a> Drop for IpNotifierHandle<'a> { + fn drop(&mut self) { + unsafe { CancelMibChangeNotify2(self.handle as *mut _) }; + } +} + +unsafe extern "system" fn inner_callback( + context: *mut winapi::ctypes::c_void, + row: *mut MIB_IPINTERFACE_ROW, + notify_type: u32, +) { + let context = &mut *(context as *mut IpNotifierHandle<'_>); + context + .callback + .lock() + .expect("NotifyIpInterfaceChange mutex poisoned")(&*row, notify_type); +} + +/// Registers a callback function that is invoked when an interface is added, removed, +/// or changed. +pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, u32) + Send + 'a>( + callback: T, + family: u16, +) -> io::Result<Box<IpNotifierHandle<'a>>> { + let mut context = Box::new(IpNotifierHandle { + callback: Mutex::new(Box::new(callback)), + handle: std::ptr::null_mut(), + }); + + let status = unsafe { + NotifyIpInterfaceChange( + family, + Some(inner_callback), + &mut *context as *mut _ as *mut _, + FALSE, + (&mut context.handle) as *mut _, + ) + }; + + if status == NO_ERROR { + Ok(context) + } else { + Err(io::Error::from_raw_os_error(status as i32)) + } +} + +/// Returns information about a network IP interface. +pub fn get_ip_interface_entry(family: u16, luid: &NET_LUID) -> io::Result<MIB_IPINTERFACE_ROW> { + let mut row: MIB_IPINTERFACE_ROW = unsafe { mem::zeroed() }; + row.Family = family; + row.InterfaceLuid = *luid; + + let result = unsafe { GetIpInterfaceEntry(&mut row as *mut _) }; + if result == NO_ERROR { + Ok(row) + } else { + Err(io::Error::from_raw_os_error(result as i32)) + } +} + +fn ip_interface_entry_exists(family: u16, luid: &NET_LUID) -> io::Result<bool> { + match get_ip_interface_entry(family, luid) { + Ok(_) => Ok(true), + Err(error) if error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => Ok(false), + Err(error) => Err(error), + } +} + +/// Waits until the specified IP interfaces have attached to a given network interface. +pub async fn wait_for_interfaces(luid: NET_LUID, ipv4: bool, ipv6: bool) -> io::Result<()> { + let (tx, rx) = futures::channel::oneshot::channel(); + + let mut found_ipv4 = if ipv4 { false } else { true }; + let mut found_ipv6 = if ipv6 { false } else { true }; + + let mut tx = Some(tx); + + let _handle = notify_ip_interface_change( + move |row, notification_type| { + if found_ipv4 && found_ipv6 { + return; + } + if notification_type != MibAddInstance { + return; + } + if row.InterfaceLuid.Value != luid.Value { + return; + } + match row.Family as i32 { + AF_INET => found_ipv4 = true, + AF_INET6 => found_ipv6 = true, + _ => (), + } + if found_ipv4 && found_ipv6 { + if let Some(tx) = tx.take() { + let _ = tx.send(()); + } + } + }, + AF_UNSPEC as u16, + )?; + + // Make sure they don't already exist + if (!ipv4 || ip_interface_entry_exists(AF_INET as u16, &luid)?) + && (!ipv6 || ip_interface_entry_exists(AF_INET6 as u16, &luid)?) + { + return Ok(()); + } + + let _ = rx.await; + Ok(()) +} diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index 9b62fc47e8..60318c071b 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -524,6 +524,11 @@ mod test { "mock-tunnel".to_string() } + #[cfg(windows)] + fn get_interface_luid(&self) -> u64 { + 0 + } + fn stop(self: Box<Self>) -> Result<(), TunnelError> { Ok(()) } diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 199fe3e6ca..0703633b4e 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -8,6 +8,8 @@ use futures::future::abortable; use lazy_static::lazy_static; #[cfg(target_os = "linux")] use std::env; +#[cfg(windows)] +use std::io; use std::{ collections::HashSet, net::SocketAddr, @@ -54,9 +56,19 @@ pub enum Error { #[error(display = "Failed obtain local address for the UDP socket in Udp2Tcp")] GetLocalUdpAddress(#[error(source)] std::io::Error), - /// Failed to setup connectivity monitor + /// Failed to set up connectivity monitor #[error(display = "Connectivity monitor failed")] ConnectivityMonitorError(#[error(source)] connectivity_check::Error), + + /// Failed to set up IP interfaces. + #[cfg(windows)] + #[error(display = "Failed while waiting on IP interfaces")] + IpInterfacesError(#[error(source)] io::Error), + + /// Failed to set IP addresses on WireGuard interface + #[cfg(target_os = "windows")] + #[error(display = "Failed to set IP addresses on WireGuard interface")] + SetIpAddressesError, } @@ -68,6 +80,8 @@ pub struct WireguardMonitor { event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>, close_msg_sender: mpsc::Sender<CloseMsg>, close_msg_receiver: mpsc::Receiver<CloseMsg>, + #[cfg(target_os = "windows")] + stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>, pinger_stop_sender: mpsc::Sender<()>, _tcp_proxies: Vec<TcpProxy>, } @@ -158,18 +172,11 @@ impl WireguardMonitor { let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?; let iface_name = tunnel.get_interface_name().to_string(); + #[cfg(windows)] + let iface_luid = tunnel.get_interface_luid(); (on_event)(TunnelEvent::InterfaceUp(iface_name.clone())); - #[cfg(target_os = "linux")] - route_manager - .create_routing_rules(config.enable_ipv6) - .map_err(Error::SetupRoutingError)?; - - route_manager - .add_routes(Self::get_routes(&iface_name, &config)) - .map_err(Error::SetupRoutingError)?; - #[cfg(target_os = "windows")] route_manager .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()); @@ -177,11 +184,15 @@ impl WireguardMonitor { let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = mpsc::channel(); let (pinger_tx, pinger_rx) = mpsc::channel(); + #[cfg(target_os = "windows")] + let (stop_setup_tx, stop_setup_rx) = futures::channel::oneshot::channel(); let monitor = WireguardMonitor { tunnel: Arc::new(Mutex::new(Some(tunnel))), event_callback, close_msg_sender, close_msg_receiver, + #[cfg(target_os = "windows")] + stop_setup_tx: Some(stop_setup_tx), pinger_stop_sender: pinger_tx, _tcp_proxies: tcp_proxies, }; @@ -191,13 +202,67 @@ impl WireguardMonitor { let close_sender = monitor.close_msg_sender.clone(); let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( gateway, - iface_name.to_string(), + iface_name.clone(), Arc::downgrade(&monitor.tunnel), pinger_rx, ) .map_err(Error::ConnectivityMonitorError)?; + let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; + #[cfg(windows)] + let runtime = route_manager.runtime_handle(); + std::thread::spawn(move || { + #[cfg(windows)] + { + let iface_close_sender = close_sender.clone(); + let enable_ipv6 = config.ipv6_gateway.is_some(); + + let result = runtime.block_on(async move { + use futures::future::FutureExt; + use winapi::shared::ifdef::NET_LUID; + let luid = NET_LUID { Value: iface_luid }; + let setup_future = super::windows::wait_for_interfaces(luid, true, enable_ipv6); + + futures::select! { + result = setup_future.fuse() => { + result.map_err(|error| + iface_close_sender.send(CloseMsg::SetupError( + Error::IpInterfacesError(error) + )) + .unwrap_or(()) + ) + } + _ = stop_setup_rx.fuse() => Err(()), + } + }); + + if result.is_err() { + return; + } + } + + let setup_iface_routes = move || -> Result<()> { + #[cfg(target_os = "windows")] + if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { + return Err(Error::SetIpAddressesError); + } + + #[cfg(target_os = "linux")] + route_handle + .create_routing_rules(config.enable_ipv6) + .map_err(Error::SetupRoutingError)?; + + route_handle + .add_routes(Self::get_routes(&iface_name, &config)) + .map_err(Error::SetupRoutingError) + }; + + if let Err(error) = setup_iface_routes() { + let _ = close_sender.send(CloseMsg::SetupError(error)); + return; + } + match connectivity_monitor.establish_connectivity() { Ok(true) => { (on_event)(TunnelEvent::Up(metadata)); @@ -291,9 +356,14 @@ impl WireguardMonitor { let wait_result = match self.close_msg_receiver.recv() { Ok(CloseMsg::PingErr) => Err(Error::TimeoutError), Ok(CloseMsg::Stop) => Ok(()), + Ok(CloseMsg::SetupError(error)) => Err(error), Err(_) => Ok(()), }; + #[cfg(windows)] + if let Some(stop_tx) = self.stop_setup_tx.take() { + let _ = stop_tx.send(()); + } let _ = self.pinger_stop_sender.send(()); self.stop_tunnel(); @@ -439,6 +509,7 @@ impl WireguardMonitor { enum CloseMsg { Stop, PingErr, + SetupError(Error), } /// Close handle for a WireGuard tunnel. @@ -458,6 +529,8 @@ impl CloseHandle { pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; + #[cfg(target_os = "windows")] + fn get_interface_luid(&self) -> u64; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::Stats, TunnelError>; #[cfg(target_os = "linux")] @@ -522,11 +595,6 @@ pub enum TunnelError { #[error(display = "Failed to convert adapter alias")] InvalidAlias, - /// Failed to set ip addresses on tunnel interface. - #[cfg(target_os = "windows")] - #[error(display = "Failed to set IP addresses on WireGuard interface")] - SetIpAddressesError, - /// Failure to set up logging #[error(display = "Failed to set up logging")] LoggingError(#[error(source)] logging::Error), diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 2751fbdfbf..2145bb96d6 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -29,7 +29,7 @@ use { type Result<T> = std::result::Result<T, TunnelError>; #[cfg(target_os = "windows")] -use crate::winnet::{self, add_device_ip_addresses}; +use crate::winnet; #[cfg(not(target_os = "windows"))] const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; @@ -44,6 +44,8 @@ impl Drop for LoggingContext { pub struct WgGoTunnel { interface_name: String, + #[cfg(windows)] + interface_luid: u64, handle: Option<i32>, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped @@ -134,6 +136,7 @@ impl WgGoTunnel { .any(|config| config.allowed_ips.iter().any(|ip| ip.is_ipv6())); let mut alias_ptr = std::ptr::null_mut(); + let mut interface_luid = 0u64; let handle = unsafe { wgTurnOn( @@ -142,6 +145,7 @@ impl WgGoTunnel { wait_on_ipv6 as u8, wg_config_str.as_ptr(), &mut alias_ptr, + &mut interface_luid, Some(logging_callback), logging_context.0 as *mut libc::c_void, ) @@ -163,13 +167,9 @@ impl WgGoTunnel { log::debug!("Adapter alias: {}", actual_iface_name); - if !add_device_ip_addresses(&actual_iface_name, &config.tunnel.addresses) { - // Todo: what kind of clean-up is required? - return Err(TunnelError::SetIpAddressesError); - } - Ok(WgGoTunnel { interface_name: actual_iface_name, + interface_luid, handle: Some(handle), _logging_context: logging_context, }) @@ -302,6 +302,11 @@ impl Tunnel for WgGoTunnel { self.interface_name.clone() } + #[cfg(target_os = "windows")] + fn get_interface_luid(&self) -> u64 { + self.interface_luid + } + fn get_tunnel_stats(&self) -> Result<Stats> { let config_str = unsafe { let ptr = wgGetConfig(self.handle.unwrap()); @@ -376,6 +381,7 @@ extern "C" { wait_on_ipv6: u8, settings: *const i8, iface_name_out: *const *mut std::os::raw::c_char, + iface_luid_out: *mut u64, logging_callback: Option<LoggingCallback>, logging_context: *mut libc::c_void, ) -> i32; diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 553cf9377d..fe87d00f0f 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -167,6 +167,15 @@ impl ConnectingState { log::debug!("WireGuard tunnel timed out"); None } + error @ tunnel::Error::WireguardTunnelMonitoringError(..) + if !should_retry(&error) => + { + error!( + "{}", + error.display_chain_with_msg("Tunnel has stopped unexpectedly") + ); + Some(ErrorStateCause::StartTunnelError) + } error => { warn!( "{}", diff --git a/wireguard/libwg/libwg_windows.go b/wireguard/libwg/libwg_windows.go index 42f27148b3..2cf04ed140 100644 --- a/wireguard/libwg/libwg_windows.go +++ b/wireguard/libwg/libwg_windows.go @@ -21,9 +21,7 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/wintun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/mullvad/mullvadvpn-app/wireguard/libwg/interfacewatcher" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/logging" "github.com/mullvad/mullvadvpn-app/wireguard/libwg/tunnelcontainer" ) @@ -43,30 +41,8 @@ func init() { } } -func createInterfaceWatcherEvents(waitOnIpv6 bool, tunLuid uint64) []interfacewatcher.Event { - if waitOnIpv6 { - return []interfacewatcher.Event{ - { - Luid: winipcfg.LUID(tunLuid), - Family: windows.AF_INET, - }, - interfacewatcher.Event { - Luid: winipcfg.LUID(tunLuid), - Family: windows.AF_INET6, - }, - } - } else { - return []interfacewatcher.Event{ - { - Luid: winipcfg.LUID(tunLuid), - Family: windows.AF_INET, - }, - } - } -} - //export wgTurnOn -func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, cIfaceNameOut **C.char, logSink LogSink, logContext LogContext) int32 { +func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, cIfaceNameOut **C.char, cLuidOut *uint64, logSink LogSink, logContext LogContext) int32 { logger := logging.NewLogger(logSink, logContext) if cIfaceNameOut != nil { *cIfaceNameOut = nil @@ -88,13 +64,6 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c // {AFE43773-E1F8-4EBB-8536-576AB86AFE9A} networkId := windows.GUID{0xafe43773, 0xe1f8, 0x4ebb, [8]byte{0x85, 0x36, 0x57, 0x6a, 0xb8, 0x6a, 0xfe, 0x9a}} - watcher, err := interfacewatcher.NewWatcher() - if err != nil { - logger.Errorf("%s\n", err) - return ERROR_GENERAL_FAILURE - } - defer watcher.Destroy() - if tun.WintunPool != MullvadPool { tun.WintunPool = MullvadPool } @@ -132,18 +101,6 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c device.Up() - interfaces := createInterfaceWatcherEvents(waitOnIpv6, nativeTun.LUID()) - - logger.Verbosef("Waiting for interfaces to attach\n") - - if !watcher.Join(interfaces, 5) { - logger.Errorf("Failed to wait for IP interfaces to become available\n") - device.Close() - return ERROR_GENERAL_FAILURE - } - - logger.Verbosef("Interfaces OK\n") - context := tunnelcontainer.Context{ Device: device, Logger: logger, @@ -159,6 +116,9 @@ func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, c if cIfaceNameOut != nil { *cIfaceNameOut = C.CString(actualInterfaceName) } + if cLuidOut != nil { + *cLuidOut = nativeTun.LUID() + } return handle } |
