summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-09-28 12:44:12 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-09-28 12:44:12 +0200
commit827d95c831f9ef8de4b419a6f7913377a20e8cf9 (patch)
tree2ffa5de8ee4ba2f777b4bfe9001dca44696fad6e
parent31e62ea07e957b2e1d285c8eb85605ce8cba5e69 (diff)
parenta247e6220fe924c89923beae638dfc182797ba18 (diff)
downloadmullvadvpn-827d95c831f9ef8de4b419a6f7913377a20e8cf9.tar.xz
mullvadvpn-827d95c831f9ef8de4b419a6f7913377a20e8cf9.zip
Merge branch 'wg-nt'
-rw-r--r--CHANGELOG.md1
-rw-r--r--Cargo.lock1
-rw-r--r--README.md2
m---------dist-assets/binaries0
-rw-r--r--dist-assets/windows/installer.nsh48
-rw-r--r--gui/tasks/distribution.js1
-rw-r--r--mullvad-cli/src/cmds/tunnel.rs55
-rw-r--r--mullvad-daemon/src/lib.rs34
-rw-r--r--mullvad-daemon/src/management_interface.rs16
-rw-r--r--mullvad-daemon/src/settings.rs14
-rw-r--r--mullvad-management-interface/proto/management_interface.proto3
-rw-r--r--mullvad-management-interface/src/types.rs6
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/tunnel/mod.rs3
-rw-r--r--talpid-core/src/tunnel/wireguard/config.rs5
-rw-r--r--talpid-core/src/tunnel/wireguard/logging.rs70
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs44
-rw-r--r--talpid-core/src/tunnel/wireguard/stats.rs3
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs27
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_nt.rs1215
-rw-r--r--talpid-core/src/windows.rs137
-rw-r--r--talpid-types/src/net/wireguard.rs4
-rw-r--r--windows/driverlogic/driverlogic.vcxproj1
-rw-r--r--windows/driverlogic/driverlogic.vcxproj.filters1
-rw-r--r--windows/driverlogic/src/driverlogic.cpp30
-rw-r--r--windows/driverlogic/src/wireguard.h58
-rw-r--r--windows/libshared/src/libshared/network/interfaceutils.cpp1
27 files changed, 1731 insertions, 50 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 58dcef1163..c1b1c29f52 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@ Line wrap the file at 100 chars. Th
#### Windows
- Resolve symbolic links and junctions for excluded apps.
+- Add opt-in support for NT kernel WireGuard driver. It can be enabled in the CLI.
### Changed
- Only use the account history file to store the last used account.
diff --git a/Cargo.lock b/Cargo.lock
index 54e805dd57..3b2733a019 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2485,6 +2485,7 @@ version = "0.1.0"
dependencies = [
"async-trait",
"atty",
+ "bitflags",
"byteorder",
"cfg-if 1.0.0",
"chrono",
diff --git a/README.md b/README.md
index c562d4cd4a..bf91197235 100644
--- a/README.md
+++ b/README.md
@@ -436,7 +436,7 @@ echo "org.gradle.jvmargs=-Xmx4608M" >> ~/.gradle/gradle.properties
* `"network-manager"`: use `NetworkManager` service through DBus
* `TALPID_FORCE_USERSPACE_WIREGUARD` - Forces the daemon to use the userspace implementation of
- WireGuard on Linux.
+ WireGuard on Linux and Windows.
* `TALPID_DNS_CACHE_POLICY` - On Windows, this changes how DNS is configured:
* `1`: The default. This sets a global list of DNS servers that `dnscache` will use instead of
diff --git a/dist-assets/binaries b/dist-assets/binaries
-Subproject c77da1b6ca952289acb668dc8a53a03367805ff
+Subproject 19a97997b188855d0ba5aedb7419683df45d93b
diff --git a/dist-assets/windows/installer.nsh b/dist-assets/windows/installer.nsh
index 5f0151f009..c1dac59bf8 100644
--- a/dist-assets/windows/installer.nsh
+++ b/dist-assets/windows/installer.nsh
@@ -13,6 +13,7 @@
#
!define WINTUN_POOL "Mullvad"
+!define WG_NT_POOL "Mullvad"
# "sc" exit code
!define SERVICE_STARTED 0
@@ -59,19 +60,20 @@
!define PERSISTENT_BLOCK_OUTBOUND_IPV4_FILTER_GUID "{79860c64-9a5e-48a3-b5f3-d64b41659aa5}"
#
-# ExtractWintun
+# ExtractWireGuard
#
-# Extract Wintun installer into $TEMP
+# Extract Wintun and WireGuardNT installer into $TEMP
#
-!macro ExtractWintun
+!macro ExtractWireGuard
SetOutPath "$TEMP"
File "${BUILD_RESOURCES_DIR}\binaries\x86_64-pc-windows-msvc\wintun\wintun.dll"
+ File "${BUILD_RESOURCES_DIR}\binaries\x86_64-pc-windows-msvc\wireguard-nt\wireguard.dll"
File "${BUILD_RESOURCES_DIR}\..\windows\driverlogic\bin\x64-Release\driverlogic.exe"
!macroend
-!define ExtractWintun '!insertmacro "ExtractWintun"'
+!define ExtractWireGuard '!insertmacro "ExtractWireGuard"'
#
# ExtractMullvadSetup
@@ -222,6 +224,41 @@
!define RemoveWintun '!insertmacro "RemoveWintun"'
#
+# RemoveWireGuardNt
+#
+# Try to remove WireGuardNT
+#
+!macro RemoveWireGuardNt
+ Push $0
+ Push $1
+
+ log::Log "RemoveWireGuardNt()"
+
+ nsExec::ExecToStack '"$TEMP\driverlogic.exe" wg-nt-cleanup ${WG_NT_POOL}'
+ Pop $0
+ Pop $1
+
+ ${If} $0 != ${DL_GENERAL_SUCCESS}
+ IntFmt $0 "0x%X" $0
+ StrCpy $R0 "Failed to remove WireGuardNT pool: error $0"
+ log::LogWithDetails $R0 $1
+ Goto RemoveWireGuardNt_return_only
+ ${EndIf}
+
+ log::Log "RemoveWireGuardNt() completed successfully"
+
+ Push 0
+ Pop $R0
+
+ RemoveWireGuardNt_return_only:
+
+ Pop $1
+ Pop $0
+
+!macroend
+
+!define RemoveWireGuardNt '!insertmacro "RemoveWireGuardNt"'
+#
# RemoveAbandonedWintunAdapter
#
# Removes old Wintun interface, even if it belongs to a different pool.
@@ -1244,8 +1281,9 @@
${ClearFirewallRules}
${RemoveWireGuardKey}
- ${ExtractWintun}
+ ${ExtractWireGuard}
${RemoveWintun}
+ ${RemoveWireGuardNt}
${ExtractSplitTunnelDriver}
${RemoveSplitTunnelDriver}
diff --git a/gui/tasks/distribution.js b/gui/tasks/distribution.js
index 452baad71f..6f74d65eda 100644
--- a/gui/tasks/distribution.js
+++ b/gui/tasks/distribution.js
@@ -114,6 +114,7 @@ const config = {
{ from: distAssets('binaries/x86_64-pc-windows-msvc/sslocal.exe'), to: '.' },
{ from: root('build/lib/x86_64-pc-windows-msvc/libwg.dll'), to: '.' },
{ from: distAssets('binaries/x86_64-pc-windows-msvc/wintun/wintun.dll'), to: '.' },
+ { from: distAssets('binaries/x86_64-pc-windows-msvc/wireguard-nt/wireguard.dll'), to: '.' },
],
},
diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs
index 08306b70eb..e01d81a9da 100644
--- a/mullvad-cli/src/cmds/tunnel.rs
+++ b/mullvad-cli/src/cmds/tunnel.rs
@@ -34,11 +34,19 @@ impl Command for Tunnel {
}
fn create_wireguard_subcommand() -> clap::App<'static, 'static> {
- clap::SubCommand::with_name("wireguard")
+ let subcmd = clap::SubCommand::with_name("wireguard")
.about("Manage options for Wireguard tunnels")
.setting(clap::AppSettings::SubcommandRequiredElseHelp)
.subcommand(create_wireguard_mtu_subcommand())
- .subcommand(create_wireguard_keys_subcommand())
+ .subcommand(create_wireguard_keys_subcommand());
+ #[cfg(windows)]
+ {
+ subcmd.subcommand(create_wireguard_use_wg_nt_subcommand())
+ }
+ #[cfg(not(windows))]
+ {
+ subcmd
+ }
}
fn create_wireguard_mtu_subcommand() -> clap::App<'static, 'static> {
@@ -61,6 +69,22 @@ fn create_wireguard_keys_subcommand() -> clap::App<'static, 'static> {
.subcommand(create_wireguard_keys_rotation_interval_subcommand())
}
+#[cfg(windows)]
+fn create_wireguard_use_wg_nt_subcommand() -> clap::App<'static, 'static> {
+ clap::SubCommand::with_name("use-wireguard-nt")
+ .about("Enable or disable wireguard-nt")
+ .setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::SubCommand::with_name("get"))
+ .subcommand(
+ clap::SubCommand::with_name("set").arg(
+ clap::Arg::with_name("policy")
+ .required(true)
+ .takes_value(true)
+ .possible_values(&["on", "off"]),
+ ),
+ )
+}
+
fn create_wireguard_keys_rotation_interval_subcommand() -> clap::App<'static, 'static> {
clap::SubCommand::with_name("rotation-interval")
.about("Manage automatic key rotation (given in hours)")
@@ -147,6 +171,13 @@ impl Tunnel {
_ => unreachable!("unhandled command"),
},
+ #[cfg(windows)]
+ ("use-wireguard-nt", Some(matches)) => match matches.subcommand() {
+ ("get", _) => Self::process_wireguard_use_wg_nt_get().await,
+ ("set", Some(matches)) => Self::process_wireguard_use_wg_nt_set(matches).await,
+ _ => unreachable!("unhandled command"),
+ },
+
_ => unreachable!("unhandled command"),
}
}
@@ -180,6 +211,26 @@ impl Tunnel {
Ok(())
}
+ #[cfg(windows)]
+ async fn process_wireguard_use_wg_nt_get() -> Result<()> {
+ let tunnel_options = Self::get_tunnel_options().await?;
+ if tunnel_options.wireguard.unwrap().use_wireguard_nt {
+ println!("enabled");
+ } else {
+ println!("disabled");
+ }
+ Ok(())
+ }
+
+ #[cfg(windows)]
+ async fn process_wireguard_use_wg_nt_set(matches: &clap::ArgMatches<'_>) -> Result<()> {
+ let new_state = matches.value_of("policy").unwrap() == "on";
+ let mut rpc = new_rpc_client().await?;
+ rpc.set_use_wireguard_nt(new_state).await?;
+ println!("Updated wireguard-nt setting");
+ Ok(())
+ }
+
async fn process_wireguard_key_check() -> Result<()> {
let mut rpc = new_rpc_client().await?;
let key = rpc.get_wireguard_key(()).await;
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 19692d57fa..bde37e85fb 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -281,6 +281,9 @@ pub enum DaemonCommand {
/// Disable split tunnel
#[cfg(windows)]
SetSplitTunnelState(ResponseTx<(), Error>, bool),
+ /// Toggle wireguard-nt on or off
+ #[cfg(target_os = "windows")]
+ UseWireGuardNt(ResponseTx<(), Error>, bool),
/// Makes the daemon exit the main loop and quit.
Shutdown,
/// Saves the target tunnel state and enters a blocking state. The state is restored
@@ -1230,6 +1233,8 @@ where
ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await,
#[cfg(windows)]
SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await,
+ #[cfg(target_os = "windows")]
+ UseWireGuardNt(tx, state) => self.on_use_wireguard_nt(tx, state).await,
Shutdown => self.trigger_shutdown_event(),
PrepareRestart => self.on_prepare_restart(),
#[cfg(target_os = "android")]
@@ -1937,6 +1942,35 @@ where
}
}
+ #[cfg(windows)]
+ async fn on_use_wireguard_nt(&mut self, tx: ResponseTx<(), Error>, state: bool) {
+ let save_result = self
+ .settings
+ .set_use_wireguard_nt(state)
+ .await
+ .map_err(Error::SettingsError);
+ match save_result {
+ Ok(settings_changed) => {
+ Self::oneshot_send(tx, Ok(()), "use_wireguard_nt response");
+ if settings_changed {
+ self.event_listener
+ .notify_settings(self.settings.to_settings());
+ if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
+ info!("Initiating tunnel restart");
+ self.reconnect_tunnel();
+ }
+ }
+ }
+ Err(error) => {
+ error!(
+ "{}",
+ error.display_chain_with_msg("Unable to save settings")
+ );
+ Self::oneshot_send(tx, Err(error), "use_wireguard_nt response");
+ }
+ }
+ }
+
async fn on_update_relay_settings(
&mut self,
tx: ResponseTx<(), settings::Error>,
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 0b2f6b463f..cef8d42f78 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -688,6 +688,22 @@ impl ManagementService for ManagementServiceImpl {
async fn set_split_tunnel_state(&self, _: Request<bool>) -> ServiceResult<()> {
Ok(Response::new(()))
}
+
+ #[cfg(windows)]
+ async fn set_use_wireguard_nt(&self, request: Request<bool>) -> ServiceResult<()> {
+ log::debug!("set_use_wireguard_nt");
+ let state = request.into_inner();
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::UseWireGuardNt(tx, state))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+ #[cfg(not(windows))]
+ async fn set_use_wireguard_nt(&self, _: Request<bool>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
}
impl ManagementServiceImpl {
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs
index 02568e3226..e3cfeb8ddc 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings.rs
@@ -330,6 +330,20 @@ impl SettingsPersister {
self.update(should_save).await
}
+ #[cfg(windows)]
+ pub async fn set_use_wireguard_nt(&mut self, state: bool) -> Result<bool, Error> {
+ let should_save = Self::update_field(
+ &mut self
+ .settings
+ .tunnel_options
+ .wireguard
+ .options
+ .use_wireguard_nt,
+ state,
+ );
+ self.update(should_save).await
+ }
+
fn update_field<T: Eq>(field: &mut T, new_value: T) -> bool {
if *field != new_value {
*field = new_value;
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 8711ba7b1b..a6d88bb0e0 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -69,6 +69,8 @@ service ManagementService {
rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
+
+ rpc SetUseWireguardNt(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
}
message RelaySettingsUpdate {
@@ -379,6 +381,7 @@ message TunnelOptions {
message WireguardOptions {
uint32 mtu = 1;
google.protobuf.Duration rotation_interval = 2;
+ bool use_wireguard_nt = 3;
}
message GenericOptions {
bool enable_ipv6 = 1;
diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs
index 81fdbb0e87..4e3a7217eb 100644
--- a/mullvad-management-interface/src/types.rs
+++ b/mullvad-management-interface/src/types.rs
@@ -562,6 +562,10 @@ impl From<&mullvad_types::settings::TunnelOptions> for TunnelOptions {
.wireguard
.rotation_interval
.map(|ivl| Duration::from(std::time::Duration::from(ivl))),
+ #[cfg(windows)]
+ use_wireguard_nt: options.wireguard.options.use_wireguard_nt,
+ #[cfg(not(windows))]
+ use_wireguard_nt: false,
}),
generic: Some(tunnel_options::GenericOptions {
enable_ipv6: options.generic.enable_ipv6,
@@ -1199,6 +1203,8 @@ impl TryFrom<TunnelOptions> for mullvad_types::settings::TunnelOptions {
} else {
None
},
+ #[cfg(windows)]
+ use_wireguard_nt: wireguard_options.use_wireguard_nt,
},
rotation_interval: wireguard_options
.rotation_interval
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 69b2f1e42c..f8ebc64564 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2018"
publish = false
[dependencies]
+bitflags = "1.2"
async-trait = "0.1"
atty = "0.2"
cfg-if = "1.0"
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 6501162616..f4f34be00b 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -130,6 +130,7 @@ impl TunnelMonitor {
runtime,
&config,
log_file,
+ resource_dir,
on_event,
tun_provider,
route_manager,
@@ -161,6 +162,7 @@ impl TunnelMonitor {
runtime: tokio::runtime::Handle,
params: &wireguard_types::TunnelParameters,
log: Option<PathBuf>,
+ resource_dir: &Path,
on_event: L,
tun_provider: &mut TunProvider,
route_manager: &mut RouteManager,
@@ -177,6 +179,7 @@ impl TunnelMonitor {
runtime,
config,
log.as_ref().map(|p| p.as_path()),
+ resource_dir,
on_event,
tun_provider,
route_manager,
diff --git a/talpid-core/src/tunnel/wireguard/config.rs b/talpid-core/src/tunnel/wireguard/config.rs
index ae82483c66..252bf8418f 100644
--- a/talpid-core/src/tunnel/wireguard/config.rs
+++ b/talpid-core/src/tunnel/wireguard/config.rs
@@ -23,6 +23,9 @@ pub struct Config {
/// Enable IPv6 routing rules
#[cfg(target_os = "linux")]
pub enable_ipv6: bool,
+ /// Temporary switch for wireguard-nt
+ #[cfg(target_os = "windows")]
+ pub use_wireguard_nt: bool,
}
const DEFAULT_MTU: u16 = 1380;
@@ -109,6 +112,8 @@ impl Config {
fwmark: crate::linux::TUNNEL_FW_MARK,
#[cfg(target_os = "linux")]
enable_ipv6: generic_options.enable_ipv6,
+ #[cfg(target_os = "windows")]
+ use_wireguard_nt: wg_options.use_wireguard_nt,
})
}
diff --git a/talpid-core/src/tunnel/wireguard/logging.rs b/talpid-core/src/tunnel/wireguard/logging.rs
index 64f2f71bf5..e795326854 100644
--- a/talpid-core/src/tunnel/wireguard/logging.rs
+++ b/talpid-core/src/tunnel/wireguard/logging.rs
@@ -1,5 +1,5 @@
use parking_lot::Mutex;
-use std::{collections::HashMap, fs, io::Write, path::Path};
+use std::{collections::HashMap, fmt, fs, io::Write, path::Path};
lazy_static::lazy_static! {
static ref LOG_MUTEX: Mutex<HashMap<u32, fs::File>> = Mutex::new(HashMap::new());
@@ -44,14 +44,58 @@ pub fn clean_up_logging(ordinal: u32) {
map.remove(&ordinal);
}
+#[allow(dead_code)]
+pub enum LogLevel {
+ Verbose,
+ Info,
+ Warning,
+ Error,
+}
+
+impl fmt::Display for LogLevel {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str(self.as_ref())
+ }
+}
+
+impl AsRef<str> for LogLevel {
+ fn as_ref(&self) -> &str {
+ match self {
+ LogLevel::Verbose => "VERBOSE",
+ LogLevel::Info => "INFO",
+ LogLevel::Warning => "WARNING",
+ LogLevel::Error => "ERROR",
+ }
+ }
+}
+
+#[cfg(windows)]
+pub fn log(context: u32, level: LogLevel, tag: &str, msg: &str) {
+ let mut map = LOG_MUTEX.lock();
+ if let Some(logfile) = map.get_mut(&(context as u32)) {
+ log_inner(logfile, level, tag, msg);
+ }
+}
+
+fn log_inner(logfile: &mut fs::File, level: LogLevel, tag: &str, msg: &str) {
+ let _ = write!(
+ logfile,
+ "{}[{}][{}] {}",
+ chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"),
+ tag,
+ level,
+ msg,
+ );
+}
+
// Callback that receives messages from WireGuard
-pub unsafe extern "system" fn logging_callback(
+pub unsafe extern "system" fn wg_go_logging_callback(
level: WgLogLevel,
msg: *const libc::c_char,
context: *mut libc::c_void,
) {
- let map = LOG_MUTEX.lock();
- if let Some(mut logfile) = map.get(&(context as u32)) {
+ let mut map = LOG_MUTEX.lock();
+ if let Some(logfile) = map.get_mut(&(context as u32)) {
let managed_msg = if !msg.is_null() {
#[cfg(not(target_os = "windows"))]
let m = std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string();
@@ -65,24 +109,14 @@ pub unsafe extern "system" fn logging_callback(
"Logging message from WireGuard is NULL".to_string()
};
- let level_str = match level {
- WG_GO_LOG_VERBOSE => "VERBOSE",
- WG_GO_LOG_ERROR | _ => "ERROR",
+ let level = match level {
+ WG_GO_LOG_VERBOSE => LogLevel::Verbose,
+ WG_GO_LOG_ERROR | _ => LogLevel::Error,
};
-
- let _ = write!(
- logfile,
- "{}[{}][{}] {}",
- chrono::Local::now().format("[%Y-%m-%d %H:%M:%S%.3f]"),
- "wireguard-go",
- level_str,
- managed_msg
- );
+ log_inner(logfile, level, "wireguard-go", &managed_msg);
}
}
-// unsafe fn
-
pub type WgLogLevel = u32;
// wireguard-go supports log levels 0 through 3 with 3 being the most verbose
// const WG_GO_LOG_SILENT: WgLogLevel = 0;
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index df3c7bb8b3..a3fa426ca2 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -26,6 +26,8 @@ mod stats;
mod wireguard_go;
#[cfg(target_os = "linux")]
pub(crate) mod wireguard_kernel;
+#[cfg(windows)]
+mod wireguard_nt;
use self::wireguard_go::WgGoTunnel;
@@ -89,8 +91,6 @@ pub struct WireguardMonitor {
stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>,
pinger_stop_sender: mpsc::Sender<()>,
_tcp_proxies: Vec<TcpProxy>,
- #[cfg(target_os = "windows")]
- _callback_handle: Option<crate::winnet::WinNetCallbackHandle>,
}
#[cfg(target_os = "linux")]
@@ -165,6 +165,7 @@ impl WireguardMonitor {
runtime: tokio::runtime::Handle,
mut config: Config,
log_path: Option<&Path>,
+ resource_dir: &Path,
on_event: F,
tun_provider: &mut TunProvider,
route_manager: &mut routing::RouteManager,
@@ -183,20 +184,12 @@ impl WireguardMonitor {
}
}
- let tunnel = Self::open_tunnel(&config, log_path, tun_provider, route_manager)?;
+ let tunnel =
+ Self::open_tunnel(&config, log_path, resource_dir, tun_provider, route_manager)?;
let iface_name = tunnel.get_interface_name().to_string();
#[cfg(windows)]
let iface_luid = tunnel.get_interface_luid();
- #[cfg(target_os = "windows")]
- let callback_handle = route_manager
- .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ())
- .ok();
- #[cfg(target_os = "windows")]
- if callback_handle.is_none() {
- log::warn!("Failed to register default route callback");
- }
-
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
let (pinger_tx, pinger_rx) = mpsc::channel();
@@ -212,8 +205,6 @@ impl WireguardMonitor {
stop_setup_tx: Some(stop_setup_tx),
pinger_stop_sender: pinger_tx,
_tcp_proxies: tcp_proxies,
- #[cfg(target_os = "windows")]
- _callback_handle: callback_handle,
};
let gateway = config.ipv4_gateway;
@@ -317,10 +308,11 @@ impl WireguardMonitor {
Ok(monitor)
}
- #[cfg_attr(not(target_os = "linux"), allow(unused_variables))]
+ #[allow(unused_variables)]
fn open_tunnel(
config: &Config,
log_path: Option<&Path>,
+ resource_dir: &Path,
tun_provider: &mut TunProvider,
route_manager: &mut routing::RouteManager,
) -> Result<Box<dyn Tunnel>> {
@@ -362,14 +354,34 @@ impl WireguardMonitor {
}
}
- #[cfg(target_os = "linux")]
+ #[cfg(target_os = "windows")]
+ if config.use_wireguard_nt {
+ match wireguard_nt::WgNtTunnel::start_tunnel(config, log_path, resource_dir) {
+ Ok(tunnel) => {
+ log::debug!("Using WireGuardNT");
+ return Ok(Box::new(tunnel));
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to setup WireGuardNT tunnel")
+ );
+ }
+ }
+ }
+
+ #[cfg(any(target_os = "linux", windows))]
log::debug!("Using userspace WireGuard implementation");
Ok(Box::new(
WgGoTunnel::start_tunnel(
&config,
log_path,
+ #[cfg(not(windows))]
tun_provider,
+ #[cfg(not(windows))]
Self::get_tunnel_destinations(config),
+ #[cfg(windows)]
+ route_manager,
)
.map_err(Error::TunnelError)?,
))
diff --git a/talpid-core/src/tunnel/wireguard/stats.rs b/talpid-core/src/tunnel/wireguard/stats.rs
index ff076fd714..f565988267 100644
--- a/talpid-core/src/tunnel/wireguard/stats.rs
+++ b/talpid-core/src/tunnel/wireguard/stats.rs
@@ -12,6 +12,9 @@ pub enum Error {
#[error(display = "Device no longer exists")]
NoTunnelDevice,
+
+ #[error(display = "Failed to obtain tunnel config")]
+ NoTunnelConfig,
}
/// Contains bytes sent and received through a tunnel
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index e83a74ab47..bc12fa676f 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -2,10 +2,14 @@ use super::{
stats::{Stats, StatsMap},
Config, Tunnel, TunnelError,
};
-use crate::tunnel::{
- tun_provider::TunProvider,
- wireguard::logging::{clean_up_logging, initialize_logging, logging_callback, WgLogLevel},
+#[cfg(windows)]
+use crate::routing;
+#[cfg(not(windows))]
+use crate::tunnel::tun_provider::TunProvider;
+use crate::tunnel::wireguard::logging::{
+ clean_up_logging, initialize_logging, wg_go_logging_callback, WgLogLevel,
};
+#[cfg(not(windows))]
use ipnetwork::IpNetwork;
use std::{
ffi::{c_void, CStr},
@@ -56,6 +60,8 @@ pub struct WgGoTunnel {
_tunnel_device: Tun,
// context that maps to fs::File instance, used with logging callback
_logging_context: LoggingContext,
+ #[cfg(target_os = "windows")]
+ _route_callback_handle: Option<crate::winnet::WinNetCallbackHandle>,
}
impl WgGoTunnel {
@@ -82,7 +88,7 @@ impl WgGoTunnel {
mtu,
wg_config_str.as_ptr() as *const i8,
tunnel_fd,
- Some(logging_callback),
+ Some(wg_go_logging_callback),
logging_context.0 as *mut libc::c_void,
)
};
@@ -104,9 +110,15 @@ impl WgGoTunnel {
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
- _tun_provider: &mut TunProvider,
- _routes: impl Iterator<Item = IpNetwork>,
+ route_manager: &mut routing::RouteManager,
) -> Result<Self> {
+ let route_callback_handle = route_manager
+ .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ())
+ .ok();
+ if route_callback_handle.is_none() {
+ log::warn!("Failed to register default route callback");
+ }
+
let wg_config_str = config.to_userspace_format();
let iface_name: String = "Mullvad".to_string();
let cstr_iface_name =
@@ -133,7 +145,7 @@ impl WgGoTunnel {
wg_config_str.as_ptr(),
&mut alias_ptr,
&mut interface_luid,
- Some(logging_callback),
+ Some(wg_go_logging_callback),
logging_context.0 as *mut libc::c_void,
)
};
@@ -156,6 +168,7 @@ impl WgGoTunnel {
interface_luid,
handle: Some(handle),
_logging_context: logging_context,
+ _route_callback_handle: route_callback_handle,
})
}
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_nt.rs b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
new file mode 100644
index 0000000000..2bd14f644d
--- /dev/null
+++ b/talpid-core/src/tunnel/wireguard/wireguard_nt.rs
@@ -0,0 +1,1215 @@
+use super::{
+ config::Config,
+ logging,
+ stats::{Stats, StatsMap},
+ Tunnel,
+};
+use crate::windows;
+use bitflags::bitflags;
+use ipnetwork::IpNetwork;
+use lazy_static::lazy_static;
+use std::{
+ ffi::CStr,
+ fmt, io, iter, mem,
+ os::windows::{ffi::OsStrExt, io::RawHandle},
+ path::Path,
+ ptr,
+ sync::{Arc, Mutex},
+};
+use talpid_types::ErrorExt;
+use widestring::{U16CStr, U16CString};
+use winapi::{
+ shared::{
+ guiddef::GUID,
+ ifdef::NET_LUID,
+ in6addr::IN6_ADDR,
+ inaddr::IN_ADDR,
+ minwindef::{BOOL, FARPROC, HINSTANCE, HMODULE},
+ nldef::RouterDiscoveryDisabled,
+ ntdef::FALSE,
+ winerror::ERROR_MORE_DATA,
+ ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6},
+ ws2ipdef::SOCKADDR_INET,
+ },
+ um::libloaderapi::{
+ FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_WITH_ALTERED_SEARCH_PATH,
+ },
+};
+
+
+lazy_static! {
+ static ref WG_NT_DLL: Mutex<Option<Arc<WgNtDll>>> = Mutex::new(None);
+ static ref ADAPTER_POOL: U16CString = U16CString::from_str("Mullvad").unwrap();
+ static ref ADAPTER_ALIAS: U16CString = U16CString::from_str("Mullvad").unwrap();
+}
+
+const ADAPTER_GUID: GUID = GUID {
+ Data1: 0x514a3988,
+ Data2: 0x9716,
+ Data3: 0x43d5,
+ Data4: [0x8b, 0x05, 0x31, 0xda, 0x25, 0xa0, 0x44, 0xa9],
+};
+
+/// Longest possible adapter name (in characters), including null terminator
+const MAX_ADAPTER_NAME: usize = 128;
+
+type WireGuardOpenAdapterFn =
+ unsafe extern "stdcall" fn(pool: *const u16, name: *const u16) -> RawHandle;
+type WireGuardCreateAdapterFn = unsafe extern "stdcall" fn(
+ pool: *const u16,
+ name: *const u16,
+ requested_guid: *const GUID,
+ reboot_required: *mut BOOL,
+) -> RawHandle;
+type WireGuardFreeAdapterFn = unsafe extern "stdcall" fn(adapter: RawHandle);
+type WireGuardDeleteAdapterFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, reboot_required: *mut BOOL) -> BOOL;
+type WireGuardGetAdapterLuidFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, luid: *mut NET_LUID);
+type WireGuardGetAdapterNameFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, name: *mut u16) -> BOOL;
+type WireGuardSetConfigurationFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: u32) -> BOOL;
+type WireGuardGetConfigurationFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, config: *const u8, bytes: *mut u32) -> BOOL;
+type WireGuardSetStateFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, state: WgAdapterState) -> BOOL;
+
+#[cfg(windows)]
+#[repr(C)]
+#[allow(dead_code)]
+enum LogLevel {
+ Info = 0,
+ Warn = 1,
+ Err = 2,
+}
+
+#[cfg(windows)]
+impl From<LogLevel> for logging::LogLevel {
+ fn from(level: LogLevel) -> Self {
+ match level {
+ LogLevel::Info => Self::Info,
+ LogLevel::Warn => Self::Warning,
+ LogLevel::Err => Self::Error,
+ }
+ }
+}
+
+type WireGuardLoggerCb = extern "stdcall" fn(LogLevel, timestamp: u64, *const u16);
+type WireGuardSetLoggerFn = extern "stdcall" fn(Option<WireGuardLoggerCb>);
+
+#[repr(C)]
+#[allow(dead_code)]
+enum WireGuardAdapterLogState {
+ Off = 0,
+ On = 1,
+ OnWithPrefix = 2,
+}
+
+type WireGuardSetAdapterLoggingFn =
+ unsafe extern "stdcall" fn(adapter: RawHandle, state: WireGuardAdapterLogState) -> BOOL;
+
+type RebootRequired = bool;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to load WireGuardNT
+ #[error(display = "Failed to load wireguard.dll")]
+ DllError(#[error(source)] io::Error),
+
+ /// Failed to remove tunnel interface
+ #[error(display = "Failed to remove residual tunnel device")]
+ DeleteExistingTunnelError(#[error(source)] io::Error),
+
+ /// Failed to create tunnel interface
+ #[error(display = "Failed to create WireGuard device")]
+ CreateTunnelDeviceError(#[error(source)] io::Error),
+
+ /// Failed to delete tunnel interface
+ #[error(display = "Failed to delete WireGuard device")]
+ DeleteTunnelDeviceError(#[error(source)] io::Error),
+
+ /// Failed to obtain tunnel interface alias
+ #[error(display = "Failed to obtain interface name")]
+ ObtainAliasError(#[error(source)] io::Error),
+
+ /// Failed to get WireGuard tunnel config for device
+ #[error(display = "Failed to get tunnel WireGuard config")]
+ GetWireGuardConfigError(#[error(source)] io::Error),
+
+ /// Failed to set WireGuard tunnel config on device
+ #[error(display = "Failed to set tunnel WireGuard config")]
+ SetWireGuardConfigError(#[error(source)] io::Error),
+
+ /// Failed to set MTU on tunnel device
+ #[error(display = "Failed to set tunnel IPv4 interface MTU")]
+ SetTunnelIpv4MtuError(#[error(source)] io::Error),
+
+ /// Failed to set MTU on tunnel device
+ #[error(display = "Failed to set tunnel IPv6 interface MTU")]
+ SetTunnelIpv6MtuError(#[error(source)] io::Error),
+
+ /// Failed to set the tunnel state to up
+ #[error(display = "Failed to enable the tunnel adapter")]
+ EnableTunnelError(#[error(source)] io::Error),
+
+ /// Unknown address family
+ #[error(display = "Unknown address family: {}", _0)]
+ UnknownAddressFamily(i32),
+
+ /// Failure to set up logging
+ #[error(display = "Failed to set up logging")]
+ InitLoggingError(#[error(source)] logging::Error),
+
+ /// Invalid allowed IP
+ #[error(display = "Invalid CIDR prefix")]
+ InvalidAllowedIpCidr,
+
+ /// Allowed IP contains non-zero host bits
+ #[error(display = "Allowed IP contains non-zero host bits")]
+ InvalidAllowedIpBits,
+
+ /// Failed to parse data returned by the driver
+ #[error(display = "Failed to parse data returned by wireguard-nt")]
+ InvalidConfigData,
+}
+
+pub struct WgNtTunnel {
+ device: Option<WgNtAdapter>,
+ interface_luid: NET_LUID,
+ interface_name: String,
+ _logger_handle: LoggerHandle,
+}
+
+const WIREGUARD_KEY_LENGTH: usize = 32;
+
+/// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+#[derive(Clone, Copy)]
+#[repr(C, align(8))]
+union WgIpAddr {
+ v4: IN_ADDR,
+ v6: IN6_ADDR,
+}
+
+/// See `WIREGUARD_ALLOWED_IP` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+#[derive(Clone, Copy)]
+#[repr(C, align(8))]
+struct WgAllowedIp {
+ address: WgIpAddr,
+ address_family: ADDRESS_FAMILY,
+ cidr: u8,
+}
+
+impl WgAllowedIp {
+ fn new(address: WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<Self> {
+ Self::validate(&address, address_family, cidr)?;
+ Ok(Self {
+ address,
+ address_family,
+ cidr,
+ })
+ }
+
+ fn validate(address: &WgIpAddr, address_family: ADDRESS_FAMILY, cidr: u8) -> Result<()> {
+ match address_family as i32 {
+ AF_INET => {
+ if cidr > 32 {
+ return Err(Error::InvalidAllowedIpCidr);
+ }
+ let host_mask = u32::MAX.checked_shr(u32::from(cidr)).unwrap_or(0);
+ if host_mask & (unsafe { *(address.v4.S_un.S_addr()) }.to_be()) != 0 {
+ return Err(Error::InvalidAllowedIpBits);
+ }
+ }
+ AF_INET6 => {
+ if cidr > 128 {
+ return Err(Error::InvalidAllowedIpCidr);
+ }
+ let mut host_mask = u128::MAX.checked_shr(u32::from(cidr)).unwrap_or(0);
+ let bytes = unsafe { address.v6.u.Byte() };
+ for byte in bytes.iter().rev() {
+ if byte & ((host_mask & 0xff) as u8) != 0 {
+ return Err(Error::InvalidAllowedIpBits);
+ }
+ host_mask = host_mask >> 8;
+ }
+ }
+ family => return Err(Error::UnknownAddressFamily(family)),
+ }
+ Ok(())
+ }
+}
+
+impl PartialEq for WgAllowedIp {
+ fn eq(&self, other: &Self) -> bool {
+ if self.cidr != other.cidr {
+ return false;
+ }
+ match self.address_family as i32 {
+ AF_INET => {
+ windows::ipaddr_from_inaddr(unsafe { self.address.v4 })
+ == windows::ipaddr_from_inaddr(unsafe { other.address.v4 })
+ }
+ AF_INET6 => {
+ windows::ipaddr_from_in6addr(unsafe { self.address.v6 })
+ == windows::ipaddr_from_in6addr(unsafe { other.address.v6 })
+ }
+ _ => {
+ log::error!("Allowed IP uses unknown address family");
+ true
+ }
+ }
+ }
+}
+impl Eq for WgAllowedIp {}
+
+impl fmt::Debug for WgAllowedIp {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let mut s = f.debug_struct("WgAllowedIp");
+ match self.address_family as i32 {
+ AF_INET => s.field(
+ "address",
+ &windows::ipaddr_from_inaddr(unsafe { self.address.v4 }),
+ ),
+ AF_INET6 => s.field(
+ "address",
+ &windows::ipaddr_from_in6addr(unsafe { self.address.v6 }),
+ ),
+ _ => s.field("address", &"<unknown>"),
+ };
+ s.field("address_family", &self.address_family)
+ .field("cidr", &self.cidr)
+ .finish()
+ }
+}
+
+bitflags! {
+ /// See `WIREGUARD_PEER_FLAG` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+ struct WgPeerFlag: u32 {
+ const HAS_PUBLIC_KEY = 0b00000001;
+ const HAS_PRESHARED_KEY = 0b00000010;
+ const HAS_PERSISTENT_KEEPALIVE = 0b00000100;
+ const HAS_ENDPOINT = 0b00001000;
+ const REPLACE_ALLOWED_IPS = 0b00100000;
+ const REMOVE = 0b01000000;
+ const UPDATE = 0b10000000;
+ }
+}
+
+/// See `WIREGUARD_PEER` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+#[repr(C, align(8))]
+struct WgPeer {
+ flags: WgPeerFlag,
+ reserved: u32,
+ public_key: [u8; WIREGUARD_KEY_LENGTH],
+ preshared_key: [u8; WIREGUARD_KEY_LENGTH],
+ persistent_keepalive: u16,
+ endpoint: SockAddrInet,
+ tx_bytes: u64,
+ rx_bytes: u64,
+ last_handshake: u64,
+ allowed_ips_count: u32,
+}
+
+#[derive(Clone, Copy)]
+#[repr(C)]
+struct SockAddrInet {
+ addr: SOCKADDR_INET,
+}
+
+impl From<SOCKADDR_INET> for SockAddrInet {
+ fn from(addr: SOCKADDR_INET) -> Self {
+ Self { addr }
+ }
+}
+impl PartialEq for SockAddrInet {
+ fn eq(&self, other: &Self) -> bool {
+ let self_addr = match windows::try_socketaddr_from_inet_sockaddr(self.addr) {
+ Ok(addr) => addr,
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to convert socket address")
+ );
+ return true;
+ }
+ };
+ let other_addr = match windows::try_socketaddr_from_inet_sockaddr(other.addr) {
+ Ok(addr) => addr,
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to convert socket address")
+ );
+ return true;
+ }
+ };
+ self_addr == other_addr
+ }
+}
+impl Eq for SockAddrInet {}
+
+impl fmt::Debug for SockAddrInet {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let mut s = f.debug_struct("SockAddrInet");
+ let self_addr = windows::try_socketaddr_from_inet_sockaddr(self.addr)
+ .map(|addr| addr.to_string())
+ .unwrap_or("<unknown>".to_string());
+ s.field("addr", &self_addr).finish()
+ }
+}
+
+bitflags! {
+ /// See `WIREGUARD_INTERFACE_FLAG` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+ struct WgInterfaceFlag: u32 {
+ const HAS_PUBLIC_KEY = 0b00000001;
+ const HAS_PRIVATE_KEY = 0b00000010;
+ const HAS_LISTEN_PORT = 0b00000100;
+ const REPLACE_PEERS = 0b00001000;
+ }
+}
+
+/// See `WIREGUARD_INTERFACE` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+#[repr(C, align(8))]
+struct WgInterface {
+ flags: WgInterfaceFlag,
+ listen_port: u16,
+ private_key: [u8; WIREGUARD_KEY_LENGTH],
+ public_key: [u8; WIREGUARD_KEY_LENGTH],
+ peers_count: u32,
+}
+
+/// See `WIREGUARD_ADAPTER_LOG_STATE` at https://git.zx2c4.com/wireguard-nt/tree/api/wireguard.h.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+#[repr(C)]
+#[allow(dead_code)]
+enum WgAdapterState {
+ Down = 0,
+ Up = 1,
+}
+
+
+impl WgNtTunnel {
+ pub fn start_tunnel(
+ config: &Config,
+ log_path: Option<&Path>,
+ resource_dir: &Path,
+ ) -> Result<Self> {
+ let dll = load_wg_nt_dll(resource_dir)?;
+
+ let logger_handle = LoggerHandle::new(dll.clone(), log_path)?;
+
+ {
+ if let Ok(device) = WgNtAdapter::open(dll.clone(), &*ADAPTER_POOL, &*ADAPTER_ALIAS) {
+ device.delete().map_err(Error::DeleteExistingTunnelError)?;
+ }
+ }
+
+ let (device, reboot_required) = WgNtAdapter::create(
+ dll.clone(),
+ &*ADAPTER_POOL,
+ &*ADAPTER_ALIAS,
+ Some(ADAPTER_GUID.clone()),
+ )
+ .map_err(Error::CreateTunnelDeviceError)?;
+
+ if reboot_required {
+ log::warn!("You may need to reboot to finish installing WireGuardNT");
+ }
+
+ let interface_luid = device.luid();
+ let interface_name = match device.name() {
+ Ok(name) => name.to_string_lossy(),
+ Err(error) => {
+ if let Err(error) = device.delete() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to delete tunnel device")
+ );
+ }
+ return Err(Error::ObtainAliasError(error));
+ }
+ };
+
+ let tunnel = WgNtTunnel {
+ device: Some(device),
+ interface_luid,
+ interface_name,
+ _logger_handle: logger_handle,
+ };
+ tunnel.configure(config)?;
+ Ok(tunnel)
+ }
+
+ fn stop_tunnel(&mut self) -> Result<()> {
+ if let Some(device) = self.device.take() {
+ if let Err(error) = device.delete() {
+ return Err(Error::DeleteTunnelDeviceError(error));
+ }
+ }
+ Ok(())
+ }
+
+ fn configure(&self, config: &Config) -> Result<()> {
+ let device = self.device.as_ref().unwrap();
+ if let Err(error) = device.set_logging(WireGuardAdapterLogState::On) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set log state on WireGuard interface")
+ );
+ }
+ device.set_config(config)?;
+ prepare_interface(&device.luid(), AF_INET as u16, u32::from(config.mtu))
+ .map_err(Error::SetTunnelIpv4MtuError)?;
+ if config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()) {
+ prepare_interface(&device.luid(), AF_INET6 as u16, u32::from(config.mtu))
+ .map_err(Error::SetTunnelIpv6MtuError)?;
+ }
+ device
+ .set_state(WgAdapterState::Up)
+ .map_err(Error::EnableTunnelError)?;
+ Ok(())
+ }
+}
+
+impl Drop for WgNtTunnel {
+ fn drop(&mut self) {
+ if let Err(error) = self.stop_tunnel() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to stop WireGuardNT tunnel")
+ );
+ }
+ }
+}
+
+lazy_static! {
+ static ref LOG_CONTEXT: Mutex<Option<u32>> = Mutex::new(None);
+}
+
+struct LoggerHandle {
+ dll: Arc<WgNtDll>,
+ context: u32,
+}
+
+impl LoggerHandle {
+ fn new(dll: Arc<WgNtDll>, log_path: Option<&Path>) -> Result<Self> {
+ let context = logging::initialize_logging(log_path).map_err(Error::InitLoggingError)?;
+ {
+ *(LOG_CONTEXT.lock().unwrap()) = Some(context);
+ }
+ dll.set_logger(Some(Self::logging_callback));
+ Ok(Self { dll, context })
+ }
+
+ extern "stdcall" fn logging_callback(level: LogLevel, _timestamp: u64, message: *const u16) {
+ if message.is_null() {
+ return;
+ }
+ let mut message = unsafe { U16CStr::from_ptr_str(message) }.to_string_lossy();
+ message.push_str("\r\n");
+
+ if let Some(context) = &*LOG_CONTEXT.lock().unwrap() {
+ // Horribly broken, because callback does not provide a context
+ logging::log(*context, level.into(), "wireguard-nt", &message);
+ }
+ }
+}
+
+impl Drop for LoggerHandle {
+ fn drop(&mut self) {
+ let mut ctx = LOG_CONTEXT.lock().unwrap();
+ if *ctx == Some(self.context) {
+ *ctx = None;
+ self.dll.set_logger(None);
+ }
+ logging::clean_up_logging(self.context);
+ }
+}
+
+
+struct WgNtAdapter {
+ dll_handle: Arc<WgNtDll>,
+ handle: RawHandle,
+}
+
+impl fmt::Debug for WgNtAdapter {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("WgNtAdapter")
+ .field("handle", &self.handle)
+ .finish()
+ }
+}
+
+unsafe impl Send for WgNtAdapter {}
+unsafe impl Sync for WgNtAdapter {}
+
+impl WgNtAdapter {
+ fn open(dll_handle: Arc<WgNtDll>, pool: &U16CStr, name: &U16CStr) -> io::Result<Self> {
+ let handle = dll_handle.open_adapter(pool, name)?;
+ Ok(Self { dll_handle, handle })
+ }
+
+ fn create(
+ dll_handle: Arc<WgNtDll>,
+ pool: &U16CStr,
+ name: &U16CStr,
+ requested_guid: Option<GUID>,
+ ) -> io::Result<(Self, RebootRequired)> {
+ let (handle, restart_required) = dll_handle.create_adapter(pool, name, requested_guid)?;
+ Ok((Self { dll_handle, handle }, restart_required))
+ }
+
+ fn delete(self) -> io::Result<RebootRequired> {
+ unsafe { self.dll_handle.delete_adapter(self.handle) }
+ }
+
+ fn name(&self) -> io::Result<U16CString> {
+ unsafe { self.dll_handle.get_adapter_name(self.handle) }
+ }
+
+ fn luid(&self) -> NET_LUID {
+ unsafe { self.dll_handle.get_adapter_luid(self.handle) }
+ }
+
+ fn set_config(&self, config: &Config) -> Result<()> {
+ let config_buffer = serialize_config(config)?;
+ unsafe {
+ self.dll_handle
+ .set_config(self.handle, config_buffer.as_ptr(), config_buffer.len())
+ .map_err(Error::SetWireGuardConfigError)
+ }
+ }
+
+ fn get_config(&self) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> {
+ unsafe {
+ deserialize_config(
+ &self
+ .dll_handle
+ .get_config(self.handle)
+ .map_err(Error::GetWireGuardConfigError)?,
+ )
+ }
+ }
+
+ fn set_state(&self, state: WgAdapterState) -> io::Result<()> {
+ unsafe { self.dll_handle.set_adapter_state(self.handle, state) }
+ }
+
+ fn set_logging(&self, state: WireGuardAdapterLogState) -> io::Result<()> {
+ unsafe { self.dll_handle.set_adapter_logging(self.handle, state) }
+ }
+}
+
+impl Drop for WgNtAdapter {
+ fn drop(&mut self) {
+ unsafe { self.dll_handle.free_adapter(self.handle) };
+ }
+}
+
+struct WgNtDll {
+ handle: HINSTANCE,
+ func_open: WireGuardOpenAdapterFn,
+ func_create: WireGuardCreateAdapterFn,
+ func_delete: WireGuardDeleteAdapterFn,
+ func_free: WireGuardFreeAdapterFn,
+ func_get_adapter_luid: WireGuardGetAdapterLuidFn,
+ func_get_adapter_name: WireGuardGetAdapterNameFn,
+ func_set_configuration: WireGuardSetConfigurationFn,
+ func_get_configuration: WireGuardGetConfigurationFn,
+ func_set_adapter_state: WireGuardSetStateFn,
+ func_set_logger: WireGuardSetLoggerFn,
+ func_set_adapter_logging: WireGuardSetAdapterLoggingFn,
+}
+
+unsafe impl Send for WgNtDll {}
+unsafe impl Sync for WgNtDll {}
+
+impl WgNtDll {
+ pub fn new(resource_dir: &Path) -> io::Result<Self> {
+ let wg_nt_dll: Vec<u16> = resource_dir
+ .join("wireguard.dll")
+ .as_os_str()
+ .encode_wide()
+ .chain(iter::once(0u16))
+ .collect();
+
+ let handle = unsafe {
+ LoadLibraryExW(
+ wg_nt_dll.as_ptr(),
+ ptr::null_mut(),
+ LOAD_WITH_ALTERED_SEARCH_PATH,
+ )
+ };
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Self::new_inner(handle, Self::get_proc_address)
+ }
+
+ fn new_inner(
+ handle: HMODULE,
+ get_proc_fn: unsafe fn(HMODULE, &CStr) -> io::Result<FARPROC>,
+ ) -> io::Result<Self> {
+ Ok(WgNtDll {
+ handle,
+ func_open: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardOpenAdapter\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_create: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardCreateAdapter\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_delete: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardDeleteAdapter\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_free: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardFreeAdapter\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_get_adapter_luid: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardGetAdapterLUID\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_get_adapter_name: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardGetAdapterName\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_set_configuration: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardSetConfiguration\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_get_configuration: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardGetConfiguration\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_set_adapter_state: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardSetAdapterState\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_set_logger: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardSetLogger\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ func_set_adapter_logging: unsafe {
+ *((&get_proc_fn(
+ handle,
+ CStr::from_bytes_with_nul(b"WireGuardSetAdapterLogging\0").unwrap(),
+ )?) as *const _ as *const _)
+ },
+ })
+ }
+
+ unsafe fn get_proc_address(handle: HMODULE, name: &CStr) -> io::Result<FARPROC> {
+ let handle = GetProcAddress(handle, name.as_ptr());
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(handle)
+ }
+
+ pub fn open_adapter(&self, pool: &U16CStr, name: &U16CStr) -> io::Result<RawHandle> {
+ let handle = unsafe { (self.func_open)(pool.as_ptr(), name.as_ptr()) };
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(handle)
+ }
+
+ pub fn create_adapter(
+ &self,
+ pool: &U16CStr,
+ name: &U16CStr,
+ requested_guid: Option<GUID>,
+ ) -> io::Result<(RawHandle, RebootRequired)> {
+ let guid_ptr = match requested_guid.as_ref() {
+ Some(guid) => guid as *const _,
+ None => ptr::null_mut(),
+ };
+ let mut reboot_required = 0;
+ let handle = unsafe {
+ (self.func_create)(pool.as_ptr(), name.as_ptr(), guid_ptr, &mut reboot_required)
+ };
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok((handle, reboot_required != 0))
+ }
+
+ pub unsafe fn delete_adapter(&self, adapter: RawHandle) -> io::Result<RebootRequired> {
+ let mut reboot_required = 0;
+ let result = (self.func_delete)(adapter, &mut reboot_required);
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(reboot_required != 0)
+ }
+
+ pub unsafe fn free_adapter(&self, adapter: RawHandle) {
+ (self.func_free)(adapter);
+ }
+
+ pub unsafe fn get_adapter_name(&self, adapter: RawHandle) -> io::Result<U16CString> {
+ let mut alias_buffer = vec![0u16; MAX_ADAPTER_NAME];
+ let result = (self.func_get_adapter_name)(adapter, alias_buffer.as_mut_ptr());
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(U16CString::from_vec_with_nul(alias_buffer)
+ .map_err(|_| io::Error::new(io::ErrorKind::Other, "missing null terminator"))?)
+ }
+
+ pub unsafe fn get_adapter_luid(&self, adapter: RawHandle) -> NET_LUID {
+ let mut luid = mem::MaybeUninit::<NET_LUID>::zeroed();
+ (self.func_get_adapter_luid)(adapter, luid.as_mut_ptr());
+ luid.assume_init()
+ }
+
+ pub unsafe fn set_config(
+ &self,
+ adapter: RawHandle,
+ config: *const u8,
+ config_size: usize,
+ ) -> io::Result<()> {
+ let result = (self.func_set_configuration)(adapter, config, config_size as u32);
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+
+ pub unsafe fn get_config(&self, adapter: RawHandle) -> io::Result<Vec<u8>> {
+ let mut config_size = 0;
+ let mut config = vec![];
+ loop {
+ let result =
+ (self.func_get_configuration)(adapter, config.as_mut_ptr(), &mut config_size);
+ if result == 0 {
+ let last_error = io::Error::last_os_error();
+ if last_error.raw_os_error() != Some(ERROR_MORE_DATA as i32) {
+ break Err(last_error);
+ }
+ config.resize(config_size as usize, 0);
+ } else {
+ break Ok(config);
+ }
+ }
+ }
+
+ pub unsafe fn set_adapter_state(
+ &self,
+ adapter: RawHandle,
+ state: WgAdapterState,
+ ) -> io::Result<()> {
+ let result = (self.func_set_adapter_state)(adapter, state);
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+
+ pub fn set_logger(&self, cb: Option<WireGuardLoggerCb>) {
+ (self.func_set_logger)(cb);
+ }
+
+ pub unsafe fn set_adapter_logging(
+ &self,
+ adapter: RawHandle,
+ state: WireGuardAdapterLogState,
+ ) -> io::Result<()> {
+ if (self.func_set_adapter_logging)(adapter, state) == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+ }
+}
+
+impl Drop for WgNtDll {
+ fn drop(&mut self) {
+ unsafe { FreeLibrary(self.handle) };
+ }
+}
+
+fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> {
+ let mut dll = (*WG_NT_DLL).lock().expect("WireGuardNT mutex poisoned");
+ match &*dll {
+ Some(dll) => Ok(dll.clone()),
+ None => {
+ let new_dll = Arc::new(WgNtDll::new(resource_dir).map_err(Error::DllError)?);
+ *dll = Some(new_dll.clone());
+ Ok(new_dll)
+ }
+ }
+}
+
+fn serialize_config(config: &Config) -> Result<Vec<u8>> {
+ let mut buffer = vec![];
+
+ let header = WgInterface {
+ flags: WgInterfaceFlag::HAS_PRIVATE_KEY | WgInterfaceFlag::REPLACE_PEERS,
+ listen_port: 0,
+ private_key: config.tunnel.private_key.to_bytes(),
+ public_key: [0u8; WIREGUARD_KEY_LENGTH],
+ peers_count: config.peers.len() as u32,
+ };
+
+ buffer.extend_from_slice(unsafe { as_u8_slice(&header) });
+
+ for peer in &config.peers {
+ let wg_peer = WgPeer {
+ flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT,
+ reserved: 0,
+ public_key: peer.public_key.as_bytes().clone(),
+ preshared_key: [0u8; WIREGUARD_KEY_LENGTH],
+ persistent_keepalive: 0,
+ endpoint: windows::inet_sockaddr_from_socketaddr(peer.endpoint).into(),
+ tx_bytes: 0,
+ rx_bytes: 0,
+ last_handshake: 0,
+ allowed_ips_count: peer.allowed_ips.len() as u32,
+ };
+
+ buffer.extend_from_slice(unsafe { as_u8_slice(&wg_peer) });
+
+ for allowed_ip in &peer.allowed_ips {
+ let address_family = match allowed_ip {
+ IpNetwork::V4(_) => AF_INET as u16,
+ IpNetwork::V6(_) => AF_INET6 as u16,
+ };
+ let address = match allowed_ip {
+ IpNetwork::V4(v4_network) => WgIpAddr {
+ v4: windows::inaddr_from_ipaddr(v4_network.ip()),
+ },
+ IpNetwork::V6(v6_network) => WgIpAddr {
+ v6: windows::in6addr_from_ipaddr(v6_network.ip()),
+ },
+ };
+
+ let wg_allowed_ip =
+ WgAllowedIp::new(address, address_family, allowed_ip.prefix() as u8)?;
+
+ buffer.extend_from_slice(unsafe { as_u8_slice(&wg_allowed_ip) });
+ }
+ }
+
+ Ok(buffer)
+}
+
+unsafe fn deserialize_config(
+ config: &[u8],
+) -> Result<(WgInterface, Vec<(WgPeer, Vec<WgAllowedIp>)>)> {
+ if config.len() < mem::size_of::<WgInterface>() {
+ return Err(Error::InvalidConfigData);
+ }
+ let (head, mut tail) = config.split_at(mem::size_of::<WgInterface>());
+ let interface: WgInterface = *(head.as_ptr() as *const WgInterface);
+
+ let mut peers = vec![];
+ for _ in 0..interface.peers_count {
+ if tail.len() < mem::size_of::<WgPeer>() {
+ return Err(Error::InvalidConfigData);
+ }
+ let (peer_data, new_tail) = tail.split_at(mem::size_of::<WgPeer>());
+ let peer: WgPeer = *(peer_data.as_ptr() as *const WgPeer);
+ tail = new_tail;
+
+ if let Err(error) = windows::try_socketaddr_from_inet_sockaddr(peer.endpoint.addr) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Received invalid endpoint address")
+ );
+ return Err(Error::InvalidConfigData);
+ }
+
+ let mut allowed_ips = vec![];
+
+ for _ in 0..peer.allowed_ips_count {
+ if tail.len() < mem::size_of::<WgAllowedIp>() {
+ return Err(Error::InvalidConfigData);
+ }
+ let (allowed_ip_data, new_tail) = tail.split_at(mem::size_of::<WgAllowedIp>());
+ let allowed_ip: WgAllowedIp = *(allowed_ip_data.as_ptr() as *const WgAllowedIp);
+ if let Err(error) = WgAllowedIp::validate(
+ &allowed_ip.address,
+ allowed_ip.address_family,
+ allowed_ip.cidr,
+ ) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Received invalid allowed IP")
+ );
+ return Err(Error::InvalidConfigData);
+ }
+ tail = new_tail;
+ allowed_ips.push(allowed_ip);
+ }
+
+ peers.push((peer, allowed_ips));
+ }
+
+ if tail.len() > 0 {
+ return Err(Error::InvalidConfigData);
+ }
+
+ Ok((interface, peers))
+}
+
+fn prepare_interface(luid: &NET_LUID, family: u16, mtu: u32) -> io::Result<()> {
+ let family = windows::AddressFamily::try_from_af_family(family)
+ .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?;
+ let mut iface = windows::get_ip_interface_entry(family, luid)?;
+ iface.SitePrefixLength = 0;
+ iface.NlMtu = mtu;
+ iface.RouterDiscoveryBehavior = RouterDiscoveryDisabled;
+ iface.DadTransmits = 0;
+ iface.ManagedAddressConfigurationSupported = FALSE;
+ iface.OtherStatefulConfigurationSupported = FALSE;
+ windows::set_ip_interface_entry(&iface)
+}
+
+impl Tunnel for WgNtTunnel {
+ fn get_interface_name(&self) -> String {
+ self.interface_name.clone()
+ }
+
+ fn get_interface_luid(&self) -> u64 {
+ self.interface_luid.Value
+ }
+
+ fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> {
+ if let Some(ref device) = self.device {
+ let mut map = StatsMap::new();
+ let (_interface, peers) = device.get_config().map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain wg-nt tunnel config")
+ );
+ super::TunnelError::StatsError(super::stats::Error::NoTunnelConfig)
+ })?;
+ for (peer, _allowed_ips) in &peers {
+ map.insert(
+ peer.public_key,
+ Stats {
+ tx_bytes: peer.tx_bytes,
+ rx_bytes: peer.rx_bytes,
+ },
+ );
+ }
+ Ok(map)
+ } else {
+ Err(super::TunnelError::StatsError(
+ super::stats::Error::NoTunnelDevice,
+ ))
+ }
+ }
+
+ fn stop(mut self: Box<Self>) -> std::result::Result<(), super::TunnelError> {
+ if let Err(error) = self.stop_tunnel() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to stop WireGuardNT tunnel")
+ );
+ Err(super::TunnelError::StopWireguardError { status: 0 })
+ } else {
+ Ok(())
+ }
+ }
+}
+
+unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] {
+ std::slice::from_raw_parts(object as *const _ as *const _, mem::size_of::<T>())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use lazy_static::lazy_static;
+ use talpid_types::net::{wireguard, TransportProtocol};
+
+ #[derive(Debug, Eq, PartialEq, Clone, Copy)]
+ #[repr(C)]
+ struct Interface {
+ interface: WgInterface,
+ p0: WgPeer,
+ p0_allowed_ip_0: WgAllowedIp,
+ }
+
+ lazy_static! {
+ static ref WG_PRIVATE_KEY: wireguard::PrivateKey = wireguard::PrivateKey::new_from_random();
+ static ref WG_PUBLIC_KEY: wireguard::PublicKey =
+ wireguard::PrivateKey::new_from_random().public_key();
+ static ref WG_CONFIG: Config = {
+ Config {
+ tunnel: wireguard::TunnelConfig {
+ private_key: WG_PRIVATE_KEY.clone(),
+ addresses: vec![],
+ },
+ peers: vec![wireguard::PeerConfig {
+ public_key: WG_PUBLIC_KEY.clone(),
+ allowed_ips: vec!["1.3.3.0/24".parse().unwrap()],
+ endpoint: "1.2.3.4:1234".parse().unwrap(),
+ protocol: TransportProtocol::Udp,
+ }],
+ ipv4_gateway: "0.0.0.0".parse().unwrap(),
+ ipv6_gateway: None,
+ mtu: 0,
+ use_wireguard_nt: true,
+ }
+ };
+ static ref WG_STRUCT_CONFIG: Interface = Interface {
+ interface: WgInterface {
+ flags: WgInterfaceFlag::HAS_PRIVATE_KEY | WgInterfaceFlag::REPLACE_PEERS,
+ listen_port: 0,
+ private_key: WG_PRIVATE_KEY.to_bytes(),
+ public_key: [0; WIREGUARD_KEY_LENGTH],
+ peers_count: 1,
+ },
+ p0: WgPeer {
+ flags: WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT,
+ reserved: 0,
+ public_key: WG_PUBLIC_KEY.as_bytes().clone(),
+ preshared_key: [0; WIREGUARD_KEY_LENGTH],
+ persistent_keepalive: 0,
+ endpoint: windows::inet_sockaddr_from_socketaddr("1.2.3.4:1234".parse().unwrap())
+ .into(),
+ tx_bytes: 0,
+ rx_bytes: 0,
+ last_handshake: 0,
+ allowed_ips_count: 1,
+ },
+ p0_allowed_ip_0: WgAllowedIp {
+ address: WgIpAddr {
+ v4: windows::inaddr_from_ipaddr("1.3.3.0".parse().unwrap()),
+ },
+ address_family: AF_INET as u16,
+ cidr: 24,
+ },
+ };
+ }
+
+ fn get_proc_fn(_handle: HMODULE, _symbol: &CStr) -> io::Result<FARPROC> {
+ Ok(std::ptr::null_mut())
+ }
+
+ #[test]
+ fn test_dll_imports() {
+ WgNtDll::new_inner(ptr::null_mut(), get_proc_fn).unwrap();
+ }
+
+ #[test]
+ fn test_config_serialization() {
+ let serialized_data = serialize_config(&*WG_CONFIG).unwrap();
+ assert_eq!(mem::size_of::<Interface>(), serialized_data.len());
+ let serialized_iface = &unsafe { *(serialized_data.as_ptr() as *const Interface) };
+ assert_eq!(&*WG_STRUCT_CONFIG, serialized_iface);
+ }
+
+ #[test]
+ fn test_config_deserialization() {
+ let (iface, peers) =
+ unsafe { deserialize_config(as_u8_slice(&*WG_STRUCT_CONFIG)) }.unwrap();
+ assert_eq!(iface, WG_STRUCT_CONFIG.interface);
+ assert_eq!(peers.len(), 1);
+ let (peer, allowed_ips) = &peers[0];
+ assert_eq!(peer, &WG_STRUCT_CONFIG.p0);
+ assert_eq!(allowed_ips.len(), 1);
+ assert_eq!(allowed_ips[0], WG_STRUCT_CONFIG.p0_allowed_ip_0);
+ }
+
+ #[test]
+ fn test_wg_allowed_ip_v4() {
+ // Valid: /32 prefix
+ let address_family = AF_INET as u16;
+ let address = WgIpAddr {
+ v4: windows::inaddr_from_ipaddr("127.0.0.1".parse().unwrap()),
+ };
+ let cidr = 32;
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid host bits
+ let cidr = 24;
+ let address = WgIpAddr {
+ v4: windows::inaddr_from_ipaddr("0.0.0.1".parse().unwrap()),
+ };
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+
+ // Valid host bits
+ let cidr = 24;
+ let address = WgIpAddr {
+ v4: windows::inaddr_from_ipaddr("255.255.255.0".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // 0.0.0.0/0
+ let cidr = 0;
+ let address = WgIpAddr {
+ v4: windows::inaddr_from_ipaddr("0.0.0.0".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid CIDR
+ let cidr = 33;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+ }
+
+ #[test]
+ fn test_wg_allowed_ip_v6() {
+ // Valid: /128 prefix
+ let address_family = AF_INET6 as u16;
+ let address = WgIpAddr {
+ v6: windows::in6addr_from_ipaddr("::1".parse().unwrap()),
+ };
+ let cidr = 128;
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid host bits
+ let cidr = 127;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+
+ // Valid host bits
+ let address = WgIpAddr {
+ v6: windows::in6addr_from_ipaddr(
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe".parse().unwrap(),
+ ),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // ::/0
+ let cidr = 0;
+ let address = WgIpAddr {
+ v6: windows::in6addr_from_ipaddr("::".parse().unwrap()),
+ };
+ WgAllowedIp::new(address, address_family, cidr).unwrap();
+
+ // Invalid CIDR
+ let cidr = 129;
+ assert!(WgAllowedIp::new(address, address_family, cidr).is_err());
+ }
+}
diff --git a/talpid-core/src/windows.rs b/talpid-core/src/windows.rs
index 8151165680..03cd8a9c74 100644
--- a/talpid-core/src/windows.rs
+++ b/talpid-core/src/windows.rs
@@ -1,12 +1,16 @@
use std::{
ffi::OsStr,
fmt, io, mem,
+ net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::windows::{ffi::OsStrExt, io::RawHandle},
+ ptr,
sync::Mutex,
time::{Duration, Instant},
};
use winapi::shared::{
ifdef::NET_LUID,
+ in6addr::IN6_ADDR,
+ inaddr::IN_ADDR,
netioapi::{
CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, FreeMibTable, GetIpInterfaceEntry,
GetUnicastIpAddressEntry, GetUnicastIpAddressTable, MibAddInstance,
@@ -17,6 +21,7 @@ use winapi::shared::{
ntdef::FALSE,
winerror::{ERROR_NOT_FOUND, NO_ERROR},
ws2def::{AF_INET, AF_INET6, AF_UNSPEC},
+ ws2ipdef::SOCKADDR_INET,
};
/// Result type for this module.
@@ -58,6 +63,10 @@ pub enum Error {
#[cfg(windows)]
#[error(display = "Unicast channel sender was unexpectedly dropped")]
UnicastSenderDropped,
+
+ /// Unknown address family
+ #[error(display = "Unknown address family: {}", _0)]
+ UnknownAddressFamily(i32),
}
/// Address family. These correspond to the `AF_*` constants.
@@ -78,6 +87,17 @@ impl fmt::Display for AddressFamily {
}
}
+impl AddressFamily {
+ /// Convert an [`AddressFamily`] to one of the `AF_*` constants.
+ pub fn try_from_af_family(family: u16) -> Result<AddressFamily> {
+ match family as i32 {
+ AF_INET => Ok(AddressFamily::Ipv4),
+ AF_INET6 => Ok(AddressFamily::Ipv6),
+ family => Err(Error::UnknownAddressFamily(family)),
+ }
+ }
+}
+
/// Context for [`notify_ip_interface_change`]. When it is dropped,
/// the callback is unregistered.
pub struct IpNotifierHandle<'a> {
@@ -341,3 +361,120 @@ fn af_family_from_family(family: Option<AddressFamily>) -> u16 {
.map(|family| family as u16)
.unwrap_or(AF_UNSPEC as u16)
}
+
+/// Converts an `Ipv4Addr` to `IN_ADDR`
+pub fn inaddr_from_ipaddr(addr: Ipv4Addr) -> IN_ADDR {
+ let mut in_addr: IN_ADDR = unsafe { mem::zeroed() };
+ let addr_octets = addr.octets();
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &addr_octets as *const _,
+ in_addr.S_un.S_addr_mut() as *mut _ as *mut u8,
+ addr_octets.len(),
+ );
+ }
+ in_addr
+}
+
+/// Converts an `Ipv6Addr` to `IN6_ADDR`
+pub fn in6addr_from_ipaddr(addr: Ipv6Addr) -> IN6_ADDR {
+ let mut in_addr: IN6_ADDR = unsafe { mem::zeroed() };
+ let addr_octets = addr.octets();
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &addr_octets as *const _,
+ in_addr.u.Byte_mut() as *mut _,
+ addr_octets.len(),
+ );
+ }
+ in_addr
+}
+
+/// Converts an `IN_ADDR` to `Ipv4Addr`
+pub fn ipaddr_from_inaddr(addr: IN_ADDR) -> Ipv4Addr {
+ Ipv4Addr::from(unsafe { *(addr.S_un.S_addr()) }.to_be())
+}
+
+/// Converts an `IN6_ADDR` to `Ipv6Addr`
+pub fn ipaddr_from_in6addr(addr: IN6_ADDR) -> Ipv6Addr {
+ Ipv6Addr::from(*unsafe { addr.u.Byte() })
+}
+
+/// Converts a `SocketAddr` to `SOCKADDR_INET`
+pub fn inet_sockaddr_from_socketaddr(addr: SocketAddr) -> SOCKADDR_INET {
+ let mut sockaddr: SOCKADDR_INET = unsafe { mem::zeroed() };
+
+ match addr {
+ SocketAddr::V4(v4_addr) => {
+ unsafe {
+ *sockaddr.si_family_mut() = AF_INET as u16;
+ }
+
+ let mut v4sockaddr = unsafe { sockaddr.Ipv4_mut() };
+ v4sockaddr.sin_family = AF_INET as u16;
+ v4sockaddr.sin_port = v4_addr.port().to_be();
+ v4sockaddr.sin_addr = inaddr_from_ipaddr(*v4_addr.ip());
+ }
+ SocketAddr::V6(v6_addr) => {
+ unsafe {
+ *sockaddr.si_family_mut() = AF_INET6 as u16;
+ }
+
+ let mut v6sockaddr = unsafe { sockaddr.Ipv6_mut() };
+ v6sockaddr.sin6_family = AF_INET6 as u16;
+ v6sockaddr.sin6_port = v6_addr.port().to_be();
+ v6sockaddr.sin6_addr = in6addr_from_ipaddr(*v6_addr.ip());
+ v6sockaddr.sin6_flowinfo = v6_addr.flowinfo();
+ *unsafe { v6sockaddr.u.sin6_scope_id_mut() } = v6_addr.scope_id();
+ }
+ }
+
+ sockaddr
+}
+
+/// Converts a `SOCKADDR_INET` to `SocketAddr`. Returns an error if the address family is invalid.
+pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAddr> {
+ unsafe {
+ match *addr.si_family() as i32 {
+ AF_INET => Ok(SocketAddr::V4(SocketAddrV4::new(
+ ipaddr_from_inaddr(addr.Ipv4().sin_addr),
+ u16::from_be(addr.Ipv4().sin_port),
+ ))),
+ AF_INET6 => Ok(SocketAddr::V6(SocketAddrV6::new(
+ ipaddr_from_in6addr(addr.Ipv6().sin6_addr),
+ u16::from_be(addr.Ipv6().sin6_port),
+ addr.Ipv6().sin6_flowinfo,
+ *addr.Ipv6().u.sin6_scope_id(),
+ ))),
+ family => Err(Error::UnknownAddressFamily(family)),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_sockaddr_v4() {
+ let addr_v4 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 1234));
+ assert_eq!(
+ addr_v4,
+ try_socketaddr_from_inet_sockaddr(inet_sockaddr_from_socketaddr(addr_v4)).unwrap()
+ );
+ }
+
+ #[test]
+ fn test_sockaddr_v6() {
+ let addr_v6 = SocketAddr::V6(SocketAddrV6::new(
+ Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
+ 1234,
+ 0xa,
+ 0xb,
+ ));
+ assert_eq!(
+ addr_v6,
+ try_socketaddr_from_inet_sockaddr(inet_sockaddr_from_socketaddr(addr_v6)).unwrap()
+ );
+ }
+}
diff --git a/talpid-types/src/net/wireguard.rs b/talpid-types/src/net/wireguard.rs
index c5bde3547e..beced8357d 100644
--- a/talpid-types/src/net/wireguard.rs
+++ b/talpid-types/src/net/wireguard.rs
@@ -86,6 +86,10 @@ pub struct TunnelOptions {
jnix(map = "|maybe_mtu| maybe_mtu.map(|mtu| mtu as i32)")
)]
pub mtu: Option<u16>,
+ /// Temporary switch for wireguard-nt
+ #[cfg(windows)]
+ #[serde(default)]
+ pub use_wireguard_nt: bool,
}
/// Wireguard x25519 private key
diff --git a/windows/driverlogic/driverlogic.vcxproj b/windows/driverlogic/driverlogic.vcxproj
index b91e97d86c..cc46c1ac72 100644
--- a/windows/driverlogic/driverlogic.vcxproj
+++ b/windows/driverlogic/driverlogic.vcxproj
@@ -117,6 +117,7 @@
<ClInclude Include="src\util.h" />
<ClInclude Include="src\version.h" />
<ClInclude Include="src\wintun.h" />
+ <ClInclude Include="src\wireguard.h" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
diff --git a/windows/driverlogic/driverlogic.vcxproj.filters b/windows/driverlogic/driverlogic.vcxproj.filters
index 91ed2267da..9665231376 100644
--- a/windows/driverlogic/driverlogic.vcxproj.filters
+++ b/windows/driverlogic/driverlogic.vcxproj.filters
@@ -28,5 +28,6 @@
<ClInclude Include="src\util.h" />
<ClInclude Include="src\wintun.h" />
<ClInclude Include="src\devenum.h" />
+ <ClInclude Include="src\wireguard.h" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/windows/driverlogic/src/driverlogic.cpp b/windows/driverlogic/src/driverlogic.cpp
index 93af20b400..3cb1739e21 100644
--- a/windows/driverlogic/src/driverlogic.cpp
+++ b/windows/driverlogic/src/driverlogic.cpp
@@ -5,6 +5,7 @@
#include "log.h"
#include "version.h"
#include "wintun.h"
+#include "wireguard.h"
#include "devenum.h"
#include <string>
#include <libcommon/error.h>
@@ -278,6 +279,32 @@ ReturnCode CommandWintunDeleteAbandonedDevice(const std::vector<std::wstring> &a
return GENERAL_SUCCESS;
}
+ReturnCode CommandWireGuardNtCleanup(const std::vector<std::wstring> &args)
+{
+ ArgumentContext argsContext(args);
+
+ argsContext.ensureExactArgumentCount(1);
+
+ const auto poolName = argsContext.next();
+
+ WireGuardNtDll wgNt;
+
+ BOOL rebootRequired;
+
+ if (FALSE == wgNt.deletePoolDriver(poolName.c_str(), &rebootRequired))
+ {
+ throw std::runtime_error("Failed to delete WireGuardNT pool");
+ }
+
+ std::wstringstream ss;
+
+ ss << L"Successfully deleted WireGuardNT pool. Reboot required: " << rebootRequired;
+
+ Log(ss.str());
+
+ return ReturnCode::GENERAL_SUCCESS;
+}
+
} // anonymous namespace
int wmain(int argc, const wchar_t *argv[])
@@ -325,7 +352,8 @@ int wmain(int argc, const wchar_t *argv[])
{ L"st-force-install", CommandSplitTunnelForceInstall },
{ L"st-remove", CommandSplitTunnelRemove },
{ L"wintun-delete-pool-driver", CommandWintunDeletePool },
- { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice }
+ { L"wintun-delete-abandoned-device", CommandWintunDeleteAbandonedDevice },
+ { L"wg-nt-cleanup", CommandWireGuardNtCleanup }
};
//
diff --git a/windows/driverlogic/src/wireguard.h b/windows/driverlogic/src/wireguard.h
new file mode 100644
index 0000000000..5892b248f1
--- /dev/null
+++ b/windows/driverlogic/src/wireguard.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <wireguard-nt/wireguard.h>
+#include <libcommon/error.h>
+#include "util.h"
+
+class WireGuardNtDll
+{
+public:
+
+ WireGuardNtDll() : dllHandle(nullptr)
+ {
+ auto path = GetProcessModulePath().replace_filename(L"wireguard.dll");
+ dllHandle = LoadLibraryExW(path.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
+
+ if (nullptr == dllHandle)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "LoadLibraryExW");
+ }
+
+ try
+ {
+ deletePoolDriver = getProcAddressOrThrow<WIREGUARD_DELETE_POOL_DRIVER_FUNC*>("WireGuardDeletePoolDriver");
+ }
+ catch (...)
+ {
+ FreeLibrary(dllHandle);
+ throw;
+ }
+ }
+
+ ~WireGuardNtDll()
+ {
+ if (nullptr != dllHandle)
+ {
+ FreeLibrary(dllHandle);
+ }
+ }
+
+ WIREGUARD_DELETE_POOL_DRIVER_FUNC *deletePoolDriver;
+
+private:
+
+ template<typename T>
+ T getProcAddressOrThrow(const char *procName)
+ {
+ const T result = reinterpret_cast<T>(GetProcAddress(dllHandle, procName));
+
+ if (nullptr == result)
+ {
+ THROW_WINDOWS_ERROR(GetLastError(), "GetProcAddress");
+ }
+
+ return result;
+ }
+
+ HMODULE dllHandle;
+};
diff --git a/windows/libshared/src/libshared/network/interfaceutils.cpp b/windows/libshared/src/libshared/network/interfaceutils.cpp
index 28263f383b..fba4d71ba0 100644
--- a/windows/libshared/src/libshared/network/interfaceutils.cpp
+++ b/windows/libshared/src/libshared/network/interfaceutils.cpp
@@ -99,6 +99,7 @@ void InterfaceUtils::AddDeviceIpAddresses(NET_LUID device, const std::vector<SOC
row.InterfaceLuid = device;
row.Address = address;
+ row.DadState = IpDadStatePreferred;
const auto status = CreateUnicastIpAddressEntry(&row);