diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-09-28 12:44:12 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-09-28 12:44:12 +0200 |
| commit | 827d95c831f9ef8de4b419a6f7913377a20e8cf9 (patch) | |
| tree | 2ffa5de8ee4ba2f777b4bfe9001dca44696fad6e | |
| parent | 31e62ea07e957b2e1d285c8eb85605ce8cba5e69 (diff) | |
| parent | a247e6220fe924c89923beae638dfc182797ba18 (diff) | |
| download | mullvadvpn-827d95c831f9ef8de4b419a6f7913377a20e8cf9.tar.xz mullvadvpn-827d95c831f9ef8de4b419a6f7913377a20e8cf9.zip | |
Merge branch 'wg-nt'
27 files changed, 1731 insertions, 50 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 58dcef1163..c1b1c29f52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Line wrap the file at 100 chars. Th #### Windows - Resolve symbolic links and junctions for excluded apps. +- Add opt-in support for NT kernel WireGuard driver. It can be enabled in the CLI. ### Changed - Only use the account history file to store the last used account. diff --git a/Cargo.lock b/Cargo.lock index 54e805dd57..3b2733a019 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2485,6 +2485,7 @@ version = "0.1.0" dependencies = [ "async-trait", "atty", + "bitflags", "byteorder", "cfg-if 1.0.0", "chrono", @@ -436,7 +436,7 @@ echo "org.gradle.jvmargs=-Xmx4608M" >> ~/.gradle/gradle.properties * `"network-manager"`: use `NetworkManager` service through DBus * `TALPID_FORCE_USERSPACE_WIREGUARD` - Forces the daemon to use the userspace implementation of - WireGuard on Linux. + WireGuard on Linux and Windows. * `TALPID_DNS_CACHE_POLICY` - On Windows, this changes how DNS is configured: * `1`: The default. This sets a global list of DNS servers that `dnscache` will use instead of diff --git a/dist-assets/binaries b/dist-assets/binaries -Subproject c77da1b6ca952289acb668dc8a53a03367805ff +Subproject 19a97997b188855d0ba5aedb7419683df45d93b diff --git a/dist-assets/windows/installer.nsh b/dist-assets/windows/installer.nsh index 5f0151f009..c1dac59bf8 100644 --- a/dist-assets/windows/installer.nsh +++ b/dist-assets/windows/installer.nsh @@ -13,6 +13,7 @@ # !define WINTUN_POOL "Mullvad" +!define WG_NT_POOL "Mullvad" # "sc" exit code !define SERVICE_STARTED 0 @@ -59,19 +60,20 @@ !define PERSISTENT_BLOCK_OUTBOUND_IPV4_FILTER_GUID "{79860c64-9a5e-48a3-b5f3-d64b41659aa5}" # -# ExtractWintun +# ExtractWireGuard # -# Extract Wintun installer into $TEMP +# Extract Wintun and WireGuardNT installer into $TEMP # -!macro ExtractWintun +!macro ExtractWireGuard SetOutPath "$TEMP" File "${BUILD_RESOURCES_DIR}\binaries\x86_64-pc-windows-msvc\wintun\wintun.dll" + File "${BUILD_RESOURCES_DIR}\binaries\x86_64-pc-windows-msvc\wireguard-nt\wireguard.dll" File "${BUILD_RESOURCES_DIR}\..\windows\driverlogic\bin\x64-Release\driverlogic.exe" !macroend -!define ExtractWintun '!insertmacro "ExtractWintun"' +!define ExtractWireGuard '!insertmacro "ExtractWireGuard"' # # ExtractMullvadSetup @@ -222,6 +224,41 @@ !define RemoveWintun '!insertmacro "RemoveWintun"' # +# RemoveWireGuardNt +# +# Try to remove WireGuardNT +# +!macro RemoveWireGuardNt + Push $0 + Push $1 + + log::Log "RemoveWireGuardNt()" + + nsExec::ExecToStack '"$TEMP\driverlogic.exe" wg-nt-cleanup ${WG_NT_POOL}' + Pop $0 + Pop $1 + + ${If} $0 != ${DL_GENERAL_SUCCESS} + IntFmt $0 "0x%X" $0 + StrCpy $R0 "Failed to remove WireGuardNT pool: error $0" + log::LogWithDetails $R0 $1 + Goto RemoveWireGuardNt_return_only + ${EndIf} + + log::Log "RemoveWireGuardNt() completed successfully" + + Push 0 + Pop $R0 + + RemoveWireGuardNt_return_only: + + Pop $1 + Pop $0 + +!macroend + +!define RemoveWireGuardNt '!insertmacro "RemoveWireGuardNt"' +# # RemoveAbandonedWintunAdapter # # Removes old Wintun interface, even if it belongs to a different pool. @@ -1244,8 +1281,9 @@ ${ClearFirewallRules} ${RemoveWireGuardKey} - ${ExtractWintun} + ${ExtractWireGuard} ${RemoveWintun} + ${RemoveWireGuardNt} ${ExtractSplitTunnelDriver} ${RemoveSplitTunnelDriver} diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js index 452baad71f..6f74d65eda 100644 --- a/gui/tasks/distribution.js +++ b/gui/tasks/distribution.js @@ -114,6 +114,7 @@ const config = { { from: distAssets('binaries/x86_64-pc-windows-msvc/sslocal.exe'), to: '.' }, { from: root('build/lib/x86_64-pc-windows-msvc/libwg.dll'), to: '.' }, { from: distAssets('binaries/x86_64-pc-windows-msvc/wintun/wintun.dll'), to: '.' }, + { from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard-nt/wireguard.dll'), to: '.' }, ], }, diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs index 08306b70eb..e01d81a9da 100644 --- a/mullvad-cli/src/cmds/tunnel.rs +++ b/mullvad-cli/src/cmds/tunnel.rs @@ -34,11 +34,19 @@ impl Command for Tunnel { } fn create_wireguard_subcommand() -> clap::App<'static, 'static> { - clap::SubCommand::with_name("wireguard") + let subcmd = clap::SubCommand::with_name("wireguard") .about("Manage options for Wireguard tunnels") .setting(clap::AppSettings::SubcommandRequiredElseHelp) .subcommand(create_wireguard_mtu_subcommand()) - .subcommand(create_wireguard_keys_subcommand()) + .subcommand(create_wireguard_keys_subcommand()); + #[cfg(windows)] + { + subcmd.subcommand(create_wireguard_use_wg_nt_subcommand()) + } + #[cfg(not(windows))] + { + subcmd + } } fn create_wireguard_mtu_subcommand() -> clap::App<'static, 'static> { @@ -61,6 +69,22 @@ fn create_wireguard_keys_subcommand() -> clap::App<'static, 'static> { .subcommand(create_wireguard_keys_rotation_interval_subcommand()) } +#[cfg(windows)] +fn create_wireguard_use_wg_nt_subcommand() -> clap::App<'static, 'static> { + clap::SubCommand::with_name("use-wireguard-nt") + .about("Enable or disable wireguard-nt") + .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::SubCommand::with_name("get")) + .subcommand( + clap::SubCommand::with_name("set").arg( + clap::Arg::with_name("policy") + .required(true) + .takes_value(true) + .possible_values(&["on", "off"]), + ), + ) +} + fn create_wireguard_keys_rotation_interval_subcommand() -> clap::App<'static, 'static> { clap::SubCommand::with_name("rotation-interval") .about("Manage automatic key rotation (given in hours)") @@ -147,6 +171,13 @@ impl Tunnel { _ => unreachable!("unhandled command"), }, + #[cfg(windows)] + ("use-wireguard-nt", Some(matches)) => match matches.subcommand() { + ("get", _) => Self::process_wireguard_use_wg_nt_get().await, + ("set", Some(matches)) => Self::process_wireguard_use_wg_nt_set(matches).await, + _ => unreachable!("unhandled command"), + }, + _ => unreachable!("unhandled command"), } } @@ -180,6 +211,26 @@ impl Tunnel { Ok(()) } + #[cfg(windows)] + async fn process_wireguard_use_wg_nt_get() -> Result<()> { + let tunnel_options = Self::get_tunnel_options().await?; + if tunnel_options.wireguard.unwrap().use_wireguard_nt { + println!("enabled"); + } else { + println!("disabled"); + } + Ok(()) + } + + #[cfg(windows)] + async fn process_wireguard_use_wg_nt_set(matches: &clap::ArgMatches<'_>) -> Result<()> { + let new_state = matches.value_of("policy").unwrap() == "on"; + let mut rpc = new_rpc_client().await?; + rpc.set_use_wireguard_nt(new_state).await?; + println!("Updated wireguard-nt setting"); + Ok(()) + } + async fn process_wireguard_key_check() -> Result<()> { let mut rpc = new_rpc_client().await?; let key = rpc.get_wireguard_key(()).await; diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 19692d57fa..bde37e85fb 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -281,6 +281,9 @@ pub enum DaemonCommand { /// Disable split tunnel #[cfg(windows)] SetSplitTunnelState(ResponseTx<(), Error>, bool), + /// Toggle wireguard-nt on or off + #[cfg(target_os = "windows")] + UseWireGuardNt(ResponseTx<(), Error>, bool), /// Makes the daemon exit the main loop and quit. Shutdown, /// Saves the target tunnel state and enters a blocking state. The state is restored @@ -1230,6 +1233,8 @@ where ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await, #[cfg(windows)] SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await, + #[cfg(target_os = "windows")] + UseWireGuardNt(tx, state) => self.on_use_wireguard_nt(tx, state).await, Shutdown => self.trigger_shutdown_event(), PrepareRestart => self.on_prepare_restart(), #[cfg(target_os = "android")] @@ -1937,6 +1942,35 @@ where } } + #[cfg(windows)] + async fn on_use_wireguard_nt(&mut self, tx: ResponseTx<(), Error>, state: bool) { + let save_result = self + .settings + .set_use_wireguard_nt(state) + .await + .map_err(Error::SettingsError); + match save_result { + Ok(settings_changed) => { + Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response"); + if settings_changed { + self.event_listener + .notify_settings(self.settings.to_settings()); + if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { + info!("Initiating tunnel restart"); + self.reconnect_tunnel(); + } + } + } + Err(error) => { + error!( + "{}", + error.display_chain_with_msg("Unable to save settings") + ); + Self::oneshot_send(tx, Err(error), "use_wireguard_nt response"); + } + } + } + async fn on_update_relay_settings( &mut self, tx: ResponseTx<(), settings::Error>, diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 0b2f6b463f..cef8d42f78 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -688,6 +688,22 @@ impl ManagementService for ManagementServiceImpl { async fn set_split_tunnel_state(&self, _: Request<bool>) -> ServiceResult<()> { Ok(Response::new(())) } + + #[cfg(windows)] + async fn set_use_wireguard_nt(&self, request: Request<bool>) -> ServiceResult<()> { + log::debug!("set_use_wireguard_nt"); + let state = request.into_inner(); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::UseWireGuardNt(tx, state))?; + self.wait_for_result(rx) + .await? + .map_err(map_daemon_error) + .map(Response::new) + } + #[cfg(not(windows))] + async fn set_use_wireguard_nt(&self, _: Request<bool>) -> ServiceResult<()> { + Ok(Response::new(())) + } } impl ManagementServiceImpl { diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index 02568e3226..e3cfeb8ddc 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -330,6 +330,20 @@ impl SettingsPersister { self.update(should_save).await } + #[cfg(windows)] + pub async fn set_use_wireguard_nt(&mut self, state: bool) -> Result<bool, Error> { + let should_save = Self::update_field( + &mut self + .settings + .tunnel_options + .wireguard + .options + .use_wireguard_nt, + state, + ); + self.update(should_save).await + } + fn update_field<T: Eq>(field: &mut T, new_value: T) -> bool { if *field != new_value { *field = new_value; diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 8711ba7b1b..a6d88bb0e0 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -69,6 +69,8 @@ service ManagementService { rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {} rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} + + rpc SetUseWireguardNt(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} } message RelaySettingsUpdate { @@ -379,6 +381,7 @@ message TunnelOptions { message WireguardOptions { uint32 mtu = 1; google.protobuf.Duration rotation_interval = 2; + bool use_wireguard_nt = 3; } message GenericOptions { bool enable_ipv6 = 1; diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 81fdbb0e87..4e3a7217eb 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -562,6 +562,10 @@ impl From<&mullvad_types::settings::TunnelOptions> for TunnelOptions { .wireguard .rotation_interval .map(|ivl| Duration::from(std::time::Duration::from(ivl))), + #[cfg(windows)] + use_wireguard_nt: options.wireguard.options.use_wireguard_nt, + #[cfg(not(windows))] + use_wireguard_nt: false, }), generic: Some(tunnel_options::GenericOptions { enable_ipv6: options.generic.enable_ipv6, @@ -1199,6 +1203,8 @@ impl TryFrom<TunnelOptions> for mullvad_types::settings::TunnelOptions { } else { None }, + #[cfg(windows)] + use_wireguard_nt: wireguard_options.use_wireguard_nt, }, rotation_interval: wireguard_options .rotation_interval diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 69b2f1e42c..f8ebc64564 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" publish = false [dependencies] +bitflags = "1.2" async-trait = "0.1" atty = "0.2" cfg-if = "1.0" diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 6501162616..f4f34be00b 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -130,6 +130,7 @@ impl TunnelMonitor { runtime, &config, log_file, + resource_dir, on_event, tun_provider, route_manager, @@ -161,6 +162,7 @@ impl TunnelMonitor { runtime: tokio::runtime::Handle, params: &wireguard_types::TunnelParameters, log: Option<PathBuf>, + resource_dir: &Path, on_event: L, tun_provider: &mut TunProvider, route_manager: &mut RouteManager, @@ -177,6 +179,7 @@ impl TunnelMonitor { runtime, config, log.as_ref().map(|p| p.as_path()), + resource_dir, on_event, tun_provider, route_manager, diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs index ae82483c66..252bf8418f 100644 --- a/talpid-core/src/tunnel/wireguard/config.rs +++ b/talpid-core/src/tunnel/wireguard/config.rs @@ -23,6 +23,9 @@ pub struct Config { /// Enable IPv6 routing rules #[cfg(target_os = "linux")] pub enable_ipv6: bool, + /// Temporary switch for wireguard-nt + #[cfg(target_os = "windows")] + pub use_wireguard_nt: bool, } const DEFAULT_MTU: u16 = 1380; @@ -109,6 +112,8 @@ impl Config { fwmark: crate::linux::TUNNEL_FW_MARK, #[cfg(target_os = "linux")] enable_ipv6: generic_options.enable_ipv6, + #[cfg(target_os = "windows")] + use_wireguard_nt: wg_options.use_wireguard_nt, }) } diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs index 64f2f71bf5..e795326854 100644 --- a/talpid-core/src/tunnel/wireguard/logging.rs +++ b/talpid-core/src/tunnel/wireguard/logging.rs @@ -1,5 +1,5 @@ use parking_lot::Mutex; -use std::{collections::HashMap, fs, io::Write, path::Path}; +use std::{collections::HashMap, fmt, fs, io::Write, path::Path}; lazy_static::lazy_static! { static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new()); @@ -44,14 +44,58 @@ pub fn clean_up_logging(ordinal: u32) { map.remove(&ordinal); } +#[allow(dead_code)] +pub enum LogLevel { + Verbose, + Info, + Warning, + Error, +} + +impl fmt::Display for LogLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl AsRef<str> for LogLevel { + fn as_ref(&self) -> &str { + match self { + LogLevel::Verbose => "VERBOSE", + LogLevel::Info => "INFO", + LogLevel::Warning => "WARNING", + LogLevel::Error => "ERROR", + } + } +} + +#[cfg(windows)] +pub fn log(context: u32, level: LogLevel, tag: &str, msg: &str) { + let mut map = LOG_MUTEX.lock(); + if let Some(logfile) = map.get_mut(&(context as u32)) { + log_inner(logfile, level, tag, msg); + } +} + +fn log_inner(logfile: &mut fs::File, level: LogLevel, tag: &str, msg: &str) { + let _ = write!( + logfile, + "{}[{}][{}] {}", + chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"), + tag, + level, + msg, + ); +} + // Callback that receives messages from WireGuard -pub unsafe extern "system" fn logging_callback( +pub unsafe extern "system" fn wg_go_logging_callback( level: WgLogLevel, msg: *const libc::c_char, context: *mut libc::c_void, ) { - let map = LOG_MUTEX.lock(); - if let Some(mut logfile) = map.get(&(context as u32)) { + let mut map = LOG_MUTEX.lock(); + if let Some(logfile) = map.get_mut(&(context as u32)) { let managed_msg = if !msg.is_null() { #[cfg(not(target_os = "windows"))] let m = std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string(); @@ -65,24 +109,14 @@ pub unsafe extern "system" fn logging_callback( "Logging message from WireGuard is NULL".to_string() }; - let level_str = match level { - WG_GO_LOG_VERBOSE => "VERBOSE", - WG_GO_LOG_ERROR | _ => "ERROR", + let level = match level { + WG_GO_LOG_VERBOSE => LogLevel::Verbose, + WG_GO_LOG_ERROR | _ => LogLevel::Error, }; - - let _ = write!( - logfile, - "{}[{}][{}] {}", - chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"), - "wireguard-go", - level_str, - managed_msg - ); + log_inner(logfile, level, "wireguard-go", &managed_msg); } } -// unsafe fn - pub type WgLogLevel = u32; // wireguard-go supports log levels 0 through 3 with 3 being the most verbose // const WG_GO_LOG_SILENT: WgLogLevel = 0; diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index df3c7bb8b3..a3fa426ca2 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -26,6 +26,8 @@ mod stats; mod wireguard_go; #[cfg(target_os = "linux")] pub(crate) mod wireguard_kernel; +#[cfg(windows)] +mod wireguard_nt; use self::wireguard_go::WgGoTunnel; @@ -89,8 +91,6 @@ pub struct WireguardMonitor { stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>, pinger_stop_sender: mpsc::Sender<()>, _tcp_proxies: Vec<TcpProxy>, - #[cfg(target_os = "windows")] - _callback_handle: Option<crate::winnet::WinNetCallbackHandle>, } #[cfg(target_os = "linux")] @@ -165,6 +165,7 @@ impl WireguardMonitor { runtime: tokio::runtime::Handle, mut config: Config, log_path: Option<&Path>, + resource_dir: &Path, on_event: F, tun_provider: &mut TunProvider, route_manager: &mut routing::RouteManager, @@ -183,20 +184,12 @@ impl WireguardMonitor { } } - let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?; + let tunnel = + Self::open_tunnel(&config, log_path, resource_dir, tun_provider, route_manager)?; let iface_name = tunnel.get_interface_name().to_string(); #[cfg(windows)] let iface_luid = tunnel.get_interface_luid(); - #[cfg(target_os = "windows")] - let callback_handle = route_manager - .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()) - .ok(); - #[cfg(target_os = "windows")] - if callback_handle.is_none() { - log::warn!("Failed to register default route callback"); - } - let event_callback = Box::new(on_event.clone()); let (close_msg_sender, close_msg_receiver) = mpsc::channel(); let (pinger_tx, pinger_rx) = mpsc::channel(); @@ -212,8 +205,6 @@ impl WireguardMonitor { stop_setup_tx: Some(stop_setup_tx), pinger_stop_sender: pinger_tx, _tcp_proxies: tcp_proxies, - #[cfg(target_os = "windows")] - _callback_handle: callback_handle, }; let gateway = config.ipv4_gateway; @@ -317,10 +308,11 @@ impl WireguardMonitor { Ok(monitor) } - #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] + #[allow(unused_variables)] fn open_tunnel( config: &Config, log_path: Option<&Path>, + resource_dir: &Path, tun_provider: &mut TunProvider, route_manager: &mut routing::RouteManager, ) -> Result<Box<dyn Tunnel>> { @@ -362,14 +354,34 @@ impl WireguardMonitor { } } - #[cfg(target_os = "linux")] + #[cfg(target_os = "windows")] + if config.use_wireguard_nt { + match wireguard_nt::WgNtTunnel::start_tunnel(config, log_path, resource_dir) { + Ok(tunnel) => { + log::debug!("Using WireGuardNT"); + return Ok(Box::new(tunnel)); + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to setup WireGuardNT tunnel") + ); + } + } + } + + #[cfg(any(target_os = "linux", windows))] log::debug!("Using userspace WireGuard implementation"); Ok(Box::new( WgGoTunnel::start_tunnel( &config, log_path, + #[cfg(not(windows))] tun_provider, + #[cfg(not(windows))] Self::get_tunnel_destinations(config), + #[cfg(windows)] + route_manager, ) .map_err(Error::TunnelError)?, )) diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs index ff076fd714..f565988267 100644 --- a/talpid-core/src/tunnel/wireguard/stats.rs +++ b/talpid-core/src/tunnel/wireguard/stats.rs @@ -12,6 +12,9 @@ pub enum Error { #[error(display = "Device no longer exists")] NoTunnelDevice, + + #[error(display = "Failed to obtain tunnel config")] + NoTunnelConfig, } /// Contains bytes sent and received through a tunnel diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index e83a74ab47..bc12fa676f 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -2,10 +2,14 @@ use super::{ stats::{Stats, StatsMap}, Config, Tunnel, TunnelError, }; -use crate::tunnel::{ - tun_provider::TunProvider, - wireguard::logging::{clean_up_logging, initialize_logging, logging_callback, WgLogLevel}, +#[cfg(windows)] +use crate::routing; +#[cfg(not(windows))] +use crate::tunnel::tun_provider::TunProvider; +use crate::tunnel::wireguard::logging::{ + clean_up_logging, initialize_logging, wg_go_logging_callback, WgLogLevel, }; +#[cfg(not(windows))] use ipnetwork::IpNetwork; use std::{ ffi::{c_void, CStr}, @@ -56,6 +60,8 @@ pub struct WgGoTunnel { _tunnel_device: Tun, // context that maps to fs::File instance, used with logging callback _logging_context: LoggingContext, + #[cfg(target_os = "windows")] + _route_callback_handle: Option<crate::winnet::WinNetCallbackHandle>, } impl WgGoTunnel { @@ -82,7 +88,7 @@ impl WgGoTunnel { mtu, wg_config_str.as_ptr() as *const i8, tunnel_fd, - Some(logging_callback), + Some(wg_go_logging_callback), logging_context.0 as *mut libc::c_void, ) }; @@ -104,9 +110,15 @@ impl WgGoTunnel { pub fn start_tunnel( config: &Config, log_path: Option<&Path>, - _tun_provider: &mut TunProvider, - _routes: impl Iterator<Item = IpNetwork>, + route_manager: &mut routing::RouteManager, ) -> Result<Self> { + let route_callback_handle = route_manager + .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()) + .ok(); + if route_callback_handle.is_none() { + log::warn!("Failed to register default route callback"); + } + let wg_config_str = config.to_userspace_format(); let iface_name: String = "Mullvad".to_string(); let cstr_iface_name = @@ -133,7 +145,7 @@ impl WgGoTunnel { wg_config_str.as_ptr(), &mut alias_ptr, &mut interface_luid, - Some(logging_callback), + Some(wg_go_logging_callback), logging_context.0 as *mut libc::c_void, ) }; @@ -156,6 +168,7 @@ impl WgGoTunnel { interface_luid, handle: Some(handle), _logging_context: logging_context, + _route_callback_handle: route_callback_handle, }) } diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs new file mode 100644 index 0000000000..2bd14f644d --- /dev/null +++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs @@ -0,0 +1,1215 @@ +use super::{ + config::Config, + logging, + stats::{Stats, StatsMap}, + Tunnel, +}; +use crate::windows; +use bitflags::bitflags; +use ipnetwork::IpNetwork; +use lazy_static::lazy_static; +use std::{ + ffi::CStr, + fmt, io, iter, mem, + os::windows::{ffi::OsStrExt, io::RawHandle}, + path::Path, + ptr, + sync::{Arc, Mutex}, +}; +use talpid_types::ErrorExt; +use widestring::{U16CStr, U16CString}; +use winapi::{ + shared::{ + guiddef::GUID, + ifdef::NET_LUID, + in6addr::IN6_ADDR, + inaddr::IN_ADDR, + minwindef::{BOOL, FARPROC, HINSTANCE, HMODULE}, + nldef::RouterDiscoveryDisabled, + ntdef::FALSE, + winerror::ERROR_MORE_DATA, + ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6}, + ws2ipdef::SOCKADDR_INET, + }, + um::libloaderapi::{ + FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH, + }, +}; + + +lazy_static! { + static ref WG_NT_DLL: Mutex<Option<Arc<WgNtDll>>> = Mutex::new(None); + static ref ADAPTER_POOL: U16CString = U16CString::from_str("Mullvad").unwrap(); + static ref ADAPTER_ALIAS: U16CString = U16CString::from_str("Mullvad").unwrap(); +} + +const ADAPTER_GUID: GUID = GUID { + Data1: 0x514a3988, + Data2: 0x9716, + Data3: 0x43d5, + Data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9], +}; + +/// Longest possible adapter name (in characters), including null terminator +const MAX_ADAPTER_NAME: usize = 128; + +type WireGuardOpenAdapterFn = + unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle; +type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn( + pool: *const u16, + name: *const u16, + requested_guid: *const GUID, + reboot_required: *mut BOOL, +) -> RawHandle; +type WireGuardFreeAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle); +type WireGuardDeleteAdapterFn = + unsafe extern "stdcall" fn(adapter: RawHandle, reboot_required: *mut BOOL) -> BOOL; +type WireGuardGetAdapterLuidFn = + unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID); +type WireGuardGetAdapterNameFn = + unsafe extern "stdcall" fn(adapter: RawHandle, name: *mut u16) -> BOOL; +type WireGuardSetConfigurationFn = + unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: u32) -> BOOL; +type WireGuardGetConfigurationFn = + unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: *mut u32) -> BOOL; +type WireGuardSetStateFn = + unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL; + +#[cfg(windows)] +#[repr(C)] +#[allow(dead_code)] +enum LogLevel { + Info = 0, + Warn = 1, + Err = 2, +} + +#[cfg(windows)] +impl From<LogLevel> for logging::LogLevel { + fn from(level: LogLevel) -> Self { + match level { + LogLevel::Info => Self::Info, + LogLevel::Warn => Self::Warning, + LogLevel::Err => Self::Error, + } + } +} + +type WireGuardLoggerCb = extern "stdcall" fn(LogLevel, timestamp: u64, *const u16); +type WireGuardSetLoggerFn = extern "stdcall" fn(Option<WireGuardLoggerCb>); + +#[repr(C)] +#[allow(dead_code)] +enum WireGuardAdapterLogState { + Off = 0, + On = 1, + OnWithPrefix = 2, +} + +type WireGuardSetAdapterLoggingFn = + unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> BOOL; + +type RebootRequired = bool; + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to load WireGuardNT + #[error(display = "Failed to load wireguard.dll")] + DllError(#[error(source)] io::Error), + + /// Failed to remove tunnel interface + #[error(display = "Failed to remove residual tunnel device")] + DeleteExistingTunnelError(#[error(source)] io::Error), + + /// Failed to create tunnel interface + #[error(display = "Failed to create WireGuard device")] + CreateTunnelDeviceError(#[error(source)] io::Error), + + /// Failed to delete tunnel interface + #[error(display = "Failed to delete WireGuard device")] + DeleteTunnelDeviceError(#[error(source)] io::Error), + + /// Failed to obtain tunnel interface alias + #[error(display = "Failed to obtain interface name")] + ObtainAliasError(#[error(source)] io::Error), + + /// Failed to get WireGuard tunnel config for device + #[error(display = "Failed to get tunnel WireGuard config")] + GetWireGuardConfigError(#[error(source)] io::Error), + + /// Failed to set WireGuard tunnel config on device + #[error(display = "Failed to set tunnel WireGuard config")] + SetWireGuardConfigError(#[error(source)] io::Error), + + /// Failed to set MTU on tunnel device + #[error(display = "Failed to set tunnel IPv4 interface MTU")] + SetTunnelIpv4MtuError(#[error(source)] io::Error), + + /// Failed to set MTU on tunnel device + #[error(display = "Failed to set tunnel IPv6 interface MTU")] + SetTunnelIpv6MtuError(#[error(source)] io::Error), + + /// Failed to set the tunnel state to up + #[error(display = "Failed to enable the tunnel adapter")] + EnableTunnelError(#[error(source)] io::Error), + + /// Unknown address family + #[error(display = "Unknown address family: {}", _0)] + UnknownAddressFamily(i32), + + /// Failure to set up logging + #[error(display = "Failed to set up logging")] + InitLoggingError(#[error(source)] logging::Error), + + /// Invalid allowed IP + #[error(display = "Invalid CIDR prefix")] + InvalidAllowedIpCidr, + + /// Allowed IP contains non-zero host bits + #[error(display = "Allowed IP contains non-zero host bits")] + InvalidAllowedIpBits, + + /// Failed to parse data returned by the driver + #[error(display = "Failed to parse data returned by wireguard-nt")] + InvalidConfigData, +} + +pub struct WgNtTunnel { + device: Option<WgNtAdapter>, + interface_luid: NET_LUID, + interface_name: String, + _logger_handle: LoggerHandle, +} + +const WIREGUARD_KEY_LENGTH: usize = 32; + +/// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. +#[derive(Clone, Copy)] +#[repr(C, align(8))] +union WgIpAddr { + v4: IN_ADDR, + v6: IN6_ADDR, +} + +/// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. +#[derive(Clone, Copy)] +#[repr(C, align(8))] +struct WgAllowedIp { + address: WgIpAddr, + address_family: ADDRESS_FAMILY, + cidr: u8, +} + +impl WgAllowedIp { + fn new(address: WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<Self> { + Self::validate(&address, address_family, cidr)?; + Ok(Self { + address, + address_family, + cidr, + }) + } + + fn validate(address: &WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<()> { + match address_family as i32 { + AF_INET => { + if cidr > 32 { + return Err(Error::InvalidAllowedIpCidr); + } + let host_mask = u32::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); + if host_mask & (unsafe { *(address.v4.S_un.S_addr()) }.to_be()) != 0 { + return Err(Error::InvalidAllowedIpBits); + } + } + AF_INET6 => { + if cidr > 128 { + return Err(Error::InvalidAllowedIpCidr); + } + let mut host_mask = u128::MAX.checked_shr(u32::from(cidr)).unwrap_or(0); + let bytes = unsafe { address.v6.u.Byte() }; + for byte in bytes.iter().rev() { + if byte & ((host_mask & 0xff) as u8) != 0 { + return Err(Error::InvalidAllowedIpBits); + } + host_mask = host_mask >> 8; + } + } + family => return Err(Error::UnknownAddressFamily(family)), + } + Ok(()) + } +} + +impl PartialEq for WgAllowedIp { + fn eq(&self, other: &Self) -> bool { + if self.cidr != other.cidr { + return false; + } + match self.address_family as i32 { + AF_INET => { + windows::ipaddr_from_inaddr(unsafe { self.address.v4 }) + == windows::ipaddr_from_inaddr(unsafe { other.address.v4 }) + } + AF_INET6 => { + windows::ipaddr_from_in6addr(unsafe { self.address.v6 }) + == windows::ipaddr_from_in6addr(unsafe { other.address.v6 }) + } + _ => { + log::error!("Allowed IP uses unknown address family"); + true + } + } + } +} +impl Eq for WgAllowedIp {} + +impl fmt::Debug for WgAllowedIp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("WgAllowedIp"); + match self.address_family as i32 { + AF_INET => s.field( + "address", + &windows::ipaddr_from_inaddr(unsafe { self.address.v4 }), + ), + AF_INET6 => s.field( + "address", + &windows::ipaddr_from_in6addr(unsafe { self.address.v6 }), + ), + _ => s.field("address", &"<unknown>"), + }; + s.field("address_family", &self.address_family) + .field("cidr", &self.cidr) + .finish() + } +} + +bitflags! { + /// See `WIREGUARD_PEER_FLAG` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. + struct WgPeerFlag: u32 { + const HAS_PUBLIC_KEY = 0b00000001; + const HAS_PRESHARED_KEY = 0b00000010; + const HAS_PERSISTENT_KEEPALIVE = 0b00000100; + const HAS_ENDPOINT = 0b00001000; + const REPLACE_ALLOWED_IPS = 0b00100000; + const REMOVE = 0b01000000; + const UPDATE = 0b10000000; + } +} + +/// See `WIREGUARD_PEER` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +#[repr(C, align(8))] +struct WgPeer { + flags: WgPeerFlag, + reserved: u32, + public_key: [u8; WIREGUARD_KEY_LENGTH], + preshared_key: [u8; WIREGUARD_KEY_LENGTH], + persistent_keepalive: u16, + endpoint: SockAddrInet, + tx_bytes: u64, + rx_bytes: u64, + last_handshake: u64, + allowed_ips_count: u32, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct SockAddrInet { + addr: SOCKADDR_INET, +} + +impl From<SOCKADDR_INET> for SockAddrInet { + fn from(addr: SOCKADDR_INET) -> Self { + Self { addr } + } +} +impl PartialEq for SockAddrInet { + fn eq(&self, other: &Self) -> bool { + let self_addr = match windows::try_socketaddr_from_inet_sockaddr(self.addr) { + Ok(addr) => addr, + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to convert socket address") + ); + return true; + } + }; + let other_addr = match windows::try_socketaddr_from_inet_sockaddr(other.addr) { + Ok(addr) => addr, + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to convert socket address") + ); + return true; + } + }; + self_addr == other_addr + } +} +impl Eq for SockAddrInet {} + +impl fmt::Debug for SockAddrInet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("SockAddrInet"); + let self_addr = windows::try_socketaddr_from_inet_sockaddr(self.addr) + .map(|addr| addr.to_string()) + .unwrap_or("<unknown>".to_string()); + s.field("addr", &self_addr).finish() + } +} + +bitflags! { + /// See `WIREGUARD_INTERFACE_FLAG` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. + struct WgInterfaceFlag: u32 { + const HAS_PUBLIC_KEY = 0b00000001; + const HAS_PRIVATE_KEY = 0b00000010; + const HAS_LISTEN_PORT = 0b00000100; + const REPLACE_PEERS = 0b00001000; + } +} + +/// See `WIREGUARD_INTERFACE` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +#[repr(C, align(8))] +struct WgInterface { + flags: WgInterfaceFlag, + listen_port: u16, + private_key: [u8; WIREGUARD_KEY_LENGTH], + public_key: [u8; WIREGUARD_KEY_LENGTH], + peers_count: u32, +} + +/// See `WIREGUARD_ADAPTER_LOG_STATE` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +#[repr(C)] +#[allow(dead_code)] +enum WgAdapterState { + Down = 0, + Up = 1, +} + + +impl WgNtTunnel { + pub fn start_tunnel( + config: &Config, + log_path: Option<&Path>, + resource_dir: &Path, + ) -> Result<Self> { + let dll = load_wg_nt_dll(resource_dir)?; + + let logger_handle = LoggerHandle::new(dll.clone(), log_path)?; + + { + if let Ok(device) = WgNtAdapter::open(dll.clone(), &*ADAPTER_POOL, &*ADAPTER_ALIAS) { + device.delete().map_err(Error::DeleteExistingTunnelError)?; + } + } + + let (device, reboot_required) = WgNtAdapter::create( + dll.clone(), + &*ADAPTER_POOL, + &*ADAPTER_ALIAS, + Some(ADAPTER_GUID.clone()), + ) + .map_err(Error::CreateTunnelDeviceError)?; + + if reboot_required { + log::warn!("You may need to reboot to finish installing WireGuardNT"); + } + + let interface_luid = device.luid(); + let interface_name = match device.name() { + Ok(name) => name.to_string_lossy(), + Err(error) => { + if let Err(error) = device.delete() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to delete tunnel device") + ); + } + return Err(Error::ObtainAliasError(error)); + } + }; + + let tunnel = WgNtTunnel { + device: Some(device), + interface_luid, + interface_name, + _logger_handle: logger_handle, + }; + tunnel.configure(config)?; + Ok(tunnel) + } + + fn stop_tunnel(&mut self) -> Result<()> { + if let Some(device) = self.device.take() { + if let Err(error) = device.delete() { + return Err(Error::DeleteTunnelDeviceError(error)); + } + } + Ok(()) + } + + fn configure(&self, config: &Config) -> Result<()> { + let device = self.device.as_ref().unwrap(); + if let Err(error) = device.set_logging(WireGuardAdapterLogState::On) { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set log state on WireGuard interface") + ); + } + device.set_config(config)?; + prepare_interface(&device.luid(), AF_INET as u16, u32::from(config.mtu)) + .map_err(Error::SetTunnelIpv4MtuError)?; + if config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()) { + prepare_interface(&device.luid(), AF_INET6 as u16, u32::from(config.mtu)) + .map_err(Error::SetTunnelIpv6MtuError)?; + } + device + .set_state(WgAdapterState::Up) + .map_err(Error::EnableTunnelError)?; + Ok(()) + } +} + +impl Drop for WgNtTunnel { + fn drop(&mut self) { + if let Err(error) = self.stop_tunnel() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to stop WireGuardNT tunnel") + ); + } + } +} + +lazy_static! { + static ref LOG_CONTEXT: Mutex<Option<u32>> = Mutex::new(None); +} + +struct LoggerHandle { + dll: Arc<WgNtDll>, + context: u32, +} + +impl LoggerHandle { + fn new(dll: Arc<WgNtDll>, log_path: Option<&Path>) -> Result<Self> { + let context = logging::initialize_logging(log_path).map_err(Error::InitLoggingError)?; + { + *(LOG_CONTEXT.lock().unwrap()) = Some(context); + } + dll.set_logger(Some(Self::logging_callback)); + Ok(Self { dll, context }) + } + + extern "stdcall" fn logging_callback(level: LogLevel, _timestamp: u64, message: *const u16) { + if message.is_null() { + return; + } + let mut message = unsafe { U16CStr::from_ptr_str(message) }.to_string_lossy(); + message.push_str("\r\n"); + + if let Some(context) = &*LOG_CONTEXT.lock().unwrap() { + // Horribly broken, because callback does not provide a context + logging::log(*context, level.into(), "wireguard-nt", &message); + } + } +} + +impl Drop for LoggerHandle { + fn drop(&mut self) { + let mut ctx = LOG_CONTEXT.lock().unwrap(); + if *ctx == Some(self.context) { + *ctx = None; + self.dll.set_logger(None); + } + logging::clean_up_logging(self.context); + } +} + + +struct WgNtAdapter { + dll_handle: Arc<WgNtDll>, + handle: RawHandle, +} + +impl fmt::Debug for WgNtAdapter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WgNtAdapter") + .field("handle", &self.handle) + .finish() + } +} + +unsafe impl Send for WgNtAdapter {} +unsafe impl Sync for WgNtAdapter {} + +impl WgNtAdapter { + fn open(dll_handle: Arc<WgNtDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> { + let handle = dll_handle.open_adapter(pool, name)?; + Ok(Self { dll_handle, handle }) + } + + fn create( + dll_handle: Arc<WgNtDll>, + pool: &U16CStr, + name: &U16CStr, + requested_guid: Option<GUID>, + ) -> io::Result<(Self, RebootRequired)> { + let (handle, restart_required) = dll_handle.create_adapter(pool, name, requested_guid)?; + Ok((Self { dll_handle, handle }, restart_required)) + } + + fn delete(self) -> io::Result<RebootRequired> { + unsafe { self.dll_handle.delete_adapter(self.handle) } + } + + fn name(&self) -> io::Result<U16CString> { + unsafe { self.dll_handle.get_adapter_name(self.handle) } + } + + fn luid(&self) -> NET_LUID { + unsafe { self.dll_handle.get_adapter_luid(self.handle) } + } + + fn set_config(&self, config: &Config) -> Result<()> { + let config_buffer = serialize_config(config)?; + unsafe { + self.dll_handle + .set_config(self.handle, config_buffer.as_ptr(), config_buffer.len()) + .map_err(Error::SetWireGuardConfigError) + } + } + + fn get_config(&self) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { + unsafe { + deserialize_config( + &self + .dll_handle + .get_config(self.handle) + .map_err(Error::GetWireGuardConfigError)?, + ) + } + } + + fn set_state(&self, state: WgAdapterState) -> io::Result<()> { + unsafe { self.dll_handle.set_adapter_state(self.handle, state) } + } + + fn set_logging(&self, state: WireGuardAdapterLogState) -> io::Result<()> { + unsafe { self.dll_handle.set_adapter_logging(self.handle, state) } + } +} + +impl Drop for WgNtAdapter { + fn drop(&mut self) { + unsafe { self.dll_handle.free_adapter(self.handle) }; + } +} + +struct WgNtDll { + handle: HINSTANCE, + func_open: WireGuardOpenAdapterFn, + func_create: WireGuardCreateAdapterFn, + func_delete: WireGuardDeleteAdapterFn, + func_free: WireGuardFreeAdapterFn, + func_get_adapter_luid: WireGuardGetAdapterLuidFn, + func_get_adapter_name: WireGuardGetAdapterNameFn, + func_set_configuration: WireGuardSetConfigurationFn, + func_get_configuration: WireGuardGetConfigurationFn, + func_set_adapter_state: WireGuardSetStateFn, + func_set_logger: WireGuardSetLoggerFn, + func_set_adapter_logging: WireGuardSetAdapterLoggingFn, +} + +unsafe impl Send for WgNtDll {} +unsafe impl Sync for WgNtDll {} + +impl WgNtDll { + pub fn new(resource_dir: &Path) -> io::Result<Self> { + let wg_nt_dll: Vec<u16> = resource_dir + .join("wireguard.dll") + .as_os_str() + .encode_wide() + .chain(iter::once(0u16)) + .collect(); + + let handle = unsafe { + LoadLibraryExW( + wg_nt_dll.as_ptr(), + ptr::null_mut(), + LOAD_WITH_ALTERED_SEARCH_PATH, + ) + }; + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Self::new_inner(handle, Self::get_proc_address) + } + + fn new_inner( + handle: HMODULE, + get_proc_fn: unsafe fn(HMODULE, &CStr) -> io::Result<FARPROC>, + ) -> io::Result<Self> { + Ok(WgNtDll { + handle, + func_open: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardOpenAdapter\0").unwrap(), + )?) as *const _ as *const _) + }, + func_create: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardCreateAdapter\0").unwrap(), + )?) as *const _ as *const _) + }, + func_delete: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardDeleteAdapter\0").unwrap(), + )?) as *const _ as *const _) + }, + func_free: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardFreeAdapter\0").unwrap(), + )?) as *const _ as *const _) + }, + func_get_adapter_luid: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardGetAdapterLUID\0").unwrap(), + )?) as *const _ as *const _) + }, + func_get_adapter_name: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardGetAdapterName\0").unwrap(), + )?) as *const _ as *const _) + }, + func_set_configuration: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardSetConfiguration\0").unwrap(), + )?) as *const _ as *const _) + }, + func_get_configuration: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardGetConfiguration\0").unwrap(), + )?) as *const _ as *const _) + }, + func_set_adapter_state: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardSetAdapterState\0").unwrap(), + )?) as *const _ as *const _) + }, + func_set_logger: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardSetLogger\0").unwrap(), + )?) as *const _ as *const _) + }, + func_set_adapter_logging: unsafe { + *((&get_proc_fn( + handle, + CStr::from_bytes_with_nul(b"WireGuardSetAdapterLogging\0").unwrap(), + )?) as *const _ as *const _) + }, + }) + } + + unsafe fn get_proc_address(handle: HMODULE, name: &CStr) -> io::Result<FARPROC> { + let handle = GetProcAddress(handle, name.as_ptr()); + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(handle) + } + + pub fn open_adapter(&self, pool: &U16CStr, name: &U16CStr) -> io::Result<RawHandle> { + let handle = unsafe { (self.func_open)(pool.as_ptr(), name.as_ptr()) }; + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(handle) + } + + pub fn create_adapter( + &self, + pool: &U16CStr, + name: &U16CStr, + requested_guid: Option<GUID>, + ) -> io::Result<(RawHandle, RebootRequired)> { + let guid_ptr = match requested_guid.as_ref() { + Some(guid) => guid as *const _, + None => ptr::null_mut(), + }; + let mut reboot_required = 0; + let handle = unsafe { + (self.func_create)(pool.as_ptr(), name.as_ptr(), guid_ptr, &mut reboot_required) + }; + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok((handle, reboot_required != 0)) + } + + pub unsafe fn delete_adapter(&self, adapter: RawHandle) -> io::Result<RebootRequired> { + let mut reboot_required = 0; + let result = (self.func_delete)(adapter, &mut reboot_required); + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(reboot_required != 0) + } + + pub unsafe fn free_adapter(&self, adapter: RawHandle) { + (self.func_free)(adapter); + } + + pub unsafe fn get_adapter_name(&self, adapter: RawHandle) -> io::Result<U16CString> { + let mut alias_buffer = vec![0u16; MAX_ADAPTER_NAME]; + let result = (self.func_get_adapter_name)(adapter, alias_buffer.as_mut_ptr()); + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(U16CString::from_vec_with_nul(alias_buffer) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "missing null terminator"))?) + } + + pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID { + let mut luid = mem::MaybeUninit::<NET_LUID>::zeroed(); + (self.func_get_adapter_luid)(adapter, luid.as_mut_ptr()); + luid.assume_init() + } + + pub unsafe fn set_config( + &self, + adapter: RawHandle, + config: *const u8, + config_size: usize, + ) -> io::Result<()> { + let result = (self.func_set_configuration)(adapter, config, config_size as u32); + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } + + pub unsafe fn get_config(&self, adapter: RawHandle) -> io::Result<Vec<u8>> { + let mut config_size = 0; + let mut config = vec![]; + loop { + let result = + (self.func_get_configuration)(adapter, config.as_mut_ptr(), &mut config_size); + if result == 0 { + let last_error = io::Error::last_os_error(); + if last_error.raw_os_error() != Some(ERROR_MORE_DATA as i32) { + break Err(last_error); + } + config.resize(config_size as usize, 0); + } else { + break Ok(config); + } + } + } + + pub unsafe fn set_adapter_state( + &self, + adapter: RawHandle, + state: WgAdapterState, + ) -> io::Result<()> { + let result = (self.func_set_adapter_state)(adapter, state); + if result == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } + + pub fn set_logger(&self, cb: Option<WireGuardLoggerCb>) { + (self.func_set_logger)(cb); + } + + pub unsafe fn set_adapter_logging( + &self, + adapter: RawHandle, + state: WireGuardAdapterLogState, + ) -> io::Result<()> { + if (self.func_set_adapter_logging)(adapter, state) == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } +} + +impl Drop for WgNtDll { + fn drop(&mut self) { + unsafe { FreeLibrary(self.handle) }; + } +} + +fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> { + let mut dll = (*WG_NT_DLL).lock().expect("WireGuardNT mutex poisoned"); + match &*dll { + Some(dll) => Ok(dll.clone()), + None => { + let new_dll = Arc::new(WgNtDll::new(resource_dir).map_err(Error::DllError)?); + *dll = Some(new_dll.clone()); + Ok(new_dll) + } + } +} + +fn serialize_config(config: &Config) -> Result<Vec<u8>> { + let mut buffer = vec![]; + + let header = WgInterface { + flags: WgInterfaceFlag::HAS_PRIVATE_KEY | WgInterfaceFlag::REPLACE_PEERS, + listen_port: 0, + private_key: config.tunnel.private_key.to_bytes(), + public_key: [0u8; WIREGUARD_KEY_LENGTH], + peers_count: config.peers.len() as u32, + }; + + buffer.extend_from_slice(unsafe { as_u8_slice(&header) }); + + for peer in &config.peers { + let wg_peer = WgPeer { + flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, + reserved: 0, + public_key: peer.public_key.as_bytes().clone(), + preshared_key: [0u8; WIREGUARD_KEY_LENGTH], + persistent_keepalive: 0, + endpoint: windows::inet_sockaddr_from_socketaddr(peer.endpoint).into(), + tx_bytes: 0, + rx_bytes: 0, + last_handshake: 0, + allowed_ips_count: peer.allowed_ips.len() as u32, + }; + + buffer.extend_from_slice(unsafe { as_u8_slice(&wg_peer) }); + + for allowed_ip in &peer.allowed_ips { + let address_family = match allowed_ip { + IpNetwork::V4(_) => AF_INET as u16, + IpNetwork::V6(_) => AF_INET6 as u16, + }; + let address = match allowed_ip { + IpNetwork::V4(v4_network) => WgIpAddr { + v4: windows::inaddr_from_ipaddr(v4_network.ip()), + }, + IpNetwork::V6(v6_network) => WgIpAddr { + v6: windows::in6addr_from_ipaddr(v6_network.ip()), + }, + }; + + let wg_allowed_ip = + WgAllowedIp::new(address, address_family, allowed_ip.prefix() as u8)?; + + buffer.extend_from_slice(unsafe { as_u8_slice(&wg_allowed_ip) }); + } + } + + Ok(buffer) +} + +unsafe fn deserialize_config( + config: &[u8], +) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> { + if config.len() < mem::size_of::<WgInterface>() { + return Err(Error::InvalidConfigData); + } + let (head, mut tail) = config.split_at(mem::size_of::<WgInterface>()); + let interface: WgInterface = *(head.as_ptr() as *const WgInterface); + + let mut peers = vec![]; + for _ in 0..interface.peers_count { + if tail.len() < mem::size_of::<WgPeer>() { + return Err(Error::InvalidConfigData); + } + let (peer_data, new_tail) = tail.split_at(mem::size_of::<WgPeer>()); + let peer: WgPeer = *(peer_data.as_ptr() as *const WgPeer); + tail = new_tail; + + if let Err(error) = windows::try_socketaddr_from_inet_sockaddr(peer.endpoint.addr) { + log::error!( + "{}", + error.display_chain_with_msg("Received invalid endpoint address") + ); + return Err(Error::InvalidConfigData); + } + + let mut allowed_ips = vec![]; + + for _ in 0..peer.allowed_ips_count { + if tail.len() < mem::size_of::<WgAllowedIp>() { + return Err(Error::InvalidConfigData); + } + let (allowed_ip_data, new_tail) = tail.split_at(mem::size_of::<WgAllowedIp>()); + let allowed_ip: WgAllowedIp = *(allowed_ip_data.as_ptr() as *const WgAllowedIp); + if let Err(error) = WgAllowedIp::validate( + &allowed_ip.address, + allowed_ip.address_family, + allowed_ip.cidr, + ) { + log::error!( + "{}", + error.display_chain_with_msg("Received invalid allowed IP") + ); + return Err(Error::InvalidConfigData); + } + tail = new_tail; + allowed_ips.push(allowed_ip); + } + + peers.push((peer, allowed_ips)); + } + + if tail.len() > 0 { + return Err(Error::InvalidConfigData); + } + + Ok((interface, peers)) +} + +fn prepare_interface(luid: &NET_LUID, family: u16, mtu: u32) -> io::Result<()> { + let family = windows::AddressFamily::try_from_af_family(family) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; + let mut iface = windows::get_ip_interface_entry(family, luid)?; + iface.SitePrefixLength = 0; + iface.NlMtu = mtu; + iface.RouterDiscoveryBehavior = RouterDiscoveryDisabled; + iface.DadTransmits = 0; + iface.ManagedAddressConfigurationSupported = FALSE; + iface.OtherStatefulConfigurationSupported = FALSE; + windows::set_ip_interface_entry(&iface) +} + +impl Tunnel for WgNtTunnel { + fn get_interface_name(&self) -> String { + self.interface_name.clone() + } + + fn get_interface_luid(&self) -> u64 { + self.interface_luid.Value + } + + fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> { + if let Some(ref device) = self.device { + let mut map = StatsMap::new(); + let (_interface, peers) = device.get_config().map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain wg-nt tunnel config") + ); + super::TunnelError::StatsError(super::stats::Error::NoTunnelConfig) + })?; + for (peer, _allowed_ips) in &peers { + map.insert( + peer.public_key, + Stats { + tx_bytes: peer.tx_bytes, + rx_bytes: peer.rx_bytes, + }, + ); + } + Ok(map) + } else { + Err(super::TunnelError::StatsError( + super::stats::Error::NoTunnelDevice, + )) + } + } + + fn stop(mut self: Box<Self>) -> std::result::Result<(), super::TunnelError> { + if let Err(error) = self.stop_tunnel() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to stop WireGuardNT tunnel") + ); + Err(super::TunnelError::StopWireguardError { status: 0 }) + } else { + Ok(()) + } + } +} + +unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] { + std::slice::from_raw_parts(object as *const _ as *const _, mem::size_of::<T>()) +} + +#[cfg(test)] +mod tests { + use super::*; + use lazy_static::lazy_static; + use talpid_types::net::{wireguard, TransportProtocol}; + + #[derive(Debug, Eq, PartialEq, Clone, Copy)] + #[repr(C)] + struct Interface { + interface: WgInterface, + p0: WgPeer, + p0_allowed_ip_0: WgAllowedIp, + } + + lazy_static! { + static ref WG_PRIVATE_KEY: wireguard::PrivateKey = wireguard::PrivateKey::new_from_random(); + static ref WG_PUBLIC_KEY: wireguard::PublicKey = + wireguard::PrivateKey::new_from_random().public_key(); + static ref WG_CONFIG: Config = { + Config { + tunnel: wireguard::TunnelConfig { + private_key: WG_PRIVATE_KEY.clone(), + addresses: vec![], + }, + peers: vec![wireguard::PeerConfig { + public_key: WG_PUBLIC_KEY.clone(), + allowed_ips: vec!["1.3.3.0/24".parse().unwrap()], + endpoint: "1.2.3.4:1234".parse().unwrap(), + protocol: TransportProtocol::Udp, + }], + ipv4_gateway: "0.0.0.0".parse().unwrap(), + ipv6_gateway: None, + mtu: 0, + use_wireguard_nt: true, + } + }; + static ref WG_STRUCT_CONFIG: Interface = Interface { + interface: WgInterface { + flags: WgInterfaceFlag::HAS_PRIVATE_KEY | WgInterfaceFlag::REPLACE_PEERS, + listen_port: 0, + private_key: WG_PRIVATE_KEY.to_bytes(), + public_key: [0; WIREGUARD_KEY_LENGTH], + peers_count: 1, + }, + p0: WgPeer { + flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT, + reserved: 0, + public_key: WG_PUBLIC_KEY.as_bytes().clone(), + preshared_key: [0; WIREGUARD_KEY_LENGTH], + persistent_keepalive: 0, + endpoint: windows::inet_sockaddr_from_socketaddr("1.2.3.4:1234".parse().unwrap()) + .into(), + tx_bytes: 0, + rx_bytes: 0, + last_handshake: 0, + allowed_ips_count: 1, + }, + p0_allowed_ip_0: WgAllowedIp { + address: WgIpAddr { + v4: windows::inaddr_from_ipaddr("1.3.3.0".parse().unwrap()), + }, + address_family: AF_INET as u16, + cidr: 24, + }, + }; + } + + fn get_proc_fn(_handle: HMODULE, _symbol: &CStr) -> io::Result<FARPROC> { + Ok(std::ptr::null_mut()) + } + + #[test] + fn test_dll_imports() { + WgNtDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap(); + } + + #[test] + fn test_config_serialization() { + let serialized_data = serialize_config(&*WG_CONFIG).unwrap(); + assert_eq!(mem::size_of::<Interface>(), serialized_data.len()); + let serialized_iface = &unsafe { *(serialized_data.as_ptr() as *const Interface) }; + assert_eq!(&*WG_STRUCT_CONFIG, serialized_iface); + } + + #[test] + fn test_config_deserialization() { + let (iface, peers) = + unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }.unwrap(); + assert_eq!(iface, WG_STRUCT_CONFIG.interface); + assert_eq!(peers.len(), 1); + let (peer, allowed_ips) = &peers[0]; + assert_eq!(peer, &WG_STRUCT_CONFIG.p0); + assert_eq!(allowed_ips.len(), 1); + assert_eq!(allowed_ips[0], WG_STRUCT_CONFIG.p0_allowed_ip_0); + } + + #[test] + fn test_wg_allowed_ip_v4() { + // Valid: /32 prefix + let address_family = AF_INET as u16; + let address = WgIpAddr { + v4: windows::inaddr_from_ipaddr("127.0.0.1".parse().unwrap()), + }; + let cidr = 32; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid host bits + let cidr = 24; + let address = WgIpAddr { + v4: windows::inaddr_from_ipaddr("0.0.0.1".parse().unwrap()), + }; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + + // Valid host bits + let cidr = 24; + let address = WgIpAddr { + v4: windows::inaddr_from_ipaddr("255.255.255.0".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // 0.0.0.0/0 + let cidr = 0; + let address = WgIpAddr { + v4: windows::inaddr_from_ipaddr("0.0.0.0".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid CIDR + let cidr = 33; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + } + + #[test] + fn test_wg_allowed_ip_v6() { + // Valid: /128 prefix + let address_family = AF_INET6 as u16; + let address = WgIpAddr { + v6: windows::in6addr_from_ipaddr("::1".parse().unwrap()), + }; + let cidr = 128; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid host bits + let cidr = 127; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + + // Valid host bits + let address = WgIpAddr { + v6: windows::in6addr_from_ipaddr( + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap(), + ), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // ::/0 + let cidr = 0; + let address = WgIpAddr { + v6: windows::in6addr_from_ipaddr("::".parse().unwrap()), + }; + WgAllowedIp::new(address, address_family, cidr).unwrap(); + + // Invalid CIDR + let cidr = 129; + assert!(WgAllowedIp::new(address, address_family, cidr).is_err()); + } +} diff --git a/talpid-core/src/windows.rs b/talpid-core/src/windows.rs index 8151165680..03cd8a9c74 100644 --- a/talpid-core/src/windows.rs +++ b/talpid-core/src/windows.rs @@ -1,12 +1,16 @@ use std::{ ffi::OsStr, fmt, io, mem, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::windows::{ffi::OsStrExt, io::RawHandle}, + ptr, sync::Mutex, time::{Duration, Instant}, }; use winapi::shared::{ ifdef::NET_LUID, + in6addr::IN6_ADDR, + inaddr::IN_ADDR, netioapi::{ CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, FreeMibTable, GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, MibAddInstance, @@ -17,6 +21,7 @@ use winapi::shared::{ ntdef::FALSE, winerror::{ERROR_NOT_FOUND, NO_ERROR}, ws2def::{AF_INET, AF_INET6, AF_UNSPEC}, + ws2ipdef::SOCKADDR_INET, }; /// Result type for this module. @@ -58,6 +63,10 @@ pub enum Error { #[cfg(windows)] #[error(display = "Unicast channel sender was unexpectedly dropped")] UnicastSenderDropped, + + /// Unknown address family + #[error(display = "Unknown address family: {}", _0)] + UnknownAddressFamily(i32), } /// Address family. These correspond to the `AF_*` constants. @@ -78,6 +87,17 @@ impl fmt::Display for AddressFamily { } } +impl AddressFamily { + /// Convert an [`AddressFamily`] to one of the `AF_*` constants. + pub fn try_from_af_family(family: u16) -> Result<AddressFamily> { + match family as i32 { + AF_INET => Ok(AddressFamily::Ipv4), + AF_INET6 => Ok(AddressFamily::Ipv6), + family => Err(Error::UnknownAddressFamily(family)), + } + } +} + /// Context for [`notify_ip_interface_change`]. When it is dropped, /// the callback is unregistered. pub struct IpNotifierHandle<'a> { @@ -341,3 +361,120 @@ fn af_family_from_family(family: Option<AddressFamily>) -> u16 { .map(|family| family as u16) .unwrap_or(AF_UNSPEC as u16) } + +/// Converts an `Ipv4Addr` to `IN_ADDR` +pub fn inaddr_from_ipaddr(addr: Ipv4Addr) -> IN_ADDR { + let mut in_addr: IN_ADDR = unsafe { mem::zeroed() }; + let addr_octets = addr.octets(); + unsafe { + ptr::copy_nonoverlapping( + &addr_octets as *const _, + in_addr.S_un.S_addr_mut() as *mut _ as *mut u8, + addr_octets.len(), + ); + } + in_addr +} + +/// Converts an `Ipv6Addr` to `IN6_ADDR` +pub fn in6addr_from_ipaddr(addr: Ipv6Addr) -> IN6_ADDR { + let mut in_addr: IN6_ADDR = unsafe { mem::zeroed() }; + let addr_octets = addr.octets(); + unsafe { + ptr::copy_nonoverlapping( + &addr_octets as *const _, + in_addr.u.Byte_mut() as *mut _, + addr_octets.len(), + ); + } + in_addr +} + +/// Converts an `IN_ADDR` to `Ipv4Addr` +pub fn ipaddr_from_inaddr(addr: IN_ADDR) -> Ipv4Addr { + Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_be()) +} + +/// Converts an `IN6_ADDR` to `Ipv6Addr` +pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr { + Ipv6Addr::from(*unsafe { addr.u.Byte() }) +} + +/// Converts a `SocketAddr` to `SOCKADDR_INET` +pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET { + let mut sockaddr: SOCKADDR_INET = unsafe { mem::zeroed() }; + + match addr { + SocketAddr::V4(v4_addr) => { + unsafe { + *sockaddr.si_family_mut() = AF_INET as u16; + } + + let mut v4sockaddr = unsafe { sockaddr.Ipv4_mut() }; + v4sockaddr.sin_family = AF_INET as u16; + v4sockaddr.sin_port = v4_addr.port().to_be(); + v4sockaddr.sin_addr = inaddr_from_ipaddr(*v4_addr.ip()); + } + SocketAddr::V6(v6_addr) => { + unsafe { + *sockaddr.si_family_mut() = AF_INET6 as u16; + } + + let mut v6sockaddr = unsafe { sockaddr.Ipv6_mut() }; + v6sockaddr.sin6_family = AF_INET6 as u16; + v6sockaddr.sin6_port = v6_addr.port().to_be(); + v6sockaddr.sin6_addr = in6addr_from_ipaddr(*v6_addr.ip()); + v6sockaddr.sin6_flowinfo = v6_addr.flowinfo(); + *unsafe { v6sockaddr.u.sin6_scope_id_mut() } = v6_addr.scope_id(); + } + } + + sockaddr +} + +/// Converts a `SOCKADDR_INET` to `SocketAddr`. Returns an error if the address family is invalid. +pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> { + unsafe { + match *addr.si_family() as i32 { + AF_INET => Ok(SocketAddr::V4(SocketAddrV4::new( + ipaddr_from_inaddr(addr.Ipv4().sin_addr), + u16::from_be(addr.Ipv4().sin_port), + ))), + AF_INET6 => Ok(SocketAddr::V6(SocketAddrV6::new( + ipaddr_from_in6addr(addr.Ipv6().sin6_addr), + u16::from_be(addr.Ipv6().sin6_port), + addr.Ipv6().sin6_flowinfo, + *addr.Ipv6().u.sin6_scope_id(), + ))), + family => Err(Error::UnknownAddressFamily(family)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sockaddr_v4() { + let addr_v4 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 1234)); + assert_eq!( + addr_v4, + try_socketaddr_from_inet_sockaddr(inet_sockaddr_from_socketaddr(addr_v4)).unwrap() + ); + } + + #[test] + fn test_sockaddr_v6() { + let addr_v6 = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), + 1234, + 0xa, + 0xb, + )); + assert_eq!( + addr_v6, + try_socketaddr_from_inet_sockaddr(inet_sockaddr_from_socketaddr(addr_v6)).unwrap() + ); + } +} diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs index c5bde3547e..beced8357d 100644 --- a/talpid-types/src/net/wireguard.rs +++ b/talpid-types/src/net/wireguard.rs @@ -86,6 +86,10 @@ pub struct TunnelOptions { jnix(map = "|maybe_mtu| maybe_mtu.map(|mtu| mtu as i32)") )] pub mtu: Option<u16>, + /// Temporary switch for wireguard-nt + #[cfg(windows)] + #[serde(default)] + pub use_wireguard_nt: bool, } /// Wireguard x25519 private key diff --git a/windows/driverlogic/driverlogic.vcxproj b/windows/driverlogic/driverlogic.vcxproj index b91e97d86c..cc46c1ac72 100644 --- a/windows/driverlogic/driverlogic.vcxproj +++ b/windows/driverlogic/driverlogic.vcxproj @@ -117,6 +117,7 @@ <ClInclude Include="src\util.h" /> <ClInclude Include="src\version.h" /> <ClInclude Include="src\wintun.h" /> + <ClInclude Include="src\wireguard.h" /> </ItemGroup> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <ImportGroup Label="ExtensionTargets"> diff --git a/windows/driverlogic/driverlogic.vcxproj.filters b/windows/driverlogic/driverlogic.vcxproj.filters index 91ed2267da..9665231376 100644 --- a/windows/driverlogic/driverlogic.vcxproj.filters +++ b/windows/driverlogic/driverlogic.vcxproj.filters @@ -28,5 +28,6 @@ <ClInclude Include="src\util.h" /> <ClInclude Include="src\wintun.h" /> <ClInclude Include="src\devenum.h" /> + <ClInclude Include="src\wireguard.h" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp index 93af20b400..3cb1739e21 100644 --- a/windows/driverlogic/src/driverlogic.cpp +++ b/windows/driverlogic/src/driverlogic.cpp @@ -5,6 +5,7 @@ #include "log.h" #include "version.h" #include "wintun.h" +#include "wireguard.h" #include "devenum.h" #include <string> #include <libcommon/error.h> @@ -278,6 +279,32 @@ ReturnCode CommandWintunDeleteAbandonedDevice(const std::vector<std::wstring> &a return GENERAL_SUCCESS; } +ReturnCode CommandWireGuardNtCleanup(const std::vector<std::wstring> &args) +{ + ArgumentContext argsContext(args); + + argsContext.ensureExactArgumentCount(1); + + const auto poolName = argsContext.next(); + + WireGuardNtDll wgNt; + + BOOL rebootRequired; + + if (FALSE == wgNt.deletePoolDriver(poolName.c_str(), &rebootRequired)) + { + throw std::runtime_error("Failed to delete WireGuardNT pool"); + } + + std::wstringstream ss; + + ss << L"Successfully deleted WireGuardNT pool. Reboot required: " << rebootRequired; + + Log(ss.str()); + + return ReturnCode::GENERAL_SUCCESS; +} + } // anonymous namespace int wmain(int argc, const wchar_t *argv[]) @@ -325,7 +352,8 @@ int wmain(int argc, const wchar_t *argv[]) { L"st-force-install", CommandSplitTunnelForceInstall }, { L"st-remove", CommandSplitTunnelRemove }, { L"wintun-delete-pool-driver", CommandWintunDeletePool }, - { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice } + { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice }, + { L"wg-nt-cleanup", CommandWireGuardNtCleanup } }; // diff --git a/windows/driverlogic/src/wireguard.h b/windows/driverlogic/src/wireguard.h new file mode 100644 index 0000000000..5892b248f1 --- /dev/null +++ b/windows/driverlogic/src/wireguard.h @@ -0,0 +1,58 @@ +#pragma once + +#include <wireguard-nt/wireguard.h> +#include <libcommon/error.h> +#include "util.h" + +class WireGuardNtDll +{ +public: + + WireGuardNtDll() : dllHandle(nullptr) + { + auto path = GetProcessModulePath().replace_filename(L"wireguard.dll"); + dllHandle = LoadLibraryExW(path.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); + + if (nullptr == dllHandle) + { + THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW"); + } + + try + { + deletePoolDriver = getProcAddressOrThrow<WIREGUARD_DELETE_POOL_DRIVER_FUNC*>("WireGuardDeletePoolDriver"); + } + catch (...) + { + FreeLibrary(dllHandle); + throw; + } + } + + ~WireGuardNtDll() + { + if (nullptr != dllHandle) + { + FreeLibrary(dllHandle); + } + } + + WIREGUARD_DELETE_POOL_DRIVER_FUNC *deletePoolDriver; + +private: + + template<typename T> + T getProcAddressOrThrow(const char *procName) + { + const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName)); + + if (nullptr == result) + { + THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress"); + } + + return result; + } + + HMODULE dllHandle; +}; diff --git a/windows/libshared/src/libshared/network/interfaceutils.cpp b/windows/libshared/src/libshared/network/interfaceutils.cpp index 28263f383b..fba4d71ba0 100644 --- a/windows/libshared/src/libshared/network/interfaceutils.cpp +++ b/windows/libshared/src/libshared/network/interfaceutils.cpp @@ -99,6 +99,7 @@ void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOC row.InterfaceLuid = device; row.Address = address; + row.DadState = IpDadStatePreferred; const auto status = CreateUnicastIpAddressEntry(&row); |
