summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJonathan <jonathan@mullvad.net>2022-09-06 10:06:17 +0200
committerJonathan <jonathan@mullvad.net>2022-10-18 14:42:25 +0200
commitd3f7ed493ecc182e582697f7ac03b1b30a2c2f52 (patch)
treecf4b0e155279b2dad9f54468393e060584950ecb
parent7bb08d42243ec05f04bc7691d9d2bb3156f10fed (diff)
downloadmullvadvpn-d3f7ed493ecc182e582697f7ac03b1b30a2c2f52.tar.xz
mullvadvpn-d3f7ed493ecc182e582697f7ac03b1b30a2c2f52.zip
Port winnet from C++ to Rust
Remove all of the C++ code in the winnet module and write an almost equivalent route manager in rust.
-rw-r--r--Cargo.lock16
-rw-r--r--talpid-core/Cargo.toml17
-rw-r--r--talpid-core/src/lib.rs4
-rw-r--r--talpid-core/src/offline/mod.rs5
-rw-r--r--talpid-core/src/offline/windows.rs60
-rw-r--r--talpid-core/src/routing/mod.rs4
-rw-r--r--talpid-core/src/routing/windows.rs221
-rw-r--r--talpid-core/src/routing/windows/default_route_monitor.rs451
-rw-r--r--talpid-core/src/routing/windows/get_best_default_route.rs190
-rw-r--r--talpid-core/src/routing/windows/mod.rs303
-rw-r--r--talpid-core/src/routing/windows/route_manager.rs885
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs94
-rw-r--r--talpid-core/src/tunnel/mod.rs2
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs7
-rw-r--r--talpid-core/src/tunnel/wireguard/wireguard_go.rs49
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs9
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs12
-rw-r--r--talpid-core/src/windows/mod.rs11
-rw-r--r--talpid-core/src/winnet.rs416
19 files changed, 1990 insertions, 766 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 77a447ec8b..16933b2567 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3161,8 +3161,9 @@ dependencies = [
"tunnel-obfuscation",
"uuid",
"which",
- "widestring 0.5.1",
+ "widestring 1.0.2",
"winapi",
+ "windows",
"windows-service",
"windows-sys 0.42.0",
"winreg",
@@ -3973,6 +3974,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
+name = "windows"
+version = "0.36.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e53b97a83176b369b0eb2fd8158d4ae215357d02df9d40c1e1bf1879c5482c80"
+dependencies = [
+ "windows_aarch64_msvc 0.36.1",
+ "windows_i686_gnu 0.36.1",
+ "windows_i686_msvc 0.36.1",
+ "windows_x86_64_gnu 0.36.1",
+ "windows_x86_64_msvc 0.36.1",
+]
+
+[[package]]
name = "windows-service"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index d1167e16fc..97bdf2db4b 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -79,13 +79,28 @@ subslice = "0.2"
[target.'cfg(windows)'.dependencies]
-widestring = "0.5"
+widestring = "1.0"
winreg = { version = "0.7", features = ["transactions"] }
winapi = { version = "0.3.6", features = ["ws2def"] }
talpid-platform-metadata = { path = "../talpid-platform-metadata" }
memoffset = "0.6"
windows-service = "0.5.0"
+[target.'cfg(windows)'.dependencies.windows]
+version = "0.36.1"
+features = [
+ "Data_Xml_Dom",
+ "Win32_Foundation",
+ "Win32_Security",
+ "Win32_System_Threading",
+ "Win32_UI_WindowsAndMessaging",
+ "Win32_NetworkManagement",
+ "Win32_NetworkManagement_IpHelper",
+ "Win32_NetworkManagement_Ndis",
+ "Win32_Foundation",
+ "Win32_Networking_WinSock",
+]
+
[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.42.0"
features = [
diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs
index 73c3293fb4..46bb4c1169 100644
--- a/talpid-core/src/lib.rs
+++ b/talpid-core/src/lib.rs
@@ -9,10 +9,6 @@
#[macro_use]
mod ffi;
-/// Misc networking functions for Windows.
-#[cfg(windows)]
-mod winnet;
-
/// Windows API wrappers and utilities
#[cfg(target_os = "windows")]
pub mod windows;
diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs
index b07fb3d8c9..3c5448762b 100644
--- a/talpid-core/src/offline/mod.rs
+++ b/talpid-core/src/offline/mod.rs
@@ -1,4 +1,4 @@
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", target_os = "windows"))]
use crate::routing::RouteManagerHandle;
#[cfg(target_os = "windows")]
use crate::windows::window::PowerManagementListener;
@@ -46,6 +46,7 @@ pub async fn spawn_monitor(
sender: UnboundedSender<bool>,
#[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
#[cfg(target_os = "android")] android_context: AndroidContext,
+ #[cfg(target_os = "windows")] route_manager: RouteManagerHandle,
#[cfg(target_os = "windows")] power_mgmt_rx: PowerManagementListener,
) -> Result<MonitorHandle, Error> {
let monitor = if !*FORCE_DISABLE_OFFLINE_MONITOR {
@@ -57,6 +58,8 @@ pub async fn spawn_monitor(
#[cfg(target_os = "android")]
android_context,
#[cfg(target_os = "windows")]
+ route_manager,
+ #[cfg(target_os = "windows")]
power_mgmt_rx,
)
.await?,
diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs
index bbe9d951a9..9dc341d03a 100644
--- a/talpid-core/src/offline/windows.rs
+++ b/talpid-core/src/offline/windows.rs
@@ -1,11 +1,13 @@
use crate::{
- windows::window::{PowerManagementEvent, PowerManagementListener},
- winnet,
+ routing::{get_best_default_route, CallbackHandle, EventType, RouteManagerHandle},
+ windows::{
+ window::{PowerManagementEvent, PowerManagementListener},
+ AddressFamily,
+ },
};
use futures::channel::mpsc::UnboundedSender;
use parking_lot::Mutex;
use std::{
- ffi::c_void,
io,
sync::{Arc, Weak},
time::Duration,
@@ -17,20 +19,21 @@ pub enum Error {
#[error(display = "Unable to create listener thread")]
ThreadCreationError(#[error(source)] io::Error),
#[error(display = "Failed to start connectivity monitor")]
- ConnectivityMonitorError(#[error(source)] winnet::DefaultRouteCallbackError),
+ ConnectivityMonitorError(#[error(source)] crate::routing::Error),
}
pub struct BroadcastListener {
system_state: Arc<Mutex<SystemState>>,
- _callback_handle: winnet::WinNetCallbackHandle,
+ _callback_handle: CallbackHandle,
_notify_tx: Arc<UnboundedSender<bool>>,
}
unsafe impl Send for BroadcastListener {}
impl BroadcastListener {
- pub fn start(
+ pub async fn start(
notify_tx: UnboundedSender<bool>,
+ route_manager_handle: RouteManagerHandle,
mut power_mgmt_rx: PowerManagementListener,
) -> Result<Self, Error> {
let notify_tx = Arc::new(notify_tx);
@@ -66,7 +69,8 @@ impl BroadcastListener {
});
let callback_handle =
- unsafe { Self::setup_network_connectivity_listener(system_state.clone())? };
+ Self::setup_network_connectivity_listener(system_state.clone(), route_manager_handle)
+ .await?;
Ok(BroadcastListener {
system_state,
@@ -76,7 +80,7 @@ impl BroadcastListener {
}
fn check_initial_connectivity() -> (bool, bool) {
- let v4_connectivity = winnet::get_best_default_route(winnet::WinNetAddrFamily::IPV4)
+ let v4_connectivity = get_best_default_route(AddressFamily::Ipv4)
.map(|route| route.is_some())
.unwrap_or_else(|error| {
log::error!(
@@ -85,7 +89,7 @@ impl BroadcastListener {
);
true
});
- let v6_connectivity = winnet::get_best_default_route(winnet::WinNetAddrFamily::IPV6)
+ let v6_connectivity = get_best_default_route(AddressFamily::Ipv6)
.map(|route| route.is_some())
.unwrap_or_else(|error| {
log::error!(
@@ -103,34 +107,35 @@ impl BroadcastListener {
/// The caller must make sure the `system_state` reference is valid
/// until after `WinNet_DeactivateConnectivityMonitor` has been called.
- unsafe fn setup_network_connectivity_listener(
+ async fn setup_network_connectivity_listener(
system_state: Arc<Mutex<SystemState>>,
- ) -> Result<winnet::WinNetCallbackHandle, Error> {
- let change_handle = winnet::add_default_route_change_callback(
- Some(Self::connectivity_callback),
- system_state,
- )?;
+ route_manager_handle: RouteManagerHandle,
+ ) -> Result<CallbackHandle, Error> {
+ let change_handle = route_manager_handle
+ .add_default_route_change_callback(Box::new(move |event, addr_family| {
+ Self::connectivity_callback(event, addr_family, &system_state)
+ }))
+ .await
+ .map_err(|e| Error::ConnectivityMonitorError(e))?;
Ok(change_handle)
}
- unsafe extern "system" fn connectivity_callback(
- event_type: winnet::WinNetDefaultRouteChangeEventType,
- family: winnet::WinNetAddrFamily,
- _default_route: winnet::WinNetDefaultRoute,
- ctx: *mut c_void,
+ fn connectivity_callback<'a>(
+ event_type: EventType<'a>,
+ family: AddressFamily,
+ state_lock: &Arc<Mutex<SystemState>>,
) {
- use winnet::WinNetDefaultRouteChangeEventType::*;
+ use crate::routing::EventType::*;
- if event_type == DefaultRouteUpdatedDetails {
+ if matches!(event_type, UpdatedDetails(_)) {
// ignore changes that don't affect the route
return;
}
- let state_lock: &mut Arc<Mutex<SystemState>> = &mut *(ctx as *mut _);
- let connectivity = event_type != DefaultRouteRemoved;
+ let connectivity = event_type != Removed;
let change = match family {
- winnet::WinNetAddrFamily::IPV4 => StateChange::NetworkV4Connectivity(connectivity),
- winnet::WinNetAddrFamily::IPV6 => StateChange::NetworkV6Connectivity(connectivity),
+ AddressFamily::Ipv4 => StateChange::NetworkV4Connectivity(connectivity),
+ AddressFamily::Ipv6 => StateChange::NetworkV6Connectivity(connectivity),
};
let mut state = state_lock.lock();
state.apply_change(change);
@@ -202,9 +207,10 @@ pub type MonitorHandle = BroadcastListener;
pub async fn spawn_monitor(
sender: UnboundedSender<bool>,
+ route_manager_handle: RouteManagerHandle,
power_mgmt_rx: PowerManagementListener,
) -> Result<MonitorHandle, Error> {
- BroadcastListener::start(sender, power_mgmt_rx)
+ BroadcastListener::start(sender, route_manager_handle, power_mgmt_rx).await
}
fn apply_system_state_change(state: Arc<Mutex<SystemState>>, change: StateChange) {
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs
index 1eb02a206b..5d1247618e 100644
--- a/talpid-core/src/routing/mod.rs
+++ b/talpid-core/src/routing/mod.rs
@@ -5,8 +5,10 @@ use ipnetwork::IpNetwork;
use std::{fmt, net::IpAddr};
#[cfg(target_os = "windows")]
-#[path = "windows.rs"]
+#[path = "windows/mod.rs"]
mod imp;
+#[cfg(target_os = "windows")]
+pub use imp::{get_best_default_route, CallbackHandle, EventType, InterfaceAndGateway};
#[cfg(not(target_os = "windows"))]
#[path = "unix.rs"]
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
deleted file mode 100644
index fad8540ecd..0000000000
--- a/talpid-core/src/routing/windows.rs
+++ /dev/null
@@ -1,221 +0,0 @@
-use super::NetNode;
-use crate::{routing::RequiredRoute, winnet};
-use futures::{
- channel::{
- mpsc::{self, UnboundedReceiver, UnboundedSender},
- oneshot,
- },
- StreamExt,
-};
-use std::{collections::HashSet, net::IpAddr};
-use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
-use winnet::WinNetAddrFamily;
-
-/// Windows routing errors.
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- /// The sender was dropped unexpectedly -- possible panic
- #[error(display = "The channel sender was dropped")]
- ManagerChannelDown,
- /// Failure to initialize route manager
- #[error(display = "Failed to start route manager")]
- FailedToStartManager,
- /// Failure to add routes
- #[error(display = "Failed to add routes")]
- AddRoutesFailed(#[error(source)] winnet::Error),
- /// Failure to clear routes
- #[error(display = "Failed to clear applied routes")]
- ClearRoutesFailed,
- /// WinNet returned an error while adding default route callback
- #[error(display = "Failed to set callback for default route")]
- FailedToAddDefaultRouteCallback,
- /// Attempt to use route manager that has been dropped
- #[error(display = "Cannot send message to route manager since it is down")]
- RouteManagerDown,
- /// Something went wrong when getting the mtu of the interface
- #[error(display = "Could not get the mtu of the interface")]
- GetMtu,
-}
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-/// Manages routes by calling into WinNet
-pub struct RouteManager {
- manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
-}
-
-/// Handle to a route manager.
-#[derive(Clone)]
-pub struct RouteManagerHandle {
- tx: UnboundedSender<RouteManagerCommand>,
-}
-
-impl RouteManagerHandle {
- /// Applies the given routes while the route manager is running.
- pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
- let (response_tx, response_rx) = oneshot::channel();
- self.tx
- .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
- .map_err(|_| Error::RouteManagerDown)?;
- response_rx.await.map_err(|_| Error::ManagerChannelDown)?
- }
-
- /// Applies the given routes while the route manager is running.
- pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
- let (response_tx, response_rx) = oneshot::channel();
- self.tx
- .unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
- .map_err(|_| Error::RouteManagerDown)?;
- response_rx.await.map_err(|_| Error::ManagerChannelDown)?
- }
-}
-
-#[derive(Debug)]
-pub enum RouteManagerCommand {
- AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
- GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>),
- Shutdown,
-}
-
-impl RouteManager {
- /// Creates a new route manager that will apply the provided routes and ensure they exist until
- /// it's stopped.
- pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
- if !winnet::activate_routing_manager() {
- return Err(Error::FailedToStartManager);
- }
- let (manage_tx, manage_rx) = mpsc::unbounded();
- let manager = Self {
- manage_tx: Some(manage_tx),
- };
- tokio::spawn(RouteManager::listen(manage_rx));
- manager.add_routes(required_routes).await?;
-
- Ok(manager)
- }
-
- /// Retrieve a sender directly to the command channel.
- pub fn handle(&self) -> Result<RouteManagerHandle> {
- if let Some(tx) = &self.manage_tx {
- Ok(RouteManagerHandle { tx: tx.clone() })
- } else {
- Err(Error::RouteManagerDown)
- }
- }
-
- async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) {
- while let Some(command) = manage_rx.next().await {
- match command {
- RouteManagerCommand::AddRoutes(routes, tx) => {
- let routes: Vec<_> = routes
- .iter()
- .map(|route| {
- let destination = winnet::WinNetIpNetwork::from(route.prefix);
- match &route.node {
- NetNode::DefaultNode => {
- winnet::WinNetRoute::through_default_node(destination)
- }
- NetNode::RealNode(node) => winnet::WinNetRoute::new(
- winnet::WinNetNode::from(node),
- destination,
- ),
- }
- })
- .collect();
-
- let _ = tx.send(
- winnet::routing_manager_add_routes(&routes).map_err(Error::AddRoutesFailed),
- );
- }
- RouteManagerCommand::GetMtuForRoute(ip, tx) => {
- let addr_family = if ip.is_ipv4() {
- winnet::WinNetAddrFamily::IPV4
- } else {
- winnet::WinNetAddrFamily::IPV6
- };
- let res = match get_mtu_for_route(addr_family) {
- Ok(Some(mtu)) => Ok(mtu),
- Ok(None) => Err(Error::GetMtu),
- Err(e) => Err(e),
- };
- let _ = tx.send(res);
- }
- RouteManagerCommand::Shutdown => {
- break;
- }
- }
- }
- }
-
- /// Stops the routing manager and invalidates the route manager - no new default route callbacks
- /// can be added
- pub fn stop(&mut self) {
- if let Some(tx) = self.manage_tx.take() {
- if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() {
- log::error!("RouteManager channel already down or thread panicked");
- }
-
- winnet::deactivate_routing_manager();
- }
- }
-
- /// Applies the given routes until [`RouteManager::stop`] is called.
- pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
- if let Some(tx) = &self.manage_tx {
- let (result_tx, result_rx) = oneshot::channel();
- if tx
- .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
- .is_err()
- {
- return Err(Error::RouteManagerDown);
- }
- result_rx.await.map_err(|_| Error::ManagerChannelDown)?
- } else {
- Err(Error::RouteManagerDown)
- }
- }
-
- /// Removes all routes previously applied in [`RouteManager::new`] or
- /// [`RouteManager::add_routes`].
- pub fn clear_routes(&self) -> Result<()> {
- if winnet::routing_manager_delete_applied_routes() {
- Ok(())
- } else {
- Err(Error::ClearRoutesFailed)
- }
- }
-}
-
-fn get_mtu_for_route(addr_family: WinNetAddrFamily) -> Result<Option<u16>> {
- use crate::windows::AddressFamily;
- match winnet::get_best_default_route(addr_family) {
- Ok(Some(route)) => {
- let addr_family = match addr_family {
- WinNetAddrFamily::IPV4 => AddressFamily::Ipv4,
- WinNetAddrFamily::IPV6 => AddressFamily::Ipv6,
- };
- let luid = NET_LUID_LH {
- Value: route.interface_luid,
- };
- let interface_row = crate::windows::get_ip_interface_entry(addr_family, &luid)
- .map_err(|e| {
- log::error!("Could not get ip interface entry: {}", e);
- Error::GetMtu
- })?;
- let mtu = interface_row.NlMtu;
- let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?;
- Ok(Some(mtu))
- }
- Ok(None) => Ok(None),
- Err(e) => {
- log::error!("Could not get best default route: {}", e);
- Err(Error::GetMtu)
- }
- }
-}
-
-impl Drop for RouteManager {
- fn drop(&mut self) {
- self.stop();
- }
-}
diff --git a/talpid-core/src/routing/windows/default_route_monitor.rs b/talpid-core/src/routing/windows/default_route_monitor.rs
new file mode 100644
index 0000000000..3976903f11
--- /dev/null
+++ b/talpid-core/src/routing/windows/default_route_monitor.rs
@@ -0,0 +1,451 @@
+use super::{
+ get_best_default_route, get_best_default_route::route_has_gateway, AddressFamily, Error,
+ InterfaceAndGateway, Result,
+};
+
+use std::{
+ ffi::c_void,
+ io,
+ sync::{
+ mpsc::{channel, RecvTimeoutError, Sender},
+ Arc, Mutex,
+ },
+ time::{Duration, Instant},
+};
+use windows_sys::Win32::{
+ Foundation::{BOOLEAN, HANDLE, NO_ERROR},
+ NetworkManagement::{
+ IpHelper::{
+ CancelMibChangeNotify2, ConvertInterfaceLuidToIndex, NotifyIpInterfaceChange,
+ NotifyRouteChange2, NotifyUnicastIpAddressChange, MIB_IPFORWARD_ROW2,
+ MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE, MIB_UNICASTIPADDRESS_ROW,
+ },
+ Ndis::NET_LUID_LH,
+ },
+};
+
+const WIN_FALSE: BOOLEAN = 0;
+
+struct DefaultRouteMonitorContext {
+ callback: Box<dyn for<'a> Fn(EventType<'a>) + Send + 'static>,
+ refresh_current_route: bool,
+ family: AddressFamily,
+ best_route: Option<InterfaceAndGateway>,
+}
+
+impl DefaultRouteMonitorContext {
+ fn new(
+ callback: Box<dyn for<'a> Fn(EventType<'a>) + Send + 'static>,
+ family: AddressFamily,
+ ) -> Self {
+ Self {
+ callback,
+ best_route: None,
+ refresh_current_route: false,
+ family,
+ }
+ }
+
+ fn update_refresh_flag(&mut self, luid: &NET_LUID_LH, index: u32) {
+ if let Some(best_route) = &self.best_route {
+ // SAFETY: luid is a union but both fields are finally represented by u64, as such any
+ // access is valid
+ if unsafe { luid.Value } == unsafe { best_route.iface.Value } {
+ self.refresh_current_route = true;
+ return;
+ }
+ // SAFETY: luid is a union but both fields are finally represented by u64, as such any
+ // access is valid
+ if unsafe { luid.Value } != 0 {
+ return;
+ }
+
+ let mut default_interface_index = 0;
+ let route_luid = best_route.iface;
+ // SAFETY: No clear safety specifications
+ if NO_ERROR as i32
+ == unsafe { ConvertInterfaceLuidToIndex(&route_luid, &mut default_interface_index) }
+ {
+ self.refresh_current_route = index == default_interface_index;
+ } else {
+ self.refresh_current_route = true;
+ }
+ }
+ }
+
+ fn evaluate_routes(&mut self) {
+ let refresh_current = self.refresh_current_route;
+ self.refresh_current_route = false;
+
+ let current_best_route = get_best_default_route(self.family).ok().flatten();
+
+ match (&self.best_route, current_best_route) {
+ (None, None) => (),
+ (None, Some(current_best_route)) => {
+ self.best_route = Some(current_best_route);
+ (self.callback)(EventType::Updated(&self.best_route.as_ref().unwrap()));
+ }
+ (Some(_), None) => {
+ self.best_route = None;
+ (self.callback)(EventType::Removed);
+ }
+ (Some(best_route), Some(current_best_route)) => {
+ if best_route != &current_best_route {
+ self.best_route = Some(current_best_route);
+ (self.callback)(EventType::Updated(&self.best_route.as_ref().unwrap()));
+ } else if refresh_current {
+ (self.callback)(EventType::UpdatedDetails(
+ &self.best_route.as_ref().unwrap(),
+ ));
+ }
+ }
+ }
+ }
+}
+
+pub struct DefaultRouteMonitor {
+ // SAFETY: These handles must be dropped before the context. This will happen automatically if
+ // it is handled by DefaultRouteMonitors drop implementation
+ notify_change_handles: Option<(NotifyChangeHandle, NotifyChangeHandle, NotifyChangeHandle)>,
+ // SAFETY: Context must be dropped after all of the notifier handles have been dropped in order
+ // to guarantee none of them use its pointer. This will be dropped by DefaultRouteMonitors
+ // drop implementation. SAFETY: The content of this pointer is not allowed to be mutated at
+ // any point except for in the drop implementation
+ context: *const ContextAndBurstGuard,
+}
+
+/// SAFETY: DefaultRouteMonitor is `Send` since `NotifyChangeHandle` is `Send` and
+/// `ContextAndBurstGuard` is `Sync` as it holds Mutex<T> and Arc<Mutex<T>> fields.
+unsafe impl Send for DefaultRouteMonitor {}
+
+impl Drop for DefaultRouteMonitor {
+ fn drop(&mut self) {
+ drop(self.notify_change_handles.take());
+ // SAFETY: This pointer was created by Box::into_raw and is not modified since then.
+ // This drop function is also only called once
+ let context = unsafe { Box::from_raw(self.context as *mut ContextAndBurstGuard) };
+
+ // Stop the burst guard
+ context.burst_guard.lock().unwrap().stop();
+
+ // Drop the context now that we are guaranteed nothing might try to access the context
+ drop(context);
+ }
+}
+
+struct NotifyChangeHandle(HANDLE);
+
+/// SAFETY: NotifyChangeHandle is `Send` since it holds sole ownership of a pointer provided by C
+unsafe impl Send for NotifyChangeHandle {}
+
+impl Drop for NotifyChangeHandle {
+ fn drop(&mut self) {
+ // SAFETY: There is no clear safety specification on this function. However self.0 should
+ // point to a handle that has been allocated by windows and should be non-null. Even
+ // if it would be null that would cause a panic rather than UB.
+ unsafe {
+ if NO_ERROR as i32 != CancelMibChangeNotify2(self.0) {
+ // If this callback is called after we free the context that could result in UB, in
+ // order to avoid that we panic.
+ panic!("Could not cancel change notification callback")
+ }
+ }
+ }
+}
+
+#[derive(PartialEq, Clone, Copy)]
+/// The type of route update passed to the callback
+pub enum EventType<'a> {
+ /// New route
+ Updated(&'a InterfaceAndGateway),
+ /// Updated details of the same old route
+ UpdatedDetails(&'a InterfaceAndGateway),
+ /// Route removed
+ Removed,
+}
+
+// SAFETY: This struct must be `Sync` otherwise it is not allowed to be sent between threads.
+// Having only `Mutex<T>` or `Arc<Mutex<T>>` fields guarantees that it is `Sync`
+struct ContextAndBurstGuard {
+ context: Arc<Mutex<DefaultRouteMonitorContext>>,
+ burst_guard: Mutex<BurstGuard>,
+}
+
+impl DefaultRouteMonitor {
+ pub fn new<F: for<'a> Fn(EventType<'a>) + Send + 'static>(
+ family: AddressFamily,
+ callback: F,
+ ) -> Result<Self> {
+ let context = Arc::new(Mutex::new(DefaultRouteMonitorContext::new(
+ Box::new(callback),
+ family,
+ )));
+
+ let moved_context = context.clone();
+ let burst_guard = Mutex::new(BurstGuard::new(move || {
+ moved_context.lock().unwrap().evaluate_routes();
+ }));
+
+ // SAFETY: We need to send the ContextAndBurstGuard to the windows notification functions as
+ // a raw pointer. This imposes the requirement it is not mutated or dropped until
+ // after those notifications are guaranteed to not run. This happens when the
+ // DefaultRouteMonitor is dropped and not before then. It also imposes the requirement that
+ // ContextAndBurstGuard is `Sync` since we will send references to it to other
+ // threads. This requirement is fullfilled since all fields of `ContextAndBurstGuard` are
+ // wrapped in either a Arc<Mutex> or Mutex.
+ let context_and_burst = Box::into_raw(Box::new(ContextAndBurstGuard {
+ context,
+ burst_guard,
+ })) as *const _;
+
+ let handles = match Self::register_callbacks(family, context_and_burst) {
+ Ok(handles) => handles,
+ Err(e) => {
+ // Clean up the memory leak in case of error
+ // SAFETY: We created context_and_burst from `Box::into_raw()` and it has not been
+ // modified since. All of the handles have been freed at this point
+ // so there will be no risk of UAF.
+ drop(unsafe { Box::from_raw(context_and_burst as *mut ContextAndBurstGuard) });
+ return Err(e);
+ }
+ };
+
+ let monitor = Self {
+ context: context_and_burst,
+ notify_change_handles: Some(handles),
+ };
+
+ // We must set the best default route after we have registered listeners in order to avoid
+ // race conditions.
+ {
+ // SAFETY: `monitor.context` will be valid since monitor will handle dropping it. No
+ // mutation happens here since we are using a Mutex.
+ let context = &unsafe { &*(monitor.context) }.context;
+ let mut context = context.lock().unwrap();
+ context.best_route = get_best_default_route(context.family)?;
+ }
+
+ Ok(monitor)
+ }
+
+ fn register_callbacks(
+ family: AddressFamily,
+ context_and_burst: *const ContextAndBurstGuard,
+ ) -> Result<(NotifyChangeHandle, NotifyChangeHandle, NotifyChangeHandle)> {
+ let family = family.to_af_family();
+
+ // We must provide a raw pointer that points to the context that will be used in the
+ // callbacks. We provide a Mutex for the state turned into a Weak pointer turned
+ // into a raw pointer in order to not have to manually deallocate the memory after
+ // we cancel the callbacks. This will leak the weak pointer but the context state itself
+ // will be correctly dropped when DefaultRouteManager is dropped.
+ let context_ptr = context_and_burst;
+ let mut handle_ptr = 0;
+ // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle
+ // has not been dropped.
+ let status = unsafe {
+ NotifyRouteChange2(
+ family,
+ Some(route_change_callback),
+ context_ptr as *const _,
+ WIN_FALSE,
+ &mut handle_ptr,
+ )
+ };
+
+ if NO_ERROR as i32 != status {
+ return Err(Error::RegisterNotifyRouteCallback(
+ io::Error::from_raw_os_error(status),
+ ));
+ }
+ let notify_route_change_handle = NotifyChangeHandle(handle_ptr);
+
+ let mut handle_ptr = 0;
+ // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle
+ // has not been dropped.
+ let status = unsafe {
+ NotifyIpInterfaceChange(
+ family,
+ Some(interface_change_callback),
+ context_ptr as *const _,
+ WIN_FALSE,
+ &mut handle_ptr,
+ )
+ };
+ if NO_ERROR as i32 != status {
+ return Err(Error::RegisterNotifyIpInterfaceCallback(
+ io::Error::from_raw_os_error(status),
+ ));
+ }
+ let notify_interface_change_handle = NotifyChangeHandle(handle_ptr);
+
+ let mut handle_ptr = 0;
+ // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle
+ // has not been dropped.
+ let status = unsafe {
+ NotifyUnicastIpAddressChange(
+ family,
+ Some(ip_address_change_callback),
+ context_ptr as *const _,
+ WIN_FALSE,
+ &mut handle_ptr,
+ )
+ };
+ if NO_ERROR as i32 != status {
+ return Err(Error::RegisterNotifyUnicastIpAddressCallback(
+ io::Error::from_raw_os_error(status),
+ ));
+ }
+ let notify_address_change_handle = NotifyChangeHandle(handle_ptr);
+
+ Ok((
+ notify_route_change_handle,
+ notify_interface_change_handle,
+ notify_address_change_handle,
+ ))
+ }
+}
+
+// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference.
+// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed
+// to not be called again.
+unsafe extern "system" fn route_change_callback(
+ context: *const c_void,
+ row: *const MIB_IPFORWARD_ROW2,
+ _notification_type: MIB_NOTIFICATION_TYPE,
+) {
+ // SAFETY: We assume Windows provides this pointer correctly
+ let row = &*row;
+
+ if row.DestinationPrefix.PrefixLength != 0 || !route_has_gateway(row) {
+ return;
+ }
+
+ // SAFETY: context must not be dropped or modified until this callback has been cancelled.
+ let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard);
+ let mut context = context_and_burst.context.lock().unwrap();
+
+ context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex);
+ context_and_burst.burst_guard.lock().unwrap().trigger();
+}
+
+// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference.
+// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed
+// to not be called again.
+unsafe extern "system" fn interface_change_callback(
+ context: *const c_void,
+ row: *const MIB_IPINTERFACE_ROW,
+ _notification_type: MIB_NOTIFICATION_TYPE,
+) {
+ // SAFETY: We assume Windows provides this pointer correctly
+ let row = &*row;
+
+ // SAFETY: context must not be dropped or modified until this callback has been cancelled.
+ let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard);
+ let mut context = context_and_burst.context.lock().unwrap();
+
+ context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex);
+ context_and_burst.burst_guard.lock().unwrap().trigger();
+}
+
+// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference.
+// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed
+// to not be called again.
+unsafe extern "system" fn ip_address_change_callback(
+ context: *const c_void,
+ row: *const MIB_UNICASTIPADDRESS_ROW,
+ _notification_type: MIB_NOTIFICATION_TYPE,
+) {
+ // SAFETY: We assume Windows provides this pointer correctly
+ let row = &*row;
+
+ // SAFETY: context must not be dropped or modified until this callback has been cancelled.
+ let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard);
+ let mut context = context_and_burst.context.lock().unwrap();
+
+ context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex);
+ context_and_burst.burst_guard.lock().unwrap().trigger();
+}
+
+/// BurstGuard is a wrapper for a function that protects that function from being called too many
+/// times in a short amount of time. To call the function use `burst_guard.trigger()`, at that point
+/// `BurstGuard` will wait for `buffer_period` and if no more calls to `trigger` are made then it
+/// will call the wrapped function. If another call to `trigger` is made during this wait then it
+/// will wait another `buffer_period`, this happens over and over until either
+/// `longest_buffer_period` time has elapsed or until no call to `trigger` has been made in
+/// `buffer_period`. At which point the wrapped function will be called.
+struct BurstGuard {
+ sender: Sender<BurstGuardEvent>,
+}
+
+enum BurstGuardEvent {
+ Trigger,
+ Shutdown(Sender<()>),
+}
+
+impl BurstGuard {
+ fn new<F: Fn() + Send + 'static>(callback: F) -> Self {
+ /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent
+ /// before it calls the callback.
+ const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200);
+ /// This is the longest period that the `BurstGuard` will wait from the first trigger till
+ /// it calls the callback.
+ const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2);
+
+ let (sender, listener) = channel();
+ std::thread::spawn(move || {
+ // The `stop` implementation assumes that this thread will not call `callback` again
+ // if the listener has been dropped.
+ while let Ok(message) = listener.recv() {
+ match message {
+ BurstGuardEvent::Trigger => {
+ let start = Instant::now();
+ loop {
+ match listener.recv_timeout(BURST_BUFFER_PERIOD) {
+ Ok(BurstGuardEvent::Trigger) => {
+ if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD {
+ callback();
+ break;
+ }
+ }
+ Ok(BurstGuardEvent::Shutdown(tx)) => {
+ let _ = tx.send(());
+ return;
+ }
+ Err(RecvTimeoutError::Timeout) => {
+ callback();
+ break;
+ }
+ Err(RecvTimeoutError::Disconnected) => {
+ break;
+ }
+ }
+ }
+ }
+ BurstGuardEvent::Shutdown(tx) => {
+ let _ = tx.send(());
+ return;
+ }
+ }
+ }
+ });
+ Self { sender }
+ }
+
+ /// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further
+ /// calls to `callback`.
+ fn stop(&self) {
+ let (sender, listener) = channel();
+ // If we could not send then it means the thread has already shut down and we can return
+ if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() {
+ // We do not care what the result is, if it is OK it means the thread shut down, if
+ // it is Err it also means it shut down.
+ let _ = listener.recv();
+ }
+ }
+
+ /// Non-blocking
+ fn trigger(&self) {
+ self.sender.send(BurstGuardEvent::Trigger).unwrap();
+ }
+}
diff --git a/talpid-core/src/routing/windows/get_best_default_route.rs b/talpid-core/src/routing/windows/get_best_default_route.rs
new file mode 100644
index 0000000000..4ec7395fff
--- /dev/null
+++ b/talpid-core/src/routing/windows/get_best_default_route.rs
@@ -0,0 +1,190 @@
+use super::{Error, Result};
+use crate::windows::{get_ip_interface_entry, try_socketaddr_from_inet_sockaddr, AddressFamily};
+use std::{convert::TryInto, io, net::SocketAddr};
+use widestring::{widecstr, WideCStr};
+use windows_sys::Win32::{
+ Foundation::NO_ERROR,
+ NetworkManagement::{
+ IpHelper::{
+ FreeMibTable, GetIfEntry2, GetIpForwardTable2, IF_TYPE_SOFTWARE_LOOPBACK,
+ IF_TYPE_TUNNEL, MIB_IF_ROW2, MIB_IPFORWARD_ROW2,
+ },
+ Ndis::NET_LUID_LH,
+ },
+};
+
+// Interface description substrings found for virtual adapters.
+const TUNNEL_INTERFACE_DESCS: [&WideCStr; 3] = [
+ widecstr!("WireGuard"),
+ widecstr!("Wintun"),
+ widecstr!("Tunnel"),
+];
+
+fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> {
+ let family = family.to_af_family();
+ let mut table_ptr = std::ptr::null_mut();
+
+ // SAFETY: GetIpForwardTable2 does not have clear safety specifications however what it does is
+ // heap allocate a IpForwardTable2 and then change table_ptr to point to that allocation.
+ let status = unsafe { GetIpForwardTable2(family, &mut table_ptr) };
+ if NO_ERROR as i32 != status {
+ return Err(Error::GetIpForwardTableFailed(
+ io::Error::from_raw_os_error(status),
+ ));
+ }
+
+ // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error
+ let num_entries = unsafe { *table_ptr }.NumEntries;
+ let mut vec = Vec::with_capacity(num_entries.try_into().unwrap_or_default());
+
+ for i in 0..num_entries {
+ assert!(
+ usize::try_from(i).unwrap() * std::mem::size_of::<MIB_IPFORWARD_ROW2>()
+ < usize::try_from(isize::MAX).unwrap()
+ );
+
+ // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error nor have we
+ // or will we modify the table
+ let ptr: *const MIB_IPFORWARD_ROW2 = unsafe { (*table_ptr).Table.as_ptr() };
+
+ // SAFETY: The assert guarantees that the amount of bytes we are jumping is not larger than
+ // isize::MAX. Win32 guarantees that the resulting pointer is aligned, non-null,
+ // init.
+ let row: &MIB_IPFORWARD_ROW2 =
+ unsafe { ptr.offset(i.try_into().unwrap()).as_ref() }.unwrap();
+ vec.push(row.clone());
+ }
+ // SAFETY: FreeMibTable does not have clear safety rules but it deallocates the
+ // MIB_IPFORWARD_TABLE2 This pointer is ONLY deallocated here so it is guaranteed to not
+ // have been already deallocated. We have cloned all MIB_IPFORWARD_ROW2s and the rows do not
+ // contain pointers to the table so they will not be dangling after this free.
+ unsafe { FreeMibTable(table_ptr as *const _) }
+ Ok(vec)
+}
+
+/// General type for passing interface and gateway
+pub struct InterfaceAndGateway {
+ /// Interface
+ pub iface: NET_LUID_LH,
+ /// Gateway
+ pub gateway: SocketAddr,
+}
+
+impl PartialEq for InterfaceAndGateway {
+ fn eq(&self, other: &InterfaceAndGateway) -> bool {
+ // SAFETY: Accessing Value is always valid in this union as both fields are the same type
+ (unsafe { self.iface.Value == other.iface.Value } && self.gateway == other.gateway)
+ }
+}
+
+/// Get the best default route for the given address family or None if none exists.
+pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceAndGateway>> {
+ let table = get_ipforward_rows(family)?;
+
+ // Remove all candidates without a gateway and which are not on a physical interface.
+ // Then get the annotated routes which are active.
+ let mut annotated: Vec<AnnotatedRoute<'_>> = table
+ .iter()
+ .filter(|row| {
+ 0 == row.DestinationPrefix.PrefixLength
+ && route_has_gateway(row)
+ && is_route_on_physical_interface(row).unwrap_or(false)
+ })
+ .filter_map(|row| annotate_route(row))
+ .collect();
+
+ if annotated.is_empty() {
+ return Ok(None);
+ }
+
+ // We previously filtered out all inactive routes so we only need to sort by acending
+ // effective_metric
+ annotated.sort_by(|lhs, rhs| lhs.effective_metric.cmp(&rhs.effective_metric));
+
+ Ok(Some(InterfaceAndGateway {
+ iface: annotated[0].route.InterfaceLuid,
+ gateway: try_socketaddr_from_inet_sockaddr(annotated[0].route.NextHop)
+ .map_err(|_| Error::InvalidSiFamily)?,
+ }))
+}
+
+pub fn route_has_gateway(route: &MIB_IPFORWARD_ROW2) -> bool {
+ match try_socketaddr_from_inet_sockaddr(route.NextHop) {
+ Ok(sock) => !sock.ip().is_unspecified(),
+ Err(_) => false,
+ }
+}
+
+// TODO(Jon): It would be more correct to filter for devices that match the known LUID of the tunnel
+// interface
+fn is_route_on_physical_interface(route: &MIB_IPFORWARD_ROW2) -> Result<bool> {
+ // The last 16 bits of _bitfield represent the interface type. For that reason we mask it with
+ // 0xFFFF. SAFETY: route.InterfaceLuid is a union. Both variants of this union are always
+ // valid since one is a u64 and the other is a wrapped u64. Access to the _bitfield as such
+ // is safe since it does not reinterpret the u64 as anything it is not.
+ let if_type = u32::try_from(unsafe { route.InterfaceLuid.Info._bitfield } & 0xFFFF).unwrap();
+ if if_type == IF_TYPE_SOFTWARE_LOOPBACK || if_type == IF_TYPE_TUNNEL {
+ return Ok(false);
+ }
+
+ // OpenVPN uses interface type IF_TYPE_PROP_VIRTUAL,
+ // but tethering etc. may rely on virtual adapters too,
+ // so we have to filter out the TAP adapter specifically.
+
+ // SAFETY: We are allowed to initialize MIB_IF_ROW2 with zeroed because it is made up entirely
+ // of types for which the zero pattern (all zeros) is valid.
+ let mut row: MIB_IF_ROW2 = unsafe { std::mem::zeroed() };
+ row.InterfaceLuid = route.InterfaceLuid;
+ row.InterfaceIndex = route.InterfaceIndex;
+
+ // SAFETY: GetIfEntry2 does not have clear safety rules however it will read the
+ // row.InterfaceLuid or row.InterfaceIndex and use that information to populate the struct.
+ // We guarantee here that these fields are valid since they are set.
+ let status = unsafe { GetIfEntry2(&mut row) };
+ if NO_ERROR as i32 != status {
+ return Err(Error::GetIfEntryFailed(io::Error::from_raw_os_error(
+ status,
+ )));
+ }
+
+ let row_description = WideCStr::from_slice_truncate(&row.Description)
+ .expect("Windows provided incorrectly formatted utf16 string");
+
+ for tunnel_interface_desc in TUNNEL_INTERFACE_DESCS {
+ if contains_subslice(row_description.as_slice(), tunnel_interface_desc.as_slice()) {
+ return Ok(false);
+ }
+ }
+
+ return Ok(true);
+}
+
+fn contains_subslice<T: PartialEq>(slice: &[T], subslice: &[T]) -> bool {
+ slice
+ .windows(subslice.len())
+ .any(|window| window == subslice)
+}
+
+struct AnnotatedRoute<'a> {
+ route: &'a MIB_IPFORWARD_ROW2,
+ effective_metric: u32,
+}
+
+fn annotate_route<'a>(route: &'a MIB_IPFORWARD_ROW2) -> Option<AnnotatedRoute<'a>> {
+ // SAFETY: `si_family` is valid in both `Ipv4` and `Ipv6` so we can safely access `si_family`.
+ let iface = get_ip_interface_entry(
+ AddressFamily::try_from_af_family(unsafe { route.DestinationPrefix.Prefix.si_family })
+ .ok()?,
+ &route.InterfaceLuid,
+ )
+ .ok()?;
+
+ if iface.Connected == 0 {
+ None
+ } else {
+ Some(AnnotatedRoute {
+ route,
+ effective_metric: route.Metric + iface.Metric,
+ })
+ }
+}
diff --git a/talpid-core/src/routing/windows/mod.rs b/talpid-core/src/routing/windows/mod.rs
new file mode 100644
index 0000000000..06d23368ca
--- /dev/null
+++ b/talpid-core/src/routing/windows/mod.rs
@@ -0,0 +1,303 @@
+use crate::{routing::RequiredRoute, windows::AddressFamily};
+use futures::channel::oneshot;
+use std::{collections::HashSet, io, net::IpAddr};
+use talpid_types::ErrorExt;
+use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
+
+pub use default_route_monitor::EventType;
+pub use get_best_default_route::{get_best_default_route, route_has_gateway, InterfaceAndGateway};
+pub use route_manager::{Callback, CallbackHandle, Route, RouteManagerInternal};
+
+mod default_route_monitor;
+mod get_best_default_route;
+mod route_manager;
+
+/// Windows routing errors.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// The sender was dropped unexpectedly -- possible panic
+ #[error(display = "The channel sender was dropped")]
+ ManagerChannelDown,
+ /// Failure to initialize route manager
+ #[error(display = "Failed to start route manager")]
+ FailedToStartManager,
+ /// Attempt to use route manager that has been dropped
+ #[error(display = "Cannot send message to route manager since it is down")]
+ RouteManagerDown,
+ /// Low level error caused by a failure to add to route table
+ #[error(display = "Could not add route to route table")]
+ AddToRouteTable(io::Error),
+ /// Low level error caused by failure to delete route from route table
+ #[error(display = "Failed to delete applied routes")]
+ DeleteFromRouteTable(io::Error),
+ /// GetIpForwardTable2 windows API call failed
+ #[error(display = "Failed to retrieve the routing table")]
+ GetIpForwardTableFailed(io::Error),
+ /// GetIfEntry2 windows API call failed
+ #[error(display = "Failed to retrieve network interface entry")]
+ GetIfEntryFailed(io::Error),
+ /// Low level error caused by failing to register the route callback
+ #[error(display = "Attempt to register notify route change callback failed")]
+ RegisterNotifyRouteCallback(io::Error),
+ /// Low level error caused by failing to register the ip interface callback
+ #[error(display = "Attempt to register notify ip interface change callback failed")]
+ RegisterNotifyIpInterfaceCallback(io::Error),
+ /// Low level error caused by failing to register the unicast ip address callback
+ #[error(display = "Attempt to register notify unicast ip address change callback failed")]
+ RegisterNotifyUnicastIpAddressCallback(io::Error),
+ /// Low level error caused by windows Adapters API
+ #[error(display = "Windows adapter error")]
+ Adapter(io::Error),
+ /// High level error caused by a failure to clear the routes in the route manager.
+ /// Contains the lower error
+ #[error(display = "Failed to clear applied routes")]
+ ClearRoutesFailed(Box<Error>),
+ /// High level error caused by a failure to add routes in the route manager.
+ /// Contains the lower error
+ #[error(display = "Failed to add routes")]
+ AddRoutesFailed(Box<Error>),
+ /// Something went wrong when getting the mtu of the interface
+ #[error(display = "Could not get the mtu of the interface")]
+ GetMtu,
+ /// The SI family was of an unexpected value
+ #[error(display = "The SI family was of an unexpected value")]
+ InvalidSiFamily,
+ /// Device name not found
+ #[error(display = "The device name was not found")]
+ DeviceNameNotFound,
+ /// No default route
+ #[error(display = "No default route found")]
+ NoDefaultRoute,
+ /// Conversion error between types
+ #[error(display = "Conversion error")]
+ Conversion,
+ /// Could not find device gateway
+ #[error(display = "Could not find device gateway")]
+ DeviceGatewayNotFound,
+ /// Could not get default route
+ #[error(display = "Could not get default route")]
+ GetDefaultRoute,
+ /// Could not find device by name
+ #[error(display = "Could not find device by name")]
+ GetDeviceByName,
+ /// Could not find device by gateway
+ #[error(display = "Could not find device by gateway")]
+ GetDeviceByGateway,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Manages routes by calling into WinNet
+pub struct RouteManager {
+ manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
+}
+
+/// Handle to a route manager.
+#[derive(Clone)]
+pub struct RouteManagerHandle {
+ tx: UnboundedSender<RouteManagerCommand>,
+}
+
+impl RouteManagerHandle {
+ /// Add a callback which will be called if the default route changes.
+ pub async fn add_default_route_change_callback(
+ &self,
+ callback: Callback,
+ ) -> Result<CallbackHandle> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .send(RouteManagerCommand::RegisterDefaultRouteChangeCallback(
+ callback,
+ response_tx,
+ ))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ }
+
+ /// Applies the given routes while the route manager is running.
+ pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .send(RouteManagerCommand::AddRoutes(routes, response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ }
+
+ /// Applies the given routes while the route manager is running.
+ pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ }
+}
+
+pub enum RouteManagerCommand {
+ AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
+ GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>),
+ ClearRoutes,
+ RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<Result<CallbackHandle>>),
+ Shutdown,
+}
+
+impl RouteManager {
+ /// Creates a new route manager that will apply the provided routes and ensure they exist until
+ /// it's stopped.
+ pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
+ let internal = match RouteManagerInternal::new() {
+ Ok(internal) => internal,
+ Err(_) => return Err(Error::FailedToStartManager),
+ };
+ let (manage_tx, manage_rx) = mpsc::unbounded_channel();
+ let manager = Self {
+ manage_tx: Some(manage_tx),
+ };
+ tokio::spawn(RouteManager::listen(manage_rx, internal));
+ manager.add_routes(required_routes).await?;
+
+ Ok(manager)
+ }
+
+ /// Add a callback which will be called if the default route changes.
+ pub async fn add_default_route_change_callback(
+ &self,
+ callback: Callback,
+ ) -> Result<CallbackHandle> {
+ if let Some(tx) = &self.manage_tx {
+ let (result_tx, result_rx) = oneshot::channel();
+ if tx
+ .send(RouteManagerCommand::RegisterDefaultRouteChangeCallback(
+ callback, result_tx,
+ ))
+ .is_err()
+ {
+ return Err(Error::RouteManagerDown);
+ }
+ result_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ /// Retrieve a sender directly to the command channel.
+ pub fn handle(&self) -> Result<RouteManagerHandle> {
+ if let Some(tx) = &self.manage_tx {
+ Ok(RouteManagerHandle { tx: tx.clone() })
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ async fn listen(
+ mut manage_rx: UnboundedReceiver<RouteManagerCommand>,
+ mut internal: RouteManagerInternal,
+ ) {
+ while let Some(command) = manage_rx.recv().await {
+ match command {
+ RouteManagerCommand::AddRoutes(routes, tx) => {
+ let routes: Vec<_> = routes
+ .into_iter()
+ .map(|route| Route {
+ network: route.prefix,
+ node: route.node,
+ })
+ .collect();
+
+ let _ = tx.send(
+ internal
+ .add_routes(routes)
+ .map_err(|e| Error::AddRoutesFailed(Box::new(e))),
+ );
+ }
+ RouteManagerCommand::GetMtuForRoute(ip, tx) => {
+ let addr_family = if ip.is_ipv4() {
+ AddressFamily::Ipv4
+ } else {
+ AddressFamily::Ipv6
+ };
+ let res = match get_mtu_for_route(addr_family) {
+ Ok(Some(mtu)) => Ok(mtu),
+ Ok(None) => Err(Error::GetMtu),
+ Err(e) => Err(e),
+ };
+ let _ = tx.send(res);
+ }
+ RouteManagerCommand::ClearRoutes => {
+ if let Err(e) = internal.delete_applied_routes() {
+ log::error!("{}", e.display_chain_with_msg("Could not clear routes"));
+ }
+ }
+ RouteManagerCommand::RegisterDefaultRouteChangeCallback(callback, tx) => {
+ let _ = tx.send(internal.register_default_route_changed_callback(callback));
+ }
+ RouteManagerCommand::Shutdown => {
+ break;
+ }
+ }
+ }
+ }
+
+ /// Stops the routing manager and invalidates the route manager - no new default route callbacks
+ /// can be added
+ pub fn stop(&mut self) {
+ if let Some(tx) = self.manage_tx.take() {
+ if tx.send(RouteManagerCommand::Shutdown).is_err() {
+ log::error!("RouteManager channel already down or thread panicked");
+ }
+ }
+ }
+
+ /// Applies the given routes until [`RouteManager::stop`] is called.
+ pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ if let Some(tx) = &self.manage_tx {
+ let (result_tx, result_rx) = oneshot::channel();
+ if tx
+ .send(RouteManagerCommand::AddRoutes(routes, result_tx))
+ .is_err()
+ {
+ return Err(Error::RouteManagerDown);
+ }
+ result_rx.await.map_err(|_| Error::ManagerChannelDown)?
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ /// Removes all routes previously applied in [`RouteManager::new`] or
+ /// [`RouteManager::add_routes`].
+ pub fn clear_routes(&self) -> Result<()> {
+ if let Some(tx) = &self.manage_tx {
+ tx.send(RouteManagerCommand::ClearRoutes)
+ .map_err(|_| Error::RouteManagerDown)
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+}
+
+fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> {
+ match get_best_default_route(addr_family) {
+ Ok(Some(route)) => {
+ let interface_row = crate::windows::get_ip_interface_entry(addr_family, &route.iface)
+ .map_err(|e| {
+ log::error!("Could not get ip interface entry: {}", e);
+ Error::GetMtu
+ })?;
+ let mtu = interface_row.NlMtu;
+ let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?;
+ Ok(Some(mtu))
+ }
+ Ok(None) => Ok(None),
+ Err(e) => {
+ log::error!("Could not get best default route: {}", e);
+ Err(Error::GetMtu)
+ }
+ }
+}
+
+impl Drop for RouteManager {
+ fn drop(&mut self) {
+ self.stop();
+ }
+}
diff --git a/talpid-core/src/routing/windows/route_manager.rs b/talpid-core/src/routing/windows/route_manager.rs
new file mode 100644
index 0000000000..f1d878dd28
--- /dev/null
+++ b/talpid-core/src/routing/windows/route_manager.rs
@@ -0,0 +1,885 @@
+use super::{
+ default_route_monitor::{DefaultRouteMonitor, EventType as RouteMonitorEventType},
+ get_best_default_route, Error, InterfaceAndGateway, Result,
+};
+use crate::{
+ routing::NetNode,
+ windows::{inet_sockaddr_from_socketaddr, try_socketaddr_from_inet_sockaddr, AddressFamily},
+};
+use ipnetwork::IpNetwork;
+use std::{
+ collections::HashMap,
+ io,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
+ sync::{Arc, Mutex},
+};
+use widestring::{WideCStr, WideCString};
+use windows_sys::Win32::{
+ Foundation::{
+ ERROR_BUFFER_OVERFLOW, ERROR_NOT_FOUND, ERROR_NO_DATA, ERROR_OBJECT_ALREADY_EXISTS,
+ ERROR_SUCCESS, NO_ERROR,
+ },
+ NetworkManagement::{
+ IpHelper::{
+ ConvertInterfaceAliasToLuid, CreateIpForwardEntry2, DeleteIpForwardEntry2,
+ GetAdaptersAddresses, InitializeIpForwardEntry, SetIpForwardEntry2,
+ GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER,
+ GAA_FLAG_SKIP_FRIENDLY_NAME, GAA_FLAG_SKIP_MULTICAST, GET_ADAPTERS_ADDRESSES_FLAGS,
+ IP_ADAPTER_ADDRESSES_LH, IP_ADAPTER_GATEWAY_ADDRESS_LH, IP_ADAPTER_IPV4_ENABLED,
+ IP_ADAPTER_IPV6_ENABLED, IP_ADDRESS_PREFIX, MIB_IPFORWARD_ROW2,
+ },
+ Ndis::NET_LUID_LH,
+ },
+ Networking::WinSock::{
+ NlroManual, ADDRESS_FAMILY, AF_INET, AF_INET6, MIB_IPPROTO_NETMGMT, SOCKADDR_IN,
+ SOCKADDR_IN6, SOCKADDR_INET, SOCKET_ADDRESS,
+ },
+};
+
+type Network = IpNetwork;
+type NodeAddress = SOCKADDR_INET;
+
+/// Callback handle for the default route changed callback. Produced by the RouteManager.
+pub struct CallbackHandle {
+ nonce: i32,
+ callbacks: Arc<Mutex<(i32, HashMap<i32, Callback>)>>,
+}
+
+impl Drop for CallbackHandle {
+ fn drop(&mut self) {
+ let (_, callbacks) = &mut *self.callbacks.lock().unwrap();
+ match callbacks.remove(&self.nonce) {
+ Some(_) => (),
+ None => {
+ log::warn!("Could not un-register route manager callback due to it already being de-registered");
+ }
+ }
+ }
+}
+
+#[derive(Clone)]
+struct RegisteredRoute {
+ network: Network,
+ luid: NET_LUID_LH,
+ next_hop: SocketAddr,
+}
+
+impl std::fmt::Display for RegisteredRoute {
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ // SAFETY: luid.Value is always valid as the underlying type of both union fields is an u64
+ formatter.write_fmt(format_args!("RegisteredRoute {{ luid: {} }}", unsafe {
+ self.luid.Value
+ }))
+ }
+}
+
+impl PartialEq for RegisteredRoute {
+ fn eq(&self, other: &Self) -> bool {
+ // SAFETY: luid.Value is always valid as the underlying type of both union fields is an u64
+ (unsafe { self.luid.Value == other.luid.Value })
+ && (self.next_hop == other.next_hop)
+ && (self.network == other.network)
+ }
+}
+
+#[derive(Clone)]
+pub struct Node {
+ pub device_name: Option<widestring::U16CString>,
+ pub gateway: Option<NodeAddress>,
+}
+
+#[derive(Clone)]
+pub struct Route {
+ pub network: Network,
+ pub node: NetNode,
+}
+
+#[derive(Clone)]
+struct RouteRecord {
+ route: Route,
+ registered_route: RegisteredRoute,
+}
+
+struct EventEntry {
+ record: RouteRecord,
+ event_type: RecordEventType,
+}
+
+enum RecordEventType {
+ AddRoute,
+ DeleteRoute,
+}
+
+pub type Callback = Box<dyn for<'a> Fn(RouteMonitorEventType<'a>, AddressFamily) + Send>;
+
+pub struct RouteManagerInternal {
+ route_monitor_v4: Option<DefaultRouteMonitor>,
+ route_monitor_v6: Option<DefaultRouteMonitor>,
+ routes: Arc<Mutex<Vec<RouteRecord>>>,
+ /// Lock for a nonce and a HashMap of callbacks and their id which is used as a handle to
+ /// unregister them. The nonce is used to create new ids and then incrementing.
+ callbacks: Arc<Mutex<(i32, HashMap<i32, Callback>)>>,
+}
+
+impl RouteManagerInternal {
+ pub fn new() -> Result<Self> {
+ let routes = Arc::new(Mutex::new(Vec::new()));
+ let callbacks = Arc::new(Mutex::new((0, HashMap::new())));
+
+ let callbacks_ipv4 = callbacks.clone();
+ let routes_ipv4 = routes.clone();
+ let callbacks_ipv6 = callbacks.clone();
+ let routes_ipv6 = routes.clone();
+
+ Ok(Self {
+ route_monitor_v4: Some(DefaultRouteMonitor::new(
+ AddressFamily::Ipv4,
+ move |event_type| {
+ Self::default_route_change(&callbacks_ipv4, &routes_ipv4, AF_INET, event_type);
+ },
+ )?),
+ route_monitor_v6: Some(DefaultRouteMonitor::new(
+ AddressFamily::Ipv6,
+ move |event_type| {
+ Self::default_route_change(&callbacks_ipv6, &routes_ipv6, AF_INET6, event_type);
+ },
+ )?),
+ routes,
+ callbacks,
+ })
+ }
+
+ pub fn add_routes(&self, new_routes: Vec<Route>) -> Result<()> {
+ let mut route_manager_routes = self.routes.lock().unwrap();
+
+ let mut event_log = vec![];
+
+ for route in new_routes {
+ let registered_route = Self::add_into_routing_table(&route).map_err(|error| {
+ if let Err(error) = Self::undo_events(&event_log, &mut route_manager_routes) {
+ error
+ } else {
+ error
+ }
+ })?;
+
+ let new_record = RouteRecord {
+ route,
+ registered_route,
+ };
+
+ event_log.push(EventEntry {
+ event_type: RecordEventType::AddRoute,
+ record: new_record.clone(),
+ });
+
+ let existing_record_idx =
+ Self::find_route_record(&mut route_manager_routes, &new_record.registered_route);
+
+ match existing_record_idx {
+ None => route_manager_routes.push(new_record),
+ Some(idx) => route_manager_routes[idx] = new_record,
+ }
+ }
+ Ok(())
+ }
+
+ fn add_into_routing_table(route: &Route) -> Result<RegisteredRoute> {
+ let node = Self::resolve_node(ipnetwork_to_address_family(route.network), &route.node)?;
+
+ // SAFETY: MIB_IPFORWARD_ROW2 contains no references or pointers only number primitives and
+ // as such it is safe to zero it.
+ let mut spec: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() };
+
+ // SAFETY: This function must be used to initialize MIB_IPFORWARD_ROW2 structs if it is to
+ // be used later by CreateIpForwardEntry2.
+ unsafe { InitializeIpForwardEntry(&mut spec) };
+
+ spec.InterfaceLuid = node.iface;
+ spec.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network);
+ spec.NextHop = inet_sockaddr_from_socketaddr(node.gateway);
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ // SAFETY: DestinationPrefix must be initialized to a valid prefix. NextHop must have a
+ // valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be set
+ // to the interface.
+ let mut status = unsafe { CreateIpForwardEntry2(&spec) };
+
+ // The return code ERROR_OBJECT_ALREADY_EXISTS means there is already an existing route
+ // on the same interface, with the same DestinationPrefix and NextHop.
+ //
+ // However, all the other properties of the route may be different. And the properties may
+ // not have the exact same values as when the route was registered, because windows
+ // will adjust route properties at time of route insertion as well as later.
+ //
+ // The simplest thing in this case is to just overwrite the route.
+ //
+
+ if ERROR_OBJECT_ALREADY_EXISTS as i32 == status {
+ // SAFETY: DestinationPrefix must be initialzed to a valid prefix. NextHop must have
+ // a valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must
+ // be set to the interface.
+ status = unsafe { SetIpForwardEntry2(&spec) };
+ }
+
+ if NO_ERROR as i32 != status {
+ log::error!("Could not register route in routing table");
+ return Err(Error::AddToRouteTable(io::Error::from_raw_os_error(status)));
+ }
+
+ Ok(RegisteredRoute {
+ network: route.network,
+ luid: node.iface,
+ next_hop: node.gateway,
+ })
+ }
+
+ fn resolve_node(family: AddressFamily, optional_node: &NetNode) -> Result<InterfaceAndGateway> {
+ // There are four cases:
+ //
+ // Unspecified node (use interface and gateway of default route).
+ // Node is specified by name.
+ // Node is specified by name and gateway.
+ // Node is specified by gateway.
+ //
+
+ match optional_node {
+ NetNode::DefaultNode => {
+ let default_route = get_best_default_route(family)?;
+ match default_route {
+ None => {
+ log::error!("Unable to determine details of default route");
+ return Err(Error::NoDefaultRoute);
+ }
+ Some(default_route) => return Ok(default_route),
+ }
+ }
+ NetNode::RealNode(node) => {
+ if let Some(device_name) = &node.get_device() {
+ let device_name = WideCString::from_str(device_name)
+ .expect("Failed to convert UTF-8 string to null terminated UCS string");
+ let luid = match Self::parse_string_encoded_luid(device_name.as_ucstr())? {
+ None => {
+ let mut luid = NET_LUID_LH { Value: 0 };
+ // SAFETY: No specific safety requirement
+ if NO_ERROR as i32
+ != unsafe {
+ ConvertInterfaceAliasToLuid(device_name.as_ptr(), &mut luid)
+ }
+ {
+ log::error!(
+ "Unable to derive interface LUID from interface alias: {:?}",
+ device_name
+ );
+ return Err(Error::DeviceNameNotFound);
+ } else {
+ luid
+ }
+ }
+ Some(luid) => luid,
+ };
+
+ return Ok(InterfaceAndGateway {
+ iface: luid,
+ gateway: match node.get_address() {
+ Some(ip) => SocketAddr::new(ip, 0),
+ None => match family {
+ AddressFamily::Ipv4 => {
+ SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
+ }
+ AddressFamily::Ipv6 => {
+ SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
+ }
+ },
+ },
+ });
+ }
+
+ // The node is specified only by gateway.
+ //
+
+ // Unwrapping is fine because the node must have an address since no device name was
+ // found.
+ let gateway = node.get_address().map(inet_sockaddr_from_ipaddr).unwrap();
+ Ok(InterfaceAndGateway {
+ iface: interface_luid_from_gateway(&gateway)?,
+ gateway: try_socketaddr_from_inet_sockaddr(gateway)
+ .map_err(|_| Error::InvalidSiFamily)?,
+ })
+ }
+ }
+ }
+
+ fn find_route_record(records: &mut Vec<RouteRecord>, route: &RegisteredRoute) -> Option<usize> {
+ records
+ .iter()
+ .position(|record| route == &record.registered_route)
+ }
+
+ fn undo_events(event_log: &Vec<EventEntry>, records: &mut Vec<RouteRecord>) -> Result<()> {
+ // Rewind state by processing events in the reverse order.
+ //
+
+ let mut result = Ok(());
+
+ for event in event_log.iter().rev() {
+ match event.event_type {
+ RecordEventType::AddRoute => {
+ let record_idx = Self::find_route_record(records, &event.record.registered_route)
+ .expect("Internal state inconsistency in route manager, could not find route record");
+ let record = records.get(record_idx)
+ .expect("Internal state inconsistency in route manager, route record index pointing at nothing");
+
+ if let Err(e) = Self::delete_from_routing_table(&record.registered_route) {
+ result = result.and(Err(e));
+ continue;
+ }
+ records.remove(record_idx);
+ }
+ RecordEventType::DeleteRoute => {
+ if let Err(e) = Self::restore_into_routing_table(&event.record.registered_route)
+ {
+ result = result.and(Err(e));
+ continue;
+ }
+ records.push(event.record.clone());
+ }
+ }
+ }
+
+ result
+ }
+
+ fn delete_from_routing_table(route: &RegisteredRoute) -> Result<()> {
+ // SAFETY: There are no pointers or references inside of MIB_IPFORWARD_ROW2, only primitive
+ // numbers as such it is safe to zero it.
+ let mut r: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() };
+
+ r.InterfaceLuid = route.luid;
+ r.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network);
+ r.NextHop = inet_sockaddr_from_socketaddr(route.next_hop);
+
+ // SAFETY: DestinationPrefix must be initialzed to a valid prefix. NextHop must have
+ // a valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be
+ // set to the interface.
+ let status = unsafe { DeleteIpForwardEntry2(&r) };
+
+ match u32::try_from(status) {
+ Ok(ERROR_NOT_FOUND) => {
+ log::warn!("Attempting to delete route which was not present in routing table, ignoring and proceeding. Route: {}", route);
+ }
+ Ok(NO_ERROR) => (),
+ _ => {
+ log::error!(
+ "Failed to delete route in routing table. Route: {}, Status: {}",
+ route,
+ status
+ );
+ return Err(Error::DeleteFromRouteTable(io::Error::from_raw_os_error(
+ status,
+ )));
+ }
+ }
+
+ Ok(())
+ }
+
+ fn restore_into_routing_table(route: &RegisteredRoute) -> Result<()> {
+ // SAFETY: There are no pointers or references inside of MIB_IPFORWARD_ROW2, only primitive
+ // numbers as such it is safe to zero it.
+ let mut spec: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() };
+
+ // SAFETY: This function must be used to initialize MIB_IPFORWARD_ROW2 structs if it is to
+ // be used later by CreateIpForwardEntry2.
+ unsafe { InitializeIpForwardEntry(&mut spec) };
+
+ spec.InterfaceLuid = route.luid;
+ spec.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network);
+ spec.NextHop = inet_sockaddr_from_socketaddr(route.next_hop);
+ spec.Metric = 0;
+ spec.Protocol = MIB_IPPROTO_NETMGMT;
+ spec.Origin = NlroManual;
+
+ // SAFETY: DestinationPrefix must be initialized to a valid prefix. NextHop must have a
+ // valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be set
+ // to the interface.
+ let status = unsafe { CreateIpForwardEntry2(&spec) };
+
+ if NO_ERROR as i32 != status {
+ log::error!(
+ "Could not register route in routing table. Route: {}, Status: {}",
+ route,
+ status
+ );
+ return Err(Error::AddToRouteTable(io::Error::from_raw_os_error(status)));
+ }
+ Ok(())
+ }
+
+ fn parse_string_encoded_luid(encoded_luid: &WideCStr) -> Result<Option<NET_LUID_LH>> {
+ // The `#` is a valid character in adapter names so we use `?` instead.
+ // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes.
+ // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`.
+ //
+
+ const STRING_ENCODED_LUID_LENGTH: usize = 17;
+
+ if encoded_luid.len() != STRING_ENCODED_LUID_LENGTH
+ || Some(Ok('?')) != encoded_luid.chars().next()
+ {
+ return Ok(None);
+ }
+
+ let luid = NET_LUID_LH {
+ Value: u64::from_str_radix(
+ &encoded_luid.to_string().map_err(|_| {
+ log::error!("Failed to parse string encoded LUID: {:?}", encoded_luid);
+ Error::Conversion
+ })?[1..],
+ 16,
+ )
+ .map_err(|_| {
+ log::error!("Failed to parse string encoded LUID: {:?}", encoded_luid);
+ Error::Conversion
+ })?,
+ };
+
+ return Ok(Some(luid));
+ }
+
+ pub fn delete_applied_routes(&mut self) -> Result<()> {
+ let mut routes = self.routes.lock().unwrap();
+ // Delete all routes owned by us.
+ //
+
+ for record in (*routes).iter() {
+ if let Err(_) = Self::delete_from_routing_table(&record.registered_route) {
+ log::error!(
+ "Failed to delete route while clearing applied routes, {}",
+ record.registered_route
+ );
+ }
+ }
+
+ routes.clear();
+ Ok(())
+ }
+
+ pub fn register_default_route_changed_callback(
+ &self,
+ callback: Callback,
+ ) -> Result<CallbackHandle> {
+ let (nonce, callbacks) = &mut *self.callbacks.lock().unwrap();
+ let old_nonce = *nonce;
+ callbacks.insert(old_nonce, callback);
+ *nonce = nonce.wrapping_add(1);
+ Ok(CallbackHandle {
+ nonce: old_nonce,
+ callbacks: self.callbacks.clone(),
+ })
+ }
+
+ fn default_route_change<'a>(
+ callbacks: &Arc<Mutex<(i32, HashMap<i32, Callback>)>>,
+ records: &Arc<Mutex<Vec<RouteRecord>>>,
+ family: ADDRESS_FAMILY,
+ event_type: RouteMonitorEventType<'a>,
+ ) {
+ // Forward event to all registered listeners.
+ //
+
+ {
+ let (_, callbacks) = &mut *callbacks.lock().unwrap();
+ for callback in callbacks.values() {
+ let family =
+ AddressFamily::try_from_af_family(u16::try_from(family).unwrap()).unwrap();
+ callback(event_type, family);
+ }
+ }
+
+ // Examine event to determine if best default route has changed.
+ //
+
+ let route = if let RouteMonitorEventType::Updated(route) = event_type {
+ route
+ } else {
+ return;
+ };
+
+ // Examine our routes to see if any of them are policy bound to the best default route.
+ //
+
+ let mut records = records.lock().unwrap();
+ let mut affected_routes: Vec<&mut RouteRecord> = vec![];
+
+ for record in (*records).iter_mut() {
+ if matches!(record.route.node, NetNode::DefaultNode)
+ && family
+ == u32::from(ipnetwork_to_address_family(record.route.network).to_af_family())
+ {
+ affected_routes.push(record);
+ }
+ }
+
+ if affected_routes.is_empty() {
+ return;
+ }
+
+ // Update all affected routes.
+ //
+
+ log::info!("Best default route has changed. Refreshing dependent routes");
+
+ for affected_route in affected_routes {
+ // We can't update the existing route because defining characteristics are being
+ // changed. So removing and adding again is the only option.
+ //
+
+ match Self::delete_from_routing_table(&affected_route.registered_route) {
+ Ok(()) => (),
+ Err(e) => {
+ log::error!(
+ "Failed to delete route when refreshing existing routes: {}",
+ e
+ );
+ continue;
+ }
+ }
+
+ affected_route.registered_route.luid = route.iface;
+ affected_route.registered_route.next_hop = route.gateway;
+
+ match Self::restore_into_routing_table(&affected_route.registered_route) {
+ Ok(()) => (),
+ Err(e) => {
+ log::error!("Failed to add route when refreshing existing routes: {}", e);
+ continue;
+ }
+ }
+ }
+ }
+}
+
+impl Drop for RouteManagerInternal {
+ fn drop(&mut self) {
+ drop(self.route_monitor_v4.take());
+ drop(self.route_monitor_v6.take());
+
+ match self.delete_applied_routes() {
+ Ok(()) => (),
+ Err(e) => {
+ log::error!("Failed to correctly drop RouteManagerInternal {}", e)
+ }
+ }
+ }
+}
+
+fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> {
+ const ADAPTER_FLAGS: GET_ADAPTERS_ADDRESSES_FLAGS = GAA_FLAG_SKIP_ANYCAST
+ | GAA_FLAG_SKIP_MULTICAST
+ | GAA_FLAG_SKIP_DNS_SERVER
+ | GAA_FLAG_SKIP_FRIENDLY_NAME
+ | GAA_FLAG_INCLUDE_GATEWAYS;
+
+ // SAFETY: The si_family field is always valid to access.
+ let family: u32 = u32::from(unsafe { gateway.si_family });
+ let adapters = Adapters::new(family, ADAPTER_FLAGS)?;
+
+ // Process adapters to find matching ones.
+ //
+
+ let mut matches: Vec<_> = adapters
+ // SAFETY: We are not allowed to dereference adapter.Head if it has been aquired in a previous iteration of the iterator
+ // we ensure this is upheld by not saving any references to adapter.Head between iterations.
+ .iter()
+ .filter(|adapter| {
+ if !adapter_interface_enabled(adapter, family).unwrap_or(false) {
+ return false;
+ }
+ let gateways = if adapter.FirstGatewayAddress.is_null() {
+ vec![]
+ } else {
+ // SAFETY: adapter.FirstGatewayAddress is not null and all elements in the linked list live
+ // in the same buffer and as such have the same lifetime.
+ unsafe { isolate_gateway_address(get_first_gateway_address_reference(adapter), family) }
+ };
+
+ address_present(gateways, &gateway).unwrap_or(false)
+ })
+ .collect();
+
+ if matches.is_empty() {
+ log::error!("Unable to find network adapter with specified gateway");
+ return Err(Error::DeviceGatewayNotFound);
+ }
+
+ // Sort matching interfaces ascending by metric.
+ //
+
+ let target_v4 = AF_INET == family;
+
+ matches.sort_by(|lhs, rhs| {
+ if target_v4 {
+ lhs.Ipv4Metric.cmp(&rhs.Ipv4Metric)
+ } else {
+ lhs.Ipv6Metric.cmp(&rhs.Ipv6Metric)
+ }
+ });
+
+ // Select the interface with the best (lowest) metric.
+ //
+
+ Ok(matches[0].Luid)
+}
+
+/// SAFETY: adapter.FirstGatewayAddress must be dereferencable and must live as long as adapter
+unsafe fn get_first_gateway_address_reference(
+ adapter: &IP_ADAPTER_ADDRESSES_LH,
+) -> &IP_ADAPTER_GATEWAY_ADDRESS_LH {
+ &*adapter.FirstGatewayAddress
+}
+
+fn adapter_interface_enabled(
+ adapter: &IP_ADAPTER_ADDRESSES_LH,
+ family: ADDRESS_FAMILY,
+) -> Result<bool> {
+ match family {
+ // SAFETY: All fields in the Anonymous2 union are at represented by a u32 so dereferencing
+ // them is safe
+ AF_INET => Ok(0 != unsafe { adapter.Anonymous2.Flags } & IP_ADAPTER_IPV4_ENABLED),
+ AF_INET6 => Ok(0 != unsafe { adapter.Anonymous2.Flags } & IP_ADAPTER_IPV6_ENABLED),
+ _ => Err(Error::InvalidSiFamily),
+ }
+}
+
+/// SAFETY: `head` must be a linked list where each `head.Next` is either null or
+/// the it and all of its fields has lifetime 'a and are dereferencable.
+unsafe fn isolate_gateway_address<'a>(
+ head: &'a IP_ADAPTER_GATEWAY_ADDRESS_LH,
+ family: ADDRESS_FAMILY,
+) -> Vec<&'a SOCKET_ADDRESS> {
+ let mut matches = vec![];
+
+ let mut gateway = head;
+ loop {
+ // SAFETY: The contract states that Address.lpSockaddr is dereferencable if the element is
+ // non-null
+ if family == u32::from((*gateway.Address.lpSockaddr).sa_family) {
+ // SAFETY: The contract states that this field must have lifetime 'a
+ matches.push(&gateway.Address);
+ }
+
+ if gateway.Next.is_null() {
+ break;
+ }
+
+ // SAFETY: Gateway.Next is not null here and the contract states it must be dereferencable
+ // if non-null
+ gateway = &*gateway.Next;
+ }
+
+ matches
+}
+
+fn address_present(hay: Vec<&'_ SOCKET_ADDRESS>, needle: &'_ SOCKADDR_INET) -> Result<bool> {
+ for candidate in hay {
+ // SAFETY: Contract states that needle is dereferencable
+ if equal_address(needle, candidate)? {
+ return Ok(true);
+ }
+ }
+
+ Ok(false)
+}
+
+fn equal_address(lhs: &'_ SOCKADDR_INET, rhs: &'_ SOCKET_ADDRESS) -> Result<bool> {
+ let rhs = &*rhs;
+ // SAFETY: The si_family field is always valid
+ if unsafe { lhs.si_family != (*rhs.lpSockaddr).sa_family } {
+ return Ok(false);
+ }
+
+ match unsafe { lhs.si_family } as u32 {
+ AF_INET => {
+ let typed_rhs = rhs.lpSockaddr as *mut SOCKADDR_IN;
+ // SAFETY: If rhs.lpSockaddr.sa_family is IPv4 then lpSockaddr is a SOCKADDR_IN
+ Ok(unsafe { lhs.Ipv4.sin_addr.S_un.S_addr == (*typed_rhs).sin_addr.S_un.S_addr })
+ }
+ AF_INET6 => {
+ let typed_rhs = rhs.lpSockaddr as *mut SOCKADDR_IN6;
+ // SAFETY: If rhs.lpSockaddr.sa_family is IPv6 then lpSockaddr is a SOCKADDR_IN6
+ Ok(unsafe { lhs.Ipv6.sin6_addr.u.Byte == (*typed_rhs).sin6_addr.u.Byte })
+ }
+ _ => {
+ log::error!("Missing case handler in match");
+ Err(Error::InvalidSiFamily)
+ }
+ }
+}
+
+/// Linked list containing `IP_ADAPTER_ADDRESSES_LH` queried from the windows API.
+/// Consume by using the iterator produced by `iter_mut()`
+struct Adapters {
+ // SAFETY: This vector is not allowed to be resized since all of the data inside of it would be
+ // dangling
+ buffer: Vec<u8>,
+}
+
+impl Adapters {
+ /// Create a new linked list of adapters from the windows API
+ fn new(family: ADDRESS_FAMILY, flags: GET_ADAPTERS_ADDRESSES_FLAGS) -> Result<Self> {
+ const MSDN_RECOMMENDED_STARTING_BUFFER_SIZE: usize = 1024 * 15;
+ let mut buffer: Vec<u8> = Vec::with_capacity(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE);
+ buffer.resize(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE, 0);
+
+ let mut buffer_size = u32::try_from(buffer.len()).unwrap();
+ let mut buffer_pointer = buffer.as_mut_ptr();
+
+ // Acquire interfaces.
+ //
+
+ loop {
+ // SAFETY: buffer_size must point to the correct amount of bytes in the buffer which it
+ // does. buffer_pointer must point to the start of a mutable buffer which it
+ // does. After this call buffer_size might have changed and as such the
+ // buffer must be resized to reflect this if this function is going to be
+ // called again.
+ let status = unsafe {
+ GetAdaptersAddresses(
+ family,
+ flags,
+ std::ptr::null_mut() as *mut _,
+ buffer_pointer as *mut IP_ADAPTER_ADDRESSES_LH,
+ &mut buffer_size,
+ )
+ };
+
+ if ERROR_SUCCESS == status {
+ // SAFETY: We truncate the buffer to avoid having a bunch of zero:ed objects at the
+ // end of it truncate will not change capacity and will therefore
+ // never reallocate the vector which means it can not cause the
+ // pointers in the buffer to dangle.
+ buffer.truncate(usize::try_from(buffer_size).unwrap());
+ break;
+ }
+
+ if ERROR_NO_DATA == status {
+ return Ok(Self { buffer: Vec::new() });
+ }
+
+ if ERROR_BUFFER_OVERFLOW != status {
+ log::error!("Probe required buffer size for GetAdaptersAddresses");
+ return Err(Error::Adapter(io::Error::from_raw_os_error(
+ i32::try_from(status).unwrap(),
+ )));
+ }
+
+ // The needed length is returned in the buffer_size pointer
+ buffer.resize(usize::try_from(buffer_size).unwrap(), 0);
+ buffer_pointer = buffer.as_mut_ptr();
+ }
+
+ // Verify structure compatibility.
+ // The structure has been extended many times.
+ //
+
+ // Unwrapping is fine because we previously would return if we got a ERROR_NO_DATA status.
+ // As such the buffer is not empty. SAFETY: Casting the buffers first element to an
+ // IP_ADAPTER_ADDRESSES_LH is safe as that is the underlying data structure. SAFETY:
+ // This union field is always valid to read from
+ let system_size = unsafe {
+ (*(buffer.get(0).unwrap() as *const u8 as *const IP_ADAPTER_ADDRESSES_LH))
+ .Anonymous1
+ .Anonymous
+ .Length
+ };
+ let code_size = u32::try_from(std::mem::size_of::<IP_ADAPTER_ADDRESSES_LH>()).unwrap();
+
+ if system_size < code_size {
+ log::error!("Expecting IP_ADAPTER_ADDRESSES to have size {code_size} bytes. Found structure with size {system_size} bytes.");
+ return Err(Error::Adapter(io::Error::new(io::ErrorKind::Other,
+ format!("Expecting IP_ADAPTER_ADDRESSES to have size {code_size} bytes. Found structure with size {system_size} bytes."))));
+ }
+
+ // Initialize members.
+ //
+
+ Ok(Self { buffer })
+ }
+
+ /// Produces a iterator for the linked list in `Adapters` see
+ /// [AdaptersIterator](struct.AdaptersIterator.html) SAFETY: See the documentation on
+ /// `AdaptersIterator`
+ fn iter<'a>(&'a self) -> AdaptersIterator<'a> {
+ let cur = if self.buffer.is_empty() {
+ std::ptr::null()
+ } else {
+ &self.buffer[0] as *const u8 as *const IP_ADAPTER_ADDRESSES_LH
+ };
+ AdaptersIterator {
+ _adapters: self,
+ cur,
+ }
+ }
+}
+
+/// SAFETY: You are only allowed to dereference `IP_ADAPTER_ADDRESSES_LH.Next` or any following
+/// `Next` items in the linked list if they were produced by the latest call to `next()`. Any raw
+/// pointers that were aquired before the call to `next()` are not valid to dereference.
+struct AdaptersIterator<'a> {
+ _adapters: &'a Adapters,
+ cur: *const IP_ADAPTER_ADDRESSES_LH,
+}
+
+impl<'a> Iterator for AdaptersIterator<'a> {
+ type Item = &'a IP_ADAPTER_ADDRESSES_LH;
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.cur.is_null() {
+ None
+ } else {
+ let ret = self.cur;
+ // SAFETY: self.cur is guaranteed to not be null, we are also holding a &Adapters which
+ // guarantees no other reference of self could be held right now which has
+ // mutably dereferenced the same address that self.cur is pointing to.
+ //
+ // It is possible that someone has copied the previous returned items `Next` pointer
+ // which points to the same as address as self.cur, however dereferencing
+ // that is unsafe and that code is responsible for not dereferencing
+ // `Next` on a reference returned by this function after that reference has been
+ // dropped.
+ self.cur = unsafe { (*self.cur).Next };
+ // SAFETY: ret is guaranteed to be non-null and valid since self.adapters owns the
+ // memory.
+ Some(unsafe { &*ret })
+ }
+ }
+}
+
+/// Convert to a windows defined `IP_ADDRESS_PREFIX` from a `ipnetwork::IpNetwork` but set the port
+/// to 0
+pub fn win_ip_address_prefix_from_ipnetwork_port_zero(from: IpNetwork) -> IP_ADDRESS_PREFIX {
+ // Port should not matter so we set it to 0
+ let prefix =
+ crate::windows::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from.ip(), 0));
+ IP_ADDRESS_PREFIX {
+ Prefix: prefix,
+ PrefixLength: from.prefix(),
+ }
+}
+
+/// Convert to a windows defined `SOCKADDR_INET` from a `IpAddr` but set the port to 0
+pub fn inet_sockaddr_from_ipaddr(from: IpAddr) -> SOCKADDR_INET {
+ // Port should not matter so we set it to 0
+ crate::windows::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from, 0))
+}
+
+/// Convert to a `AddressFamily` from a `ipnetwork::IpNetwork`
+pub fn ipnetwork_to_address_family(from: IpNetwork) -> AddressFamily {
+ if from.is_ipv4() {
+ AddressFamily::Ipv4
+ } else {
+ AddressFamily::Ipv6
+ }
+}
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index 0bde6ac435..49028319a0 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -5,6 +5,7 @@ mod volume_monitor;
mod windows;
use crate::{
+ routing::{self, get_best_default_route, CallbackHandle, EventType, RouteManagerHandle},
tunnel::TunnelMetadata,
tunnel_state_machine::TunnelCommand,
windows::{
@@ -12,7 +13,6 @@ use crate::{
window::{PowerManagementEvent, PowerManagementListener},
AddressFamily,
},
- winnet::{self, get_best_default_route, WinNetAddrFamily, WinNetCallbackHandle},
};
use futures::channel::{mpsc, oneshot};
use std::{
@@ -29,9 +29,7 @@ use std::{
time::Duration,
};
use talpid_types::{tunnel::ErrorStateCause, ErrorExt};
-use windows_sys::Win32::{
- Foundation::ERROR_OPERATION_ABORTED, NetworkManagement::Ndis::NET_LUID_LH,
-};
+use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED;
const DRIVER_EVENT_BUFFER_SIZE: usize = 2048;
const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123);
@@ -74,7 +72,7 @@ pub enum Error {
/// Failed to obtain default route
#[error(display = "Failed to obtain the default route")]
- ObtainDefaultRoute(#[error(source)] winnet::Error),
+ ObtainDefaultRoute(#[error(source)] routing::Error),
/// Failed to obtain an IP address given a network interface LUID
#[error(display = "Failed to obtain IP address for interface LUID")]
@@ -116,10 +114,11 @@ pub struct SplitTunnel {
event_thread: Option<std::thread::JoinHandle<()>>,
quit_event: Arc<windows::Event>,
excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
- _route_change_callback: Option<WinNetCallbackHandle>,
+ _route_change_callback: Option<CallbackHandle>,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
async_path_update_in_progress: Arc<AtomicBool>,
power_mgmt_handle: tokio::task::JoinHandle<()>,
+ route_manager: RouteManagerHandle,
}
enum Request {
@@ -187,6 +186,7 @@ impl SplitTunnel {
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
volume_update_rx: mpsc::UnboundedReceiver<()>,
power_mgmt_rx: PowerManagementListener,
+ route_manager: RouteManagerHandle,
) -> Result<Self, Error> {
let excluded_processes = Arc::new(RwLock::new(HashMap::new()));
@@ -209,6 +209,7 @@ impl SplitTunnel {
async_path_update_in_progress: Arc::new(AtomicBool::new(false)),
excluded_processes,
power_mgmt_handle,
+ route_manager,
})
}
@@ -715,13 +716,22 @@ impl SplitTunnel {
));
self._route_change_callback = None;
+ let moved_context_mutex = context_mutex.clone();
let mut context = context_mutex.lock().unwrap();
- let callback = winnet::add_default_route_change_callback(
- Some(split_tunnel_default_route_change_handler),
- context_mutex.clone(),
- )
- .map(Some)
- .map_err(|_| Error::RegisterRouteChangeCallback)?;
+ let callback = self
+ .runtime
+ .block_on(
+ self.route_manager
+ .add_default_route_change_callback(Box::new(move |event, addr_family| {
+ split_tunnel_default_route_change_handler(
+ event,
+ addr_family,
+ &moved_context_mutex,
+ )
+ })),
+ )
+ .map(Some)
+ .map_err(|_| Error::RegisterRouteChangeCallback)?;
context.initialize_internet_addresses()?;
context.register_ips()?;
@@ -801,16 +811,10 @@ impl SplitTunnelDefaultRouteChangeHandlerContext {
pub fn initialize_internet_addresses(&mut self) -> Result<(), Error> {
// Identify IP address that gives us Internet access
- let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4)
+ let internet_ipv4 = get_best_default_route(AddressFamily::Ipv4)
.map_err(Error::ObtainDefaultRoute)?
.map(|route| {
- get_ip_address_for_interface(
- AddressFamily::Ipv4,
- NET_LUID_LH {
- Value: route.interface_luid,
- },
- )
- .map(|ip| match ip {
+ get_ip_address_for_interface(AddressFamily::Ipv4, route.iface).map(|ip| match ip {
Some(IpAddr::V4(addr)) => Some(addr),
Some(_) => unreachable!("wrong address family (expected IPv4)"),
None => {
@@ -822,16 +826,10 @@ impl SplitTunnelDefaultRouteChangeHandlerContext {
.transpose()
.map_err(Error::LuidToIp)?
.flatten();
- let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6)
+ let internet_ipv6 = get_best_default_route(AddressFamily::Ipv6)
.map_err(Error::ObtainDefaultRoute)?
.map(|route| {
- get_ip_address_for_interface(
- AddressFamily::Ipv6,
- NET_LUID_LH {
- Value: route.interface_luid,
- },
- )
- .map(|ip| match ip {
+ get_ip_address_for_interface(AddressFamily::Ipv6, route.iface).map(|ip| match ip {
Some(IpAddr::V6(addr)) => Some(addr),
Some(_) => unreachable!("wrong address family (expected IPv6)"),
None => {
@@ -851,16 +849,14 @@ impl SplitTunnelDefaultRouteChangeHandlerContext {
}
}
-unsafe extern "system" fn split_tunnel_default_route_change_handler(
- event_type: winnet::WinNetDefaultRouteChangeEventType,
- address_family: WinNetAddrFamily,
- default_route: winnet::WinNetDefaultRoute,
- ctx: *mut libc::c_void,
+fn split_tunnel_default_route_change_handler<'a>(
+ event_type: EventType<'a>,
+ address_family: AddressFamily,
+ ctx_mutex: &Arc<Mutex<SplitTunnelDefaultRouteChangeHandlerContext>>,
) {
- use winnet::WinNetDefaultRouteChangeEventType::*;
+ use crate::routing::EventType::*;
// Update the "internet interface" IP when best default route changes
- let ctx_mutex = &mut *(ctx as *mut Arc<Mutex<SplitTunnelDefaultRouteChangeHandlerContext>>);
let mut ctx = ctx_mutex.lock().expect("ST route handler mutex poisoned");
let daemon_tx = ctx.daemon_tx.upgrade();
@@ -870,16 +866,9 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
}
};
- let translated_family = winnet_to_talpid_family(address_family);
-
let result = match event_type {
- DefaultRouteChanged | DefaultRouteUpdatedDetails => {
- match get_ip_address_for_interface(
- translated_family,
- NET_LUID_LH {
- Value: default_route.interface_luid,
- },
- ) {
+ Updated(default_route) | UpdatedDetails(default_route) => {
+ match get_ip_address_for_interface(address_family, default_route.iface) {
Ok(Some(ip)) => match IpAddr::from(ip) {
IpAddr::V4(addr) => ctx.addresses.internet_ipv4 = Some(addr),
IpAddr::V6(addr) => ctx.addresses.internet_ipv6 = Some(addr),
@@ -887,10 +876,10 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
Ok(None) => {
log::warn!("Failed to obtain default route interface address");
match address_family {
- WinNetAddrFamily::IPV4 => {
+ AddressFamily::Ipv4 => {
ctx.addresses.internet_ipv4 = None;
}
- WinNetAddrFamily::IPV6 => {
+ AddressFamily::Ipv6 => {
ctx.addresses.internet_ipv6 = None;
}
}
@@ -910,12 +899,12 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
ctx.register_ips()
}
// no default route
- DefaultRouteRemoved => {
+ Removed => {
match address_family {
- WinNetAddrFamily::IPV4 => {
+ AddressFamily::Ipv4 => {
ctx.addresses.internet_ipv4 = None;
}
- WinNetAddrFamily::IPV6 => {
+ AddressFamily::Ipv6 => {
ctx.addresses.internet_ipv6 = None;
}
}
@@ -931,10 +920,3 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler(
maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError));
}
}
-
-fn winnet_to_talpid_family(address_family: WinNetAddrFamily) -> AddressFamily {
- match address_family {
- WinNetAddrFamily::IPV4 => AddressFamily::Ipv4,
- WinNetAddrFamily::IPV6 => AddressFamily::Ipv6,
- }
-}
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 302c8003c9..4425c3c69d 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -42,7 +42,7 @@ pub enum Error {
/// Failure in Windows syscall.
#[cfg(windows)]
#[error(display = "Failure in Windows syscall")]
- WinnetError(#[error(source)] crate::winnet::Error),
+ WinnetError(#[error(source)] crate::routing::Error),
/// Running on an operating system which is not supported yet.
#[error(display = "Tunnel type not supported on this operating system")]
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 5e8c0ede49..b982dc148d 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -226,6 +226,8 @@ impl WireguardMonitor {
args.resource_dir,
args.tun_provider,
#[cfg(target_os = "windows")]
+ args.route_manager.clone(),
+ #[cfg(target_os = "windows")]
setup_done_tx,
)?;
let iface_name = tunnel.get_interface_name();
@@ -507,6 +509,7 @@ impl WireguardMonitor {
log_path: Option<&Path>,
resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
+ #[cfg(windows)] route_manager_handle: crate::routing::RouteManagerHandle,
#[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Box<dyn Tunnel>> {
#[cfg(target_os = "linux")]
@@ -576,7 +579,11 @@ impl WireguardMonitor {
#[cfg(not(windows))]
Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes),
#[cfg(windows)]
+ route_manager_handle,
+ #[cfg(windows)]
setup_done_tx,
+ #[cfg(windows)]
+ &runtime,
)
.map_err(Error::TunnelError)?,
))
diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
index 60705d324f..a1ca8be6ba 100644
--- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs
+++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs
@@ -39,9 +39,6 @@ use {
type Result<T> = std::result::Result<T, TunnelError>;
-#[cfg(target_os = "windows")]
-use crate::winnet;
-
#[cfg(not(target_os = "windows"))]
use std::sync::{Arc, Mutex};
@@ -66,7 +63,7 @@ pub struct WgGoTunnel {
// context that maps to fs::File instance, used with logging callback
_logging_context: LoggingContext,
#[cfg(target_os = "windows")]
- _route_callback_handle: Option<crate::winnet::WinNetCallbackHandle>,
+ _route_callback_handle: Option<crate::routing::CallbackHandle>,
#[cfg(target_os = "windows")]
setup_handle: tokio::task::JoinHandle<()>,
}
@@ -117,15 +114,19 @@ impl WgGoTunnel {
pub fn start_tunnel(
config: &Config,
log_path: Option<&Path>,
+ route_manager_handle: crate::tunnel::RouteManagerHandle,
mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>,
+ runtime: &tokio::runtime::Handle,
) -> Result<Self> {
use talpid_types::ErrorExt;
- let route_callback_handle = winnet::add_default_route_change_callback(
- Some(WgGoTunnel::default_route_changed_callback),
- (),
- )
- .ok();
+ let route_callback_handle = runtime
+ .block_on(
+ route_manager_handle.add_default_route_change_callback(Box::new(
+ WgGoTunnel::default_route_changed_callback,
+ )),
+ )
+ .ok();
if route_callback_handle.is_none() {
log::warn!("Failed to register default route callback");
}
@@ -208,25 +209,21 @@ impl WgGoTunnel {
// Callback to be used to rebind the tunnel sockets when the default route changes
#[cfg(target_os = "windows")]
- pub unsafe extern "system" fn default_route_changed_callback(
- event_type: winnet::WinNetDefaultRouteChangeEventType,
- address_family: winnet::WinNetAddrFamily,
- default_route: winnet::WinNetDefaultRoute,
- _ctx: *mut libc::c_void,
+ pub fn default_route_changed_callback<'a>(
+ event_type: crate::routing::EventType<'a>,
+ address_family: crate::windows::AddressFamily,
) {
- use windows_sys::Win32::NetworkManagement::{
- IpHelper::ConvertInterfaceLuidToIndex, Ndis::NET_LUID_LH,
- };
- use winnet::WinNetDefaultRouteChangeEventType::*;
+ use crate::routing::EventType::*;
+ use windows_sys::Win32::NetworkManagement::IpHelper::ConvertInterfaceLuidToIndex;
let iface_idx: u32 = match event_type {
- DefaultRouteChanged => {
+ Updated(default_route) => {
let mut iface_idx = 0u32;
- let iface_luid = NET_LUID_LH {
- Value: default_route.interface_luid,
+ // TODO: Make sure unwrap is fine
+ let iface_luid = default_route.iface;
+ let status = unsafe {
+ ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _)
};
- let status =
- ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _);
if status != 0 {
log::error!(
"Failed to convert interface LUID to interface index: {}: {}",
@@ -238,12 +235,12 @@ impl WgGoTunnel {
iface_idx
}
// if there is no new default route, specify 0 as the interface index
- DefaultRouteRemoved => 0,
+ Removed => 0,
// ignore interface updates that don't affect the interface to use
- DefaultRouteUpdatedDetails => return,
+ UpdatedDetails(_) => return,
};
- wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx);
+ unsafe { wgRebindTunnelSocket(address_family.to_af_family(), iface_idx) };
}
#[cfg(not(target_os = "windows"))]
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 5a83bd6b76..964963d46e 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -29,7 +29,7 @@ use talpid_types::{
};
#[cfg(windows)]
-use crate::{routing, winnet};
+use crate::routing;
#[cfg(target_os = "android")]
use crate::tunnel::tun_provider;
@@ -524,12 +524,7 @@ fn should_retry(error: &tunnel::Error, retry_attempt: u32) -> bool {
#[cfg(windows)]
fn is_recoverable_routing_error(error: &crate::routing::Error) -> bool {
match error {
- routing::Error::AddRoutesFailed(route_error) => match route_error {
- winnet::Error::GetDefaultRoute
- | winnet::Error::GetDeviceByName
- | winnet::Error::GetDeviceByGateway => true,
- _ => false,
- },
+ routing::Error::AddRoutesFailed(_) => true,
_ => false,
}
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index c1b52278f0..5d13b7d1f2 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -258,6 +258,10 @@ impl TunnelStateMachine {
#[cfg(target_os = "windows")]
let power_mgmt_rx = crate::windows::window::PowerManagementListener::new();
+ let route_manager = RouteManager::new(HashSet::new())
+ .await
+ .map_err(Error::InitRouteManagerError)?;
+
#[cfg(windows)]
let split_tunnel = split_tunnel::SplitTunnel::new(
runtime.clone(),
@@ -265,6 +269,9 @@ impl TunnelStateMachine {
args.command_tx.clone(),
volume_update_rx,
power_mgmt_rx.clone(),
+ route_manager
+ .handle()
+ .map_err(Error::InitRouteManagerError)?,
)
.map_err(Error::InitSplitTunneling)?;
@@ -279,9 +286,6 @@ impl TunnelStateMachine {
};
let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?;
- let route_manager = RouteManager::new(HashSet::new())
- .await
- .map_err(Error::InitRouteManagerError)?;
let dns_monitor = DnsMonitor::new(
#[cfg(target_os = "linux")]
runtime.clone(),
@@ -315,6 +319,8 @@ impl TunnelStateMachine {
#[cfg(target_os = "android")]
android_context,
#[cfg(target_os = "windows")]
+ route_manager.handle()?,
+ #[cfg(target_os = "windows")]
power_mgmt_rx,
)
.await
diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs
index d853991707..5504e11d93 100644
--- a/talpid-core/src/windows/mod.rs
+++ b/talpid-core/src/windows/mod.rs
@@ -109,7 +109,7 @@ impl fmt::Display for AddressFamily {
}
impl AddressFamily {
- /// Convert an [`AddressFamily`] to one of the `AF_*` constants.
+ /// Convert one of the `AF_*` constants to an [`AddressFamily`].
pub fn try_from_af_family(family: u16) -> Result<AddressFamily> {
match u32::from(family) {
AF_INET => Ok(AddressFamily::Ipv4),
@@ -117,6 +117,15 @@ impl AddressFamily {
family => Err(Error::UnknownAddressFamily(family)),
}
}
+
+ /// Convert an [`AddressFamily`] to one of the `AF_*` constants.
+ pub fn to_af_family(&self) -> u16 {
+ match self {
+ // These values are both small enough to fit in a u16
+ Self::Ipv4 => u16::try_from(AF_INET).unwrap(),
+ Self::Ipv6 => u16::try_from(AF_INET6).unwrap(),
+ }
+ }
}
/// Context for [`notify_ip_interface_change`]. When it is dropped,
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
deleted file mode 100644
index 9843d873aa..0000000000
--- a/talpid-core/src/winnet.rs
+++ /dev/null
@@ -1,416 +0,0 @@
-use self::api::*;
-use crate::{logging::windows::log_sink, routing::Node};
-use ipnetwork::IpNetwork;
-use libc::c_void;
-use std::{
- convert::TryFrom,
- net::{IpAddr, Ipv4Addr, Ipv6Addr},
- ptr,
-};
-use widestring::WideCString;
-
-/// Errors that this module may produce.
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- /// Supplied interface alias is invalid.
- #[error(display = "Supplied interface alias is invalid")]
- InvalidInterfaceAlias(#[error(source)] widestring::NulError<u16>),
-
- /// Failed to enable IPv6 on the network interface.
- #[error(display = "Failed to enable IPv6 on the network interface")]
- EnableIpv6,
-
- /// Failed to get the current default route.
- #[error(display = "Failed to obtain default route")]
- GetDefaultRoute,
-
- /// Failed to get a network device.
- #[error(display = "Failed to obtain network interface by name")]
- GetDeviceByName,
-
- /// Failed to get a network device.
- #[error(display = "Failed to obtain network interface by gateway")]
- GetDeviceByGateway,
-
- /// Unexpected error while adding routes
- #[error(display = "Winnet returned an error while adding routes")]
- GeneralAddRoutesError,
-
- /// Failed to obtain an IP address given a LUID.
- #[error(display = "Failed to obtain IP address for the given interface")]
- GetIpAddressFromLuid,
-
- /// Failed to read IPv6 status on the TAP network interface.
- #[error(display = "Failed to read IPv6 status on the TAP network interface")]
- GetIpv6Status,
-}
-
-fn logging_context() -> *const u8 {
- b"WinNet\0".as_ptr()
-}
-
-#[derive(Debug, Default, Clone, Copy)]
-#[allow(dead_code)]
-#[repr(u32)]
-pub enum WinNetAddrFamily {
- #[default]
- IPV4 = 0,
- IPV6 = 1,
-}
-
-impl WinNetAddrFamily {
- pub fn to_windows_proto_enum(&self) -> u16 {
- match self {
- Self::IPV4 => 2,
- Self::IPV6 => 23,
- }
- }
-}
-
-#[repr(C)]
-#[derive(Default)]
-pub struct WinNetIp {
- pub addr_family: WinNetAddrFamily,
- pub ip_bytes: [u8; 16],
-}
-
-#[repr(C)]
-#[derive(Default)]
-pub struct WinNetDefaultRoute {
- pub interface_luid: u64,
- pub gateway: WinNetIp,
-}
-
-#[derive(Debug)]
-pub struct WrongIpFamilyError;
-
-impl TryFrom<WinNetIp> for Ipv4Addr {
- type Error = WrongIpFamilyError;
-
- fn try_from(addr: WinNetIp) -> Result<Ipv4Addr, WrongIpFamilyError> {
- match addr.addr_family {
- WinNetAddrFamily::IPV4 => {
- let mut bytes: [u8; 4] = Default::default();
- bytes.clone_from_slice(&addr.ip_bytes[..4]);
- Ok(Ipv4Addr::from(bytes))
- }
- WinNetAddrFamily::IPV6 => Err(WrongIpFamilyError),
- }
- }
-}
-
-impl TryFrom<WinNetIp> for Ipv6Addr {
- type Error = WrongIpFamilyError;
-
- fn try_from(addr: WinNetIp) -> Result<Ipv6Addr, WrongIpFamilyError> {
- match addr.addr_family {
- WinNetAddrFamily::IPV4 => Err(WrongIpFamilyError),
- WinNetAddrFamily::IPV6 => Ok(Ipv6Addr::from(addr.ip_bytes)),
- }
- }
-}
-
-impl From<WinNetIp> for IpAddr {
- fn from(addr: WinNetIp) -> IpAddr {
- match addr.addr_family {
- WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::try_from(addr).unwrap()),
- WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::try_from(addr).unwrap()),
- }
- }
-}
-
-impl From<IpAddr> for WinNetIp {
- fn from(addr: IpAddr) -> WinNetIp {
- let mut bytes = [0u8; 16];
- match addr {
- IpAddr::V4(v4_addr) => {
- bytes[..4].copy_from_slice(&v4_addr.octets());
- WinNetIp {
- addr_family: WinNetAddrFamily::IPV4,
- ip_bytes: bytes,
- }
- }
- IpAddr::V6(v6_addr) => {
- bytes.copy_from_slice(&v6_addr.octets());
-
- WinNetIp {
- addr_family: WinNetAddrFamily::IPV6,
- ip_bytes: bytes,
- }
- }
- }
- }
-}
-
-#[repr(C)]
-pub struct WinNetIpNetwork {
- prefix: u8,
- ip: WinNetIp,
-}
-
-impl From<IpNetwork> for WinNetIpNetwork {
- fn from(network: IpNetwork) -> WinNetIpNetwork {
- WinNetIpNetwork {
- prefix: network.prefix(),
- ip: WinNetIp::from(network.ip()),
- }
- }
-}
-
-#[repr(C)]
-pub struct WinNetNode {
- gateway: *mut WinNetIp,
- device_name: *mut u16,
-}
-
-impl WinNetNode {
- fn new(name: &str, ip: WinNetIp) -> Self {
- let device_name = WideCString::from_str(name)
- .expect("Failed to convert UTF-8 string to null terminated UCS string")
- .into_raw();
- let gateway = Box::into_raw(Box::new(ip));
- Self {
- gateway,
- device_name,
- }
- }
-
- fn from_gateway(ip: WinNetIp) -> Self {
- let gateway = Box::into_raw(Box::new(ip));
- Self {
- gateway,
- device_name: ptr::null_mut(),
- }
- }
-
- fn from_device(name: &str) -> Self {
- let device_name = WideCString::from_str(name)
- .expect("Failed to convert UTF-8 string to null terminated UCS string")
- .into_raw();
- Self {
- gateway: ptr::null_mut(),
- device_name,
- }
- }
-}
-
-impl From<&Node> for WinNetNode {
- fn from(node: &Node) -> Self {
- match (node.get_address(), node.get_device()) {
- (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()),
- (None, Some(device)) => WinNetNode::from_device(device),
- (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()),
- _ => unreachable!(),
- }
- }
-}
-
-impl Drop for WinNetNode {
- fn drop(&mut self) {
- if !self.gateway.is_null() {
- unsafe {
- let _ = Box::from_raw(self.gateway);
- }
- }
- if !self.device_name.is_null() {
- unsafe {
- let _ = WideCString::from_ptr_str(self.device_name);
- }
- }
- }
-}
-
-#[repr(C)]
-pub struct WinNetRoute {
- gateway: WinNetIpNetwork,
- node: *mut WinNetNode,
-}
-
-impl WinNetRoute {
- pub fn through_default_node(gateway: WinNetIpNetwork) -> Self {
- Self {
- gateway,
- node: ptr::null_mut(),
- }
- }
-
- pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self {
- let node = Box::into_raw(Box::new(node));
- Self { gateway, node }
- }
-}
-
-impl Drop for WinNetRoute {
- fn drop(&mut self) {
- if !self.node.is_null() {
- unsafe {
- let _ = Box::from_raw(self.node);
- }
- self.node = ptr::null_mut();
- }
- }
-}
-
-pub fn activate_routing_manager() -> bool {
- unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) }
-}
-
-pub struct WinNetCallbackHandle {
- handle: *mut libc::c_void,
- // Allows us to keep the context pointer alive.
- _context: Box<dyn std::any::Any>,
-}
-
-unsafe impl Send for WinNetCallbackHandle {}
-
-impl Drop for WinNetCallbackHandle {
- fn drop(&mut self) {
- unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) };
- }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq)]
-#[allow(dead_code)]
-#[repr(u16)]
-pub enum WinNetDefaultRouteChangeEventType {
- DefaultRouteChanged = 0,
- DefaultRouteUpdatedDetails = 1,
- DefaultRouteRemoved = 2,
-}
-
-pub type DefaultRouteChangedCallback = unsafe extern "system" fn(
- event_type: WinNetDefaultRouteChangeEventType,
- family: WinNetAddrFamily,
- default_route: WinNetDefaultRoute,
- ctx: *mut c_void,
-);
-
-#[derive(err_derive::Error, Debug)]
-#[error(display = "Failed to set callback for default route")]
-pub struct DefaultRouteCallbackError;
-
-pub fn add_default_route_change_callback<T: 'static>(
- callback: Option<DefaultRouteChangedCallback>,
- context: T,
-) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> {
- let mut handle_ptr = ptr::null_mut();
- let mut context = Box::new(context);
- let ctx_ptr = &mut *context as *mut T as *mut libc::c_void;
- unsafe {
- if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _)
- {
- return Err(DefaultRouteCallbackError);
- }
-
- Ok(WinNetCallbackHandle {
- handle: handle_ptr,
- _context: context,
- })
- }
-}
-
-pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> Result<(), Error> {
- let ptr = routes.as_ptr();
- let length: u32 = routes.len() as u32;
- match unsafe { WinNet_AddRoutes(ptr, length) } {
- WinNetAddRouteStatus::Success => Ok(()),
- WinNetAddRouteStatus::GeneralError => Err(Error::GeneralAddRoutesError),
- WinNetAddRouteStatus::NoDefaultRoute => Err(Error::GetDefaultRoute),
- WinNetAddRouteStatus::NameNotFound => Err(Error::GetDeviceByName),
- WinNetAddRouteStatus::GatewayNotFound => Err(Error::GetDeviceByGateway),
- }
-}
-
-pub fn routing_manager_delete_applied_routes() -> bool {
- unsafe { WinNet_DeleteAppliedRoutes() }
-}
-
-pub fn deactivate_routing_manager() {
- unsafe { WinNet_DeactivateRouteManager() }
-}
-
-pub fn get_best_default_route(
- family: WinNetAddrFamily,
-) -> Result<Option<WinNetDefaultRoute>, Error> {
- let mut default_route = WinNetDefaultRoute::default();
- match unsafe {
- WinNet_GetBestDefaultRoute(
- family,
- &mut default_route as *mut _,
- Some(log_sink),
- logging_context(),
- )
- } {
- WinNetStatus::Success => Ok(Some(default_route)),
- WinNetStatus::NotFound => Ok(None),
- WinNetStatus::Failure => Err(Error::GetDefaultRoute),
- }
-}
-
-#[allow(non_snake_case)]
-mod api {
- use super::DefaultRouteChangedCallback;
- use crate::logging::windows::LogSink;
-
- #[allow(dead_code)]
- #[repr(u32)]
- pub enum WinNetStatus {
- Success = 0,
- NotFound = 1,
- Failure = 2,
- }
-
- #[allow(dead_code)]
- #[repr(u32)]
- pub enum WinNetAddRouteStatus {
- Success = 0,
- GeneralError = 1,
- NoDefaultRoute = 2,
- NameNotFound = 3,
- GatewayNotFound = 4,
- }
-
- extern "system" {
- #[link_name = "WinNet_ActivateRouteManager"]
- pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *const u8) -> bool;
-
- #[link_name = "WinNet_AddRoutes"]
- pub fn WinNet_AddRoutes(
- routes: *const super::WinNetRoute,
- num_routes: u32,
- ) -> WinNetAddRouteStatus;
-
- // #[link_name = "WinNet_AddRoute"]
- // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> WinNetAddRouteStatus;
-
- // #[link_name = "WinNet_DeleteRoutes"]
- // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool;
-
- // #[link_name = "WinNet_DeleteRoute"]
- // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool;
-
- #[link_name = "WinNet_DeleteAppliedRoutes"]
- pub fn WinNet_DeleteAppliedRoutes() -> bool;
-
- #[link_name = "WinNet_DeactivateRouteManager"]
- pub fn WinNet_DeactivateRouteManager();
-
- #[link_name = "WinNet_GetBestDefaultRoute"]
- pub fn WinNet_GetBestDefaultRoute(
- family: super::WinNetAddrFamily,
- default_route: *mut super::WinNetDefaultRoute,
- sink: Option<LogSink>,
- sink_context: *const u8,
- ) -> WinNetStatus;
-
- #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"]
- pub fn WinNet_RegisterDefaultRouteChangedCallback(
- callback: Option<DefaultRouteChangedCallback>,
- callbackContext: *mut libc::c_void,
- registrationHandle: *mut *mut libc::c_void,
- ) -> bool;
-
- #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"]
- pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void);
- }
-}