diff options
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 59 | ||||
| -rw-r--r-- | talpid-wireguard/src/logging.rs | 16 | ||||
| -rw-r--r-- | talpid-wireguard/src/stats.rs | 3 | ||||
| -rw-r--r-- | talpid-wireguard/src/wireguard_go.rs | 185 |
4 files changed, 42 insertions, 221 deletions
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 784cafa75b..2a98ad3fb0 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -50,12 +50,14 @@ mod connectivity_check; mod logging; mod ping_monitor; mod stats; +#[cfg(unix)] mod wireguard_go; #[cfg(target_os = "linux")] pub(crate) mod wireguard_kernel; #[cfg(windows)] mod wireguard_nt; +#[cfg(unix)] use self::wireguard_go::WgGoTunnel; type Result<T> = std::result::Result<T, Error>; @@ -764,44 +766,33 @@ impl WireguardMonitor { } #[cfg(target_os = "windows")] - if config.use_wireguard_nt { - log::debug!("Using WireGuardNT"); - return wireguard_nt::WgNtTunnel::start_tunnel( - config, - log_path, - resource_dir, - setup_done_tx, - ) - .map(|tun| Box::new(tun) as Box<dyn Tunnel + 'static>) - .map_err(Error::TunnelError); + { + wireguard_nt::WgNtTunnel::start_tunnel(config, log_path, resource_dir, setup_done_tx) + .map(|tun| Box::new(tun) as Box<dyn Tunnel + 'static>) + .map_err(Error::TunnelError) } - #[cfg(not(windows))] - let routes = Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes); + #[cfg(not(target_os = "windows"))] + { + let routes = + Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes); - #[cfg(target_os = "android")] - let config = Self::patch_allowed_ips(config, psk_negotiation); + #[cfg(target_os = "android")] + let config = Self::patch_allowed_ips(config, psk_negotiation); - #[cfg(any(target_os = "linux", windows))] - log::debug!("Using userspace WireGuard implementation"); - Ok(Box::new( - WgGoTunnel::start_tunnel( - #[allow(clippy::needless_borrow)] - &config, - log_path, - #[cfg(not(windows))] - tun_provider, - #[cfg(not(windows))] - routes, - #[cfg(windows)] - route_manager_handle, - #[cfg(windows)] - setup_done_tx, - #[cfg(windows)] - &runtime, - ) - .map_err(Error::TunnelError)?, - )) + #[cfg(target_os = "linux")] + log::debug!("Using userspace WireGuard implementation"); + Ok(Box::new( + WgGoTunnel::start_tunnel( + #[allow(clippy::needless_borrow)] + &config, + log_path, + tun_provider, + routes, + ) + .map_err(Error::TunnelError)?, + )) + } } /// Blocks the current thread until tunnel disconnects diff --git a/talpid-wireguard/src/logging.rs b/talpid-wireguard/src/logging.rs index 5f006d418c..99cf405d23 100644 --- a/talpid-wireguard/src/logging.rs +++ b/talpid-wireguard/src/logging.rs @@ -1,5 +1,6 @@ use once_cell::sync::Lazy; use parking_lot::Mutex; +#[cfg(unix)] use std::ffi::{c_char, c_void}; use std::{collections::HashMap, fmt, fs, io::Write, path::Path}; @@ -90,6 +91,7 @@ fn log_inner(logfile: &mut fs::File, level: LogLevel, tag: &str, msg: &str) { } // Callback that receives messages from WireGuard +#[cfg(unix)] pub unsafe extern "system" fn wg_go_logging_callback( level: WgLogLevel, msg: *const c_char, @@ -98,14 +100,7 @@ pub unsafe extern "system" fn wg_go_logging_callback( let mut map = LOG_MUTEX.lock(); if let Some(logfile) = map.get_mut(&(context as u32)) { let managed_msg = if !msg.is_null() { - #[cfg(not(target_os = "windows"))] - let m = std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string(); - #[cfg(target_os = "windows")] - let m = std::ffi::CStr::from_ptr(msg) - .to_string_lossy() - .to_string() - .replace('\n', "\r\n"); - m + std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string() } else { "Logging message from WireGuard is NULL".to_string() }; @@ -118,8 +113,11 @@ pub unsafe extern "system" fn wg_go_logging_callback( } } -pub type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose // const WG_GO_LOG_SILENT: WgLogLevel = 0; // const WG_GO_LOG_ERROR: WgLogLevel = 1; +#[cfg(unix)] const WG_GO_LOG_VERBOSE: WgLogLevel = 2; + +#[cfg(unix)] +pub type WgLogLevel = u32; diff --git a/talpid-wireguard/src/stats.rs b/talpid-wireguard/src/stats.rs index bd7b578545..79db5937a0 100644 --- a/talpid-wireguard/src/stats.rs +++ b/talpid-wireguard/src/stats.rs @@ -27,6 +27,7 @@ pub struct Stats { pub type StatsMap = std::collections::HashMap<[u8; 32], Stats>; impl Stats { + #[cfg(unix)] pub fn parse_config_str(config: &str) -> Result<StatsMap, Error> { let mut map = StatsMap::new(); @@ -124,6 +125,7 @@ impl Stats { mod test { use super::{Error, Stats}; + #[cfg(unix)] #[test] fn test_parsing() { let valid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=2740\nrx_bytes=2396\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; @@ -137,6 +139,7 @@ mod test { assert_eq!(stats[&pubkey].tx_bytes, 2740); } + #[cfg(unix)] #[test] fn test_parsing_invalid_input() { let invalid_input = "private_key=0000000000000000000000000000000000000000000000000000000000000000\npublic_key=0000000000000000000000000000000000000000000000000000000000000000\npreshared_key=0000000000000000000000000000000000000000000000000000000000000000\nprotocol_version=1\nendpoint=000.000.000.000:00000\nlast_handshake_time_sec=1578420649\nlast_handshake_time_nsec=369416131\ntx_bytes=27error40\npersistent_keepalive_interval=0\nallowed_ip=0.0.0.0/0\n"; diff --git a/talpid-wireguard/src/wireguard_go.rs b/talpid-wireguard/src/wireguard_go.rs index e848e32679..24ad613659 100644 --- a/talpid-wireguard/src/wireguard_go.rs +++ b/talpid-wireguard/src/wireguard_go.rs @@ -3,41 +3,29 @@ use super::{ Config, Tunnel, TunnelError, }; use crate::logging::{clean_up_logging, initialize_logging, wg_go_logging_callback, WgLogLevel}; -#[cfg(windows)] -use futures::SinkExt; +use ipnetwork::IpNetwork; use std::{ ffi::{c_char, c_void, CStr}, future::Future, path::Path, pin::Pin, }; -#[cfg(windows)] -use talpid_types::BoxedError; +use talpid_tunnel::tun_provider::TunProvider; use zeroize::Zeroize; -#[cfg(not(windows))] -use {ipnetwork::IpNetwork, talpid_tunnel::tun_provider::TunProvider}; - -#[cfg(target_os = "windows")] -use std::ffi::CString; #[cfg(target_os = "android")] use talpid_tunnel::tun_provider; -#[cfg(not(target_os = "windows"))] -use { - std::{ - net::IpAddr, - os::unix::io::{AsRawFd, RawFd}, - }, - talpid_tunnel::tun_provider::{Tun, TunConfig}, +use std::{ + net::IpAddr, + os::unix::io::{AsRawFd, RawFd}, }; +use talpid_tunnel::tun_provider::{Tun, TunConfig}; type Result<T> = std::result::Result<T, TunnelError>; -#[cfg(not(target_os = "windows"))] use std::sync::{Arc, Mutex}; -#[cfg(not(target_os = "windows"))] const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; struct LoggingContext(u32); @@ -53,20 +41,14 @@ pub struct WgGoTunnel { handle: Option<i32>, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped - #[cfg(not(target_os = "windows"))] _tunnel_device: Tun, // context that maps to fs::File instance, used with logging callback _logging_context: LoggingContext, - #[cfg(target_os = "windows")] - _route_callback_handle: Option<talpid_routing::CallbackHandle>, - #[cfg(target_os = "windows")] - setup_handle: tokio::task::JoinHandle<()>, #[cfg(target_os = "android")] tun_provider: Arc<Mutex<TunProvider>>, } impl WgGoTunnel { - #[cfg(not(target_os = "windows"))] pub fn start_tunnel( config: &Config, log_path: Option<&Path>, @@ -113,139 +95,6 @@ impl WgGoTunnel { }) } - #[cfg(target_os = "windows")] - pub fn start_tunnel( - config: &Config, - log_path: Option<&Path>, - route_manager_handle: talpid_routing::RouteManagerHandle, - mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, - runtime: &tokio::runtime::Handle, - ) -> Result<Self> { - use talpid_types::ErrorExt; - - let route_callback_handle = runtime - .block_on( - route_manager_handle.add_default_route_change_callback(Box::new( - WgGoTunnel::default_route_changed_callback, - )), - ) - .ok(); - if route_callback_handle.is_none() { - log::warn!("Failed to register default route callback"); - } - - let wg_config_str = config.to_userspace_format(); - let iface_name: String = "Mullvad".to_string(); - let cstr_iface_name = - CString::new(iface_name.as_bytes()).map_err(TunnelError::InterfaceNameError)?; - let logging_context = initialize_logging(log_path) - .map(LoggingContext) - .map_err(TunnelError::LoggingError)?; - - let mut alias_ptr = std::ptr::null_mut(); - let mut interface_luid = 0u64; - - let handle = unsafe { - wgTurnOn( - cstr_iface_name.as_ptr(), - config.mtu as i64, - wg_config_str.as_ptr(), - &mut alias_ptr, - &mut interface_luid, - Some(wg_go_logging_callback), - logging_context.0 as *mut c_void, - ) - }; - check_wg_status(handle)?; - - let actual_iface_name = { - let actual_iface_name_c = unsafe { CStr::from_ptr(alias_ptr) }; - let actual_iface_name = actual_iface_name_c - .to_str() - .map_err(|_| TunnelError::InvalidAlias)? - .to_string(); - unsafe { wgFreePtr(alias_ptr as *mut c_void) }; - actual_iface_name - }; - - log::debug!("Adapter alias: {}", actual_iface_name); - - let has_ipv6 = config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()); - let setup_handle = tokio::spawn(async move { - use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH; - let luid = NET_LUID_LH { - Value: interface_luid, - }; - log::debug!("Waiting for tunnel IP interfaces to arrive"); - - let prepare_interfaces = async move { - talpid_windows_net::wait_for_interfaces(luid, true, has_ipv6).await?; - - if let Err(error) = - talpid_tunnel::network_interface::initialize_interfaces(luid, None) - { - log::error!( - "{}", - error.display_chain_with_msg("Failed to set tunnel interface metric"), - ); - } - - Ok(()) - }; - - let _ = done_tx - .send( - prepare_interfaces - .await - .map_err(|error| BoxedError::new(TunnelError::SetupIpInterfaces(error))), - ) - .await; - log::debug!("Waiting for tunnel IP interfaces: Done"); - }); - - Ok(WgGoTunnel { - interface_name: actual_iface_name, - handle: Some(handle), - setup_handle, - _logging_context: logging_context, - _route_callback_handle: route_callback_handle, - }) - } - - // Callback to be used to rebind the tunnel sockets when the default route changes - #[cfg(target_os = "windows")] - pub fn default_route_changed_callback( - event_type: crate::routing::EventType<'_>, - address_family: talpid_windows_net::AddressFamily, - ) { - use crate::routing::EventType::*; - use windows_sys::Win32::NetworkManagement::IpHelper::ConvertInterfaceLuidToIndex; - - let iface_idx: u32 = match event_type { - Updated(default_route) => { - let mut iface_idx = 0u32; - let iface_luid = default_route.iface; - let status = unsafe { ConvertInterfaceLuidToIndex(&iface_luid, &mut iface_idx) }; - if status != 0 { - log::error!( - "Failed to convert interface LUID to interface index: {}: {}", - status, - std::io::Error::last_os_error() - ); - return; - } - iface_idx - } - // if there is no new default route, specify 0 as the interface index - Removed => 0, - // ignore interface updates that don't affect the interface to use - UpdatedDetails(_) => return, - }; - - unsafe { wgRebindTunnelSocket(address_family.to_af_family(), iface_idx) }; - } - - #[cfg(not(target_os = "windows"))] fn create_tunnel_config(config: &Config, routes: impl Iterator<Item = IpNetwork>) -> TunConfig { let mut dns_servers = vec![IpAddr::V4(config.ipv4_gateway)]; dns_servers.extend(config.ipv6_gateway.map(IpAddr::V6)); @@ -287,8 +136,6 @@ impl WgGoTunnel { } fn stop_tunnel(&mut self) -> Result<()> { - #[cfg(windows)] - self.setup_handle.abort(); if let Some(handle) = self.handle.take() { let status = unsafe { wgTurnOff(handle) }; if status < 0 { @@ -298,7 +145,6 @@ impl WgGoTunnel { Ok(()) } - #[cfg(not(target_os = "windows"))] fn get_tunnel( tun_provider: Arc<Mutex<TunProvider>>, config: &Config, @@ -420,7 +266,6 @@ fn check_wg_status(wg_code: i32) -> Result<()> { } } -#[cfg(unix)] pub type Fd = std::os::unix::io::RawFd; pub type LoggingCallback = @@ -435,7 +280,7 @@ extern "C" { /// /// Positive return values are tunnel handles for this specific wireguard tunnel instance. /// Negative return values signify errors. All error codes are opaque. - #[cfg(not(any(target_os = "android", target_os = "windows")))] + #[cfg(not(target_os = "android"))] fn wgTurnOn( mtu: isize, settings: *const i8, @@ -453,18 +298,6 @@ extern "C" { logging_context: *mut c_void, ) -> i32; - // Windows - #[cfg(target_os = "windows")] - fn wgTurnOn( - iface_name: *const i8, - mtu: i64, - settings: *const i8, - iface_name_out: *mut *mut c_char, - iface_luid_out: *mut u64, - logging_callback: Option<LoggingCallback>, - logging_context: *mut c_void, - ) -> i32; - // Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. fn wgTurnOff(handle: i32) -> i32; @@ -484,8 +317,4 @@ extern "C" { // Returns the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] fn wgGetSocketV6(handle: i32) -> Fd; - - // Rebind tunnel socket when network interfaces change - #[cfg(target_os = "windows")] - fn wgRebindTunnelSocket(family: u16, interfaceIndex: u32); } |
