diff options
| -rw-r--r-- | CHANGELOG.md | 1 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/connectivity_check.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 84 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 44 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_nt.rs | 85 | ||||
| -rw-r--r-- | talpid-types/src/lib.rs | 4 | ||||
| -rw-r--r-- | wireguard/libwg/libwg_windows.go | 2 |
7 files changed, 138 insertions, 87 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index bfb38b4bbb..f48c909157 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,7 @@ Line wrap the file at 100 chars. Th - Remove tray icon of current running app version when upgrading. - Allow Mullvad wireguard-nt tunnels to work simultaneously with other wg-nt tunnels. - Fix notifications on Windows not showing if window is unpinned and hidden. +- Wait for IP interfaces to arrive before trying to configure them when using wireguard-nt. #### Android - Fix Quick Settings tile showing wrong state in certain scenarios. diff --git a/talpid-core/src/tunnel/wireguard/connectivity_check.rs b/talpid-core/src/tunnel/wireguard/connectivity_check.rs index ad90cff807..553943b576 100644 --- a/talpid-core/src/tunnel/wireguard/connectivity_check.rs +++ b/talpid-core/src/tunnel/wireguard/connectivity_check.rs @@ -589,11 +589,6 @@ mod test { "mock-tunnel".to_string() } - #[cfg(windows)] - fn get_interface_luid(&self) -> u64 { - 0 - } - fn stop(self: Box<Self>) -> Result<(), TunnelError> { Ok(()) } diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 0e14ea9269..6cc2d85fdb 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -3,7 +3,11 @@ use self::config::Config; use super::tun_provider; use super::{tun_provider::TunProvider, TunnelEvent, TunnelMetadata}; use crate::routing::{self, RequiredRoute}; +#[cfg(windows)] +use futures::channel::{mpsc, oneshot}; use futures::future::abortable; +#[cfg(windows)] +use futures::{FutureExt, StreamExt}; #[cfg(target_os = "linux")] use lazy_static::lazy_static; #[cfg(target_os = "linux")] @@ -13,8 +17,10 @@ use std::io; use std::{ net::{IpAddr, SocketAddr}, path::Path, - sync::{mpsc, Arc, Mutex}, + sync::{mpsc as sync_mpsc, Arc, Mutex}, }; +#[cfg(windows)] +use talpid_types::BoxedError; use talpid_types::{net::TransportProtocol, ErrorExt}; use udp_over_tcp::{TcpOptions, Udp2Tcp}; @@ -63,8 +69,8 @@ pub enum Error { /// Failed to set up IP interfaces. #[cfg(windows)] - #[error(display = "Failed while waiting on IP interfaces")] - IpInterfacesError(#[error(source)] io::Error), + #[error(display = "Failed to set up IP interfaces")] + IpInterfacesError, /// Failed to set IP addresses on WireGuard interface #[cfg(target_os = "windows")] @@ -84,11 +90,11 @@ pub struct WireguardMonitor { + Sync + 'static, >, - close_msg_sender: mpsc::Sender<CloseMsg>, - close_msg_receiver: mpsc::Receiver<CloseMsg>, + close_msg_sender: sync_mpsc::Sender<CloseMsg>, + close_msg_receiver: sync_mpsc::Receiver<CloseMsg>, #[cfg(target_os = "windows")] - stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>, - pinger_stop_sender: mpsc::Sender<()>, + stop_setup_tx: Option<oneshot::Sender<()>>, + pinger_stop_sender: sync_mpsc::Sender<()>, _tcp_proxies: Vec<TcpProxy>, } @@ -184,6 +190,8 @@ impl WireguardMonitor { } } + #[cfg(target_os = "windows")] + let (setup_done_tx, setup_done_rx) = mpsc::channel(0); let tunnel = Self::open_tunnel( runtime.clone(), &config, @@ -191,16 +199,16 @@ impl WireguardMonitor { resource_dir, tun_provider, route_manager, + #[cfg(target_os = "windows")] + setup_done_tx, )?; let iface_name = tunnel.get_interface_name().to_string(); - #[cfg(windows)] - let iface_luid = tunnel.get_interface_luid(); let event_callback = Box::new(on_event.clone()); - let (close_msg_sender, close_msg_receiver) = mpsc::channel(); - let (pinger_tx, pinger_rx) = mpsc::channel(); + let (close_msg_sender, close_msg_receiver) = sync_mpsc::channel(); + let (pinger_tx, pinger_rx) = sync_mpsc::channel(); #[cfg(target_os = "windows")] - let (stop_setup_tx, stop_setup_rx) = futures::channel::oneshot::channel(); + let (stop_setup_tx, stop_setup_rx) = oneshot::channel(); let monitor = WireguardMonitor { runtime: runtime.clone(), tunnel: Arc::new(Mutex::new(Some(tunnel))), @@ -229,37 +237,36 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); std::thread::spawn(move || { - runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone()))); - #[cfg(windows)] { + let mut done_rx = setup_done_rx.fuse(); let iface_close_sender = close_sender.clone(); - let enable_ipv6 = config.ipv6_gateway.is_some(); - let result = runtime.block_on(async move { - use futures::future::FutureExt; - use winapi::shared::ifdef::NET_LUID; - let luid = NET_LUID { Value: iface_luid }; - let setup_future = crate::windows::wait_for_interfaces(luid, true, enable_ipv6); - futures::select! { - result = setup_future.fuse() => { - result.map_err(|error| - iface_close_sender.send(CloseMsg::SetupError( - Error::IpInterfacesError(error) - )) - .unwrap_or(()) - ) + result = done_rx.next() => { + match result { + Some(result) => { + result.map_err(|error| { + log::error!("{}", error.display_chain_with_msg("Failed to configure tunnel interface")); + iface_close_sender.send(CloseMsg::SetupError( + Error::IpInterfacesError + )) + .unwrap_or(()) + }) + } + None => Err(()), + } } _ = stop_setup_rx.fuse() => Err(()), } }); - if result.is_err() { return; } } + runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone()))); + let setup_iface_routes = || -> Result<()> { #[cfg(target_os = "windows")] if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { @@ -322,6 +329,7 @@ impl WireguardMonitor { resource_dir: &Path, tun_provider: &mut TunProvider, route_manager: &mut routing::RouteManager, + #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { #[cfg(target_os = "linux")] if !*FORCE_USERSPACE_WIREGUARD { @@ -360,7 +368,12 @@ impl WireguardMonitor { #[cfg(target_os = "windows")] if config.use_wireguard_nt { - match wireguard_nt::WgNtTunnel::start_tunnel(config, log_path, resource_dir) { + match wireguard_nt::WgNtTunnel::start_tunnel( + config, + log_path, + resource_dir, + setup_done_tx.clone(), + ) { Ok(tunnel) => { log::debug!("Using WireGuardNT"); return Ok(Box::new(tunnel)); @@ -386,6 +399,8 @@ impl WireguardMonitor { Self::get_tunnel_destinations(config), #[cfg(windows)] route_manager, + #[cfg(windows)] + setup_done_tx, ) .map_err(Error::TunnelError)?, )) @@ -560,7 +575,7 @@ enum CloseMsg { /// Close handle for a WireGuard tunnel. #[derive(Clone, Debug)] pub struct CloseHandle { - chan: mpsc::Sender<CloseMsg>, + chan: sync_mpsc::Sender<CloseMsg>, } impl CloseHandle { @@ -574,8 +589,6 @@ impl CloseHandle { pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; - #[cfg(target_os = "windows")] - fn get_interface_luid(&self) -> u64; fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>; } @@ -624,6 +637,11 @@ pub enum TunnelError { #[error(display = "Failed to create tunnel device")] SetupTunnelDeviceError(#[error(source)] tun_provider::Error), + /// Failed to setup a tunnel device. + #[cfg(windows)] + #[error(display = "Failed to config IP interfaces on tunnel device")] + SetupIpInterfaces(#[error(source)] io::Error), + /// Failed to configure Wireguard sockets to bypass the tunnel. #[cfg(target_os = "android")] #[error(display = "Failed to configure Wireguard sockets to bypass the tunnel")] diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 8e98a7fbdb..e53b0cfbb3 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -9,6 +9,8 @@ use crate::tunnel::tun_provider::TunProvider; use crate::tunnel::wireguard::logging::{ clean_up_logging, initialize_logging, wg_go_logging_callback, WgLogLevel, }; +#[cfg(windows)] +use futures::SinkExt; #[cfg(not(windows))] use ipnetwork::IpNetwork; use std::{ @@ -16,6 +18,8 @@ use std::{ os::raw::c_char, path::Path, }; +#[cfg(windows)] +use talpid_types::BoxedError; use zeroize::Zeroize; #[cfg(target_os = "windows")] @@ -51,8 +55,6 @@ impl Drop for LoggingContext { pub struct WgGoTunnel { interface_name: String, - #[cfg(windows)] - interface_luid: u64, handle: Option<i32>, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped @@ -62,6 +64,8 @@ pub struct WgGoTunnel { _logging_context: LoggingContext, #[cfg(target_os = "windows")] _route_callback_handle: Option<crate::winnet::WinNetCallbackHandle>, + #[cfg(target_os = "windows")] + setup_handle: tokio::task::JoinHandle<()>, } impl WgGoTunnel { @@ -111,6 +115,7 @@ impl WgGoTunnel { config: &Config, log_path: Option<&Path>, route_manager: &mut routing::RouteManager, + mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Self> { let route_callback_handle = route_manager .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()) @@ -127,13 +132,6 @@ impl WgGoTunnel { .map(LoggingContext) .map_err(TunnelError::LoggingError)?; - let wait_on_ipv6 = config.ipv6_gateway.is_some() - || config.tunnel.addresses.iter().any(|ip| ip.is_ipv6()) - || config - .peers - .iter() - .any(|config| config.allowed_ips.iter().any(|ip| ip.is_ipv6())); - let mut alias_ptr = std::ptr::null_mut(); let mut interface_luid = 0u64; @@ -141,7 +139,6 @@ impl WgGoTunnel { wgTurnOn( cstr_iface_name.as_ptr(), config.mtu as i64, - wait_on_ipv6 as u8, wg_config_str.as_ptr(), &mut alias_ptr, &mut interface_luid, @@ -163,10 +160,27 @@ impl WgGoTunnel { log::debug!("Adapter alias: {}", actual_iface_name); + let has_ipv6 = config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()); + let setup_handle = tokio::spawn(async move { + use winapi::shared::ifdef::NET_LUID; + let luid = NET_LUID { + Value: interface_luid, + }; + log::debug!("Waiting for tunnel IP interfaces to arrive"); + let _ = done_tx + .send( + crate::windows::wait_for_interfaces(luid, true, has_ipv6) + .await + .map_err(|error| BoxedError::new(TunnelError::SetupIpInterfaces(error))), + ) + .await; + log::debug!("Waiting for tunnel IP interfaces: Done"); + }); + Ok(WgGoTunnel { interface_name: actual_iface_name, - interface_luid, handle: Some(handle), + setup_handle, _logging_context: logging_context, _route_callback_handle: route_callback_handle, }) @@ -248,6 +262,8 @@ impl WgGoTunnel { } fn stop_tunnel(&mut self) -> Result<()> { + #[cfg(windows)] + self.setup_handle.abort(); if let Some(handle) = self.handle.take() { let status = unsafe { wgTurnOff(handle) }; if status < 0 { @@ -299,11 +315,6 @@ impl Tunnel for WgGoTunnel { self.interface_name.clone() } - #[cfg(target_os = "windows")] - fn get_interface_luid(&self) -> u64 { - self.interface_luid - } - fn get_tunnel_stats(&self) -> Result<StatsMap> { let config_str = unsafe { let ptr = wgGetConfig(self.handle.unwrap()); @@ -390,7 +401,6 @@ extern "C" { fn wgTurnOn( iface_name: *const i8, mtu: i64, - wait_on_ipv6: u8, settings: *const i8, iface_name_out: *const *mut std::os::raw::c_char, iface_luid_out: *mut u64, diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs index dde1a991fd..705d09892d 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::windows; use bitflags::bitflags; +use futures::SinkExt; use ipnetwork::IpNetwork; use lazy_static::lazy_static; use std::{ @@ -18,7 +19,7 @@ use std::{ ptr, sync::{Arc, Mutex}, }; -use talpid_types::ErrorExt; +use talpid_types::{BoxedError, ErrorExt}; use widestring::{U16CStr, U16CString}; use winapi::{ shared::{ @@ -131,6 +132,10 @@ pub enum Error { #[error(display = "Failed to set tunnel WireGuard config")] SetWireGuardConfigError(#[error(source)] io::Error), + /// Error listening to tunnel IP interfaces + #[error(display = "Failed to wait on tunnel IP interfaces")] + IpInterfacesError(#[error(source)] io::Error), + /// Failed to set MTU on tunnel device #[error(display = "Failed to set tunnel IPv4 interface MTU")] SetTunnelIpv4MtuError(#[error(source)] io::Error), @@ -165,9 +170,9 @@ pub enum Error { } pub struct WgNtTunnel { - device: Option<WgNtAdapter>, - interface_luid: NET_LUID, + device: Arc<Mutex<Option<WgNtAdapter>>>, interface_name: String, + setup_handle: tokio::task::JoinHandle<()>, _logger_handle: LoggerHandle, } @@ -410,6 +415,7 @@ impl WgNtTunnel { config: &Config, log_path: Option<&Path>, resource_dir: &Path, + mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Self> { let dll = load_wg_nt_dll(resource_dir)?; let logger_handle = LoggerHandle::new(dll.clone(), log_path)?; @@ -421,44 +427,69 @@ impl WgNtTunnel { ) .map_err(Error::CreateTunnelDeviceError)?; - let interface_luid = device.luid(); let interface_name = device .name() .map_err(Error::ObtainAliasError)? .to_string_lossy(); + if let Err(error) = device.set_logging(WireGuardAdapterLogState::On) { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set log state on WireGuard interface") + ); + } + device.set_config(config)?; + let device = Arc::new(Mutex::new(Some(device))); + + let setup_future = setup_ip_listener( + device.clone(), + u32::from(config.mtu), + config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()), + ); + let setup_handle = tokio::spawn(async move { + let _ = done_tx + .send(setup_future.await.map_err(BoxedError::new)) + .await; + }); + let tunnel = WgNtTunnel { - device: Some(device), - interface_luid, + device, interface_name, + setup_handle, _logger_handle: logger_handle, }; - tunnel.configure(config)?; Ok(tunnel) } fn stop_tunnel(&mut self) { - let _ = self.device.take(); + self.setup_handle.abort(); + let _ = self.device.lock().unwrap().take(); } +} - fn configure(&self, config: &Config) -> Result<()> { - let device = self.device.as_ref().unwrap(); - if let Err(error) = device.set_logging(WireGuardAdapterLogState::On) { - log::error!( - "{}", - error.display_chain_with_msg("Failed to set log state on WireGuard interface") - ); - } - device.set_config(config)?; - prepare_interface(&device.luid(), AF_INET as u16, u32::from(config.mtu)) - .map_err(Error::SetTunnelIpv4MtuError)?; - if config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()) { - prepare_interface(&device.luid(), AF_INET6 as u16, u32::from(config.mtu)) - .map_err(Error::SetTunnelIpv6MtuError)?; - } +async fn setup_ip_listener( + device: Arc<Mutex<Option<WgNtAdapter>>>, + mtu: u32, + has_ipv6: bool, +) -> Result<()> { + let luid = { device.lock().unwrap().as_ref().unwrap().luid() }; + + log::debug!("Waiting for tunnel IP interfaces to arrive"); + windows::wait_for_interfaces(luid.clone(), true, has_ipv6) + .await + .map_err(Error::IpInterfacesError)?; + log::debug!("Waiting for tunnel IP interfaces: Done"); + + prepare_interface(&luid, AF_INET as u16, mtu).map_err(Error::SetTunnelIpv4MtuError)?; + if has_ipv6 { + prepare_interface(&luid, AF_INET6 as u16, mtu).map_err(Error::SetTunnelIpv6MtuError)?; + } + + if let Some(device) = &*device.lock().unwrap() { device .set_state(WgAdapterState::Up) - .map_err(Error::EnableTunnelError)?; + .map_err(Error::EnableTunnelError) + } else { Ok(()) } } @@ -914,12 +945,8 @@ impl Tunnel for WgNtTunnel { self.interface_name.clone() } - fn get_interface_luid(&self) -> u64 { - self.interface_luid.Value - } - fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> { - if let Some(ref device) = self.device { + if let Some(ref device) = &*self.device.lock().unwrap() { let mut map = StatsMap::new(); let (_interface, peers) = device.get_config().map_err(|error| { log::error!( diff --git a/talpid-types/src/lib.rs b/talpid-types/src/lib.rs index 835077855b..04b83fec72 100644 --- a/talpid-types/src/lib.rs +++ b/talpid-types/src/lib.rs @@ -42,7 +42,7 @@ impl<E: Error> ErrorExt for E { } #[derive(Debug)] -pub struct BoxedError(Box<dyn Error>); +pub struct BoxedError(Box<dyn Error + 'static + Send>); impl fmt::Display for BoxedError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -57,7 +57,7 @@ impl Error for BoxedError { } impl BoxedError { - pub fn new(error: impl Error + 'static) -> Self { + pub fn new(error: impl Error + 'static + Send) -> Self { BoxedError(Box::new(error)) } } diff --git a/wireguard/libwg/libwg_windows.go b/wireguard/libwg/libwg_windows.go index 2f7dcc7ecc..af7d7f6488 100644 --- a/wireguard/libwg/libwg_windows.go +++ b/wireguard/libwg/libwg_windows.go @@ -42,7 +42,7 @@ func init() { } //export wgTurnOn -func wgTurnOn(cIfaceName *C.char, mtu int, waitOnIpv6 bool, cSettings *C.char, cIfaceNameOut **C.char, cLuidOut *uint64, logSink LogSink, logContext LogContext) int32 { +func wgTurnOn(cIfaceName *C.char, mtu int, cSettings *C.char, cIfaceNameOut **C.char, cLuidOut *uint64, logSink LogSink, logContext LogContext) int32 { logger := logging.NewLogger(logSink, logContext) if cIfaceNameOut != nil { *cIfaceNameOut = nil |
