diff options
| author | Jonathan <jonathan@mullvad.net> | 2022-09-06 10:06:17 +0200 |
|---|---|---|
| committer | Jonathan <jonathan@mullvad.net> | 2022-10-18 14:42:25 +0200 |
| commit | d3f7ed493ecc182e582697f7ac03b1b30a2c2f52 (patch) | |
| tree | cf4b0e155279b2dad9f54468393e060584950ecb | |
| parent | 7bb08d42243ec05f04bc7691d9d2bb3156f10fed (diff) | |
| download | mullvadvpn-d3f7ed493ecc182e582697f7ac03b1b30a2c2f52.tar.xz mullvadvpn-d3f7ed493ecc182e582697f7ac03b1b30a2c2f52.zip | |
Port winnet from C++ to Rust
Remove all of the C++ code in the winnet module and write an almost
equivalent route manager in rust.
| -rw-r--r-- | Cargo.lock | 16 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 17 | ||||
| -rw-r--r-- | talpid-core/src/lib.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/offline/mod.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/offline/windows.rs | 60 | ||||
| -rw-r--r-- | talpid-core/src/routing/mod.rs | 4 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 221 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows/default_route_monitor.rs | 451 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows/get_best_default_route.rs | 190 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows/mod.rs | 303 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows/route_manager.rs | 885 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 94 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 7 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/wireguard_go.rs | 49 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 9 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/windows/mod.rs | 11 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 416 |
19 files changed, 1990 insertions, 766 deletions
diff --git a/Cargo.lock b/Cargo.lock index 77a447ec8b..16933b2567 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3161,8 +3161,9 @@ dependencies = [ "tunnel-obfuscation", "uuid", "which", - "widestring 0.5.1", + "widestring 1.0.2", "winapi", + "windows", "windows-service", "windows-sys 0.42.0", "winreg", @@ -3973,6 +3974,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] +name = "windows" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53b97a83176b369b0eb2fd8158d4ae215357d02df9d40c1e1bf1879c5482c80" +dependencies = [ + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", +] + +[[package]] name = "windows-service" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index d1167e16fc..97bdf2db4b 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -79,13 +79,28 @@ subslice = "0.2" [target.'cfg(windows)'.dependencies] -widestring = "0.5" +widestring = "1.0" winreg = { version = "0.7", features = ["transactions"] } winapi = { version = "0.3.6", features = ["ws2def"] } talpid-platform-metadata = { path = "../talpid-platform-metadata" } memoffset = "0.6" windows-service = "0.5.0" +[target.'cfg(windows)'.dependencies.windows] +version = "0.36.1" +features = [ + "Data_Xml_Dom", + "Win32_Foundation", + "Win32_Security", + "Win32_System_Threading", + "Win32_UI_WindowsAndMessaging", + "Win32_NetworkManagement", + "Win32_NetworkManagement_IpHelper", + "Win32_NetworkManagement_Ndis", + "Win32_Foundation", + "Win32_Networking_WinSock", +] + [target.'cfg(windows)'.dependencies.windows-sys] version = "0.42.0" features = [ diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs index 73c3293fb4..46bb4c1169 100644 --- a/talpid-core/src/lib.rs +++ b/talpid-core/src/lib.rs @@ -9,10 +9,6 @@ #[macro_use] mod ffi; -/// Misc networking functions for Windows. -#[cfg(windows)] -mod winnet; - /// Windows API wrappers and utilities #[cfg(target_os = "windows")] pub mod windows; diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs index b07fb3d8c9..3c5448762b 100644 --- a/talpid-core/src/offline/mod.rs +++ b/talpid-core/src/offline/mod.rs @@ -1,4 +1,4 @@ -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "windows"))] use crate::routing::RouteManagerHandle; #[cfg(target_os = "windows")] use crate::windows::window::PowerManagementListener; @@ -46,6 +46,7 @@ pub async fn spawn_monitor( sender: UnboundedSender<bool>, #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(target_os = "windows")] route_manager: RouteManagerHandle, #[cfg(target_os = "windows")] power_mgmt_rx: PowerManagementListener, ) -> Result<MonitorHandle, Error> { let monitor = if !*FORCE_DISABLE_OFFLINE_MONITOR { @@ -57,6 +58,8 @@ pub async fn spawn_monitor( #[cfg(target_os = "android")] android_context, #[cfg(target_os = "windows")] + route_manager, + #[cfg(target_os = "windows")] power_mgmt_rx, ) .await?, diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs index bbe9d951a9..9dc341d03a 100644 --- a/talpid-core/src/offline/windows.rs +++ b/talpid-core/src/offline/windows.rs @@ -1,11 +1,13 @@ use crate::{ - windows::window::{PowerManagementEvent, PowerManagementListener}, - winnet, + routing::{get_best_default_route, CallbackHandle, EventType, RouteManagerHandle}, + windows::{ + window::{PowerManagementEvent, PowerManagementListener}, + AddressFamily, + }, }; use futures::channel::mpsc::UnboundedSender; use parking_lot::Mutex; use std::{ - ffi::c_void, io, sync::{Arc, Weak}, time::Duration, @@ -17,20 +19,21 @@ pub enum Error { #[error(display = "Unable to create listener thread")] ThreadCreationError(#[error(source)] io::Error), #[error(display = "Failed to start connectivity monitor")] - ConnectivityMonitorError(#[error(source)] winnet::DefaultRouteCallbackError), + ConnectivityMonitorError(#[error(source)] crate::routing::Error), } pub struct BroadcastListener { system_state: Arc<Mutex<SystemState>>, - _callback_handle: winnet::WinNetCallbackHandle, + _callback_handle: CallbackHandle, _notify_tx: Arc<UnboundedSender<bool>>, } unsafe impl Send for BroadcastListener {} impl BroadcastListener { - pub fn start( + pub async fn start( notify_tx: UnboundedSender<bool>, + route_manager_handle: RouteManagerHandle, mut power_mgmt_rx: PowerManagementListener, ) -> Result<Self, Error> { let notify_tx = Arc::new(notify_tx); @@ -66,7 +69,8 @@ impl BroadcastListener { }); let callback_handle = - unsafe { Self::setup_network_connectivity_listener(system_state.clone())? }; + Self::setup_network_connectivity_listener(system_state.clone(), route_manager_handle) + .await?; Ok(BroadcastListener { system_state, @@ -76,7 +80,7 @@ impl BroadcastListener { } fn check_initial_connectivity() -> (bool, bool) { - let v4_connectivity = winnet::get_best_default_route(winnet::WinNetAddrFamily::IPV4) + let v4_connectivity = get_best_default_route(AddressFamily::Ipv4) .map(|route| route.is_some()) .unwrap_or_else(|error| { log::error!( @@ -85,7 +89,7 @@ impl BroadcastListener { ); true }); - let v6_connectivity = winnet::get_best_default_route(winnet::WinNetAddrFamily::IPV6) + let v6_connectivity = get_best_default_route(AddressFamily::Ipv6) .map(|route| route.is_some()) .unwrap_or_else(|error| { log::error!( @@ -103,34 +107,35 @@ impl BroadcastListener { /// The caller must make sure the `system_state` reference is valid /// until after `WinNet_DeactivateConnectivityMonitor` has been called. - unsafe fn setup_network_connectivity_listener( + async fn setup_network_connectivity_listener( system_state: Arc<Mutex<SystemState>>, - ) -> Result<winnet::WinNetCallbackHandle, Error> { - let change_handle = winnet::add_default_route_change_callback( - Some(Self::connectivity_callback), - system_state, - )?; + route_manager_handle: RouteManagerHandle, + ) -> Result<CallbackHandle, Error> { + let change_handle = route_manager_handle + .add_default_route_change_callback(Box::new(move |event, addr_family| { + Self::connectivity_callback(event, addr_family, &system_state) + })) + .await + .map_err(|e| Error::ConnectivityMonitorError(e))?; Ok(change_handle) } - unsafe extern "system" fn connectivity_callback( - event_type: winnet::WinNetDefaultRouteChangeEventType, - family: winnet::WinNetAddrFamily, - _default_route: winnet::WinNetDefaultRoute, - ctx: *mut c_void, + fn connectivity_callback<'a>( + event_type: EventType<'a>, + family: AddressFamily, + state_lock: &Arc<Mutex<SystemState>>, ) { - use winnet::WinNetDefaultRouteChangeEventType::*; + use crate::routing::EventType::*; - if event_type == DefaultRouteUpdatedDetails { + if matches!(event_type, UpdatedDetails(_)) { // ignore changes that don't affect the route return; } - let state_lock: &mut Arc<Mutex<SystemState>> = &mut *(ctx as *mut _); - let connectivity = event_type != DefaultRouteRemoved; + let connectivity = event_type != Removed; let change = match family { - winnet::WinNetAddrFamily::IPV4 => StateChange::NetworkV4Connectivity(connectivity), - winnet::WinNetAddrFamily::IPV6 => StateChange::NetworkV6Connectivity(connectivity), + AddressFamily::Ipv4 => StateChange::NetworkV4Connectivity(connectivity), + AddressFamily::Ipv6 => StateChange::NetworkV6Connectivity(connectivity), }; let mut state = state_lock.lock(); state.apply_change(change); @@ -202,9 +207,10 @@ pub type MonitorHandle = BroadcastListener; pub async fn spawn_monitor( sender: UnboundedSender<bool>, + route_manager_handle: RouteManagerHandle, power_mgmt_rx: PowerManagementListener, ) -> Result<MonitorHandle, Error> { - BroadcastListener::start(sender, power_mgmt_rx) + BroadcastListener::start(sender, route_manager_handle, power_mgmt_rx).await } fn apply_system_state_change(state: Arc<Mutex<SystemState>>, change: StateChange) { diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs index 1eb02a206b..5d1247618e 100644 --- a/talpid-core/src/routing/mod.rs +++ b/talpid-core/src/routing/mod.rs @@ -5,8 +5,10 @@ use ipnetwork::IpNetwork; use std::{fmt, net::IpAddr}; #[cfg(target_os = "windows")] -#[path = "windows.rs"] +#[path = "windows/mod.rs"] mod imp; +#[cfg(target_os = "windows")] +pub use imp::{get_best_default_route, CallbackHandle, EventType, InterfaceAndGateway}; #[cfg(not(target_os = "windows"))] #[path = "unix.rs"] diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs deleted file mode 100644 index fad8540ecd..0000000000 --- a/talpid-core/src/routing/windows.rs +++ /dev/null @@ -1,221 +0,0 @@ -use super::NetNode; -use crate::{routing::RequiredRoute, winnet}; -use futures::{ - channel::{ - mpsc::{self, UnboundedReceiver, UnboundedSender}, - oneshot, - }, - StreamExt, -}; -use std::{collections::HashSet, net::IpAddr}; -use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH; -use winnet::WinNetAddrFamily; - -/// Windows routing errors. -#[derive(err_derive::Error, Debug)] -pub enum Error { - /// The sender was dropped unexpectedly -- possible panic - #[error(display = "The channel sender was dropped")] - ManagerChannelDown, - /// Failure to initialize route manager - #[error(display = "Failed to start route manager")] - FailedToStartManager, - /// Failure to add routes - #[error(display = "Failed to add routes")] - AddRoutesFailed(#[error(source)] winnet::Error), - /// Failure to clear routes - #[error(display = "Failed to clear applied routes")] - ClearRoutesFailed, - /// WinNet returned an error while adding default route callback - #[error(display = "Failed to set callback for default route")] - FailedToAddDefaultRouteCallback, - /// Attempt to use route manager that has been dropped - #[error(display = "Cannot send message to route manager since it is down")] - RouteManagerDown, - /// Something went wrong when getting the mtu of the interface - #[error(display = "Could not get the mtu of the interface")] - GetMtu, -} - -pub type Result<T> = std::result::Result<T, Error>; - -/// Manages routes by calling into WinNet -pub struct RouteManager { - manage_tx: Option<UnboundedSender<RouteManagerCommand>>, -} - -/// Handle to a route manager. -#[derive(Clone)] -pub struct RouteManagerHandle { - tx: UnboundedSender<RouteManagerCommand>, -} - -impl RouteManagerHandle { - /// Applies the given routes while the route manager is running. - pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - self.tx - .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) - .map_err(|_| Error::RouteManagerDown)?; - response_rx.await.map_err(|_| Error::ManagerChannelDown)? - } - - /// Applies the given routes while the route manager is running. - pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> { - let (response_tx, response_rx) = oneshot::channel(); - self.tx - .unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx)) - .map_err(|_| Error::RouteManagerDown)?; - response_rx.await.map_err(|_| Error::ManagerChannelDown)? - } -} - -#[derive(Debug)] -pub enum RouteManagerCommand { - AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), - GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>), - Shutdown, -} - -impl RouteManager { - /// Creates a new route manager that will apply the provided routes and ensure they exist until - /// it's stopped. - pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { - if !winnet::activate_routing_manager() { - return Err(Error::FailedToStartManager); - } - let (manage_tx, manage_rx) = mpsc::unbounded(); - let manager = Self { - manage_tx: Some(manage_tx), - }; - tokio::spawn(RouteManager::listen(manage_rx)); - manager.add_routes(required_routes).await?; - - Ok(manager) - } - - /// Retrieve a sender directly to the command channel. - pub fn handle(&self) -> Result<RouteManagerHandle> { - if let Some(tx) = &self.manage_tx { - Ok(RouteManagerHandle { tx: tx.clone() }) - } else { - Err(Error::RouteManagerDown) - } - } - - async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) { - while let Some(command) = manage_rx.next().await { - match command { - RouteManagerCommand::AddRoutes(routes, tx) => { - let routes: Vec<_> = routes - .iter() - .map(|route| { - let destination = winnet::WinNetIpNetwork::from(route.prefix); - match &route.node { - NetNode::DefaultNode => { - winnet::WinNetRoute::through_default_node(destination) - } - NetNode::RealNode(node) => winnet::WinNetRoute::new( - winnet::WinNetNode::from(node), - destination, - ), - } - }) - .collect(); - - let _ = tx.send( - winnet::routing_manager_add_routes(&routes).map_err(Error::AddRoutesFailed), - ); - } - RouteManagerCommand::GetMtuForRoute(ip, tx) => { - let addr_family = if ip.is_ipv4() { - winnet::WinNetAddrFamily::IPV4 - } else { - winnet::WinNetAddrFamily::IPV6 - }; - let res = match get_mtu_for_route(addr_family) { - Ok(Some(mtu)) => Ok(mtu), - Ok(None) => Err(Error::GetMtu), - Err(e) => Err(e), - }; - let _ = tx.send(res); - } - RouteManagerCommand::Shutdown => { - break; - } - } - } - } - - /// Stops the routing manager and invalidates the route manager - no new default route callbacks - /// can be added - pub fn stop(&mut self) { - if let Some(tx) = self.manage_tx.take() { - if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() { - log::error!("RouteManager channel already down or thread panicked"); - } - - winnet::deactivate_routing_manager(); - } - } - - /// Applies the given routes until [`RouteManager::stop`] is called. - pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - if let Some(tx) = &self.manage_tx { - let (result_tx, result_rx) = oneshot::channel(); - if tx - .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) - .is_err() - { - return Err(Error::RouteManagerDown); - } - result_rx.await.map_err(|_| Error::ManagerChannelDown)? - } else { - Err(Error::RouteManagerDown) - } - } - - /// Removes all routes previously applied in [`RouteManager::new`] or - /// [`RouteManager::add_routes`]. - pub fn clear_routes(&self) -> Result<()> { - if winnet::routing_manager_delete_applied_routes() { - Ok(()) - } else { - Err(Error::ClearRoutesFailed) - } - } -} - -fn get_mtu_for_route(addr_family: WinNetAddrFamily) -> Result<Option<u16>> { - use crate::windows::AddressFamily; - match winnet::get_best_default_route(addr_family) { - Ok(Some(route)) => { - let addr_family = match addr_family { - WinNetAddrFamily::IPV4 => AddressFamily::Ipv4, - WinNetAddrFamily::IPV6 => AddressFamily::Ipv6, - }; - let luid = NET_LUID_LH { - Value: route.interface_luid, - }; - let interface_row = crate::windows::get_ip_interface_entry(addr_family, &luid) - .map_err(|e| { - log::error!("Could not get ip interface entry: {}", e); - Error::GetMtu - })?; - let mtu = interface_row.NlMtu; - let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?; - Ok(Some(mtu)) - } - Ok(None) => Ok(None), - Err(e) => { - log::error!("Could not get best default route: {}", e); - Err(Error::GetMtu) - } - } -} - -impl Drop for RouteManager { - fn drop(&mut self) { - self.stop(); - } -} diff --git a/talpid-core/src/routing/windows/default_route_monitor.rs b/talpid-core/src/routing/windows/default_route_monitor.rs new file mode 100644 index 0000000000..3976903f11 --- /dev/null +++ b/talpid-core/src/routing/windows/default_route_monitor.rs @@ -0,0 +1,451 @@ +use super::{ + get_best_default_route, get_best_default_route::route_has_gateway, AddressFamily, Error, + InterfaceAndGateway, Result, +}; + +use std::{ + ffi::c_void, + io, + sync::{ + mpsc::{channel, RecvTimeoutError, Sender}, + Arc, Mutex, + }, + time::{Duration, Instant}, +}; +use windows_sys::Win32::{ + Foundation::{BOOLEAN, HANDLE, NO_ERROR}, + NetworkManagement::{ + IpHelper::{ + CancelMibChangeNotify2, ConvertInterfaceLuidToIndex, NotifyIpInterfaceChange, + NotifyRouteChange2, NotifyUnicastIpAddressChange, MIB_IPFORWARD_ROW2, + MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE, MIB_UNICASTIPADDRESS_ROW, + }, + Ndis::NET_LUID_LH, + }, +}; + +const WIN_FALSE: BOOLEAN = 0; + +struct DefaultRouteMonitorContext { + callback: Box<dyn for<'a> Fn(EventType<'a>) + Send + 'static>, + refresh_current_route: bool, + family: AddressFamily, + best_route: Option<InterfaceAndGateway>, +} + +impl DefaultRouteMonitorContext { + fn new( + callback: Box<dyn for<'a> Fn(EventType<'a>) + Send + 'static>, + family: AddressFamily, + ) -> Self { + Self { + callback, + best_route: None, + refresh_current_route: false, + family, + } + } + + fn update_refresh_flag(&mut self, luid: &NET_LUID_LH, index: u32) { + if let Some(best_route) = &self.best_route { + // SAFETY: luid is a union but both fields are finally represented by u64, as such any + // access is valid + if unsafe { luid.Value } == unsafe { best_route.iface.Value } { + self.refresh_current_route = true; + return; + } + // SAFETY: luid is a union but both fields are finally represented by u64, as such any + // access is valid + if unsafe { luid.Value } != 0 { + return; + } + + let mut default_interface_index = 0; + let route_luid = best_route.iface; + // SAFETY: No clear safety specifications + if NO_ERROR as i32 + == unsafe { ConvertInterfaceLuidToIndex(&route_luid, &mut default_interface_index) } + { + self.refresh_current_route = index == default_interface_index; + } else { + self.refresh_current_route = true; + } + } + } + + fn evaluate_routes(&mut self) { + let refresh_current = self.refresh_current_route; + self.refresh_current_route = false; + + let current_best_route = get_best_default_route(self.family).ok().flatten(); + + match (&self.best_route, current_best_route) { + (None, None) => (), + (None, Some(current_best_route)) => { + self.best_route = Some(current_best_route); + (self.callback)(EventType::Updated(&self.best_route.as_ref().unwrap())); + } + (Some(_), None) => { + self.best_route = None; + (self.callback)(EventType::Removed); + } + (Some(best_route), Some(current_best_route)) => { + if best_route != ¤t_best_route { + self.best_route = Some(current_best_route); + (self.callback)(EventType::Updated(&self.best_route.as_ref().unwrap())); + } else if refresh_current { + (self.callback)(EventType::UpdatedDetails( + &self.best_route.as_ref().unwrap(), + )); + } + } + } + } +} + +pub struct DefaultRouteMonitor { + // SAFETY: These handles must be dropped before the context. This will happen automatically if + // it is handled by DefaultRouteMonitors drop implementation + notify_change_handles: Option<(NotifyChangeHandle, NotifyChangeHandle, NotifyChangeHandle)>, + // SAFETY: Context must be dropped after all of the notifier handles have been dropped in order + // to guarantee none of them use its pointer. This will be dropped by DefaultRouteMonitors + // drop implementation. SAFETY: The content of this pointer is not allowed to be mutated at + // any point except for in the drop implementation + context: *const ContextAndBurstGuard, +} + +/// SAFETY: DefaultRouteMonitor is `Send` since `NotifyChangeHandle` is `Send` and +/// `ContextAndBurstGuard` is `Sync` as it holds Mutex<T> and Arc<Mutex<T>> fields. +unsafe impl Send for DefaultRouteMonitor {} + +impl Drop for DefaultRouteMonitor { + fn drop(&mut self) { + drop(self.notify_change_handles.take()); + // SAFETY: This pointer was created by Box::into_raw and is not modified since then. + // This drop function is also only called once + let context = unsafe { Box::from_raw(self.context as *mut ContextAndBurstGuard) }; + + // Stop the burst guard + context.burst_guard.lock().unwrap().stop(); + + // Drop the context now that we are guaranteed nothing might try to access the context + drop(context); + } +} + +struct NotifyChangeHandle(HANDLE); + +/// SAFETY: NotifyChangeHandle is `Send` since it holds sole ownership of a pointer provided by C +unsafe impl Send for NotifyChangeHandle {} + +impl Drop for NotifyChangeHandle { + fn drop(&mut self) { + // SAFETY: There is no clear safety specification on this function. However self.0 should + // point to a handle that has been allocated by windows and should be non-null. Even + // if it would be null that would cause a panic rather than UB. + unsafe { + if NO_ERROR as i32 != CancelMibChangeNotify2(self.0) { + // If this callback is called after we free the context that could result in UB, in + // order to avoid that we panic. + panic!("Could not cancel change notification callback") + } + } + } +} + +#[derive(PartialEq, Clone, Copy)] +/// The type of route update passed to the callback +pub enum EventType<'a> { + /// New route + Updated(&'a InterfaceAndGateway), + /// Updated details of the same old route + UpdatedDetails(&'a InterfaceAndGateway), + /// Route removed + Removed, +} + +// SAFETY: This struct must be `Sync` otherwise it is not allowed to be sent between threads. +// Having only `Mutex<T>` or `Arc<Mutex<T>>` fields guarantees that it is `Sync` +struct ContextAndBurstGuard { + context: Arc<Mutex<DefaultRouteMonitorContext>>, + burst_guard: Mutex<BurstGuard>, +} + +impl DefaultRouteMonitor { + pub fn new<F: for<'a> Fn(EventType<'a>) + Send + 'static>( + family: AddressFamily, + callback: F, + ) -> Result<Self> { + let context = Arc::new(Mutex::new(DefaultRouteMonitorContext::new( + Box::new(callback), + family, + ))); + + let moved_context = context.clone(); + let burst_guard = Mutex::new(BurstGuard::new(move || { + moved_context.lock().unwrap().evaluate_routes(); + })); + + // SAFETY: We need to send the ContextAndBurstGuard to the windows notification functions as + // a raw pointer. This imposes the requirement it is not mutated or dropped until + // after those notifications are guaranteed to not run. This happens when the + // DefaultRouteMonitor is dropped and not before then. It also imposes the requirement that + // ContextAndBurstGuard is `Sync` since we will send references to it to other + // threads. This requirement is fullfilled since all fields of `ContextAndBurstGuard` are + // wrapped in either a Arc<Mutex> or Mutex. + let context_and_burst = Box::into_raw(Box::new(ContextAndBurstGuard { + context, + burst_guard, + })) as *const _; + + let handles = match Self::register_callbacks(family, context_and_burst) { + Ok(handles) => handles, + Err(e) => { + // Clean up the memory leak in case of error + // SAFETY: We created context_and_burst from `Box::into_raw()` and it has not been + // modified since. All of the handles have been freed at this point + // so there will be no risk of UAF. + drop(unsafe { Box::from_raw(context_and_burst as *mut ContextAndBurstGuard) }); + return Err(e); + } + }; + + let monitor = Self { + context: context_and_burst, + notify_change_handles: Some(handles), + }; + + // We must set the best default route after we have registered listeners in order to avoid + // race conditions. + { + // SAFETY: `monitor.context` will be valid since monitor will handle dropping it. No + // mutation happens here since we are using a Mutex. + let context = &unsafe { &*(monitor.context) }.context; + let mut context = context.lock().unwrap(); + context.best_route = get_best_default_route(context.family)?; + } + + Ok(monitor) + } + + fn register_callbacks( + family: AddressFamily, + context_and_burst: *const ContextAndBurstGuard, + ) -> Result<(NotifyChangeHandle, NotifyChangeHandle, NotifyChangeHandle)> { + let family = family.to_af_family(); + + // We must provide a raw pointer that points to the context that will be used in the + // callbacks. We provide a Mutex for the state turned into a Weak pointer turned + // into a raw pointer in order to not have to manually deallocate the memory after + // we cancel the callbacks. This will leak the weak pointer but the context state itself + // will be correctly dropped when DefaultRouteManager is dropped. + let context_ptr = context_and_burst; + let mut handle_ptr = 0; + // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle + // has not been dropped. + let status = unsafe { + NotifyRouteChange2( + family, + Some(route_change_callback), + context_ptr as *const _, + WIN_FALSE, + &mut handle_ptr, + ) + }; + + if NO_ERROR as i32 != status { + return Err(Error::RegisterNotifyRouteCallback( + io::Error::from_raw_os_error(status), + )); + } + let notify_route_change_handle = NotifyChangeHandle(handle_ptr); + + let mut handle_ptr = 0; + // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle + // has not been dropped. + let status = unsafe { + NotifyIpInterfaceChange( + family, + Some(interface_change_callback), + context_ptr as *const _, + WIN_FALSE, + &mut handle_ptr, + ) + }; + if NO_ERROR as i32 != status { + return Err(Error::RegisterNotifyIpInterfaceCallback( + io::Error::from_raw_os_error(status), + )); + } + let notify_interface_change_handle = NotifyChangeHandle(handle_ptr); + + let mut handle_ptr = 0; + // SAFETY: No clear safety specifications, context_ptr must be valid for as long as handle + // has not been dropped. + let status = unsafe { + NotifyUnicastIpAddressChange( + family, + Some(ip_address_change_callback), + context_ptr as *const _, + WIN_FALSE, + &mut handle_ptr, + ) + }; + if NO_ERROR as i32 != status { + return Err(Error::RegisterNotifyUnicastIpAddressCallback( + io::Error::from_raw_os_error(status), + )); + } + let notify_address_change_handle = NotifyChangeHandle(handle_ptr); + + Ok(( + notify_route_change_handle, + notify_interface_change_handle, + notify_address_change_handle, + )) + } +} + +// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference. +// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed +// to not be called again. +unsafe extern "system" fn route_change_callback( + context: *const c_void, + row: *const MIB_IPFORWARD_ROW2, + _notification_type: MIB_NOTIFICATION_TYPE, +) { + // SAFETY: We assume Windows provides this pointer correctly + let row = &*row; + + if row.DestinationPrefix.PrefixLength != 0 || !route_has_gateway(row) { + return; + } + + // SAFETY: context must not be dropped or modified until this callback has been cancelled. + let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard); + let mut context = context_and_burst.context.lock().unwrap(); + + context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex); + context_and_burst.burst_guard.lock().unwrap().trigger(); +} + +// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference. +// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed +// to not be called again. +unsafe extern "system" fn interface_change_callback( + context: *const c_void, + row: *const MIB_IPINTERFACE_ROW, + _notification_type: MIB_NOTIFICATION_TYPE, +) { + // SAFETY: We assume Windows provides this pointer correctly + let row = &*row; + + // SAFETY: context must not be dropped or modified until this callback has been cancelled. + let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard); + let mut context = context_and_burst.context.lock().unwrap(); + + context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex); + context_and_burst.burst_guard.lock().unwrap().trigger(); +} + +// SAFETY: `context` is a Box::into_raw() pointer which may only be used as a non-mutable reference. +// It is guaranteed by the DefaultRouteMonitor to not be dropped before this function is guaranteed +// to not be called again. +unsafe extern "system" fn ip_address_change_callback( + context: *const c_void, + row: *const MIB_UNICASTIPADDRESS_ROW, + _notification_type: MIB_NOTIFICATION_TYPE, +) { + // SAFETY: We assume Windows provides this pointer correctly + let row = &*row; + + // SAFETY: context must not be dropped or modified until this callback has been cancelled. + let context_and_burst: &ContextAndBurstGuard = &*(context as *const ContextAndBurstGuard); + let mut context = context_and_burst.context.lock().unwrap(); + + context.update_refresh_flag(&row.InterfaceLuid, row.InterfaceIndex); + context_and_burst.burst_guard.lock().unwrap().trigger(); +} + +/// BurstGuard is a wrapper for a function that protects that function from being called too many +/// times in a short amount of time. To call the function use `burst_guard.trigger()`, at that point +/// `BurstGuard` will wait for `buffer_period` and if no more calls to `trigger` are made then it +/// will call the wrapped function. If another call to `trigger` is made during this wait then it +/// will wait another `buffer_period`, this happens over and over until either +/// `longest_buffer_period` time has elapsed or until no call to `trigger` has been made in +/// `buffer_period`. At which point the wrapped function will be called. +struct BurstGuard { + sender: Sender<BurstGuardEvent>, +} + +enum BurstGuardEvent { + Trigger, + Shutdown(Sender<()>), +} + +impl BurstGuard { + fn new<F: Fn() + Send + 'static>(callback: F) -> Self { + /// This is the period of time the `BurstGuard` will wait for a new trigger to be sent + /// before it calls the callback. + const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200); + /// This is the longest period that the `BurstGuard` will wait from the first trigger till + /// it calls the callback. + const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2); + + let (sender, listener) = channel(); + std::thread::spawn(move || { + // The `stop` implementation assumes that this thread will not call `callback` again + // if the listener has been dropped. + while let Ok(message) = listener.recv() { + match message { + BurstGuardEvent::Trigger => { + let start = Instant::now(); + loop { + match listener.recv_timeout(BURST_BUFFER_PERIOD) { + Ok(BurstGuardEvent::Trigger) => { + if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD { + callback(); + break; + } + } + Ok(BurstGuardEvent::Shutdown(tx)) => { + let _ = tx.send(()); + return; + } + Err(RecvTimeoutError::Timeout) => { + callback(); + break; + } + Err(RecvTimeoutError::Disconnected) => { + break; + } + } + } + } + BurstGuardEvent::Shutdown(tx) => { + let _ = tx.send(()); + return; + } + } + } + }); + Self { sender } + } + + /// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further + /// calls to `callback`. + fn stop(&self) { + let (sender, listener) = channel(); + // If we could not send then it means the thread has already shut down and we can return + if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() { + // We do not care what the result is, if it is OK it means the thread shut down, if + // it is Err it also means it shut down. + let _ = listener.recv(); + } + } + + /// Non-blocking + fn trigger(&self) { + self.sender.send(BurstGuardEvent::Trigger).unwrap(); + } +} diff --git a/talpid-core/src/routing/windows/get_best_default_route.rs b/talpid-core/src/routing/windows/get_best_default_route.rs new file mode 100644 index 0000000000..4ec7395fff --- /dev/null +++ b/talpid-core/src/routing/windows/get_best_default_route.rs @@ -0,0 +1,190 @@ +use super::{Error, Result}; +use crate::windows::{get_ip_interface_entry, try_socketaddr_from_inet_sockaddr, AddressFamily}; +use std::{convert::TryInto, io, net::SocketAddr}; +use widestring::{widecstr, WideCStr}; +use windows_sys::Win32::{ + Foundation::NO_ERROR, + NetworkManagement::{ + IpHelper::{ + FreeMibTable, GetIfEntry2, GetIpForwardTable2, IF_TYPE_SOFTWARE_LOOPBACK, + IF_TYPE_TUNNEL, MIB_IF_ROW2, MIB_IPFORWARD_ROW2, + }, + Ndis::NET_LUID_LH, + }, +}; + +// Interface description substrings found for virtual adapters. +const TUNNEL_INTERFACE_DESCS: [&WideCStr; 3] = [ + widecstr!("WireGuard"), + widecstr!("Wintun"), + widecstr!("Tunnel"), +]; + +fn get_ipforward_rows(family: AddressFamily) -> Result<Vec<MIB_IPFORWARD_ROW2>> { + let family = family.to_af_family(); + let mut table_ptr = std::ptr::null_mut(); + + // SAFETY: GetIpForwardTable2 does not have clear safety specifications however what it does is + // heap allocate a IpForwardTable2 and then change table_ptr to point to that allocation. + let status = unsafe { GetIpForwardTable2(family, &mut table_ptr) }; + if NO_ERROR as i32 != status { + return Err(Error::GetIpForwardTableFailed( + io::Error::from_raw_os_error(status), + )); + } + + // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error + let num_entries = unsafe { *table_ptr }.NumEntries; + let mut vec = Vec::with_capacity(num_entries.try_into().unwrap_or_default()); + + for i in 0..num_entries { + assert!( + usize::try_from(i).unwrap() * std::mem::size_of::<MIB_IPFORWARD_ROW2>() + < usize::try_from(isize::MAX).unwrap() + ); + + // SAFETY: table_ptr is valid since GetIpForwardTable2 did not return an error nor have we + // or will we modify the table + let ptr: *const MIB_IPFORWARD_ROW2 = unsafe { (*table_ptr).Table.as_ptr() }; + + // SAFETY: The assert guarantees that the amount of bytes we are jumping is not larger than + // isize::MAX. Win32 guarantees that the resulting pointer is aligned, non-null, + // init. + let row: &MIB_IPFORWARD_ROW2 = + unsafe { ptr.offset(i.try_into().unwrap()).as_ref() }.unwrap(); + vec.push(row.clone()); + } + // SAFETY: FreeMibTable does not have clear safety rules but it deallocates the + // MIB_IPFORWARD_TABLE2 This pointer is ONLY deallocated here so it is guaranteed to not + // have been already deallocated. We have cloned all MIB_IPFORWARD_ROW2s and the rows do not + // contain pointers to the table so they will not be dangling after this free. + unsafe { FreeMibTable(table_ptr as *const _) } + Ok(vec) +} + +/// General type for passing interface and gateway +pub struct InterfaceAndGateway { + /// Interface + pub iface: NET_LUID_LH, + /// Gateway + pub gateway: SocketAddr, +} + +impl PartialEq for InterfaceAndGateway { + fn eq(&self, other: &InterfaceAndGateway) -> bool { + // SAFETY: Accessing Value is always valid in this union as both fields are the same type + (unsafe { self.iface.Value == other.iface.Value } && self.gateway == other.gateway) + } +} + +/// Get the best default route for the given address family or None if none exists. +pub fn get_best_default_route(family: AddressFamily) -> Result<Option<InterfaceAndGateway>> { + let table = get_ipforward_rows(family)?; + + // Remove all candidates without a gateway and which are not on a physical interface. + // Then get the annotated routes which are active. + let mut annotated: Vec<AnnotatedRoute<'_>> = table + .iter() + .filter(|row| { + 0 == row.DestinationPrefix.PrefixLength + && route_has_gateway(row) + && is_route_on_physical_interface(row).unwrap_or(false) + }) + .filter_map(|row| annotate_route(row)) + .collect(); + + if annotated.is_empty() { + return Ok(None); + } + + // We previously filtered out all inactive routes so we only need to sort by acending + // effective_metric + annotated.sort_by(|lhs, rhs| lhs.effective_metric.cmp(&rhs.effective_metric)); + + Ok(Some(InterfaceAndGateway { + iface: annotated[0].route.InterfaceLuid, + gateway: try_socketaddr_from_inet_sockaddr(annotated[0].route.NextHop) + .map_err(|_| Error::InvalidSiFamily)?, + })) +} + +pub fn route_has_gateway(route: &MIB_IPFORWARD_ROW2) -> bool { + match try_socketaddr_from_inet_sockaddr(route.NextHop) { + Ok(sock) => !sock.ip().is_unspecified(), + Err(_) => false, + } +} + +// TODO(Jon): It would be more correct to filter for devices that match the known LUID of the tunnel +// interface +fn is_route_on_physical_interface(route: &MIB_IPFORWARD_ROW2) -> Result<bool> { + // The last 16 bits of _bitfield represent the interface type. For that reason we mask it with + // 0xFFFF. SAFETY: route.InterfaceLuid is a union. Both variants of this union are always + // valid since one is a u64 and the other is a wrapped u64. Access to the _bitfield as such + // is safe since it does not reinterpret the u64 as anything it is not. + let if_type = u32::try_from(unsafe { route.InterfaceLuid.Info._bitfield } & 0xFFFF).unwrap(); + if if_type == IF_TYPE_SOFTWARE_LOOPBACK || if_type == IF_TYPE_TUNNEL { + return Ok(false); + } + + // OpenVPN uses interface type IF_TYPE_PROP_VIRTUAL, + // but tethering etc. may rely on virtual adapters too, + // so we have to filter out the TAP adapter specifically. + + // SAFETY: We are allowed to initialize MIB_IF_ROW2 with zeroed because it is made up entirely + // of types for which the zero pattern (all zeros) is valid. + let mut row: MIB_IF_ROW2 = unsafe { std::mem::zeroed() }; + row.InterfaceLuid = route.InterfaceLuid; + row.InterfaceIndex = route.InterfaceIndex; + + // SAFETY: GetIfEntry2 does not have clear safety rules however it will read the + // row.InterfaceLuid or row.InterfaceIndex and use that information to populate the struct. + // We guarantee here that these fields are valid since they are set. + let status = unsafe { GetIfEntry2(&mut row) }; + if NO_ERROR as i32 != status { + return Err(Error::GetIfEntryFailed(io::Error::from_raw_os_error( + status, + ))); + } + + let row_description = WideCStr::from_slice_truncate(&row.Description) + .expect("Windows provided incorrectly formatted utf16 string"); + + for tunnel_interface_desc in TUNNEL_INTERFACE_DESCS { + if contains_subslice(row_description.as_slice(), tunnel_interface_desc.as_slice()) { + return Ok(false); + } + } + + return Ok(true); +} + +fn contains_subslice<T: PartialEq>(slice: &[T], subslice: &[T]) -> bool { + slice + .windows(subslice.len()) + .any(|window| window == subslice) +} + +struct AnnotatedRoute<'a> { + route: &'a MIB_IPFORWARD_ROW2, + effective_metric: u32, +} + +fn annotate_route<'a>(route: &'a MIB_IPFORWARD_ROW2) -> Option<AnnotatedRoute<'a>> { + // SAFETY: `si_family` is valid in both `Ipv4` and `Ipv6` so we can safely access `si_family`. + let iface = get_ip_interface_entry( + AddressFamily::try_from_af_family(unsafe { route.DestinationPrefix.Prefix.si_family }) + .ok()?, + &route.InterfaceLuid, + ) + .ok()?; + + if iface.Connected == 0 { + None + } else { + Some(AnnotatedRoute { + route, + effective_metric: route.Metric + iface.Metric, + }) + } +} diff --git a/talpid-core/src/routing/windows/mod.rs b/talpid-core/src/routing/windows/mod.rs new file mode 100644 index 0000000000..06d23368ca --- /dev/null +++ b/talpid-core/src/routing/windows/mod.rs @@ -0,0 +1,303 @@ +use crate::{routing::RequiredRoute, windows::AddressFamily}; +use futures::channel::oneshot; +use std::{collections::HashSet, io, net::IpAddr}; +use talpid_types::ErrorExt; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; + +pub use default_route_monitor::EventType; +pub use get_best_default_route::{get_best_default_route, route_has_gateway, InterfaceAndGateway}; +pub use route_manager::{Callback, CallbackHandle, Route, RouteManagerInternal}; + +mod default_route_monitor; +mod get_best_default_route; +mod route_manager; + +/// Windows routing errors. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The sender was dropped unexpectedly -- possible panic + #[error(display = "The channel sender was dropped")] + ManagerChannelDown, + /// Failure to initialize route manager + #[error(display = "Failed to start route manager")] + FailedToStartManager, + /// Attempt to use route manager that has been dropped + #[error(display = "Cannot send message to route manager since it is down")] + RouteManagerDown, + /// Low level error caused by a failure to add to route table + #[error(display = "Could not add route to route table")] + AddToRouteTable(io::Error), + /// Low level error caused by failure to delete route from route table + #[error(display = "Failed to delete applied routes")] + DeleteFromRouteTable(io::Error), + /// GetIpForwardTable2 windows API call failed + #[error(display = "Failed to retrieve the routing table")] + GetIpForwardTableFailed(io::Error), + /// GetIfEntry2 windows API call failed + #[error(display = "Failed to retrieve network interface entry")] + GetIfEntryFailed(io::Error), + /// Low level error caused by failing to register the route callback + #[error(display = "Attempt to register notify route change callback failed")] + RegisterNotifyRouteCallback(io::Error), + /// Low level error caused by failing to register the ip interface callback + #[error(display = "Attempt to register notify ip interface change callback failed")] + RegisterNotifyIpInterfaceCallback(io::Error), + /// Low level error caused by failing to register the unicast ip address callback + #[error(display = "Attempt to register notify unicast ip address change callback failed")] + RegisterNotifyUnicastIpAddressCallback(io::Error), + /// Low level error caused by windows Adapters API + #[error(display = "Windows adapter error")] + Adapter(io::Error), + /// High level error caused by a failure to clear the routes in the route manager. + /// Contains the lower error + #[error(display = "Failed to clear applied routes")] + ClearRoutesFailed(Box<Error>), + /// High level error caused by a failure to add routes in the route manager. + /// Contains the lower error + #[error(display = "Failed to add routes")] + AddRoutesFailed(Box<Error>), + /// Something went wrong when getting the mtu of the interface + #[error(display = "Could not get the mtu of the interface")] + GetMtu, + /// The SI family was of an unexpected value + #[error(display = "The SI family was of an unexpected value")] + InvalidSiFamily, + /// Device name not found + #[error(display = "The device name was not found")] + DeviceNameNotFound, + /// No default route + #[error(display = "No default route found")] + NoDefaultRoute, + /// Conversion error between types + #[error(display = "Conversion error")] + Conversion, + /// Could not find device gateway + #[error(display = "Could not find device gateway")] + DeviceGatewayNotFound, + /// Could not get default route + #[error(display = "Could not get default route")] + GetDefaultRoute, + /// Could not find device by name + #[error(display = "Could not find device by name")] + GetDeviceByName, + /// Could not find device by gateway + #[error(display = "Could not find device by gateway")] + GetDeviceByGateway, +} + +pub type Result<T> = std::result::Result<T, Error>; + +/// Manages routes by calling into WinNet +pub struct RouteManager { + manage_tx: Option<UnboundedSender<RouteManagerCommand>>, +} + +/// Handle to a route manager. +#[derive(Clone)] +pub struct RouteManagerHandle { + tx: UnboundedSender<RouteManagerCommand>, +} + +impl RouteManagerHandle { + /// Add a callback which will be called if the default route changes. + pub async fn add_default_route_change_callback( + &self, + callback: Callback, + ) -> Result<CallbackHandle> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .send(RouteManagerCommand::RegisterDefaultRouteChangeCallback( + callback, + response_tx, + )) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown)? + } + + /// Applies the given routes while the route manager is running. + pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .send(RouteManagerCommand::AddRoutes(routes, response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown)? + } + + /// Applies the given routes while the route manager is running. + pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .send(RouteManagerCommand::GetMtuForRoute(ip, response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown)? + } +} + +pub enum RouteManagerCommand { + AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), + GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>), + ClearRoutes, + RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<Result<CallbackHandle>>), + Shutdown, +} + +impl RouteManager { + /// Creates a new route manager that will apply the provided routes and ensure they exist until + /// it's stopped. + pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { + let internal = match RouteManagerInternal::new() { + Ok(internal) => internal, + Err(_) => return Err(Error::FailedToStartManager), + }; + let (manage_tx, manage_rx) = mpsc::unbounded_channel(); + let manager = Self { + manage_tx: Some(manage_tx), + }; + tokio::spawn(RouteManager::listen(manage_rx, internal)); + manager.add_routes(required_routes).await?; + + Ok(manager) + } + + /// Add a callback which will be called if the default route changes. + pub async fn add_default_route_change_callback( + &self, + callback: Callback, + ) -> Result<CallbackHandle> { + if let Some(tx) = &self.manage_tx { + let (result_tx, result_rx) = oneshot::channel(); + if tx + .send(RouteManagerCommand::RegisterDefaultRouteChangeCallback( + callback, result_tx, + )) + .is_err() + { + return Err(Error::RouteManagerDown); + } + result_rx.await.map_err(|_| Error::ManagerChannelDown)? + } else { + Err(Error::RouteManagerDown) + } + } + + /// Retrieve a sender directly to the command channel. + pub fn handle(&self) -> Result<RouteManagerHandle> { + if let Some(tx) = &self.manage_tx { + Ok(RouteManagerHandle { tx: tx.clone() }) + } else { + Err(Error::RouteManagerDown) + } + } + + async fn listen( + mut manage_rx: UnboundedReceiver<RouteManagerCommand>, + mut internal: RouteManagerInternal, + ) { + while let Some(command) = manage_rx.recv().await { + match command { + RouteManagerCommand::AddRoutes(routes, tx) => { + let routes: Vec<_> = routes + .into_iter() + .map(|route| Route { + network: route.prefix, + node: route.node, + }) + .collect(); + + let _ = tx.send( + internal + .add_routes(routes) + .map_err(|e| Error::AddRoutesFailed(Box::new(e))), + ); + } + RouteManagerCommand::GetMtuForRoute(ip, tx) => { + let addr_family = if ip.is_ipv4() { + AddressFamily::Ipv4 + } else { + AddressFamily::Ipv6 + }; + let res = match get_mtu_for_route(addr_family) { + Ok(Some(mtu)) => Ok(mtu), + Ok(None) => Err(Error::GetMtu), + Err(e) => Err(e), + }; + let _ = tx.send(res); + } + RouteManagerCommand::ClearRoutes => { + if let Err(e) = internal.delete_applied_routes() { + log::error!("{}", e.display_chain_with_msg("Could not clear routes")); + } + } + RouteManagerCommand::RegisterDefaultRouteChangeCallback(callback, tx) => { + let _ = tx.send(internal.register_default_route_changed_callback(callback)); + } + RouteManagerCommand::Shutdown => { + break; + } + } + } + } + + /// Stops the routing manager and invalidates the route manager - no new default route callbacks + /// can be added + pub fn stop(&mut self) { + if let Some(tx) = self.manage_tx.take() { + if tx.send(RouteManagerCommand::Shutdown).is_err() { + log::error!("RouteManager channel already down or thread panicked"); + } + } + } + + /// Applies the given routes until [`RouteManager::stop`] is called. + pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + if let Some(tx) = &self.manage_tx { + let (result_tx, result_rx) = oneshot::channel(); + if tx + .send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .is_err() + { + return Err(Error::RouteManagerDown); + } + result_rx.await.map_err(|_| Error::ManagerChannelDown)? + } else { + Err(Error::RouteManagerDown) + } + } + + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_routes(&self) -> Result<()> { + if let Some(tx) = &self.manage_tx { + tx.send(RouteManagerCommand::ClearRoutes) + .map_err(|_| Error::RouteManagerDown) + } else { + Err(Error::RouteManagerDown) + } + } +} + +fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> { + match get_best_default_route(addr_family) { + Ok(Some(route)) => { + let interface_row = crate::windows::get_ip_interface_entry(addr_family, &route.iface) + .map_err(|e| { + log::error!("Could not get ip interface entry: {}", e); + Error::GetMtu + })?; + let mtu = interface_row.NlMtu; + let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?; + Ok(Some(mtu)) + } + Ok(None) => Ok(None), + Err(e) => { + log::error!("Could not get best default route: {}", e); + Err(Error::GetMtu) + } + } +} + +impl Drop for RouteManager { + fn drop(&mut self) { + self.stop(); + } +} diff --git a/talpid-core/src/routing/windows/route_manager.rs b/talpid-core/src/routing/windows/route_manager.rs new file mode 100644 index 0000000000..f1d878dd28 --- /dev/null +++ b/talpid-core/src/routing/windows/route_manager.rs @@ -0,0 +1,885 @@ +use super::{ + default_route_monitor::{DefaultRouteMonitor, EventType as RouteMonitorEventType}, + get_best_default_route, Error, InterfaceAndGateway, Result, +}; +use crate::{ + routing::NetNode, + windows::{inet_sockaddr_from_socketaddr, try_socketaddr_from_inet_sockaddr, AddressFamily}, +}; +use ipnetwork::IpNetwork; +use std::{ + collections::HashMap, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::{Arc, Mutex}, +}; +use widestring::{WideCStr, WideCString}; +use windows_sys::Win32::{ + Foundation::{ + ERROR_BUFFER_OVERFLOW, ERROR_NOT_FOUND, ERROR_NO_DATA, ERROR_OBJECT_ALREADY_EXISTS, + ERROR_SUCCESS, NO_ERROR, + }, + NetworkManagement::{ + IpHelper::{ + ConvertInterfaceAliasToLuid, CreateIpForwardEntry2, DeleteIpForwardEntry2, + GetAdaptersAddresses, InitializeIpForwardEntry, SetIpForwardEntry2, + GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER, + GAA_FLAG_SKIP_FRIENDLY_NAME, GAA_FLAG_SKIP_MULTICAST, GET_ADAPTERS_ADDRESSES_FLAGS, + IP_ADAPTER_ADDRESSES_LH, IP_ADAPTER_GATEWAY_ADDRESS_LH, IP_ADAPTER_IPV4_ENABLED, + IP_ADAPTER_IPV6_ENABLED, IP_ADDRESS_PREFIX, MIB_IPFORWARD_ROW2, + }, + Ndis::NET_LUID_LH, + }, + Networking::WinSock::{ + NlroManual, ADDRESS_FAMILY, AF_INET, AF_INET6, MIB_IPPROTO_NETMGMT, SOCKADDR_IN, + SOCKADDR_IN6, SOCKADDR_INET, SOCKET_ADDRESS, + }, +}; + +type Network = IpNetwork; +type NodeAddress = SOCKADDR_INET; + +/// Callback handle for the default route changed callback. Produced by the RouteManager. +pub struct CallbackHandle { + nonce: i32, + callbacks: Arc<Mutex<(i32, HashMap<i32, Callback>)>>, +} + +impl Drop for CallbackHandle { + fn drop(&mut self) { + let (_, callbacks) = &mut *self.callbacks.lock().unwrap(); + match callbacks.remove(&self.nonce) { + Some(_) => (), + None => { + log::warn!("Could not un-register route manager callback due to it already being de-registered"); + } + } + } +} + +#[derive(Clone)] +struct RegisteredRoute { + network: Network, + luid: NET_LUID_LH, + next_hop: SocketAddr, +} + +impl std::fmt::Display for RegisteredRoute { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // SAFETY: luid.Value is always valid as the underlying type of both union fields is an u64 + formatter.write_fmt(format_args!("RegisteredRoute {{ luid: {} }}", unsafe { + self.luid.Value + })) + } +} + +impl PartialEq for RegisteredRoute { + fn eq(&self, other: &Self) -> bool { + // SAFETY: luid.Value is always valid as the underlying type of both union fields is an u64 + (unsafe { self.luid.Value == other.luid.Value }) + && (self.next_hop == other.next_hop) + && (self.network == other.network) + } +} + +#[derive(Clone)] +pub struct Node { + pub device_name: Option<widestring::U16CString>, + pub gateway: Option<NodeAddress>, +} + +#[derive(Clone)] +pub struct Route { + pub network: Network, + pub node: NetNode, +} + +#[derive(Clone)] +struct RouteRecord { + route: Route, + registered_route: RegisteredRoute, +} + +struct EventEntry { + record: RouteRecord, + event_type: RecordEventType, +} + +enum RecordEventType { + AddRoute, + DeleteRoute, +} + +pub type Callback = Box<dyn for<'a> Fn(RouteMonitorEventType<'a>, AddressFamily) + Send>; + +pub struct RouteManagerInternal { + route_monitor_v4: Option<DefaultRouteMonitor>, + route_monitor_v6: Option<DefaultRouteMonitor>, + routes: Arc<Mutex<Vec<RouteRecord>>>, + /// Lock for a nonce and a HashMap of callbacks and their id which is used as a handle to + /// unregister them. The nonce is used to create new ids and then incrementing. + callbacks: Arc<Mutex<(i32, HashMap<i32, Callback>)>>, +} + +impl RouteManagerInternal { + pub fn new() -> Result<Self> { + let routes = Arc::new(Mutex::new(Vec::new())); + let callbacks = Arc::new(Mutex::new((0, HashMap::new()))); + + let callbacks_ipv4 = callbacks.clone(); + let routes_ipv4 = routes.clone(); + let callbacks_ipv6 = callbacks.clone(); + let routes_ipv6 = routes.clone(); + + Ok(Self { + route_monitor_v4: Some(DefaultRouteMonitor::new( + AddressFamily::Ipv4, + move |event_type| { + Self::default_route_change(&callbacks_ipv4, &routes_ipv4, AF_INET, event_type); + }, + )?), + route_monitor_v6: Some(DefaultRouteMonitor::new( + AddressFamily::Ipv6, + move |event_type| { + Self::default_route_change(&callbacks_ipv6, &routes_ipv6, AF_INET6, event_type); + }, + )?), + routes, + callbacks, + }) + } + + pub fn add_routes(&self, new_routes: Vec<Route>) -> Result<()> { + let mut route_manager_routes = self.routes.lock().unwrap(); + + let mut event_log = vec![]; + + for route in new_routes { + let registered_route = Self::add_into_routing_table(&route).map_err(|error| { + if let Err(error) = Self::undo_events(&event_log, &mut route_manager_routes) { + error + } else { + error + } + })?; + + let new_record = RouteRecord { + route, + registered_route, + }; + + event_log.push(EventEntry { + event_type: RecordEventType::AddRoute, + record: new_record.clone(), + }); + + let existing_record_idx = + Self::find_route_record(&mut route_manager_routes, &new_record.registered_route); + + match existing_record_idx { + None => route_manager_routes.push(new_record), + Some(idx) => route_manager_routes[idx] = new_record, + } + } + Ok(()) + } + + fn add_into_routing_table(route: &Route) -> Result<RegisteredRoute> { + let node = Self::resolve_node(ipnetwork_to_address_family(route.network), &route.node)?; + + // SAFETY: MIB_IPFORWARD_ROW2 contains no references or pointers only number primitives and + // as such it is safe to zero it. + let mut spec: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() }; + + // SAFETY: This function must be used to initialize MIB_IPFORWARD_ROW2 structs if it is to + // be used later by CreateIpForwardEntry2. + unsafe { InitializeIpForwardEntry(&mut spec) }; + + spec.InterfaceLuid = node.iface; + spec.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network); + spec.NextHop = inet_sockaddr_from_socketaddr(node.gateway); + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + // SAFETY: DestinationPrefix must be initialized to a valid prefix. NextHop must have a + // valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be set + // to the interface. + let mut status = unsafe { CreateIpForwardEntry2(&spec) }; + + // The return code ERROR_OBJECT_ALREADY_EXISTS means there is already an existing route + // on the same interface, with the same DestinationPrefix and NextHop. + // + // However, all the other properties of the route may be different. And the properties may + // not have the exact same values as when the route was registered, because windows + // will adjust route properties at time of route insertion as well as later. + // + // The simplest thing in this case is to just overwrite the route. + // + + if ERROR_OBJECT_ALREADY_EXISTS as i32 == status { + // SAFETY: DestinationPrefix must be initialzed to a valid prefix. NextHop must have + // a valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must + // be set to the interface. + status = unsafe { SetIpForwardEntry2(&spec) }; + } + + if NO_ERROR as i32 != status { + log::error!("Could not register route in routing table"); + return Err(Error::AddToRouteTable(io::Error::from_raw_os_error(status))); + } + + Ok(RegisteredRoute { + network: route.network, + luid: node.iface, + next_hop: node.gateway, + }) + } + + fn resolve_node(family: AddressFamily, optional_node: &NetNode) -> Result<InterfaceAndGateway> { + // There are four cases: + // + // Unspecified node (use interface and gateway of default route). + // Node is specified by name. + // Node is specified by name and gateway. + // Node is specified by gateway. + // + + match optional_node { + NetNode::DefaultNode => { + let default_route = get_best_default_route(family)?; + match default_route { + None => { + log::error!("Unable to determine details of default route"); + return Err(Error::NoDefaultRoute); + } + Some(default_route) => return Ok(default_route), + } + } + NetNode::RealNode(node) => { + if let Some(device_name) = &node.get_device() { + let device_name = WideCString::from_str(device_name) + .expect("Failed to convert UTF-8 string to null terminated UCS string"); + let luid = match Self::parse_string_encoded_luid(device_name.as_ucstr())? { + None => { + let mut luid = NET_LUID_LH { Value: 0 }; + // SAFETY: No specific safety requirement + if NO_ERROR as i32 + != unsafe { + ConvertInterfaceAliasToLuid(device_name.as_ptr(), &mut luid) + } + { + log::error!( + "Unable to derive interface LUID from interface alias: {:?}", + device_name + ); + return Err(Error::DeviceNameNotFound); + } else { + luid + } + } + Some(luid) => luid, + }; + + return Ok(InterfaceAndGateway { + iface: luid, + gateway: match node.get_address() { + Some(ip) => SocketAddr::new(ip, 0), + None => match family { + AddressFamily::Ipv4 => { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } + AddressFamily::Ipv6 => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }, + }, + }); + } + + // The node is specified only by gateway. + // + + // Unwrapping is fine because the node must have an address since no device name was + // found. + let gateway = node.get_address().map(inet_sockaddr_from_ipaddr).unwrap(); + Ok(InterfaceAndGateway { + iface: interface_luid_from_gateway(&gateway)?, + gateway: try_socketaddr_from_inet_sockaddr(gateway) + .map_err(|_| Error::InvalidSiFamily)?, + }) + } + } + } + + fn find_route_record(records: &mut Vec<RouteRecord>, route: &RegisteredRoute) -> Option<usize> { + records + .iter() + .position(|record| route == &record.registered_route) + } + + fn undo_events(event_log: &Vec<EventEntry>, records: &mut Vec<RouteRecord>) -> Result<()> { + // Rewind state by processing events in the reverse order. + // + + let mut result = Ok(()); + + for event in event_log.iter().rev() { + match event.event_type { + RecordEventType::AddRoute => { + let record_idx = Self::find_route_record(records, &event.record.registered_route) + .expect("Internal state inconsistency in route manager, could not find route record"); + let record = records.get(record_idx) + .expect("Internal state inconsistency in route manager, route record index pointing at nothing"); + + if let Err(e) = Self::delete_from_routing_table(&record.registered_route) { + result = result.and(Err(e)); + continue; + } + records.remove(record_idx); + } + RecordEventType::DeleteRoute => { + if let Err(e) = Self::restore_into_routing_table(&event.record.registered_route) + { + result = result.and(Err(e)); + continue; + } + records.push(event.record.clone()); + } + } + } + + result + } + + fn delete_from_routing_table(route: &RegisteredRoute) -> Result<()> { + // SAFETY: There are no pointers or references inside of MIB_IPFORWARD_ROW2, only primitive + // numbers as such it is safe to zero it. + let mut r: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() }; + + r.InterfaceLuid = route.luid; + r.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network); + r.NextHop = inet_sockaddr_from_socketaddr(route.next_hop); + + // SAFETY: DestinationPrefix must be initialzed to a valid prefix. NextHop must have + // a valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be + // set to the interface. + let status = unsafe { DeleteIpForwardEntry2(&r) }; + + match u32::try_from(status) { + Ok(ERROR_NOT_FOUND) => { + log::warn!("Attempting to delete route which was not present in routing table, ignoring and proceeding. Route: {}", route); + } + Ok(NO_ERROR) => (), + _ => { + log::error!( + "Failed to delete route in routing table. Route: {}, Status: {}", + route, + status + ); + return Err(Error::DeleteFromRouteTable(io::Error::from_raw_os_error( + status, + ))); + } + } + + Ok(()) + } + + fn restore_into_routing_table(route: &RegisteredRoute) -> Result<()> { + // SAFETY: There are no pointers or references inside of MIB_IPFORWARD_ROW2, only primitive + // numbers as such it is safe to zero it. + let mut spec: MIB_IPFORWARD_ROW2 = unsafe { std::mem::zeroed() }; + + // SAFETY: This function must be used to initialize MIB_IPFORWARD_ROW2 structs if it is to + // be used later by CreateIpForwardEntry2. + unsafe { InitializeIpForwardEntry(&mut spec) }; + + spec.InterfaceLuid = route.luid; + spec.DestinationPrefix = win_ip_address_prefix_from_ipnetwork_port_zero(route.network); + spec.NextHop = inet_sockaddr_from_socketaddr(route.next_hop); + spec.Metric = 0; + spec.Protocol = MIB_IPPROTO_NETMGMT; + spec.Origin = NlroManual; + + // SAFETY: DestinationPrefix must be initialized to a valid prefix. NextHop must have a + // valid IP address and family. At least one of InterfaceLuid and InterfaceIndex must be set + // to the interface. + let status = unsafe { CreateIpForwardEntry2(&spec) }; + + if NO_ERROR as i32 != status { + log::error!( + "Could not register route in routing table. Route: {}, Status: {}", + route, + status + ); + return Err(Error::AddToRouteTable(io::Error::from_raw_os_error(status))); + } + Ok(()) + } + + fn parse_string_encoded_luid(encoded_luid: &WideCStr) -> Result<Option<NET_LUID_LH>> { + // The `#` is a valid character in adapter names so we use `?` instead. + // The LUID is thus prefixed with `?` and hex encoded and left-padded with zeroes. + // E.g. `?deadbeefcafebabe` or `?000dbeefcafebabe`. + // + + const STRING_ENCODED_LUID_LENGTH: usize = 17; + + if encoded_luid.len() != STRING_ENCODED_LUID_LENGTH + || Some(Ok('?')) != encoded_luid.chars().next() + { + return Ok(None); + } + + let luid = NET_LUID_LH { + Value: u64::from_str_radix( + &encoded_luid.to_string().map_err(|_| { + log::error!("Failed to parse string encoded LUID: {:?}", encoded_luid); + Error::Conversion + })?[1..], + 16, + ) + .map_err(|_| { + log::error!("Failed to parse string encoded LUID: {:?}", encoded_luid); + Error::Conversion + })?, + }; + + return Ok(Some(luid)); + } + + pub fn delete_applied_routes(&mut self) -> Result<()> { + let mut routes = self.routes.lock().unwrap(); + // Delete all routes owned by us. + // + + for record in (*routes).iter() { + if let Err(_) = Self::delete_from_routing_table(&record.registered_route) { + log::error!( + "Failed to delete route while clearing applied routes, {}", + record.registered_route + ); + } + } + + routes.clear(); + Ok(()) + } + + pub fn register_default_route_changed_callback( + &self, + callback: Callback, + ) -> Result<CallbackHandle> { + let (nonce, callbacks) = &mut *self.callbacks.lock().unwrap(); + let old_nonce = *nonce; + callbacks.insert(old_nonce, callback); + *nonce = nonce.wrapping_add(1); + Ok(CallbackHandle { + nonce: old_nonce, + callbacks: self.callbacks.clone(), + }) + } + + fn default_route_change<'a>( + callbacks: &Arc<Mutex<(i32, HashMap<i32, Callback>)>>, + records: &Arc<Mutex<Vec<RouteRecord>>>, + family: ADDRESS_FAMILY, + event_type: RouteMonitorEventType<'a>, + ) { + // Forward event to all registered listeners. + // + + { + let (_, callbacks) = &mut *callbacks.lock().unwrap(); + for callback in callbacks.values() { + let family = + AddressFamily::try_from_af_family(u16::try_from(family).unwrap()).unwrap(); + callback(event_type, family); + } + } + + // Examine event to determine if best default route has changed. + // + + let route = if let RouteMonitorEventType::Updated(route) = event_type { + route + } else { + return; + }; + + // Examine our routes to see if any of them are policy bound to the best default route. + // + + let mut records = records.lock().unwrap(); + let mut affected_routes: Vec<&mut RouteRecord> = vec![]; + + for record in (*records).iter_mut() { + if matches!(record.route.node, NetNode::DefaultNode) + && family + == u32::from(ipnetwork_to_address_family(record.route.network).to_af_family()) + { + affected_routes.push(record); + } + } + + if affected_routes.is_empty() { + return; + } + + // Update all affected routes. + // + + log::info!("Best default route has changed. Refreshing dependent routes"); + + for affected_route in affected_routes { + // We can't update the existing route because defining characteristics are being + // changed. So removing and adding again is the only option. + // + + match Self::delete_from_routing_table(&affected_route.registered_route) { + Ok(()) => (), + Err(e) => { + log::error!( + "Failed to delete route when refreshing existing routes: {}", + e + ); + continue; + } + } + + affected_route.registered_route.luid = route.iface; + affected_route.registered_route.next_hop = route.gateway; + + match Self::restore_into_routing_table(&affected_route.registered_route) { + Ok(()) => (), + Err(e) => { + log::error!("Failed to add route when refreshing existing routes: {}", e); + continue; + } + } + } + } +} + +impl Drop for RouteManagerInternal { + fn drop(&mut self) { + drop(self.route_monitor_v4.take()); + drop(self.route_monitor_v6.take()); + + match self.delete_applied_routes() { + Ok(()) => (), + Err(e) => { + log::error!("Failed to correctly drop RouteManagerInternal {}", e) + } + } + } +} + +fn interface_luid_from_gateway(gateway: &SOCKADDR_INET) -> Result<NET_LUID_LH> { + const ADAPTER_FLAGS: GET_ADAPTERS_ADDRESSES_FLAGS = GAA_FLAG_SKIP_ANYCAST + | GAA_FLAG_SKIP_MULTICAST + | GAA_FLAG_SKIP_DNS_SERVER + | GAA_FLAG_SKIP_FRIENDLY_NAME + | GAA_FLAG_INCLUDE_GATEWAYS; + + // SAFETY: The si_family field is always valid to access. + let family: u32 = u32::from(unsafe { gateway.si_family }); + let adapters = Adapters::new(family, ADAPTER_FLAGS)?; + + // Process adapters to find matching ones. + // + + let mut matches: Vec<_> = adapters + // SAFETY: We are not allowed to dereference adapter.Head if it has been aquired in a previous iteration of the iterator + // we ensure this is upheld by not saving any references to adapter.Head between iterations. + .iter() + .filter(|adapter| { + if !adapter_interface_enabled(adapter, family).unwrap_or(false) { + return false; + } + let gateways = if adapter.FirstGatewayAddress.is_null() { + vec![] + } else { + // SAFETY: adapter.FirstGatewayAddress is not null and all elements in the linked list live + // in the same buffer and as such have the same lifetime. + unsafe { isolate_gateway_address(get_first_gateway_address_reference(adapter), family) } + }; + + address_present(gateways, &gateway).unwrap_or(false) + }) + .collect(); + + if matches.is_empty() { + log::error!("Unable to find network adapter with specified gateway"); + return Err(Error::DeviceGatewayNotFound); + } + + // Sort matching interfaces ascending by metric. + // + + let target_v4 = AF_INET == family; + + matches.sort_by(|lhs, rhs| { + if target_v4 { + lhs.Ipv4Metric.cmp(&rhs.Ipv4Metric) + } else { + lhs.Ipv6Metric.cmp(&rhs.Ipv6Metric) + } + }); + + // Select the interface with the best (lowest) metric. + // + + Ok(matches[0].Luid) +} + +/// SAFETY: adapter.FirstGatewayAddress must be dereferencable and must live as long as adapter +unsafe fn get_first_gateway_address_reference( + adapter: &IP_ADAPTER_ADDRESSES_LH, +) -> &IP_ADAPTER_GATEWAY_ADDRESS_LH { + &*adapter.FirstGatewayAddress +} + +fn adapter_interface_enabled( + adapter: &IP_ADAPTER_ADDRESSES_LH, + family: ADDRESS_FAMILY, +) -> Result<bool> { + match family { + // SAFETY: All fields in the Anonymous2 union are at represented by a u32 so dereferencing + // them is safe + AF_INET => Ok(0 != unsafe { adapter.Anonymous2.Flags } & IP_ADAPTER_IPV4_ENABLED), + AF_INET6 => Ok(0 != unsafe { adapter.Anonymous2.Flags } & IP_ADAPTER_IPV6_ENABLED), + _ => Err(Error::InvalidSiFamily), + } +} + +/// SAFETY: `head` must be a linked list where each `head.Next` is either null or +/// the it and all of its fields has lifetime 'a and are dereferencable. +unsafe fn isolate_gateway_address<'a>( + head: &'a IP_ADAPTER_GATEWAY_ADDRESS_LH, + family: ADDRESS_FAMILY, +) -> Vec<&'a SOCKET_ADDRESS> { + let mut matches = vec![]; + + let mut gateway = head; + loop { + // SAFETY: The contract states that Address.lpSockaddr is dereferencable if the element is + // non-null + if family == u32::from((*gateway.Address.lpSockaddr).sa_family) { + // SAFETY: The contract states that this field must have lifetime 'a + matches.push(&gateway.Address); + } + + if gateway.Next.is_null() { + break; + } + + // SAFETY: Gateway.Next is not null here and the contract states it must be dereferencable + // if non-null + gateway = &*gateway.Next; + } + + matches +} + +fn address_present(hay: Vec<&'_ SOCKET_ADDRESS>, needle: &'_ SOCKADDR_INET) -> Result<bool> { + for candidate in hay { + // SAFETY: Contract states that needle is dereferencable + if equal_address(needle, candidate)? { + return Ok(true); + } + } + + Ok(false) +} + +fn equal_address(lhs: &'_ SOCKADDR_INET, rhs: &'_ SOCKET_ADDRESS) -> Result<bool> { + let rhs = &*rhs; + // SAFETY: The si_family field is always valid + if unsafe { lhs.si_family != (*rhs.lpSockaddr).sa_family } { + return Ok(false); + } + + match unsafe { lhs.si_family } as u32 { + AF_INET => { + let typed_rhs = rhs.lpSockaddr as *mut SOCKADDR_IN; + // SAFETY: If rhs.lpSockaddr.sa_family is IPv4 then lpSockaddr is a SOCKADDR_IN + Ok(unsafe { lhs.Ipv4.sin_addr.S_un.S_addr == (*typed_rhs).sin_addr.S_un.S_addr }) + } + AF_INET6 => { + let typed_rhs = rhs.lpSockaddr as *mut SOCKADDR_IN6; + // SAFETY: If rhs.lpSockaddr.sa_family is IPv6 then lpSockaddr is a SOCKADDR_IN6 + Ok(unsafe { lhs.Ipv6.sin6_addr.u.Byte == (*typed_rhs).sin6_addr.u.Byte }) + } + _ => { + log::error!("Missing case handler in match"); + Err(Error::InvalidSiFamily) + } + } +} + +/// Linked list containing `IP_ADAPTER_ADDRESSES_LH` queried from the windows API. +/// Consume by using the iterator produced by `iter_mut()` +struct Adapters { + // SAFETY: This vector is not allowed to be resized since all of the data inside of it would be + // dangling + buffer: Vec<u8>, +} + +impl Adapters { + /// Create a new linked list of adapters from the windows API + fn new(family: ADDRESS_FAMILY, flags: GET_ADAPTERS_ADDRESSES_FLAGS) -> Result<Self> { + const MSDN_RECOMMENDED_STARTING_BUFFER_SIZE: usize = 1024 * 15; + let mut buffer: Vec<u8> = Vec::with_capacity(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE); + buffer.resize(MSDN_RECOMMENDED_STARTING_BUFFER_SIZE, 0); + + let mut buffer_size = u32::try_from(buffer.len()).unwrap(); + let mut buffer_pointer = buffer.as_mut_ptr(); + + // Acquire interfaces. + // + + loop { + // SAFETY: buffer_size must point to the correct amount of bytes in the buffer which it + // does. buffer_pointer must point to the start of a mutable buffer which it + // does. After this call buffer_size might have changed and as such the + // buffer must be resized to reflect this if this function is going to be + // called again. + let status = unsafe { + GetAdaptersAddresses( + family, + flags, + std::ptr::null_mut() as *mut _, + buffer_pointer as *mut IP_ADAPTER_ADDRESSES_LH, + &mut buffer_size, + ) + }; + + if ERROR_SUCCESS == status { + // SAFETY: We truncate the buffer to avoid having a bunch of zero:ed objects at the + // end of it truncate will not change capacity and will therefore + // never reallocate the vector which means it can not cause the + // pointers in the buffer to dangle. + buffer.truncate(usize::try_from(buffer_size).unwrap()); + break; + } + + if ERROR_NO_DATA == status { + return Ok(Self { buffer: Vec::new() }); + } + + if ERROR_BUFFER_OVERFLOW != status { + log::error!("Probe required buffer size for GetAdaptersAddresses"); + return Err(Error::Adapter(io::Error::from_raw_os_error( + i32::try_from(status).unwrap(), + ))); + } + + // The needed length is returned in the buffer_size pointer + buffer.resize(usize::try_from(buffer_size).unwrap(), 0); + buffer_pointer = buffer.as_mut_ptr(); + } + + // Verify structure compatibility. + // The structure has been extended many times. + // + + // Unwrapping is fine because we previously would return if we got a ERROR_NO_DATA status. + // As such the buffer is not empty. SAFETY: Casting the buffers first element to an + // IP_ADAPTER_ADDRESSES_LH is safe as that is the underlying data structure. SAFETY: + // This union field is always valid to read from + let system_size = unsafe { + (*(buffer.get(0).unwrap() as *const u8 as *const IP_ADAPTER_ADDRESSES_LH)) + .Anonymous1 + .Anonymous + .Length + }; + let code_size = u32::try_from(std::mem::size_of::<IP_ADAPTER_ADDRESSES_LH>()).unwrap(); + + if system_size < code_size { + log::error!("Expecting IP_ADAPTER_ADDRESSES to have size {code_size} bytes. Found structure with size {system_size} bytes."); + return Err(Error::Adapter(io::Error::new(io::ErrorKind::Other, + format!("Expecting IP_ADAPTER_ADDRESSES to have size {code_size} bytes. Found structure with size {system_size} bytes.")))); + } + + // Initialize members. + // + + Ok(Self { buffer }) + } + + /// Produces a iterator for the linked list in `Adapters` see + /// [AdaptersIterator](struct.AdaptersIterator.html) SAFETY: See the documentation on + /// `AdaptersIterator` + fn iter<'a>(&'a self) -> AdaptersIterator<'a> { + let cur = if self.buffer.is_empty() { + std::ptr::null() + } else { + &self.buffer[0] as *const u8 as *const IP_ADAPTER_ADDRESSES_LH + }; + AdaptersIterator { + _adapters: self, + cur, + } + } +} + +/// SAFETY: You are only allowed to dereference `IP_ADAPTER_ADDRESSES_LH.Next` or any following +/// `Next` items in the linked list if they were produced by the latest call to `next()`. Any raw +/// pointers that were aquired before the call to `next()` are not valid to dereference. +struct AdaptersIterator<'a> { + _adapters: &'a Adapters, + cur: *const IP_ADAPTER_ADDRESSES_LH, +} + +impl<'a> Iterator for AdaptersIterator<'a> { + type Item = &'a IP_ADAPTER_ADDRESSES_LH; + fn next(&mut self) -> Option<Self::Item> { + if self.cur.is_null() { + None + } else { + let ret = self.cur; + // SAFETY: self.cur is guaranteed to not be null, we are also holding a &Adapters which + // guarantees no other reference of self could be held right now which has + // mutably dereferenced the same address that self.cur is pointing to. + // + // It is possible that someone has copied the previous returned items `Next` pointer + // which points to the same as address as self.cur, however dereferencing + // that is unsafe and that code is responsible for not dereferencing + // `Next` on a reference returned by this function after that reference has been + // dropped. + self.cur = unsafe { (*self.cur).Next }; + // SAFETY: ret is guaranteed to be non-null and valid since self.adapters owns the + // memory. + Some(unsafe { &*ret }) + } + } +} + +/// Convert to a windows defined `IP_ADDRESS_PREFIX` from a `ipnetwork::IpNetwork` but set the port +/// to 0 +pub fn win_ip_address_prefix_from_ipnetwork_port_zero(from: IpNetwork) -> IP_ADDRESS_PREFIX { + // Port should not matter so we set it to 0 + let prefix = + crate::windows::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from.ip(), 0)); + IP_ADDRESS_PREFIX { + Prefix: prefix, + PrefixLength: from.prefix(), + } +} + +/// Convert to a windows defined `SOCKADDR_INET` from a `IpAddr` but set the port to 0 +pub fn inet_sockaddr_from_ipaddr(from: IpAddr) -> SOCKADDR_INET { + // Port should not matter so we set it to 0 + crate::windows::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from, 0)) +} + +/// Convert to a `AddressFamily` from a `ipnetwork::IpNetwork` +pub fn ipnetwork_to_address_family(from: IpNetwork) -> AddressFamily { + if from.is_ipv4() { + AddressFamily::Ipv4 + } else { + AddressFamily::Ipv6 + } +} diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 0bde6ac435..49028319a0 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -5,6 +5,7 @@ mod volume_monitor; mod windows; use crate::{ + routing::{self, get_best_default_route, CallbackHandle, EventType, RouteManagerHandle}, tunnel::TunnelMetadata, tunnel_state_machine::TunnelCommand, windows::{ @@ -12,7 +13,6 @@ use crate::{ window::{PowerManagementEvent, PowerManagementListener}, AddressFamily, }, - winnet::{self, get_best_default_route, WinNetAddrFamily, WinNetCallbackHandle}, }; use futures::channel::{mpsc, oneshot}; use std::{ @@ -29,9 +29,7 @@ use std::{ time::Duration, }; use talpid_types::{tunnel::ErrorStateCause, ErrorExt}; -use windows_sys::Win32::{ - Foundation::ERROR_OPERATION_ABORTED, NetworkManagement::Ndis::NET_LUID_LH, -}; +use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); @@ -74,7 +72,7 @@ pub enum Error { /// Failed to obtain default route #[error(display = "Failed to obtain the default route")] - ObtainDefaultRoute(#[error(source)] winnet::Error), + ObtainDefaultRoute(#[error(source)] routing::Error), /// Failed to obtain an IP address given a network interface LUID #[error(display = "Failed to obtain IP address for interface LUID")] @@ -116,10 +114,11 @@ pub struct SplitTunnel { event_thread: Option<std::thread::JoinHandle<()>>, quit_event: Arc<windows::Event>, excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, - _route_change_callback: Option<WinNetCallbackHandle>, + _route_change_callback: Option<CallbackHandle>, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, async_path_update_in_progress: Arc<AtomicBool>, power_mgmt_handle: tokio::task::JoinHandle<()>, + route_manager: RouteManagerHandle, } enum Request { @@ -187,6 +186,7 @@ impl SplitTunnel { daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, volume_update_rx: mpsc::UnboundedReceiver<()>, power_mgmt_rx: PowerManagementListener, + route_manager: RouteManagerHandle, ) -> Result<Self, Error> { let excluded_processes = Arc::new(RwLock::new(HashMap::new())); @@ -209,6 +209,7 @@ impl SplitTunnel { async_path_update_in_progress: Arc::new(AtomicBool::new(false)), excluded_processes, power_mgmt_handle, + route_manager, }) } @@ -715,13 +716,22 @@ impl SplitTunnel { )); self._route_change_callback = None; + let moved_context_mutex = context_mutex.clone(); let mut context = context_mutex.lock().unwrap(); - let callback = winnet::add_default_route_change_callback( - Some(split_tunnel_default_route_change_handler), - context_mutex.clone(), - ) - .map(Some) - .map_err(|_| Error::RegisterRouteChangeCallback)?; + let callback = self + .runtime + .block_on( + self.route_manager + .add_default_route_change_callback(Box::new(move |event, addr_family| { + split_tunnel_default_route_change_handler( + event, + addr_family, + &moved_context_mutex, + ) + })), + ) + .map(Some) + .map_err(|_| Error::RegisterRouteChangeCallback)?; context.initialize_internet_addresses()?; context.register_ips()?; @@ -801,16 +811,10 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { pub fn initialize_internet_addresses(&mut self) -> Result<(), Error> { // Identify IP address that gives us Internet access - let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4) + let internet_ipv4 = get_best_default_route(AddressFamily::Ipv4) .map_err(Error::ObtainDefaultRoute)? .map(|route| { - get_ip_address_for_interface( - AddressFamily::Ipv4, - NET_LUID_LH { - Value: route.interface_luid, - }, - ) - .map(|ip| match ip { + get_ip_address_for_interface(AddressFamily::Ipv4, route.iface).map(|ip| match ip { Some(IpAddr::V4(addr)) => Some(addr), Some(_) => unreachable!("wrong address family (expected IPv4)"), None => { @@ -822,16 +826,10 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { .transpose() .map_err(Error::LuidToIp)? .flatten(); - let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6) + let internet_ipv6 = get_best_default_route(AddressFamily::Ipv6) .map_err(Error::ObtainDefaultRoute)? .map(|route| { - get_ip_address_for_interface( - AddressFamily::Ipv6, - NET_LUID_LH { - Value: route.interface_luid, - }, - ) - .map(|ip| match ip { + get_ip_address_for_interface(AddressFamily::Ipv6, route.iface).map(|ip| match ip { Some(IpAddr::V6(addr)) => Some(addr), Some(_) => unreachable!("wrong address family (expected IPv6)"), None => { @@ -851,16 +849,14 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { } } -unsafe extern "system" fn split_tunnel_default_route_change_handler( - event_type: winnet::WinNetDefaultRouteChangeEventType, - address_family: WinNetAddrFamily, - default_route: winnet::WinNetDefaultRoute, - ctx: *mut libc::c_void, +fn split_tunnel_default_route_change_handler<'a>( + event_type: EventType<'a>, + address_family: AddressFamily, + ctx_mutex: &Arc<Mutex<SplitTunnelDefaultRouteChangeHandlerContext>>, ) { - use winnet::WinNetDefaultRouteChangeEventType::*; + use crate::routing::EventType::*; // Update the "internet interface" IP when best default route changes - let ctx_mutex = &mut *(ctx as *mut Arc<Mutex<SplitTunnelDefaultRouteChangeHandlerContext>>); let mut ctx = ctx_mutex.lock().expect("ST route handler mutex poisoned"); let daemon_tx = ctx.daemon_tx.upgrade(); @@ -870,16 +866,9 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( } }; - let translated_family = winnet_to_talpid_family(address_family); - let result = match event_type { - DefaultRouteChanged | DefaultRouteUpdatedDetails => { - match get_ip_address_for_interface( - translated_family, - NET_LUID_LH { - Value: default_route.interface_luid, - }, - ) { + Updated(default_route) | UpdatedDetails(default_route) => { + match get_ip_address_for_interface(address_family, default_route.iface) { Ok(Some(ip)) => match IpAddr::from(ip) { IpAddr::V4(addr) => ctx.addresses.internet_ipv4 = Some(addr), IpAddr::V6(addr) => ctx.addresses.internet_ipv6 = Some(addr), @@ -887,10 +876,10 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( Ok(None) => { log::warn!("Failed to obtain default route interface address"); match address_family { - WinNetAddrFamily::IPV4 => { + AddressFamily::Ipv4 => { ctx.addresses.internet_ipv4 = None; } - WinNetAddrFamily::IPV6 => { + AddressFamily::Ipv6 => { ctx.addresses.internet_ipv6 = None; } } @@ -910,12 +899,12 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( ctx.register_ips() } // no default route - DefaultRouteRemoved => { + Removed => { match address_family { - WinNetAddrFamily::IPV4 => { + AddressFamily::Ipv4 => { ctx.addresses.internet_ipv4 = None; } - WinNetAddrFamily::IPV6 => { + AddressFamily::Ipv6 => { ctx.addresses.internet_ipv6 = None; } } @@ -931,10 +920,3 @@ unsafe extern "system" fn split_tunnel_default_route_change_handler( maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); } } - -fn winnet_to_talpid_family(address_family: WinNetAddrFamily) -> AddressFamily { - match address_family { - WinNetAddrFamily::IPV4 => AddressFamily::Ipv4, - WinNetAddrFamily::IPV6 => AddressFamily::Ipv6, - } -} diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 302c8003c9..4425c3c69d 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -42,7 +42,7 @@ pub enum Error { /// Failure in Windows syscall. #[cfg(windows)] #[error(display = "Failure in Windows syscall")] - WinnetError(#[error(source)] crate::winnet::Error), + WinnetError(#[error(source)] crate::routing::Error), /// Running on an operating system which is not supported yet. #[error(display = "Tunnel type not supported on this operating system")] diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 5e8c0ede49..b982dc148d 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -226,6 +226,8 @@ impl WireguardMonitor { args.resource_dir, args.tun_provider, #[cfg(target_os = "windows")] + args.route_manager.clone(), + #[cfg(target_os = "windows")] setup_done_tx, )?; let iface_name = tunnel.get_interface_name(); @@ -507,6 +509,7 @@ impl WireguardMonitor { log_path: Option<&Path>, resource_dir: &Path, tun_provider: Arc<Mutex<TunProvider>>, + #[cfg(windows)] route_manager_handle: crate::routing::RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender<std::result::Result<(), BoxedError>>, ) -> Result<Box<dyn Tunnel>> { #[cfg(target_os = "linux")] @@ -576,7 +579,11 @@ impl WireguardMonitor { #[cfg(not(windows))] Self::get_tunnel_destinations(config).flat_map(Self::replace_default_prefixes), #[cfg(windows)] + route_manager_handle, + #[cfg(windows)] setup_done_tx, + #[cfg(windows)] + &runtime, ) .map_err(Error::TunnelError)?, )) diff --git a/talpid-core/src/tunnel/wireguard/wireguard_go.rs b/talpid-core/src/tunnel/wireguard/wireguard_go.rs index 60705d324f..a1ca8be6ba 100644 --- a/talpid-core/src/tunnel/wireguard/wireguard_go.rs +++ b/talpid-core/src/tunnel/wireguard/wireguard_go.rs @@ -39,9 +39,6 @@ use { type Result<T> = std::result::Result<T, TunnelError>; -#[cfg(target_os = "windows")] -use crate::winnet; - #[cfg(not(target_os = "windows"))] use std::sync::{Arc, Mutex}; @@ -66,7 +63,7 @@ pub struct WgGoTunnel { // context that maps to fs::File instance, used with logging callback _logging_context: LoggingContext, #[cfg(target_os = "windows")] - _route_callback_handle: Option<crate::winnet::WinNetCallbackHandle>, + _route_callback_handle: Option<crate::routing::CallbackHandle>, #[cfg(target_os = "windows")] setup_handle: tokio::task::JoinHandle<()>, } @@ -117,15 +114,19 @@ impl WgGoTunnel { pub fn start_tunnel( config: &Config, log_path: Option<&Path>, + route_manager_handle: crate::tunnel::RouteManagerHandle, mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>, + runtime: &tokio::runtime::Handle, ) -> Result<Self> { use talpid_types::ErrorExt; - let route_callback_handle = winnet::add_default_route_change_callback( - Some(WgGoTunnel::default_route_changed_callback), - (), - ) - .ok(); + let route_callback_handle = runtime + .block_on( + route_manager_handle.add_default_route_change_callback(Box::new( + WgGoTunnel::default_route_changed_callback, + )), + ) + .ok(); if route_callback_handle.is_none() { log::warn!("Failed to register default route callback"); } @@ -208,25 +209,21 @@ impl WgGoTunnel { // Callback to be used to rebind the tunnel sockets when the default route changes #[cfg(target_os = "windows")] - pub unsafe extern "system" fn default_route_changed_callback( - event_type: winnet::WinNetDefaultRouteChangeEventType, - address_family: winnet::WinNetAddrFamily, - default_route: winnet::WinNetDefaultRoute, - _ctx: *mut libc::c_void, + pub fn default_route_changed_callback<'a>( + event_type: crate::routing::EventType<'a>, + address_family: crate::windows::AddressFamily, ) { - use windows_sys::Win32::NetworkManagement::{ - IpHelper::ConvertInterfaceLuidToIndex, Ndis::NET_LUID_LH, - }; - use winnet::WinNetDefaultRouteChangeEventType::*; + use crate::routing::EventType::*; + use windows_sys::Win32::NetworkManagement::IpHelper::ConvertInterfaceLuidToIndex; let iface_idx: u32 = match event_type { - DefaultRouteChanged => { + Updated(default_route) => { let mut iface_idx = 0u32; - let iface_luid = NET_LUID_LH { - Value: default_route.interface_luid, + // TODO: Make sure unwrap is fine + let iface_luid = default_route.iface; + let status = unsafe { + ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _) }; - let status = - ConvertInterfaceLuidToIndex(&iface_luid as *const _, &mut iface_idx as *mut _); if status != 0 { log::error!( "Failed to convert interface LUID to interface index: {}: {}", @@ -238,12 +235,12 @@ impl WgGoTunnel { iface_idx } // if there is no new default route, specify 0 as the interface index - DefaultRouteRemoved => 0, + Removed => 0, // ignore interface updates that don't affect the interface to use - DefaultRouteUpdatedDetails => return, + UpdatedDetails(_) => return, }; - wgRebindTunnelSocket(address_family.to_windows_proto_enum(), iface_idx); + unsafe { wgRebindTunnelSocket(address_family.to_af_family(), iface_idx) }; } #[cfg(not(target_os = "windows"))] diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 5a83bd6b76..964963d46e 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -29,7 +29,7 @@ use talpid_types::{ }; #[cfg(windows)] -use crate::{routing, winnet}; +use crate::routing; #[cfg(target_os = "android")] use crate::tunnel::tun_provider; @@ -524,12 +524,7 @@ fn should_retry(error: &tunnel::Error, retry_attempt: u32) -> bool { #[cfg(windows)] fn is_recoverable_routing_error(error: &crate::routing::Error) -> bool { match error { - routing::Error::AddRoutesFailed(route_error) => match route_error { - winnet::Error::GetDefaultRoute - | winnet::Error::GetDeviceByName - | winnet::Error::GetDeviceByGateway => true, - _ => false, - }, + routing::Error::AddRoutesFailed(_) => true, _ => false, } } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index c1b52278f0..5d13b7d1f2 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -258,6 +258,10 @@ impl TunnelStateMachine { #[cfg(target_os = "windows")] let power_mgmt_rx = crate::windows::window::PowerManagementListener::new(); + let route_manager = RouteManager::new(HashSet::new()) + .await + .map_err(Error::InitRouteManagerError)?; + #[cfg(windows)] let split_tunnel = split_tunnel::SplitTunnel::new( runtime.clone(), @@ -265,6 +269,9 @@ impl TunnelStateMachine { args.command_tx.clone(), volume_update_rx, power_mgmt_rx.clone(), + route_manager + .handle() + .map_err(Error::InitRouteManagerError)?, ) .map_err(Error::InitSplitTunneling)?; @@ -279,9 +286,6 @@ impl TunnelStateMachine { }; let firewall = Firewall::from_args(fw_args).map_err(Error::InitFirewallError)?; - let route_manager = RouteManager::new(HashSet::new()) - .await - .map_err(Error::InitRouteManagerError)?; let dns_monitor = DnsMonitor::new( #[cfg(target_os = "linux")] runtime.clone(), @@ -315,6 +319,8 @@ impl TunnelStateMachine { #[cfg(target_os = "android")] android_context, #[cfg(target_os = "windows")] + route_manager.handle()?, + #[cfg(target_os = "windows")] power_mgmt_rx, ) .await diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs index d853991707..5504e11d93 100644 --- a/talpid-core/src/windows/mod.rs +++ b/talpid-core/src/windows/mod.rs @@ -109,7 +109,7 @@ impl fmt::Display for AddressFamily { } impl AddressFamily { - /// Convert an [`AddressFamily`] to one of the `AF_*` constants. + /// Convert one of the `AF_*` constants to an [`AddressFamily`]. pub fn try_from_af_family(family: u16) -> Result<AddressFamily> { match u32::from(family) { AF_INET => Ok(AddressFamily::Ipv4), @@ -117,6 +117,15 @@ impl AddressFamily { family => Err(Error::UnknownAddressFamily(family)), } } + + /// Convert an [`AddressFamily`] to one of the `AF_*` constants. + pub fn to_af_family(&self) -> u16 { + match self { + // These values are both small enough to fit in a u16 + Self::Ipv4 => u16::try_from(AF_INET).unwrap(), + Self::Ipv6 => u16::try_from(AF_INET6).unwrap(), + } + } } /// Context for [`notify_ip_interface_change`]. When it is dropped, diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs deleted file mode 100644 index 9843d873aa..0000000000 --- a/talpid-core/src/winnet.rs +++ /dev/null @@ -1,416 +0,0 @@ -use self::api::*; -use crate::{logging::windows::log_sink, routing::Node}; -use ipnetwork::IpNetwork; -use libc::c_void; -use std::{ - convert::TryFrom, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, - ptr, -}; -use widestring::WideCString; - -/// Errors that this module may produce. -#[derive(err_derive::Error, Debug)] -pub enum Error { - /// Supplied interface alias is invalid. - #[error(display = "Supplied interface alias is invalid")] - InvalidInterfaceAlias(#[error(source)] widestring::NulError<u16>), - - /// Failed to enable IPv6 on the network interface. - #[error(display = "Failed to enable IPv6 on the network interface")] - EnableIpv6, - - /// Failed to get the current default route. - #[error(display = "Failed to obtain default route")] - GetDefaultRoute, - - /// Failed to get a network device. - #[error(display = "Failed to obtain network interface by name")] - GetDeviceByName, - - /// Failed to get a network device. - #[error(display = "Failed to obtain network interface by gateway")] - GetDeviceByGateway, - - /// Unexpected error while adding routes - #[error(display = "Winnet returned an error while adding routes")] - GeneralAddRoutesError, - - /// Failed to obtain an IP address given a LUID. - #[error(display = "Failed to obtain IP address for the given interface")] - GetIpAddressFromLuid, - - /// Failed to read IPv6 status on the TAP network interface. - #[error(display = "Failed to read IPv6 status on the TAP network interface")] - GetIpv6Status, -} - -fn logging_context() -> *const u8 { - b"WinNet\0".as_ptr() -} - -#[derive(Debug, Default, Clone, Copy)] -#[allow(dead_code)] -#[repr(u32)] -pub enum WinNetAddrFamily { - #[default] - IPV4 = 0, - IPV6 = 1, -} - -impl WinNetAddrFamily { - pub fn to_windows_proto_enum(&self) -> u16 { - match self { - Self::IPV4 => 2, - Self::IPV6 => 23, - } - } -} - -#[repr(C)] -#[derive(Default)] -pub struct WinNetIp { - pub addr_family: WinNetAddrFamily, - pub ip_bytes: [u8; 16], -} - -#[repr(C)] -#[derive(Default)] -pub struct WinNetDefaultRoute { - pub interface_luid: u64, - pub gateway: WinNetIp, -} - -#[derive(Debug)] -pub struct WrongIpFamilyError; - -impl TryFrom<WinNetIp> for Ipv4Addr { - type Error = WrongIpFamilyError; - - fn try_from(addr: WinNetIp) -> Result<Ipv4Addr, WrongIpFamilyError> { - match addr.addr_family { - WinNetAddrFamily::IPV4 => { - let mut bytes: [u8; 4] = Default::default(); - bytes.clone_from_slice(&addr.ip_bytes[..4]); - Ok(Ipv4Addr::from(bytes)) - } - WinNetAddrFamily::IPV6 => Err(WrongIpFamilyError), - } - } -} - -impl TryFrom<WinNetIp> for Ipv6Addr { - type Error = WrongIpFamilyError; - - fn try_from(addr: WinNetIp) -> Result<Ipv6Addr, WrongIpFamilyError> { - match addr.addr_family { - WinNetAddrFamily::IPV4 => Err(WrongIpFamilyError), - WinNetAddrFamily::IPV6 => Ok(Ipv6Addr::from(addr.ip_bytes)), - } - } -} - -impl From<WinNetIp> for IpAddr { - fn from(addr: WinNetIp) -> IpAddr { - match addr.addr_family { - WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::try_from(addr).unwrap()), - WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::try_from(addr).unwrap()), - } - } -} - -impl From<IpAddr> for WinNetIp { - fn from(addr: IpAddr) -> WinNetIp { - let mut bytes = [0u8; 16]; - match addr { - IpAddr::V4(v4_addr) => { - bytes[..4].copy_from_slice(&v4_addr.octets()); - WinNetIp { - addr_family: WinNetAddrFamily::IPV4, - ip_bytes: bytes, - } - } - IpAddr::V6(v6_addr) => { - bytes.copy_from_slice(&v6_addr.octets()); - - WinNetIp { - addr_family: WinNetAddrFamily::IPV6, - ip_bytes: bytes, - } - } - } - } -} - -#[repr(C)] -pub struct WinNetIpNetwork { - prefix: u8, - ip: WinNetIp, -} - -impl From<IpNetwork> for WinNetIpNetwork { - fn from(network: IpNetwork) -> WinNetIpNetwork { - WinNetIpNetwork { - prefix: network.prefix(), - ip: WinNetIp::from(network.ip()), - } - } -} - -#[repr(C)] -pub struct WinNetNode { - gateway: *mut WinNetIp, - device_name: *mut u16, -} - -impl WinNetNode { - fn new(name: &str, ip: WinNetIp) -> Self { - let device_name = WideCString::from_str(name) - .expect("Failed to convert UTF-8 string to null terminated UCS string") - .into_raw(); - let gateway = Box::into_raw(Box::new(ip)); - Self { - gateway, - device_name, - } - } - - fn from_gateway(ip: WinNetIp) -> Self { - let gateway = Box::into_raw(Box::new(ip)); - Self { - gateway, - device_name: ptr::null_mut(), - } - } - - fn from_device(name: &str) -> Self { - let device_name = WideCString::from_str(name) - .expect("Failed to convert UTF-8 string to null terminated UCS string") - .into_raw(); - Self { - gateway: ptr::null_mut(), - device_name, - } - } -} - -impl From<&Node> for WinNetNode { - fn from(node: &Node) -> Self { - match (node.get_address(), node.get_device()) { - (Some(gateway), None) => WinNetNode::from_gateway(gateway.into()), - (None, Some(device)) => WinNetNode::from_device(device), - (Some(gateway), Some(device)) => WinNetNode::new(device, gateway.into()), - _ => unreachable!(), - } - } -} - -impl Drop for WinNetNode { - fn drop(&mut self) { - if !self.gateway.is_null() { - unsafe { - let _ = Box::from_raw(self.gateway); - } - } - if !self.device_name.is_null() { - unsafe { - let _ = WideCString::from_ptr_str(self.device_name); - } - } - } -} - -#[repr(C)] -pub struct WinNetRoute { - gateway: WinNetIpNetwork, - node: *mut WinNetNode, -} - -impl WinNetRoute { - pub fn through_default_node(gateway: WinNetIpNetwork) -> Self { - Self { - gateway, - node: ptr::null_mut(), - } - } - - pub fn new(node: WinNetNode, gateway: WinNetIpNetwork) -> Self { - let node = Box::into_raw(Box::new(node)); - Self { gateway, node } - } -} - -impl Drop for WinNetRoute { - fn drop(&mut self) { - if !self.node.is_null() { - unsafe { - let _ = Box::from_raw(self.node); - } - self.node = ptr::null_mut(); - } - } -} - -pub fn activate_routing_manager() -> bool { - unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } -} - -pub struct WinNetCallbackHandle { - handle: *mut libc::c_void, - // Allows us to keep the context pointer alive. - _context: Box<dyn std::any::Any>, -} - -unsafe impl Send for WinNetCallbackHandle {} - -impl Drop for WinNetCallbackHandle { - fn drop(&mut self) { - unsafe { WinNet_UnregisterDefaultRouteChangedCallback(self.handle) }; - } -} - -#[derive(Debug, Clone, Copy, PartialEq)] -#[allow(dead_code)] -#[repr(u16)] -pub enum WinNetDefaultRouteChangeEventType { - DefaultRouteChanged = 0, - DefaultRouteUpdatedDetails = 1, - DefaultRouteRemoved = 2, -} - -pub type DefaultRouteChangedCallback = unsafe extern "system" fn( - event_type: WinNetDefaultRouteChangeEventType, - family: WinNetAddrFamily, - default_route: WinNetDefaultRoute, - ctx: *mut c_void, -); - -#[derive(err_derive::Error, Debug)] -#[error(display = "Failed to set callback for default route")] -pub struct DefaultRouteCallbackError; - -pub fn add_default_route_change_callback<T: 'static>( - callback: Option<DefaultRouteChangedCallback>, - context: T, -) -> std::result::Result<WinNetCallbackHandle, DefaultRouteCallbackError> { - let mut handle_ptr = ptr::null_mut(); - let mut context = Box::new(context); - let ctx_ptr = &mut *context as *mut T as *mut libc::c_void; - unsafe { - if !WinNet_RegisterDefaultRouteChangedCallback(callback, ctx_ptr, &mut handle_ptr as *mut _) - { - return Err(DefaultRouteCallbackError); - } - - Ok(WinNetCallbackHandle { - handle: handle_ptr, - _context: context, - }) - } -} - -pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> Result<(), Error> { - let ptr = routes.as_ptr(); - let length: u32 = routes.len() as u32; - match unsafe { WinNet_AddRoutes(ptr, length) } { - WinNetAddRouteStatus::Success => Ok(()), - WinNetAddRouteStatus::GeneralError => Err(Error::GeneralAddRoutesError), - WinNetAddRouteStatus::NoDefaultRoute => Err(Error::GetDefaultRoute), - WinNetAddRouteStatus::NameNotFound => Err(Error::GetDeviceByName), - WinNetAddRouteStatus::GatewayNotFound => Err(Error::GetDeviceByGateway), - } -} - -pub fn routing_manager_delete_applied_routes() -> bool { - unsafe { WinNet_DeleteAppliedRoutes() } -} - -pub fn deactivate_routing_manager() { - unsafe { WinNet_DeactivateRouteManager() } -} - -pub fn get_best_default_route( - family: WinNetAddrFamily, -) -> Result<Option<WinNetDefaultRoute>, Error> { - let mut default_route = WinNetDefaultRoute::default(); - match unsafe { - WinNet_GetBestDefaultRoute( - family, - &mut default_route as *mut _, - Some(log_sink), - logging_context(), - ) - } { - WinNetStatus::Success => Ok(Some(default_route)), - WinNetStatus::NotFound => Ok(None), - WinNetStatus::Failure => Err(Error::GetDefaultRoute), - } -} - -#[allow(non_snake_case)] -mod api { - use super::DefaultRouteChangedCallback; - use crate::logging::windows::LogSink; - - #[allow(dead_code)] - #[repr(u32)] - pub enum WinNetStatus { - Success = 0, - NotFound = 1, - Failure = 2, - } - - #[allow(dead_code)] - #[repr(u32)] - pub enum WinNetAddRouteStatus { - Success = 0, - GeneralError = 1, - NoDefaultRoute = 2, - NameNotFound = 3, - GatewayNotFound = 4, - } - - extern "system" { - #[link_name = "WinNet_ActivateRouteManager"] - pub fn WinNet_ActivateRouteManager(sink: Option<LogSink>, sink_context: *const u8) -> bool; - - #[link_name = "WinNet_AddRoutes"] - pub fn WinNet_AddRoutes( - routes: *const super::WinNetRoute, - num_routes: u32, - ) -> WinNetAddRouteStatus; - - // #[link_name = "WinNet_AddRoute"] - // pub fn WinNet_AddRoute(route: *const super::WinNetRoute) -> WinNetAddRouteStatus; - - // #[link_name = "WinNet_DeleteRoutes"] - // pub fn WinNet_DeleteRoutes(routes: *const super::WinNetRoute, num_routes: u32) -> bool; - - // #[link_name = "WinNet_DeleteRoute"] - // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool; - - #[link_name = "WinNet_DeleteAppliedRoutes"] - pub fn WinNet_DeleteAppliedRoutes() -> bool; - - #[link_name = "WinNet_DeactivateRouteManager"] - pub fn WinNet_DeactivateRouteManager(); - - #[link_name = "WinNet_GetBestDefaultRoute"] - pub fn WinNet_GetBestDefaultRoute( - family: super::WinNetAddrFamily, - default_route: *mut super::WinNetDefaultRoute, - sink: Option<LogSink>, - sink_context: *const u8, - ) -> WinNetStatus; - - #[link_name = "WinNet_RegisterDefaultRouteChangedCallback"] - pub fn WinNet_RegisterDefaultRouteChangedCallback( - callback: Option<DefaultRouteChangedCallback>, - callbackContext: *mut libc::c_void, - registrationHandle: *mut *mut libc::c_void, - ) -> bool; - - #[link_name = "WinNet_UnregisterDefaultRouteChangedCallback"] - pub fn WinNet_UnregisterDefaultRouteChangedCallback(registrationHandle: *mut libc::c_void); - } -} |
