diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-06-05 20:24:24 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-06-05 20:24:24 +0200 |
| commit | 4f60f2368fba733c1aaa968c34d0a33d43e4e90b (patch) | |
| tree | ebbae33de7e464d85bc794b344df14076d8a9923 | |
| parent | 2adf93174484af166c2104627530874698498b7d (diff) | |
| parent | 38d7c4b0a02c293a84ccec4a4cc632934a36cdfb (diff) | |
| download | mullvadvpn-4f60f2368fba733c1aaa968c34d0a33d43e4e90b.tar.xz mullvadvpn-4f60f2368fba733c1aaa968c34d0a33d43e4e90b.zip | |
Merge branch 'macos-routing-rework-2'
| -rw-r--r-- | CHANGELOG.md | 3 | ||||
| -rw-r--r-- | Cargo.lock | 30 | ||||
| -rw-r--r-- | deny.toml | 3 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 2 | ||||
| -rw-r--r-- | talpid-core/src/offline/macos.rs | 213 | ||||
| -rw-r--r-- | talpid-core/src/offline/mod.rs | 6 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 10 | ||||
| -rw-r--r-- | talpid-routing/Cargo.toml | 9 | ||||
| -rw-r--r-- | talpid-routing/src/lib.rs | 2 | ||||
| -rw-r--r-- | talpid-routing/src/unix/android.rs | 5 | ||||
| -rw-r--r-- | talpid-routing/src/unix/linux.rs | 12 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos.rs | 362 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/data.rs | 1147 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/interface.rs | 78 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/mod.rs | 533 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/routing_socket.rs | 185 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos/watch.rs | 147 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 65 | ||||
| -rw-r--r-- | talpid-routing/src/windows/mod.rs | 10 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 5 |
20 files changed, 2321 insertions, 506 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 19d4f60383..042f059b1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,9 @@ Line wrap the file at 100 chars. Th - Fix adaptive app icon which previously had a displaced nose and some other oddities. - Fix app version sometimes missing in the settings menu. +#### macOS +- Fix inability to sync iCloud and Safari bookmarks while connected to the VPN. + ## [2023.4-beta1] - 2023-05-02 ### Added diff --git a/Cargo.lock b/Cargo.lock index 17847ae5cd..61c6c83af4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1851,6 +1851,15 @@ dependencies = [ ] [[package]] +name = "memoffset" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +dependencies = [ + "autocfg", +] + +[[package]] name = "mime" version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2304,6 +2313,19 @@ dependencies = [ [[package]] name = "nix" +version = "0.26.1" +source = "git+https://github.com/nix-rust/nix?rev=b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28#b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "memoffset 0.8.0", + "pin-utils", + "static_assertions", +] + +[[package]] +name = "nix" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" @@ -3511,9 +3533,9 @@ dependencies = [ [[package]] name = "system-configuration" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75182f12f490e953596550b65ee31bda7c8e043d9386174b353bda50838c3fd" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags", "core-foundation", @@ -3678,6 +3700,7 @@ dependencies = [ name = "talpid-routing" version = "0.0.0" dependencies = [ + "bitflags", "err-derive", "futures", "ipnetwork", @@ -3686,12 +3709,13 @@ dependencies = [ "log", "netlink-packet-route", "netlink-sys", + "nix 0.26.1", "rtnetlink", "socket2", + "system-configuration", "talpid-types", "talpid-windows-net", "tokio", - "tokio-stream", "widestring 1.0.2", "windows-sys 0.45.0", ] @@ -96,7 +96,8 @@ skip-tree = [] unknown-registry = "deny" unknown-git = "deny" allow-registry = ["https://github.com/rust-lang/crates.io-index"] -allow-git = [] +# TODO: The PF socket type isn't released yet +allow-git = ["https://github.com/nix-rust/nix"] [sources.allow-org] # 1 or more github.com organizations to allow git sources for diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 9a833353d3..f16aa7307c 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -67,7 +67,7 @@ talpid-dbus = { path = "../talpid-dbus" } [target.'cfg(target_os = "macos")'.dependencies] pfctl = "0.4.4" -system-configuration = "0.5" +system-configuration = "0.5.1" trust-dns-server = { version = "0.22.0", features = ["resolver"] } tun = "0.5.1" subslice = "0.2" diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs index baae3c40f1..7263f9ab0f 100644 --- a/talpid-core/src/offline/macos.rs +++ b/talpid-core/src/offline/macos.rs @@ -2,119 +2,156 @@ //! that the app gets stuck in an offline state, blocking all internet access and preventing the //! user from connecting to a relay. //! -//! Currently, this functionality is implemented by using `route monitor -n` to observe routing -//! table changes and then use the CLI once more to query if there exists a default route. -//! Generally, it is assumed that a machine is online if there exists a route to a public IP -//! address that isn't using a tunnel adapter. On macOS, there were various ways of deducing this: -//! - watching the `State:/Network/Global/IPv4` key in SystemConfiguration via -//! `system-configuration-rs`, relying on a CoreFoundation runloop to drive callbacks. -//! The issue with this is that sometimes during early boot or after a re-install, the callbacks -//! won't be called, often leaving the daemon stuck in an offline state. -//! - setting a callback via [`SCNetworkReachability`]. The callback should be called whenever the -//! reachability of a remote host changes, but sometimes the callbacks just don't get called. -//! - [`NWPathMonitor`] is a macOS native interface to watch changes in the routing table. It works -//! great, but it seems to deliver updates before they actually get added to the routing table, -//! effectively calling our callbacks with routes that aren't yet usable, so starting tunnels -//! would fail anyway. This would be the API to use if we were able to bind the sockets our tunnel -//! implementations would use, but that is far too much complexity. -//! -//! [`SCNetworkReachability`]: https://developer.apple.com/documentation/systemconfiguration/scnetworkreachability-g7d -//! [`NWPathMonitor`]: https://developer.apple.com/documentation/network/nwpathmonitor -use futures::{channel::mpsc::UnboundedSender, Future, StreamExt}; -use std::sync::{Arc, Weak}; -use talpid_types::ErrorExt; +//! Currently, this functionality is implemented by watching for changes to the default route +//! in [`RouteManager`] using a `PF_ROUTE` socket. If there is no default route for neither IPv4 nor +//! IPv6, the host is considered to be offline. +use futures::{channel::mpsc::UnboundedSender, StreamExt}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; +use talpid_routing::{DefaultRouteEvent, RouteManagerHandle}; + +/// How long to wait before announcing changes to the offline state +const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(2); #[derive(err_derive::Error, Debug)] pub enum Error { #[error(display = "Failed to initialize route monitor")] - StartMonitorError(#[error(source)] talpid_routing::PlatformError), + StartMonitorError(#[error(source)] talpid_routing::Error), } pub struct MonitorHandle { + state: Arc<Mutex<ConnectivityState>>, _notify_tx: Arc<UnboundedSender<bool>>, } -impl MonitorHandle { - /// Host is considered to be offline if the IPv4 internet is considered to be unreachable by the - /// given reachability flags *or* there are no active physical interfaces. - pub async fn host_is_offline(&self) -> bool { - !exists_non_tunnel_default_route().await +struct ConnectivityState { + v4_connectivity: bool, + v6_connectivity: bool, +} + +impl ConnectivityState { + fn get_connectivity(&self) -> bool { + self.v4_connectivity || self.v6_connectivity } } -async fn exists_non_tunnel_default_route() -> bool { - match talpid_routing::get_default_routes().await { - Ok((Some(node), _)) | Ok((None, Some(node))) => { - let route_exists = node - .get_device() - .map(|iface_name| !iface_name.contains("tun")) - .unwrap_or(true); - log::debug!("Assuming non-tunnel default route exists due to {:?}", node); - route_exists - } - Ok((None, None)) => { - log::debug!("No default routes exist, assuming machine is offline"); - false - } - Err(err) => { - log::error!( - "{}", - err.display_chain_with_msg( - "Failed to obtain default routes, assuming machine is online." - ) - ); - true - } +impl MonitorHandle { + /// Host is considered to be offline if macOS doesn't assign a non-tunnel default route + #[allow(clippy::unused_async)] + pub async fn host_is_offline(&self) -> bool { + let state = self.state.lock().unwrap(); + !state.get_connectivity() } } -pub async fn spawn_monitor(notify_tx: UnboundedSender<bool>) -> Result<MonitorHandle, Error> { + +pub async fn spawn_monitor( + notify_tx: UnboundedSender<bool>, + route_manager_handle: RouteManagerHandle, +) -> Result<MonitorHandle, Error> { let notify_tx = Arc::new(notify_tx); - let context = OfflineStateContext { - sender: Arc::downgrade(¬ify_tx), - is_offline: !exists_non_tunnel_default_route().await, + let (v4_connectivity, v6_connectivity) = match route_manager_handle.get_default_routes().await { + Ok((v4_route, v6_route)) => (v4_route.is_some(), v6_route.is_some()), + Err(error) => { + log::warn!("Failed to initialize offline monitor: {error}"); + // Fail open: Assume that we have connectivity if we cannot determine the existence of + // a default route, since we don't want to block the user from connecting + (true, true) + } }; - let route_monitor = watch_route_monitor(context)?; - tokio::spawn(route_monitor); - Ok(MonitorHandle { - _notify_tx: notify_tx, - }) -} + let state = ConnectivityState { + v4_connectivity, + v6_connectivity, + }; + let initial_connectivity = state.get_connectivity(); + let state = Arc::new(Mutex::new(state)); + + let mut route_listener = route_manager_handle.default_route_listener().await?; + let weak_state = Arc::downgrade(&state); + let weak_notify_tx = Arc::downgrade(¬ify_tx); + + // Detect changes to the default route + tokio::spawn(async move { + let mut state_update_handle: Option<tokio::task::JoinHandle<()>> = None; + let prev_notified_state = Arc::new(AtomicBool::new(initial_connectivity)); -fn watch_route_monitor( - mut context: OfflineStateContext, -) -> Result<impl Future<Output = ()>, Error> { - let mut monitor = talpid_routing::listen_for_default_route_changes()?; + while let Some(event) = route_listener.next().await { + let state = match weak_state.upgrade() { + Some(state) => state, + None => break, + }; - Ok(async move { - while let Some(_route_change) = monitor.next().await { - context.new_state(!exists_non_tunnel_default_route().await); - if context.should_shut_down() { - break; + let mut state = state.lock().unwrap(); + + log::trace!("Default route event: {event:?}"); + + let previous_connectivity = state.get_connectivity(); + + match event { + DefaultRouteEvent::AddedOrChangedV4 => { + state.v4_connectivity = true; + } + DefaultRouteEvent::AddedOrChangedV6 => { + state.v6_connectivity = true; + } + DefaultRouteEvent::RemovedV4 => { + state.v4_connectivity = false; + } + DefaultRouteEvent::RemovedV6 => { + state.v6_connectivity = false; + } } - } - log::debug!("Stopping offline monitor"); - }) -} -#[derive(Clone)] -struct OfflineStateContext { - sender: Weak<UnboundedSender<bool>>, - is_offline: bool, -} + let new_connectivity = state.get_connectivity(); + if previous_connectivity != new_connectivity { + if let Some(update_state) = state_update_handle.take() { + update_state.abort(); + } -impl OfflineStateContext { - fn should_shut_down(&self) -> bool { - self.sender.upgrade().is_none() - } + let prev_notified = prev_notified_state.clone(); + + let notify_copy = weak_notify_tx.clone(); + let update_task = tokio::spawn(async move { + let notify_tx = match notify_copy.upgrade() { + Some(tx) => tx, + None => return, + }; + + // Debounce event updates + tokio::time::sleep(DEBOUNCE_INTERVAL).await; - fn new_state(&mut self, is_offline: bool) { - if self.is_offline != is_offline { - self.is_offline = is_offline; - if let Some(sender) = self.sender.upgrade() { - let _ = sender.unbounded_send(is_offline); + if prev_notified.swap(new_connectivity, Ordering::AcqRel) == new_connectivity { + // We don't care about network changes here + return; + } + + log::info!( + "Connectivity changed: {}", + if new_connectivity { + "Connected" + } else { + "Offline" + } + ); + + let _ = notify_tx.unbounded_send(!new_connectivity); + }); + + state_update_handle = Some(update_task); } } - } + + log::trace!("Offline monitor exiting"); + }); + + Ok(MonitorHandle { + state, + _notify_tx: notify_tx, + }) } diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs index e2df51e4e1..76f9b1c315 100644 --- a/talpid-core/src/offline/mod.rs +++ b/talpid-core/src/offline/mod.rs @@ -1,5 +1,5 @@ use futures::channel::mpsc::UnboundedSender; -#[cfg(any(target_os = "linux", target_os = "windows"))] +#[cfg(not(target_os = "android"))] use talpid_routing::RouteManagerHandle; #[cfg(target_os = "android")] use talpid_types::android::AndroidContext; @@ -42,7 +42,7 @@ impl MonitorHandle { pub async fn spawn_monitor( sender: UnboundedSender<bool>, - #[cfg(any(target_os = "linux", target_os = "windows"))] route_manager: RouteManagerHandle, + #[cfg(not(target_os = "android"))] route_manager: RouteManagerHandle, #[cfg(target_os = "linux")] fwmark: Option<u32>, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<MonitorHandle, Error> { @@ -50,7 +50,7 @@ pub async fn spawn_monitor( Some( imp::spawn_monitor( sender, - #[cfg(any(target_os = "windows", target_os = "linux"))] + #[cfg(not(target_os = "android"))] route_manager, #[cfg(target_os = "linux")] fwmark, diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 86ad1c6dd2..f4c58b849c 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -31,7 +31,6 @@ use futures::{ #[cfg(target_os = "android")] use std::os::unix::io::RawFd; use std::{ - collections::HashSet, future::Future, io, net::IpAddr, @@ -271,7 +270,6 @@ impl TunnelStateMachine { let filtering_resolver = crate::resolver::start_resolver().await?; let route_manager = RouteManager::new( - HashSet::new(), #[cfg(target_os = "linux")] args.linux_ids.fwmark, #[cfg(target_os = "linux")] @@ -332,16 +330,12 @@ impl TunnelStateMachine { }); let offline_monitor = offline::spawn_monitor( offline_tx, - #[cfg(target_os = "linux")] - route_manager - .handle() - .map_err(Error::InitRouteManagerError)?, + #[cfg(not(target_os = "android"))] + route_manager.handle()?, #[cfg(target_os = "linux")] Some(args.linux_ids.fwmark), #[cfg(target_os = "android")] android_context, - #[cfg(target_os = "windows")] - route_manager.handle()?, ) .await .map_err(Error::OfflineMonitorError)?; diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml index a2f7e780be..27ffccce2b 100644 --- a/talpid-routing/Cargo.toml +++ b/talpid-routing/Cargo.toml @@ -13,7 +13,8 @@ err-derive = "0.3.1" futures = "0.3.15" ipnetwork = "0.16" log = "0.4" -tokio = { version = "1.8", features = ["process", "rt-multi-thread"] } +talpid-types = { path = "../talpid-types" } +tokio = { version = "1.8", features = ["process", "rt-multi-thread", "net"] } [target.'cfg(not(target_os="android"))'.dependencies] talpid-types = { path = "../talpid-types" } @@ -26,7 +27,11 @@ netlink-packet-route = "0.13" netlink-sys = "0.8.3" [target.'cfg(target_os = "macos")'.dependencies] -tokio-stream = { version = "0.1", features = ["io-util"] } +# TODO: The PF socket type isn't released yet +nix = { git = "https://github.com/nix-rust/nix", rev = "b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28", features = ["socket"] } +libc = "0.2" +bitflags = "1.2" +system-configuration = "0.5.1" [target.'cfg(windows)'.dependencies] diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index d8c65e80da..8985b4e394 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -21,7 +21,7 @@ mod imp; use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; #[cfg(target_os = "macos")] -pub use imp::{get_default_routes, listen_for_default_route_changes, PlatformError}; +pub use imp::{DefaultRouteEvent, PlatformError}; pub use imp::{Error, RouteManager}; diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs index 953f0901ca..b8ee26e480 100644 --- a/talpid-routing/src/unix/android.rs +++ b/talpid-routing/src/unix/android.rs @@ -1,6 +1,5 @@ -use crate::{imp::RouteManagerCommand, RequiredRoute}; +use crate::imp::RouteManagerCommand; use futures::{channel::mpsc, stream::StreamExt}; -use std::collections::HashSet; /// Stub error type for routing errors on Android. #[derive(Debug, err_derive::Error)] @@ -12,7 +11,7 @@ pub struct RouteManagerImpl {} impl RouteManagerImpl { #[allow(clippy::unused_async)] - pub async fn new(_required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> { + pub async fn new() -> Result<Self, Error> { Ok(RouteManagerImpl {}) } diff --git a/talpid-routing/src/unix/linux.rs b/talpid-routing/src/unix/linux.rs index a642bf5e3d..d2f98b0701 100644 --- a/talpid-routing/src/unix/linux.rs +++ b/talpid-routing/src/unix/linux.rs @@ -148,11 +148,7 @@ pub struct RouteManagerImpl { } impl RouteManagerImpl { - pub async fn new( - required_routes: HashSet<RequiredRoute>, - table_id: u32, - fwmark: u32, - ) -> Result<Self> { + pub async fn new(table_id: u32, fwmark: u32) -> Result<Self> { let (mut connection, handle, messages) = rtnetlink::new_connection().map_err(Error::Connect)?; @@ -179,7 +175,6 @@ impl RouteManagerImpl { }; monitor.clear_routing_rules().await?; - monitor.add_required_routes(required_routes).await?; Ok(monitor) } @@ -903,14 +898,13 @@ impl NetworkInterface { #[cfg(test)] mod test { use super::*; - use std::collections::HashSet; /// Tests if dropping inside a tokio runtime panics #[test] fn test_drop_in_executor() { let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); runtime.block_on(async { - let manager = RouteManagerImpl::new(HashSet::new(), 0, 0) + let manager = RouteManagerImpl::new(0, 0) .await .expect("Failed to initialize route manager"); std::mem::drop(manager); @@ -922,7 +916,7 @@ mod test { fn test_drop() { let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); let manager = runtime.block_on(async { - RouteManagerImpl::new(HashSet::new(), 1000, 1000) + RouteManagerImpl::new(1000, 1000) .await .expect("Failed to initialize route manager") }); diff --git a/talpid-routing/src/unix/macos.rs b/talpid-routing/src/unix/macos.rs deleted file mode 100644 index 893416a699..0000000000 --- a/talpid-routing/src/unix/macos.rs +++ /dev/null @@ -1,362 +0,0 @@ -use super::RouteManagerCommand; -use crate::{NetNode, Node, RequiredRoute, Route}; - -use futures::{ - channel::mpsc, - future, - stream::{FusedStream, Stream, StreamExt, TryStreamExt}, -}; -use ipnetwork::IpNetwork; -use std::{ - collections::HashSet, - io, - net::IpAddr, - process::{ExitStatus, Stdio}, -}; -use talpid_types::net::IpVersion; -use tokio::{io::AsyncBufReadExt, process::Command}; -use tokio_stream::wrappers::LinesStream; - -pub type Result<T> = std::result::Result<T, Error>; - -/// Errors that can happen in the macOS routing integration. -#[derive(err_derive::Error, Debug)] -#[error(no_from)] -pub enum Error { - /// Failed to add route. - #[error(display = "Failed to add route")] - FailedToAddRoute(#[error(source)] io::Error), - - /// Failed to remove route. - #[error(display = "Failed to remove route")] - FailedToRemoveRoute(#[error(source)] io::Error), - - /// Error while running "ip route". - #[error(display = "Error while running \"route get\"")] - FailedToRunRoute(#[error(source)] io::Error), - - /// Error while monitoring routes with `route -nv monitor` - #[error(display = "Error while running \"route -nv monitor\"")] - FailedToMonitorRoutes(#[error(source)] io::Error), - - /// Unexpected output from netstat - #[error(display = "Unexpected output from netstat")] - BadOutputFromNetstat, -} - -/// Route manager can be in 1 of 4 states - -/// - waiting for a route to be added or removed from the route table -/// - obtaining default routes -/// - applying changes to the route table -/// - shutting down -/// -/// Only the _shutting down_ state can be reached from all other states, but during normal -/// operation, the route manager will add all the required routes during startup and will start -/// waiting for changes to the route table. If any change is detected, it will stop listening for -/// new changes, obtain new default routes and reapply routes that should be routed through the -/// default nodes. Once the routes are reapplied, the route table changes are monitored again. -pub struct RouteManagerImpl { - default_destinations: HashSet<IpNetwork>, - applied_routes: HashSet<Route>, - v4_gateway: Option<Node>, - v6_gateway: Option<Node>, - connectivity_change: - Option<Box<dyn FusedStream<Item = std::io::Result<()>> + Unpin + Send + Sync>>, -} - -impl RouteManagerImpl { - pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { - let v4_gateway = Self::get_default_node(IpVersion::V4).await?; - let v6_gateway = Self::get_default_node(IpVersion::V6).await?; - - let monitor = listen_for_default_route_changes()?; - - let mut manager = Self { - default_destinations: HashSet::new(), - applied_routes: HashSet::new(), - connectivity_change: Some(Box::new(monitor.fuse())), - v4_gateway, - v6_gateway, - }; - - manager.add_required_routes(required_routes).await?; - - Ok(manager) - } - - pub(crate) async fn run(mut self, manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>) { - let mut manage_rx = manage_rx.fuse(); - let mut connectivity_change = self.connectivity_change.take().unwrap(); - - loop { - futures::select! { - command = manage_rx.next() => { - match command { - Some(RouteManagerCommand::Shutdown(tx)) => { - self.cleanup_routes().await; - let _ = tx.send(()); - return; - }, - - Some(RouteManagerCommand::AddRoutes(routes, result_tx)) => { - let result = self.add_required_routes(routes).await; - let _ = result_tx.send(result); - }, - Some(RouteManagerCommand::ClearRoutes) => { - self.cleanup_routes().await; - }, - None => { - break; - } - } - }, - - _result = connectivity_change.select_next_some() => { - let v4_gateway = Self::get_default_node(IpVersion::V4).await.unwrap_or(None); - let v6_gateway = Self::get_default_node(IpVersion::V6).await.unwrap_or(None); - - if v4_gateway != self.v4_gateway { - self.v4_gateway = v4_gateway; - self.apply_new_default_route(&self.v4_gateway, true).await; - } - - if v6_gateway != self.v6_gateway { - self.v6_gateway = v6_gateway; - self.apply_new_default_route(&self.v6_gateway, false).await; - } - }, - complete => { - break; - } - }; - } - self.cleanup_routes().await; - } - - async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { - let mut routes_to_apply = vec![]; - let mut default_destinations = HashSet::new(); - - for route in required_routes { - match route.node { - NetNode::DefaultNode => { - default_destinations.insert(route.prefix); - } - - NetNode::RealNode(node) => routes_to_apply.push(Route::new(node, route.prefix)), - } - } - - for route in routes_to_apply { - Self::add_route(&route).await?; - self.applied_routes.insert(route); - } - - for destination in default_destinations.iter() { - match (&self.v4_gateway, &self.v6_gateway, destination.is_ipv4()) { - (Some(gateway), _, true) | (_, Some(gateway), false) => { - let route = Route::new(gateway.clone(), *destination); - Self::add_route(&route).await?; - self.applied_routes.insert(route); - } - _ => (), - }; - } - - self.default_destinations = default_destinations; - - Ok(()) - } - - // Retrieves the node that's currently used to reach 0.0.0.0/0 - pub(crate) async fn get_default_node(ip_version: IpVersion) -> Result<Option<Node>> { - let ip_version_arg = match ip_version { - IpVersion::V4 => "-inet", - IpVersion::V6 => "-inet6", - }; - let mut cmd = Command::new("route"); - cmd.arg("-n").arg("get").arg(ip_version_arg).arg("default"); - - let output = cmd.output().await.map_err(Error::FailedToRunRoute)?; - let output = String::from_utf8(output.stdout).map_err(|e| { - log::error!("Failed to parse utf-8 bytes from output of netstat: {}", e); - Error::BadOutputFromNetstat - })?; - Ok(Self::parse_route(&output)) - } - - fn parse_route(route_output: &str) -> Option<Node> { - let mut address: Option<IpAddr> = None; - let mut device = None; - for line in route_output.lines() { - // we're looking for just 2 different lines: - // interface: utun0 - // gateway: 192.168.3.1 - let tokens: Vec<_> = line.split_whitespace().collect(); - if tokens.len() == 2 { - match tokens[0].trim() { - "interface:" => { - device = Some(tokens[1].to_string()); - } - "gateway:" => { - address = Self::parse_gateway_line(tokens[1]); - } - _ => continue, - } - } - } - - match (address, device) { - (Some(address), Some(device)) => Some(Node::new(address, device)), - (Some(address), None) => Some(Node::address(address)), - (None, Some(device)) => Some(Node::device(device)), - _ => None, - } - } - - fn parse_gateway_line(line: &str) -> Option<IpAddr> { - // IPv6 addresses may contain interfaces - // if line contains '%' it should be split off - line.split('%') - .next() - .and_then(|ip_str| ip_str.parse().ok()) - } - - async fn delete_route(destination: IpNetwork) -> Result<ExitStatus> { - let mut cmd = Command::new("route"); - cmd.arg("-q") - .arg("-n") - .arg("delete") - .arg(ip_vers(destination)) - .arg(destination.to_string()) - .stderr(Stdio::null()); - - cmd.status().await.map_err(Error::FailedToRemoveRoute) - } - - async fn add_route(route: &Route) -> Result<ExitStatus> { - let mut cmd = Command::new("route"); - cmd.arg("-q") - .arg("-n") - .arg("add") - .arg(ip_vers(route.prefix)) - .arg(route.prefix.to_string()); - - if let Some(addr) = route.node.get_address() { - cmd.arg("-gateway").arg(addr.to_string()); - } else if let Some(device) = route.node.get_device() { - cmd.arg("-interface").arg(device); - } - - cmd.status().await.map_err(Error::FailedToAddRoute) - } - - async fn cleanup_routes(&self) { - let destinations_to_remove = self - .applied_routes - .iter() - .map(|route| &route.prefix) - .chain(self.default_destinations.iter()); - - for destination in destinations_to_remove { - match Self::delete_route(*destination).await { - Ok(status) => { - if !status.success() { - log::debug!("Failed to remove route during shutdown"); - } - } - Err(e) => log::error!("Failed to remove route during shutdown: {}", e), - }; - } - } - - async fn apply_new_default_route(&self, new_node: &Option<Node>, v4: bool) { - for destination in self.default_destinations.iter() { - if destination.is_ipv4() == v4 { - let _ = Self::delete_route(*destination).await; - - if let Some(node) = new_node { - log::error!("Resetting default route for {}", destination); - match Self::add_route(&Route::new(node.clone(), *destination)).await { - Ok(status) => { - if !status.success() { - log::error!("Failed to reapply route"); - } - } - Err(e) => log::error!("Failed to reset route: {}", e), - } - } - } - } - } -} - -fn ip_vers(prefix: IpNetwork) -> &'static str { - if prefix.is_ipv4() { - "-inet" - } else { - "-inet6" - } -} - -/// Returns a tuple containing a IPv4 and IPv6 default route nodes. -pub async fn get_default_routes() -> Result<(Option<Node>, Option<Node>)> { - futures::try_join!( - RouteManagerImpl::get_default_node(IpVersion::V4), - RouteManagerImpl::get_default_node(IpVersion::V6) - ) -} - -/// Returns a stream that produces an item whenever a default route is either added or deleted from -/// the routing table. -pub fn listen_for_default_route_changes() -> Result<impl Stream<Item = std::io::Result<()>>> { - let mut cmd = Command::new("route"); - cmd.arg("-n") - .arg("monitor") - .arg("-") - .stderr(Stdio::null()) - .stdout(Stdio::piped()) - .stdin(Stdio::null()); - - let mut process = cmd.spawn().map_err(Error::FailedToMonitorRoutes)?; - let reader = tokio::io::BufReader::new(process.stdout.take().unwrap()); - let lines = reader.lines(); - - // route -n monitor will produce netlink messages in the following format - // ``` - // got message of size 176 on Thu Jun 4 10:08:05 2020 - // RTM_DELETE: Delete Route: len 176, pid: 109, seq 1151, errno 3, ifscope 23, - // flags:<UP,GATEWAY,STATIC,IFSCOPE> - // locks: inits: - // sockaddrs: <DST,GATEWAY,NETMASK,IFP,IFA> - // default 192.168.44.1 default 192.168.44.90 - // ``` - // On the second line of the message, the message type is specified. Only messages with the - // type 'RTM_ADD' or 'RTM_DELETE' are considered. On the 6th line, message attribute values are - // shown. To detect a change for a default route in the routing table, check whether this line - // contains 'default'. Whenever an empty line is encountered, the message has been sent, so - // the state can be reset. - - let mut add_or_delete_message = false; - let mut contains_default = false; - - let monitor = LinesStream::new(lines).try_filter_map(move |line| { - if add_or_delete_message { - if line.contains("default") { - contains_default = true; - } - if line.trim().is_empty() { - add_or_delete_message = false; - if contains_default { - contains_default = false; - return future::ready(Ok(Some(()))); - } - } - } else { - add_or_delete_message = line.starts_with("RTM_ADD:") || line.starts_with("RTM_DELETE:"); - } - future::ready(Ok(None)) - }); - - Ok(monitor) -} diff --git a/talpid-routing/src/unix/macos/data.rs b/talpid-routing/src/unix/macos/data.rs new file mode 100644 index 0000000000..1e955c5fa0 --- /dev/null +++ b/talpid-routing/src/unix/macos/data.rs @@ -0,0 +1,1147 @@ +use ipnetwork::IpNetwork; +use nix::{ + ifaddrs::InterfaceAddress, + sys::socket::{SockaddrLike, SockaddrStorage}, +}; +use std::{ + collections::BTreeMap, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, +}; + +/// Message that describes a route - either an added, removed, changed or plainly retrieved route. +#[derive(Debug, Clone, PartialEq)] +pub struct RouteMessage { + sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>, + route_flags: RouteFlag, + interface_index: u16, + errno: i32, +} + +impl RouteMessage { + pub fn new_route(destination: Destination) -> Self { + let mut route_flags = RouteFlag::RTF_STATIC | RouteFlag::RTF_DONE | RouteFlag::RTF_UP; + let mut sockaddrs = BTreeMap::new(); + match destination { + Destination::Network(net) => { + let dest_addr = SockaddrStorage::from(SocketAddr::from((net.ip(), 0))); + let destination = RouteSocketAddress::Destination(Some(dest_addr)); + let netmask = + RouteSocketAddress::Netmask(Some(SocketAddr::from((net.mask(), 0)).into())); + sockaddrs.insert(destination.address_flag(), destination); + sockaddrs.insert(netmask.address_flag(), netmask); + } + Destination::Host(addr) => { + let destination = + RouteSocketAddress::Destination(Some(SocketAddr::from((addr, 0)).into())); + route_flags |= RouteFlag::RTF_HOST; + sockaddrs.insert(destination.address_flag(), destination); + } + }; + + Self { + sockaddrs, + route_flags, + interface_index: 0, + errno: 0, + } + } + + pub fn route_addrs(&self) -> impl Iterator<Item = &RouteSocketAddress> { + self.sockaddrs.values() + } + + fn socketaddress_to_ip(sockaddr: &SockaddrStorage) -> Option<IpAddr> { + saddr_to_ipv4(sockaddr) + .map(Into::into) + .or_else(|| saddr_to_ipv6(sockaddr).map(Into::into)) + } + + pub fn netmask(&self) -> Option<IpAddr> { + self.route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Netmask(netmask) => Some(netmask), + _ => None, + })? + .as_ref() + .and_then(Self::socketaddress_to_ip) + } + + pub fn is_default(&self) -> Result<bool> { + Ok(self.is_default_v4()? || self.is_default_v6()?) + } + + pub fn is_default_v4(&self) -> Result<bool> { + let destination_is_default = self + .destination_v4()? + .map(|addr| addr == Ipv4Addr::UNSPECIFIED) + .unwrap_or(false); + let netmask = self.route_addrs().find_map(|saddr| match saddr { + RouteSocketAddress::Netmask(addr) => Some(addr), + _ => None, + }); + + // TODO: This might be superfluous + let netmask_is_default = match netmask { + // empty socket address implies that it is a 'default' netmask + Some(None) => true, + Some(Some(addr)) => { + if let Some(netmask_addr) = saddr_to_ipv4(addr) { + netmask_addr.is_unspecified() + } else if let Some(netmask_addr) = saddr_to_ipv6(addr) { + netmask_addr.is_unspecified() + } else { + // if the route socket address describing the netmask isn't a sockaddr_in or a + // sockaddr_in6, it can't possibly be a default route for IP + false + } + } + // absence of a netmask socket address implies that it is a host route + None => false, + }; + + Ok(destination_is_default && netmask_is_default) + } + + pub fn is_default_v6(&self) -> Result<bool> { + Ok(self + .destination_v6()? + .map(|addr| addr == Ipv6Addr::UNSPECIFIED) + .unwrap_or(false)) + } + + fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + let header: rt_msghdr = rt_msghdr::from_bytes(buffer)?; + + let msg_len = usize::from(header.rtm_msglen); + if msg_len > buffer.len() { + return Err(Error::BufferTooSmall( + "Message is shorter than it's msg_len indicates", + msg_len, + buffer.len(), + )); + } + + let payload = &buffer[ROUTE_MESSAGE_HEADER_SIZE..std::cmp::min(msg_len, buffer.len())]; + + let route_flags = RouteFlag::from_bits(header.rtm_flags) + .ok_or(Error::UnknownRouteFlag(header.rtm_flags))?; + let address_flags = AddressFlag::from_bits(header.rtm_addrs) + .ok_or(Error::UnknownAddressFlag(header.rtm_addrs))?; + if !address_flags.contains(AddressFlag::RTA_DST) { + return Err(Error::NoDestination); + } + let sockaddrs = RouteSockAddrIterator::new(payload, address_flags) + .map(|addr| addr.map(|a| (a.address_flag(), a))) + .collect::<Result<BTreeMap<_, _>>>()?; + let interface_index = header.rtm_index; + + Ok(Self { + route_flags, + sockaddrs, + interface_index, + errno: header.rtm_errno, + }) + } + + fn insert_sockaddr(&mut self, saddr: RouteSocketAddress) { + self.sockaddrs.insert(saddr.address_flag(), saddr); + } + + pub fn set_destination(mut self, destination: Destination) -> Self { + match destination { + Destination::Network(net) => { + let sockaddr: SocketAddr = (net.ip(), 0).into(); + let netmask: SocketAddr = (net.mask(), 0).into(); + self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into()))); + self.insert_sockaddr(RouteSocketAddress::Netmask(Some(netmask.into()))); + self.route_flags.remove(RouteFlag::RTF_HOST); + } + Destination::Host(addr) => { + self.route_flags.insert(RouteFlag::RTF_HOST); + let sockaddr: SocketAddr = (addr, 0).into(); + self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into()))); + } + }; + + self + } + + pub fn set_interface_addr(mut self, link: &InterfaceAddress) -> Self { + self.insert_sockaddr(RouteSocketAddress::Gateway(link.address)); + self.route_flags |= RouteFlag::RTF_GATEWAY; + self + } + + pub fn set_gateway_sockaddr(mut self, sockaddr: SockaddrStorage) -> Self { + self.insert_sockaddr(RouteSocketAddress::Gateway(Some(sockaddr))); + self.route_flags |= RouteFlag::RTF_GATEWAY; + self + } + + pub fn set_gateway_addr(mut self, addr: IpAddr) -> Self { + let gateway: SocketAddr = (addr, 0).into(); + self.insert_sockaddr(RouteSocketAddress::Gateway(Some(gateway.into()))); + self.route_flags |= RouteFlag::RTF_GATEWAY; + + self + } + + pub fn set_gateway_route(mut self, is_gateway_route: bool) -> Self { + if is_gateway_route { + self.route_flags.insert(RouteFlag::RTF_GATEWAY); + } else { + self.route_flags.remove(RouteFlag::RTF_GATEWAY); + } + self + } + + pub fn route_flag(mut self, route_flags: RouteFlag) -> Self { + self.route_flags = route_flags; + self + } + + pub fn gateway(&self) -> Option<&SockaddrStorage> { + self.route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Gateway(gateway) => Some(gateway), + _ => None, + })? + .as_ref() + } + + pub fn gateway_ip(&self) -> Option<IpAddr> { + self.gateway_v4() + .map(IpAddr::V4) + .or(self.gateway_v6().map(IpAddr::V6)) + } + + pub fn gateway_v4(&self) -> Option<Ipv4Addr> { + saddr_to_ipv4(self.gateway()?) + } + + pub fn gateway_v6(&self) -> Option<Ipv6Addr> { + saddr_to_ipv6(self.gateway()?) + } + + pub fn destination_ip(&self) -> Result<IpNetwork> { + if let Some(saddr) = self.destination()? { + if let Some(v4) = saddr.as_sockaddr_in() { + let ip_addr = *SocketAddrV4::from(*v4).ip(); + let netmask = self.netmask().unwrap_or(Ipv4Addr::UNSPECIFIED.into()); + let destination = IpNetwork::with_netmask(ip_addr.into(), netmask) + .map_err(Error::InvalidNetmask)?; + return Ok(destination); + } + + if let Some(v6) = saddr.as_sockaddr_in6() { + let ip_addr = *SocketAddrV6::from(*v6).ip(); + let netmask = self.netmask().unwrap_or(Ipv6Addr::UNSPECIFIED.into()); + let destination = IpNetwork::with_netmask(ip_addr.into(), netmask) + .map_err(Error::InvalidNetmask)?; + return Ok(destination); + } + + return Err(Error::MismatchedSocketAddress( + AddressFlag::RTA_DST, + Box::new(*saddr), + )); + } + Err(Error::NoDestination) + } + + pub fn destination(&self) -> Result<Option<&SockaddrStorage>> { + Ok(self + .route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Destination(destination) => Some(destination), + _ => None, + }) + .ok_or(Error::NoDestination)? + .as_ref()) + } + + pub fn destination_v4(&self) -> Result<Option<Ipv4Addr>> { + Ok(self.destination()?.and_then(saddr_to_ipv4)) + } + + pub fn destination_v6(&self) -> Result<Option<Ipv6Addr>> { + Ok(self.destination()?.and_then(saddr_to_ipv6)) + } + + pub fn flags(&self) -> &RouteFlag { + &self.route_flags + } + + pub fn payload( + &self, + message_type: MessageType, + sequence: i32, + pid: i32, + ) -> (rt_msghdr, Vec<Vec<u8>>) { + let address_flags = self.route_addrs().fold(AddressFlag::empty(), |flag, addr| { + flag | addr.address_flag() + }); + + // The sockaddrs should be ordered by their address flag in the payload, + // because the payload does not contain their flags. Flags are only specified + // in the header. + let mut sockaddrs = self.route_addrs().collect::<Vec<_>>(); + sockaddrs.sort_by_key(|saddr| saddr.address_flag()); + let payload_bytes = sockaddrs + .into_iter() + .map(RouteSocketAddress::to_bytes) + .collect::<Vec<_>>(); + + let payload_len: usize = payload_bytes.iter().map(Vec::len).sum(); + + let rtm_msglen = (payload_len + ROUTE_MESSAGE_HEADER_SIZE) + .try_into() + .expect("route message buffer size cannot fit in 32 bits"); + + let header = super::data::rt_msghdr { + rtm_msglen, + rtm_version: libc::RTM_VERSION.try_into().unwrap(), + rtm_type: message_type.bits(), + rtm_index: self.interface_index, + rtm_flags: self.route_flags.bits(), + rtm_addrs: address_flags.bits(), + rtm_pid: pid, + rtm_seq: sequence, + rtm_errno: 0, + rtm_use: 0, + rtm_inits: 0, + rtm_rmx: Default::default(), + }; + + (header, payload_bytes) + } + + pub fn interface_index(&self) -> u16 { + self.interface_index + } + + pub fn interface_address(&self) -> Option<IpAddr> { + self.get_address(&AddressFlag::RTA_IFA) + } + + fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> { + let addr = self.sockaddrs.get(address_flag)?; + saddr_to_ipv4(addr.inner()?) + .map(IpAddr::from) + .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from)) + } + + pub fn interface_sockaddr_index(&self) -> Option<u16> { + self.sockaddrs + .values() + .find_map(|addr| addr.interface_index()) + } + + pub fn errno(&self) -> i32 { + self.errno + } + + pub fn is_ipv4(&self) -> bool { + self.destination_v4() + .map(|addr| addr.is_some()) + .unwrap_or(false) + } + + pub fn is_ipv6(&self) -> bool { + self.destination_v6() + .map(|addr| addr.is_some()) + .unwrap_or(false) + } + + pub fn is_ifscope(&self) -> bool { + self.route_flags.contains(RouteFlag::RTF_IFSCOPE) + } + + pub fn ifscope(&self) -> Option<u16> { + if self.is_ifscope() { + Some(self.interface_index) + } else { + None + } + } + + pub fn unset_ifscope(mut self) -> Self { + self.route_flags.remove(RouteFlag::RTF_IFSCOPE); + self + } + + pub fn set_ifscope(mut self, iface_index: u16) -> Self { + if iface_index > 0 { + self.interface_index = iface_index; + self.route_flags.insert(RouteFlag::RTF_IFSCOPE); + } else { + self.route_flags.remove(RouteFlag::RTF_IFSCOPE); + } + + self + } +} + +#[derive(Debug)] +#[repr(C)] +struct ifa_msghdr { + ifam_msglen: libc::c_ushort, + ifam_version: libc::c_uchar, + ifam_type: libc::c_uchar, + ifam_addrs: libc::c_int, + ifam_flags: libc::c_int, + ifam_index: libc::c_ushort, + ifam_metric: libc::c_int, +} + +#[derive(Debug)] +pub struct AddressMessage { + sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>, + interface_index: u16, +} + +impl AddressMessage { + pub fn index(&self) -> u16 { + self.interface_index + } + + pub fn address(&self) -> Result<IpAddr> { + self.get_address(&AddressFlag::RTA_IFP) + .or_else(|| self.get_address(&AddressFlag::RTA_IFA)) + .ok_or(Error::NoInterfaceAddress) + } + + fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> { + let addr = self.sockaddrs.get(address_flag)?; + saddr_to_ipv4(addr.inner()?) + .map(IpAddr::from) + .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from)) + } + + pub fn netmask(&self) -> Result<IpAddr> { + self.get_address(&AddressFlag::RTA_NETMASK) + .ok_or(Error::NoNetmaskAddress) + } + + pub fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + const HEADER_SIZE: usize = std::mem::size_of::<ifa_msghdr>(); + if HEADER_SIZE > buffer.len() { + return Err(Error::BufferTooSmall( + "ifa_msghdr", + buffer.len(), + HEADER_SIZE, + )); + } + + // SAFETY: buffer is pointing to enough memory to contain a valid value for ifa_msghdr + let header: ifa_msghdr = unsafe { std::ptr::read_unaligned(buffer.as_ptr() as *const _) }; + + let msg_len = usize::from(header.ifam_msglen); + if msg_len > buffer.len() { + return Err(Error::BufferTooSmall( + "Message is shorter than it's msg_len indicates", + msg_len, + buffer.len(), + )); + } + + let payload = &buffer[HEADER_SIZE..std::cmp::min(msg_len, buffer.len())]; + + let address_flags = AddressFlag::from_bits(header.ifam_addrs) + .ok_or(Error::UnknownAddressFlag(header.ifam_addrs))?; + + let sockaddrs = RouteSockAddrIterator::new(payload, address_flags) + .map(|addr| addr.map(|addr| (addr.address_flag(), addr))) + .collect::<Result<BTreeMap<_, _>>>()?; + + Ok(Self { + sockaddrs, + interface_index: header.ifam_index, + }) + } +} + +#[derive(Debug)] +pub enum RouteSocketMessage { + AddRoute(RouteMessage), + DeleteRoute(RouteMessage), + ChangeRoute(RouteMessage), + GetRoute(RouteMessage), + Interface(Interface), + AddAddress(AddressMessage), + DeleteAddress(AddressMessage), + Other { + header: rt_msghdr_short, + payload: Vec<u8>, + }, + Error { + header: rt_msghdr_short, + payload: Vec<u8>, + }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Destination { + Host(IpAddr), + Network(IpNetwork), +} + +impl Destination { + pub fn is_network(&self) -> bool { + matches!(self, Self::Network(_)) + } + + pub fn default_v4() -> Self { + Destination::Network(IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap()) + } + + pub fn default_v6() -> Self { + Destination::Network(IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap()) + } +} + +impl From<IpAddr> for Destination { + fn from(addr: IpAddr) -> Self { + Self::Host(addr) + } +} + +impl From<IpNetwork> for Destination { + fn from(net: IpNetwork) -> Self { + if net.prefix() == 32 && net.is_ipv4() { + return Self::Host(net.ip()); + } + + Self::Network(net) + } +} + +#[derive(Debug)] +pub enum Error { + /// Payload buffer didn't match the reported message size in header + InvalidBuffer(Vec<u8>, AddressFlag), + /// Buffer too small for specific message type + BufferTooSmall(&'static str, usize, usize), + /// Unknown route flag + UnknownRouteFlag(i32), + /// Socket address is empty for the given address flag + EmptySockaddr(AddressFlag), + /// Unrecognized message + UnknownMessageType(u8), + /// Unrecognized address flag + UnknownAddressFlag(libc::c_int), + /// Mismatched socket address type + MismatchedSocketAddress(AddressFlag, Box<SockaddrStorage>), + /// Link socket address contains no identifier + NoLinkIdentifier(nix::libc::sockaddr_dl), + /// Failed to resolve an interface name to an index + InterfaceIndex(nix::Error), + /// Invalid netmask + InvalidNetmask(ipnetwork::IpNetworkError), + /// Route contains no netmask socket address + NoDestination, + /// Address message does not contain an interface address + NoInterfaceAddress, + /// Address message does not contain an interface address + NoNetmaskAddress, +} + +type Result<T> = std::result::Result<T, Error>; + +impl RouteSocketMessage { + pub fn parse_message(buffer: &[u8]) -> Result<Self> { + let route_message = |route_constructor: fn(RouteMessage) -> RouteSocketMessage, buffer| { + let route = RouteMessage::from_byte_buffer(buffer)?; + Ok(route_constructor(route)) + }; + + match rt_msghdr_short::from_bytes(buffer) { + Some(header) if header.is_type(libc::RTM_ADD) => route_message(Self::AddRoute, buffer), + + Some(header) if header.is_type(libc::RTM_CHANGE) => { + route_message(Self::ChangeRoute, buffer) + } + + Some(header) if header.is_type(libc::RTM_DELETE) => { + route_message(Self::DeleteRoute, buffer) + } + + Some(header) if header.is_type(libc::RTM_GET) => route_message(Self::GetRoute, buffer), + + Some(header) if header.is_type(libc::RTM_IFINFO) => Ok(RouteSocketMessage::Interface( + Interface::from_byte_buffer(buffer)?, + )), + + Some(header) if header.is_type(libc::RTM_NEWADDR) => Ok( + RouteSocketMessage::AddAddress(AddressMessage::from_byte_buffer(buffer)?), + ), + Some(header) if header.is_type(libc::RTM_DELADDR) => Ok( + RouteSocketMessage::DeleteAddress(AddressMessage::from_byte_buffer(buffer)?), + ), + Some(header) => Ok(Self::Other { + header, + payload: buffer.to_vec(), + }), + None => Err(Error::BufferTooSmall( + "rt_msghdr_short", + buffer.len(), + ROUTE_MESSAGE_HEADER_SHORT_SIZE, + )), + } + } +} + +#[derive(Debug)] +pub struct Interface { + header: libc::if_msghdr, +} + +impl Interface { + pub fn is_up(&self) -> bool { + self.header.ifm_flags & nix::libc::IFF_UP != 0 + } + + pub fn index(&self) -> u16 { + self.header.ifm_index + } + + fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + const INTERFACE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<libc::if_msghdr>(); + if INTERFACE_MESSAGE_HEADER_SIZE > buffer.len() { + return Err(Error::BufferTooSmall( + "if_msghdr", + buffer.len(), + INTERFACE_MESSAGE_HEADER_SIZE, + )); + } + let header: libc::if_msghdr = unsafe { std::ptr::read(buffer.as_ptr() as *const _) }; + // let payload = buffer[INTERFACE_MESSAGE_HEADER_SIZE..header.ifm_msglen.into()].to_vec(); + Ok(Self { header }) + } +} + +// #define RTA_DST 0x1 /* destination sockaddr present */ +// #define RTA_GATEWAY 0x2 /* gateway sockaddr present */ +// #define RTA_NETMASK 0x4 /* netmask sockaddr present */ +// #define RTA_GENMASK 0x8 /* cloning mask sockaddr present */ +// #define RTA_IFP 0x10 /* interface name sockaddr present */ +// #define RTA_IFA 0x20 /* interface addr sockaddr present */ +// #define RTA_AUTHOR 0x40 /* sockaddr for author of redirect */ +// #define RTA_BRD 0x80 /* for NEWADDR, broadcast or p-p dest addr */ +bitflags::bitflags! { + /// All enum values of address flags can be iterated via `flag <<= 1`, starting from 1. + /// See https://www.manpagez.com/man/4/route/. + pub struct AddressFlag: i32 { + /// Destination socket address + const RTA_DST = 0x1; + /// Gateway socket address + const RTA_GATEWAY = 0x2; + /// Netmask socket address + const RTA_NETMASK = 0x4; + /// Cloning mask socket address + const RTA_GENMASK = 0x8; + /// Interface name socket address + const RTA_IFP = 0x10; + /// Interface address socket address + const RTA_IFA = 0x20; + /// Socket address for author of redirect + const RTA_AUTHOR = 0x40; + /// Socket address for `NEWADDR`, broadcast or point-to-point destination address + const RTA_BRD = 0x80; + } +} + +bitflags::bitflags! { + /// Types of routing messages + /// See https://www.manpagez.com/man/4/route/. + pub struct MessageType: u8 { + /// Add Route + const RTM_ADD = 0x1; + /// Delete Route + const RTM_DELETE = 0x2; + /// Change Metrics or flags + const RTM_CHANGE = 0x3; + /// Report Metrics + const RTM_GET = 0x4; + /// RTM_LOSING is no longer generated by and is deprecated + const RTM_LOSING = 0x5; + /// Told to use different route + const RTM_REDIRECT = 0x6; + /// Lookup failed on this address + const RTM_MISS = 0x7; + /// fix specified metrics + const RTM_LOCK = 0x8; + /// caused by SIOCADDRT + const RTM_OLDADD = 0x9; + /// caused by SIOCDELRT + const RTM_OLDDEL = 0xa; + /// req to resolve dst to LL addr + const RTM_RESOLVE = 0xb; + /// address being added to iface + const RTM_NEWADDR = 0xc; + /// address being removed from iface + const RTM_DELADDR = 0xd; + /// iface going up/down etc. + const RTM_IFINFO = 0xe; + /// mcast group membership being added to if + const RTM_NEWMADDR = 0xf; + /// mcast group membership being deleted + const RTM_DELMADDR = 0x10; + } + + /// Routing message flags + /// See https://www.manpagez.com/man/4/route/. + pub struct RouteFlag: i32 { + /// route usable + const RTF_UP = 0x1; + /// destination is a gateway + const RTF_GATEWAY = 0x2; + /// host entry (net otherwise) + const RTF_HOST = 0x4; + /// host or net unreachable + const RTF_REJECT = 0x8; + /// created dynamically (by redirect) + const RTF_DYNAMIC = 0x10; + /// modified dynamically (by redirect) + const RTF_MODIFIED = 0x20; + /// message confirmed + const RTF_DONE = 0x40; + /// delete cloned route + const RTF_DELCLONE = 0x80; + /// generate new routes on use + const RTF_CLONING = 0x100; + /// external daemon resolves name + const RTF_XRESOLVE = 0x200; + /// DEPRECATED - exists ONLY for backwards compatibility + const RTF_LLINFO = 0x400; + /// used by apps to add/del L2 entries + const RTF_LLDATA = 0x400; + /// manually added + const RTF_STATIC = 0x800; + /// just discard pkts (during updates) + const RTF_BLACKHOLE = 0x1000; + /// not eligible for RTF_IFREF + const RTF_NOIFREF = 0x2000; + /// protocol specific routing flag + const RTF_PROTO2 = 0x4000; + /// protocol specific routing flag + const RTF_PROTO1 = 0x8000; + /// protocol requires cloning + const RTF_PRCLONING = 0x10000; + /// route generated through cloning + const RTF_WASCLONED = 0x20000; + /// protocol specific routing flag + const RTF_PROTO3 = 0x40000; + /// future use + const RTF_PINNED = 0x100000; + /// route represents a local address + const RTF_LOCAL = 0x200000; + /// route represents a bcast address + const RTF_BROADCAST = 0x400000; + /// route represents a mcast address + const RTF_MULTICAST = 0x800000; + /// has valid interface scope + const RTF_IFSCOPE = 0x1000000; + /// defunct; no longer modifiable + const RTF_CONDEMNED = 0x2000000; + /// route holds a ref to interface + const RTF_IFREF = 0x4000000; + /// proxying, no interface scope + const RTF_PROXY = 0x8000000; + /// host is a router + const RTF_ROUTER = 0x10000000; + /// Route entry is being freed + const RTF_DEAD = 0x20000000; + /// route to destination of the global internet + const RTF_GLOBAL = 0x40000000; + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RouteSocketAddress { + /// Corresponds to RTA_DST + Destination(Option<SockaddrStorage>), + /// RTA_GATEWAY + Gateway(Option<SockaddrStorage>), + /// RTA_NETMASK + Netmask(Option<SockaddrStorage>), + /// RTA_GENMASK + CloningMask(Option<SockaddrStorage>), + /// RTA_IFP + IfName(Option<SockaddrStorage>), + /// RTA_IFA + IfSockaddr(Option<SockaddrStorage>), + /// RTA_AUTHOR + RedirectAuthor(Option<SockaddrStorage>), + /// RTA_BRD + Broadcast(Option<SockaddrStorage>), +} + +impl RouteSocketAddress { + // Returns a new route socket address and number of bytes read from the buffer + pub fn new(flag: AddressFlag, buf: &[u8]) -> Result<(Self, u8)> { + // If buffer is empty, then the socket address is empty too, the backing buffer shouldn't + // be advanced. + if buf.is_empty() { + return Ok((Self::with_sockaddr(flag, None)?, 0)); + } + + // to get the length and type of + if buf.len() < std::mem::size_of::<sockaddr_hdr>() { + return Err(Error::BufferTooSmall( + "sockaddr buffer too small", + buf.len(), + std::mem::size_of::<sockaddr_hdr>(), + )); + } + + let addr_header_ptr = buf.as_ptr() as *const sockaddr_hdr; + // SAFETY: Since `buf` is at least as long as a `sockaddr_hdr`, it's perfectly valid to + // read from. + let addr_header = unsafe { std::ptr::read(addr_header_ptr) }; + let saddr_len = addr_header.sa_len; + if saddr_len == 0 { + return Ok((Self::with_sockaddr(flag, None)?, 4)); + } + + if Into::<usize>::into(saddr_len) > buf.len() { + return Err(Error::InvalidBuffer(buf.to_vec(), flag)); + } + + // SAFETY: the buffer is big enough for the sockaddr struct inside it, so accessing as a + // `sockaddr` is valid. + let saddr = unsafe { + SockaddrStorage::from_raw( + addr_header_ptr as *const nix::libc::sockaddr, + Some(saddr_len.into()), + ) + }; + + Ok((Self::with_sockaddr(flag, saddr)?, saddr_len)) + } + + pub fn to_bytes(&self) -> Vec<u8> { + match self.inner() { + None => vec![0u8; 4], + Some(addr) => { + let len = usize::try_from(addr.len()).unwrap(); + assert!(len >= 4); + + // The "serialized" socket addresses must be padded to be aligned to 4 bytes, with + // the smallest size being 4 bytes. + let buffer_size = len + len % 4; + let mut buffer = vec![0u8; buffer_size]; + unsafe { + // SAFETY: copying conents of addr into buffer is safe, as long as addr.len() + // returns a correct size for the socket address pointer. + std::ptr::copy_nonoverlapping( + addr.as_ptr() as *const _, + buffer.as_mut_ptr(), + len, + ); + } + buffer + } + } + } + + pub fn address_flag(&self) -> AddressFlag { + match &self { + Self::Destination(_) => AddressFlag::RTA_DST, + Self::Gateway(_) => AddressFlag::RTA_GATEWAY, + Self::Netmask(_) => AddressFlag::RTA_NETMASK, + Self::CloningMask(_) => AddressFlag::RTA_GENMASK, + Self::IfName(_) => AddressFlag::RTA_IFP, + Self::IfSockaddr(_) => AddressFlag::RTA_IFA, + Self::RedirectAuthor(_) => AddressFlag::RTA_AUTHOR, + Self::Broadcast(_) => AddressFlag::RTA_BRD, + } + } + + pub fn inner(&self) -> Option<&SockaddrStorage> { + match &self { + Self::Gateway(addr) + | Self::Destination(addr) + | Self::Netmask(addr) + | Self::CloningMask(addr) + | Self::IfName(addr) + | Self::IfSockaddr(addr) + | Self::RedirectAuthor(addr) + | Self::Broadcast(addr) => addr.as_ref(), + } + } + + fn with_sockaddr(flag: AddressFlag, sockaddr: Option<SockaddrStorage>) -> Result<Self> { + let constructor = match flag { + AddressFlag::RTA_GATEWAY => Self::Gateway, + AddressFlag::RTA_DST => Self::Destination, + AddressFlag::RTA_NETMASK => Self::Netmask, + AddressFlag::RTA_GENMASK => Self::CloningMask, + AddressFlag::RTA_IFP => Self::IfName, + AddressFlag::RTA_IFA => Self::IfSockaddr, + AddressFlag::RTA_AUTHOR => Self::RedirectAuthor, + AddressFlag::RTA_BRD => Self::Broadcast, + unknown => return Err(Error::UnknownAddressFlag(unknown.bits())), + }; + + Ok(constructor(sockaddr)) + } + + pub fn interface_index(&self) -> Option<u16> { + match self { + Self::IfName(Some(iface)) => { + let index = iface.as_link_addr()?.ifindex(); + Some( + u16::try_from(index) + .expect("interface indexes actually are u16s, nix is just *interesting*"), + ) + } + _ => None, + } + } +} + +/// Route socket addreses should be ordered by their corresponding address flag when a route +/// message is constructed +impl std::cmp::PartialOrd for RouteSocketAddress { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + self.address_flag().partial_cmp(&other.address_flag()) + } +} + +#[repr(C)] +#[derive(Copy, Clone, Debug)] +struct sockaddr_hdr { + sa_len: u8, + sa_family: libc::sa_family_t, + padding: u16, +} + +/// An iterator to consume a byte buffer containing socket address structures originating from a +/// routing socket message. +pub struct RouteSockAddrIterator<'a> { + buffer: &'a [u8], + flags: AddressFlag, + // Cursor used to iterate through address flags + flag_cursor: i32, +} + +impl<'a> RouteSockAddrIterator<'a> { + fn new(buffer: &'a [u8], flags: AddressFlag) -> Self { + Self { + buffer, + flags, + flag_cursor: AddressFlag::RTA_DST.bits(), + } + } + + /// Advances internal byte buffer by given amount. The byte amount will be padded to be + /// aligned to 4 bytes if there's more data in the buffer. + fn advance_buffer(&mut self, saddr_len: u8) { + let saddr_len = usize::from(saddr_len); + + // if consumed as many bytes as are left in the buffer, the buffer can be cleared + if saddr_len == self.buffer.len() { + self.buffer = &[]; + return; + } + + let padded_saddr_len = if saddr_len % 4 != 0 { + saddr_len + (4 - saddr_len % 4) + } else { + saddr_len + }; + + // if offset is larger than current buffer, ensure slice gets truncated + // since the socket address should've already be read from the buffer at this point, this + // probably should be an invariant? + self.buffer = &self.buffer[padded_saddr_len..]; + } +} + +impl<'a> Iterator for RouteSockAddrIterator<'a> { + type Item = Result<RouteSocketAddress>; + + fn next(&mut self) -> Option<Self::Item> { + loop { + // If address flags don't contain the current one, try the next one. + // Will return None if it runs out of valid flags. + let current_flag = AddressFlag::from_bits(self.flag_cursor)?; + self.flag_cursor <<= 1; + + if !self.flags.contains(current_flag) { + continue; + } + return match RouteSocketAddress::new(current_flag, self.buffer) { + Ok((next_addr, addr_len)) => { + self.advance_buffer(addr_len); + Some(Ok(next_addr)) + } + Err(err) => { + self.buffer = &[]; + Some(Err(err)) + } + }; + } + } +} + +// struct rt_msghdr { +// u_short rtm_msglen; /* to skip over non-understood messages */ +// u_char rtm_version; /* future binary compatibility */ +// u_char rtm_type; /* message type */ +// u_short rtm_index; /* index for associated ifp */ +// int rtm_flags; /* flags, incl. kern & message, e.g. DONE */ +// int rtm_addrs; /* bitmask identifying sockaddrs in msg */ +// pid_t rtm_pid; /* identify sender */ +// int rtm_seq; /* for sender to identify action */ +// int rtm_errno; /* why failed */ +// int rtm_use; /* from rtentry */ +// u_int32_t rtm_inits; /* which metrics we are initializing */ +// struct rt_metrics rtm_rmx; /* metrics themselves */ +// }; +#[derive(Debug, Clone)] +#[repr(C)] +pub struct rt_msghdr { + pub rtm_msglen: libc::c_ushort, + pub rtm_version: libc::c_uchar, + pub rtm_type: libc::c_uchar, + pub rtm_index: libc::c_ushort, + pub rtm_flags: libc::c_int, + pub rtm_addrs: libc::c_int, + pub rtm_pid: libc::pid_t, + pub rtm_seq: libc::c_int, + pub rtm_errno: libc::c_int, + pub rtm_use: libc::c_int, + pub rtm_inits: u32, + pub rtm_rmx: rt_metrics, +} +const ROUTE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<rt_msghdr>(); + +fn saddr_to_ipv4(saddr: &SockaddrStorage) -> Option<Ipv4Addr> { + let addr = saddr.as_sockaddr_in()?; + Some(*SocketAddrV4::from(*addr).ip()) +} + +fn saddr_to_ipv6(saddr: &SockaddrStorage) -> Option<Ipv6Addr> { + let addr = saddr.as_sockaddr_in6()?; + Some(*SocketAddrV6::from(*addr).ip()) +} + +impl rt_msghdr { + pub fn from_bytes(buf: &[u8]) -> Result<Self> { + if buf.len() >= ROUTE_MESSAGE_HEADER_SIZE { + let ptr = buf.as_ptr(); + // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr value and it's + // readable. rt_msghdr doesn't contain any pointers so any values are valid. + Ok(unsafe { std::ptr::read(ptr as *const _) }) + } else { + Err(Error::BufferTooSmall( + "if_msghdr", + buf.len(), + ROUTE_MESSAGE_HEADER_SIZE, + )) + } + } +} + +/// Shorter rt_msghdr version that matches all routing messages +#[derive(Debug)] +#[repr(C)] +pub struct rt_msghdr_short { + pub rtm_msglen: libc::c_ushort, + pub rtm_version: libc::c_uchar, + pub rtm_type: libc::c_uchar, + pub rtm_index: libc::c_ushort, + pub rtm_flags: libc::c_int, + pub rtm_addrs: libc::c_int, + pub rtm_pid: libc::pid_t, + pub rtm_seq: libc::c_int, + pub rtm_errno: libc::c_int, +} +const ROUTE_MESSAGE_HEADER_SHORT_SIZE: usize = std::mem::size_of::<rt_msghdr_short>(); + +impl rt_msghdr_short { + fn is_type(&self, expected_type: i32) -> bool { + u8::try_from(expected_type) + .map(|expected| self.rtm_type == expected) + .unwrap_or(false) + } + + pub fn from_bytes(buf: &[u8]) -> Option<Self> { + if buf.len() >= ROUTE_MESSAGE_HEADER_SHORT_SIZE { + let ptr = buf.as_ptr(); + // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr_short value and + // is readable. `rt_msghdr_short` doesn't contain any pointers so any values are valid. + Some(unsafe { std::ptr::read(ptr as *const rt_msghdr_short) }) + } else { + None + } + } +} + +#[derive(PartialEq, PartialOrd, Ord, Eq, Clone)] +pub struct RouteDestination { + pub network: IpNetwork, + pub interface: Option<u16>, + pub gateway: Option<IpAddr>, +} + +impl TryFrom<&RouteMessage> for RouteDestination { + type Error = Error; + + fn try_from(msg: &RouteMessage) -> std::result::Result<Self, Self::Error> { + let network = msg.destination_ip()?; + let interface = msg.ifscope(); + let gateway = msg.gateway_ip(); + Ok(Self { + network, + interface, + gateway, + }) + } +} + +// Struct containing metrics of various metrics for a specific route +// struct rt_metrics { +// u_int32_t rmx_locks; /* Kernel leaves these values alone */ +// u_int32_t rmx_mtu; /* MTU for this path */ +// u_int32_t rmx_hopcount; /* max hops expected */ +// int32_t rmx_expire; /* lifetime for route, e.g. redirect */ +// u_int32_t rmx_recvpipe; /* inbound delay-bandwidth product */ +// u_int32_t rmx_sendpipe; /* outbound delay-bandwidth product */ +// u_int32_t rmx_ssthresh; /* outbound gateway buffer limit */ +// u_int32_t rmx_rtt; /* estimated round trip time */ +// u_int32_t rmx_rttvar; /* estimated rtt variance */ +// u_int32_t rmx_pksent; /* packets sent using this route */ +// u_int32_t rmx_state; /* route state */ +// u_int32_t rmx_filler[3]; /* will be used for TCP's peer-MSS cache */ +// }; +#[derive(Debug, Default, Clone)] +#[repr(C)] +pub struct rt_metrics { + pub rmx_locks: u32, + pub rmx_mtu: u32, + pub rmx_hopcount: u32, + pub rmx_expire: i32, + pub rmx_recvpipe: u32, + pub rmx_sendpipe: u32, + pub rmx_ssthresh: u32, + pub rmx_rtt: u32, + pub rmx_rttvar: u32, + pub rmx_pksent: u32, + pub rmx_state: u32, + pub rmx_filler: [u32; 3], +} + +#[test] +fn test_failing_rtmsg() { + let bytes = [ + 135, 0, 5, 1, 11, 0, 0, 0, 1, 1, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 16, 2, 0, 0, 192, 168, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 18, 11, 0, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 255, 255, 255, 255, 255, 255, + ]; + let _ = RouteSocketMessage::parse_message(&bytes).unwrap(); +} diff --git a/talpid-routing/src/unix/macos/interface.rs b/talpid-routing/src/unix/macos/interface.rs new file mode 100644 index 0000000000..8ec4863ab7 --- /dev/null +++ b/talpid-routing/src/unix/macos/interface.rs @@ -0,0 +1,78 @@ +use nix::net::if_::if_nametoindex; +use std::ffi::CString; +use system_configuration::{ + core_foundation::string::CFString, + network_configuration::{SCNetworkService, SCNetworkSet}, + preferences::SCPreferences, +}; + +use super::{ + data::{Destination, RouteMessage}, + watch::RoutingTable, +}; + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum Family { + V4, + V6, +} + +/// Attempt to retrieve the best current default route. +/// Note: The tunnel interface is not even listed in the service order, so it will be skipped. +pub async fn get_best_default_route( + routing_table: &mut RoutingTable, + family: Family, +) -> Option<RouteMessage> { + let destination = match family { + Family::V4 => super::v4_default(), + Family::V6 => super::v6_default(), + }; + + let mut msg = RouteMessage::new_route(Destination::Network(destination)); + msg = msg.set_gateway_route(true); + + for iface in network_service_order() { + let iface_bytes = match CString::new(iface.as_bytes()) { + Ok(name) => name, + Err(error) => { + log::error!("Invalid interface name: {iface}, {error}"); + continue; + } + }; + + // Get interface ID + let index = match if_nametoindex(iface_bytes.as_c_str()) { + Ok(index) => index, + Err(error) => { + log::error!("Failed to get index of network interface: {error}"); + continue; + } + }; + + // Request ifscoped route for this interface + let route_msg = msg.clone().set_ifscope(u16::try_from(index).unwrap()); + if let Ok(Some(route)) = routing_table.get_route(&route_msg).await { + return Some(route); + } + } + + None +} + +fn network_service_order() -> Vec<String> { + let prefs = SCPreferences::default(&CFString::new("talpid-routing")); + let services = SCNetworkService::get_services(&prefs); + let set = SCNetworkSet::new(&prefs); + let service_order = set.service_order(); + + service_order + .iter() + .filter_map(|service_id| { + services + .iter() + .find(|service| service.id().as_ref() == Some(&*service_id)) + .and_then(|service| service.network_interface()?.bsd_name()) + .map(|cf_name| cf_name.to_string()) + }) + .collect::<Vec<_>>() +} diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs new file mode 100644 index 0000000000..3e5d8b0aed --- /dev/null +++ b/talpid-routing/src/unix/macos/mod.rs @@ -0,0 +1,533 @@ +use crate::{NetNode, Node, RequiredRoute, Route}; + +use futures::{channel::mpsc, future::FutureExt, stream::StreamExt}; +use ipnetwork::IpNetwork; +use nix::sys::socket::{AddressFamily, SockaddrLike, SockaddrStorage}; +use std::{ + collections::{BTreeMap, HashSet}, + net::{Ipv4Addr, Ipv6Addr}, +}; +use talpid_types::ErrorExt; +use watch::RoutingTable; + +use super::{DefaultRouteEvent, RouteManagerCommand}; +use data::{Destination, RouteDestination, RouteMessage, RouteSocketMessage}; + +mod data; +mod interface; +mod routing_socket; +mod watch; + +pub type Result<T> = std::result::Result<T, Error>; + +/// Errors that can happen in the macOS routing integration. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Encountered an error when interacting with the routing socket + #[error(display = "Error occurred when interfacing with the routing table")] + RoutingTable(#[error(source)] watch::Error), + + /// Failed to remvoe route + #[error(display = "Error occurred when deleting a route")] + DeleteRoute(#[error(source)] watch::Error), + + /// Failed to add route + #[error(display = "Error occurred when adding a route")] + AddRoute(#[error(source)] watch::Error), + + /// Failed to fetch link addresses + #[error(display = "Failed to fetch link addresses")] + FetchLinkAddresses(nix::Error), + + /// Received message isn't valid + #[error(display = "Invalid data")] + InvalidData(data::Error), +} + +/// Route manager can be in 1 of 4 states - +/// - waiting for a route to be added or removed from the route table +/// - obtaining default routes +/// - applying changes to the route table +/// - shutting down +/// +/// Only the _shutting down_ state can be reached from all other states, but during normal +/// operation, the route manager will add all the required routes during startup and will start +/// waiting for changes to the route table. If any change is detected, it will stop listening for +/// new changes, obtain new default routes and reapply routes that should be routed through the +/// default nodes. Once the routes are reapplied, the route table changes are monitored again. +pub struct RouteManagerImpl { + routing_table: RoutingTable, + // Routes that use the default non-tunnel interface + non_tunnel_routes: HashSet<IpNetwork>, + v4_tunnel_default_route: Option<data::RouteMessage>, + v6_tunnel_default_route: Option<data::RouteMessage>, + applied_routes: BTreeMap<RouteDestination, RouteMessage>, + v4_default_route: Option<data::RouteMessage>, + v6_default_route: Option<data::RouteMessage>, + default_route_listeners: Vec<mpsc::UnboundedSender<DefaultRouteEvent>>, +} + +impl RouteManagerImpl { + /// Create new route manager + #[allow(clippy::unused_async)] + pub async fn new() -> Result<Self> { + let routing_table = RoutingTable::new().map_err(Error::RoutingTable)?; + Ok(Self { + routing_table, + non_tunnel_routes: HashSet::new(), + v4_tunnel_default_route: None, + v6_tunnel_default_route: None, + applied_routes: BTreeMap::new(), + v4_default_route: None, + v6_default_route: None, + default_route_listeners: vec![], + }) + } + + pub(crate) async fn run(mut self, manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>) { + let mut manage_rx = manage_rx.fuse(); + + // Initialize default routes + // NOTE: This isn't race-free, as we're not listening for route changes before initializing + self.update_best_default_route(interface::Family::V4) + .await + .unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to get initial default v4 route") + ); + }); + self.update_best_default_route(interface::Family::V6) + .await + .unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to get initial default v6 route") + ); + }); + + loop { + futures::select_biased! { + route_message = self.routing_table.next_message().fuse() => { + self.handle_route_message(route_message).await; + } + + command = manage_rx.next() => { + match command { + Some(RouteManagerCommand::Shutdown(tx)) => { + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up routes: {err}"); + } + let _ = tx.send(()); + return; + }, + + Some(RouteManagerCommand::NewDefaultRouteListener(tx)) => { + let (events_tx, events_rx) = mpsc::unbounded(); + self.default_route_listeners.push(events_tx); + let _ = tx.send(events_rx); + } + Some(RouteManagerCommand::GetDefaultRoutes(tx)) => { + // NOTE: The device name isn't really relevant here, + // as we only care about routes with a gateway IP. + let v4_route = self.v4_default_route.as_ref().map(|route| { + Route { + node: Node { + device: None, + ip: route.gateway_ip(), + }, + prefix: v4_default(), + metric: None, + } + }); + let v6_route = self.v6_default_route.as_ref().map(|route| { + Route { + node: Node { + device: None, + ip: route.gateway_ip(), + }, + prefix: v6_default(), + metric: None, + } + }); + + let _ = tx.send((v4_route, v6_route)); + } + + Some(RouteManagerCommand::AddRoutes(routes, tx)) => { + log::debug!("Adding routes: {routes:?}"); + let _ = tx.send(self.add_required_routes(routes).await); + } + Some(RouteManagerCommand::ClearRoutes) => { + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up rotues: {err}"); + } + }, + None => { + break; + } + } + }, + }; + } + + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up routing table when shutting down: {err}"); + } + } + + async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { + let mut routes_to_apply = vec![]; + + for route in required_routes { + match route.node { + NetNode::DefaultNode => { + self.non_tunnel_routes.insert(route.prefix); + } + + NetNode::RealNode(node) => routes_to_apply.push(Route::new(node, route.prefix)), + } + } + + // Map all interfaces to their link addresses + let interface_link_addrs = get_interface_link_addresses()?; + + // Add routes not using the default interface + for route in routes_to_apply { + let message = if let Some(ref device) = route.node.device { + // If we specify route by interface name, use the link address of the given + // interface + match interface_link_addrs.get(device) { + Some(link_addr) => RouteMessage::new_route(Destination::from(route.prefix)) + .set_gateway_sockaddr(*link_addr), + None => { + log::error!("Route with unknown device: {route:?}, {device}"); + continue; + } + } + } else { + log::error!("Specifying gateway by IP rather than device is unimplemented"); + continue; + }; + + // Default routes are a special case: We must apply it after replacing the current + // default route with an ifscope route. + if route.prefix.prefix() == 0 { + if route.prefix.is_ipv4() { + self.v4_tunnel_default_route = Some(message); + } else { + self.v6_tunnel_default_route = Some(message); + } + continue; + } + + // Add route + self.add_route_with_record(message).await?; + } + + self.apply_tunnel_default_route().await?; + + // Add routes that use the default interface + if let Err(error) = self.apply_non_tunnel_routes().await { + self.non_tunnel_routes.clear(); + return Err(error); + } + + Ok(()) + } + + async fn handle_route_message( + &mut self, + message: std::result::Result<RouteSocketMessage, watch::Error>, + ) { + match message { + Ok(RouteSocketMessage::DeleteRoute(route)) => { + // Forget about applied route, if relevant. This is simply prevent ourselves from + // deleting it later. + match RouteDestination::try_from(&route).map_err(Error::InvalidData) { + Ok(destination) => { + self.applied_routes.remove(&destination); + } + Err(err) => { + log::error!("Failed to process deleted route: {err}"); + } + } + + if let Err(error) = self.handle_route_change(route).await { + log::error!("Failed to process route change: {error}"); + } + } + Ok(RouteSocketMessage::AddRoute(route)) + | Ok(RouteSocketMessage::ChangeRoute(route)) => { + // Refresh routes that are using the default interface + if let Err(error) = self.handle_route_change(route).await { + log::error!("Failed to process route change: {error}"); + } + } + // ignore all other message types + Ok(_) => {} + Err(err) => { + log::error!("Failed to receive a message from the routing table: {err}"); + } + } + } + + /// Update routes that use the non-tunnel default interface + async fn handle_route_change(&mut self, route: data::RouteMessage) -> Result<()> { + // Ignore routes that aren't default routes + if !route.is_default().map_err(Error::InvalidData)? { + return Ok(()); + } + + let new_gateway_link_addr = route.gateway().and_then(|addr| addr.as_link_addr()); + + // Ignore the new route if it is our tunnel route, lest we create a loop + for tunnel_default_route in [&self.v4_tunnel_default_route, &self.v6_tunnel_default_route] { + if let Some(tunnel_route) = tunnel_default_route.clone() { + let tun_gateway_link_addr = + tunnel_route.gateway().and_then(|addr| addr.as_link_addr()); + + if new_gateway_link_addr == tun_gateway_link_addr { + return Ok(()); + } + } + } + + let ip_version = if route.is_ipv4() { + interface::Family::V4 + } else { + interface::Family::V6 + }; + self.update_best_default_route(ip_version).await + } + + async fn update_best_default_route(&mut self, family: interface::Family) -> Result<()> { + let best_route = interface::get_best_default_route(&mut self.routing_table, family).await; + log::trace!("Best route: {best_route:?}"); + + let default_route = match family { + interface::Family::V4 => &mut self.v4_default_route, + interface::Family::V6 => &mut self.v6_default_route, + }; + + if default_route == &best_route { + log::trace!("Default route is unchanged"); + return Ok(()); + } + + let old_route = std::mem::replace(default_route, best_route); + + log::debug!("New default route: {old_route:?} -> {default_route:?}"); + + // Notify default route listeners + let event = match (family, default_route.is_some()) { + (interface::Family::V4, true) => DefaultRouteEvent::AddedOrChangedV4, + (interface::Family::V6, true) => DefaultRouteEvent::AddedOrChangedV6, + (interface::Family::V4, false) => DefaultRouteEvent::RemovedV4, + (interface::Family::V6, false) => DefaultRouteEvent::RemovedV6, + }; + self.default_route_listeners + .retain(|tx| tx.unbounded_send(event).is_ok()); + + // Substitute route with a tunnel route + self.apply_tunnel_default_route().await?; + + // Update routes using default interface + self.apply_non_tunnel_routes().await?; + + Ok(()) + } + + /// Replace the default routes with an ifscope route, and + /// add a new default tunnel route. + async fn apply_tunnel_default_route(&mut self) -> Result<()> { + // As long as the relay route has a way of reaching the internet, we'll want to add a tunnel + // route for both IPv4 and IPv6. + // NOTE: This is incorrect. We're assuming that any "default destination" is used for + // tunneling. + let (v4_conn, v6_conn) = self + .non_tunnel_routes + .iter() + .fold((false, false), |(v4, v6), route| { + (v4 || route.is_ipv4(), v6 || route.is_ipv6()) + }); + let relay_route_is_valid = (v4_conn && self.v4_default_route.is_some()) + || (v6_conn && self.v6_default_route.is_some()); + + if !relay_route_is_valid { + return Ok(()); + } + + for tunnel_route in [ + self.v4_tunnel_default_route.clone(), + self.v6_tunnel_default_route.clone(), + ] { + let tunnel_route = match tunnel_route { + Some(route) => route, + None => continue, + }; + + log::debug!("Adding default route for tunnel"); + + // Replace the default route with an ifscope route + self.set_default_route_ifscope(tunnel_route.is_ipv4(), true) + .await?; + self.add_route_with_record(tunnel_route).await?; + } + + Ok(()) + } + + /// Update/add routes that use the default non-tunnel interface. If some applied destination is + /// a default route, this function replaces the non-tunnel default route with an ifscope route. + async fn apply_non_tunnel_routes(&mut self) -> Result<()> { + let v4_gateway = self + .v4_default_route + .as_ref() + .and_then(|route| route.gateway()) + .cloned(); + let v6_gateway = self + .v6_default_route + .as_ref() + .and_then(|route| route.gateway()) + .cloned(); + + // Reapply routes that use the default (non-tunnel) node + for dest in self.non_tunnel_routes.clone() { + let gateway = if dest.is_ipv4() { + v4_gateway + } else { + v6_gateway + }; + let gateway = match gateway { + Some(gateway) => gateway, + None => continue, + }; + let route = + RouteMessage::new_route(Destination::Network(dest)).set_gateway_sockaddr(gateway); + + if let Some(dest) = self + .applied_routes + .keys() + .find(|applied_dest| applied_dest.network == dest) + .cloned() + { + let _ = self.routing_table.delete_route(&route).await; + self.applied_routes.remove(&dest); + } + + self.add_route_with_record(route).await?; + } + + Ok(()) + } + + /// Replace a known default route with an ifscope route, if should_be_ifscoped is true. + /// If should_be_ifscoped is false, the route is replaced with a non-ifscoped default route + /// instead. + async fn set_default_route_ifscope( + &mut self, + ipv4: bool, + should_be_ifscoped: bool, + ) -> Result<()> { + let default_route = match (ipv4, &mut self.v4_default_route, &mut self.v6_default_route) { + (true, Some(default_route), _) | (false, _, Some(default_route)) => default_route, + _ => { + return Ok(()); + } + }; + + if default_route.is_ifscope() == should_be_ifscoped { + return Ok(()); + } + + log::trace!("Setting non-ifscope: {default_route:?}"); + + let interface_index = if should_be_ifscoped { + let interface_index = default_route.interface_index(); + if interface_index == 0 { + log::error!("Cannot find interface index of default interface"); + } + interface_index + } else { + 0 + }; + let new_route = default_route.clone().set_ifscope(interface_index); + let old_route = std::mem::replace(default_route, new_route); + + self.routing_table + .delete_route(&old_route) + .await + .map_err(Error::DeleteRoute)?; + + self.routing_table + .add_route(default_route) + .await + .map_err(Error::AddRoute) + } + + async fn add_route_with_record(&mut self, route: RouteMessage) -> Result<()> { + let destination = RouteDestination::try_from(&route).map_err(Error::InvalidData)?; + + self.routing_table + .add_route(&route) + .await + .map_err(Error::AddRoute)?; + + self.applied_routes.insert(destination, route); + Ok(()) + } + + async fn cleanup_routes(&mut self) -> Result<()> { + // Remove all applied routes. This includes default destination routes + let old_routes = std::mem::take(&mut self.applied_routes); + for (_dest, route) in old_routes.into_iter() { + log::trace!("Removing route: {route:?}"); + match self.routing_table.delete_route(&route).await { + Ok(_) | Err(watch::Error::RouteNotFound) | Err(watch::Error::Unreachable) => (), + Err(err) => { + log::error!("Failed to remove relay route: {err:?}"); + } + } + } + + // Reset default route + if let Err(error) = self + .set_default_route_ifscope(true, false) + .await + .and(self.set_default_route_ifscope(false, false).await) + { + log::error!("Failed to restore default routes: {error}"); + } + + // We have already removed the applied default routes + self.v4_tunnel_default_route = None; + self.v6_tunnel_default_route = None; + + self.non_tunnel_routes.clear(); + + Ok(()) + } +} + +fn v4_default() -> IpNetwork { + IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap() +} + +fn v6_default() -> IpNetwork { + IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap() +} + +/// Return a map from interface name to link addresses (AF_LINK) +fn get_interface_link_addresses() -> Result<BTreeMap<String, SockaddrStorage>> { + let mut gateway_link_addrs = BTreeMap::new(); + let addrs = nix::ifaddrs::getifaddrs().map_err(Error::FetchLinkAddresses)?; + for addr in addrs.into_iter() { + if addr.address.and_then(|addr| addr.family()) != Some(AddressFamily::Link) { + continue; + } + gateway_link_addrs.insert(addr.interface_name, addr.address.unwrap()); + } + Ok(gateway_link_addrs) +} diff --git a/talpid-routing/src/unix/macos/routing_socket.rs b/talpid-routing/src/unix/macos/routing_socket.rs new file mode 100644 index 0000000000..213128fd8c --- /dev/null +++ b/talpid-routing/src/unix/macos/routing_socket.rs @@ -0,0 +1,185 @@ +use std::{ + collections::VecDeque, + mem::size_of, + os::unix::prelude::{FromRawFd, RawFd}, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use nix::{ + fcntl, + sys::socket::{socket, AddressFamily, SockFlag, SockType}, +}; +use std::{ + fs::File, + io::{self, Read, Write}, +}; + +use super::data::{rt_msghdr_short, MessageType, RouteMessage}; + +use tokio::io::{unix::AsyncFd, AsyncWrite, AsyncWriteExt}; + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Failed to open routing socket")] + OpenSocket(io::Error), + #[error(display = "Failed to write to routing socket")] + Write(io::Error), + #[error(display = "Failed to read from routing socket")] + Read(io::Error), + #[error(display = "Received a message that's too small")] + MessageTooSmall(usize), +} + +type Result<T> = std::result::Result<T, Error>; + +/// Wraps a `PF_ROUTE` socket, keeps track of sent message IDs, and facilitates sending and +/// receiving [route socket messages](#RouteMessage) +pub struct RoutingSocket { + socket: RoutingSocketInner, + seq: i32, + // buffers up messages received whilst waiting on a response + // TODO: might we want to limit the max size of this? + buf: VecDeque<Vec<u8>>, + own_pid: i32, +} + +impl RoutingSocket { + pub fn new() -> Result<Self> { + Ok(Self { + socket: RoutingSocketInner::new().map_err(Error::OpenSocket)?, + seq: 1, + buf: Default::default(), + own_pid: std::process::id().try_into().unwrap(), + }) + } + + pub async fn recv_msg(&mut self, mut buf: &mut [u8]) -> Result<usize> { + if let Some(buffered_msg) = self.buf.pop_front() { + let bytes_written = buf.write(&buffered_msg).map_err(Error::Write)?; + return Ok(bytes_written); + } + self.read_next_msg(buf).await + } + + async fn read_next_msg(&mut self, buf: &mut [u8]) -> Result<usize> { + self.socket.read(buf).await.map_err(Error::Read) + } + + pub async fn send_route_message( + &mut self, + message: &RouteMessage, + message_type: MessageType, + ) -> Result<Vec<u8>> { + let (msg, seq) = self.next_route_msg(message, message_type); + match self.socket.write(&msg).await { + Ok(_) => self.wait_for_response(seq).await, + Err(err) => Err(Error::Write(err)), + } + } + + pub async fn wait_for_response(&mut self, response_num: i32) -> Result<Vec<u8>> { + loop { + let mut buffer = vec![0u8; 2048]; + // do not truncate the buffer - trailing empty bytes won't be written but will be + // assumed in the data format. + let bytes_read = self.read_next_msg(&mut buffer).await?; + + { + let header = rt_msghdr_short::from_bytes(buffer.as_slice()) + .ok_or(Error::MessageTooSmall(bytes_read))?; + + if header.rtm_pid == self.own_pid && response_num == header.rtm_seq { + return Ok(buffer); + } + } + + self.buf.push_back(buffer); + } + } + + fn next_route_msg(&mut self, message: &RouteMessage, msg_type: MessageType) -> (Vec<u8>, i32) { + let seq = self.seq; + self.seq = seq.wrapping_add(1); + + let (header, payload) = message.payload(msg_type, seq, self.own_pid); + let mut msg_buffer = vec![0u8; header.rtm_msglen.into()]; + + // SAFETY: `msg_buffer` is guaranteed to be at least as large as `rt_msghdr`. + unsafe { + std::ptr::copy_nonoverlapping( + &header as *const _ as *const u8, + msg_buffer.as_mut_ptr(), + size_of::<super::data::rt_msghdr>(), + ); + } + let mut sockaddr_buf = &mut msg_buffer[std::mem::size_of::<super::data::rt_msghdr>()..]; + for socket_addr in payload { + sockaddr_buf + .write_all(socket_addr.as_slice()) + .expect("faled to write socket address into message buffer"); + } + (msg_buffer, header.rtm_seq) + } +} + +struct RoutingSocketInner { + // storing the file handle in a std::file::File automagically provides sane io::{Write, + // Read} and Drop implementations. + socket: AsyncFd<File>, +} + +impl RoutingSocketInner { + fn new() -> io::Result<Self> { + let fd = socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None)?; + let _ = fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::OFlag::O_NONBLOCK))?; + // SAFETY: File handle is valid here + let socket = unsafe { File::from_raw_fd(fd) }; + Ok(Self { + socket: AsyncFd::new(socket)?, + }) + } + + async fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> { + loop { + let mut guard = self.socket.readable().await?; + match guard.try_io(|sock| sock.get_ref().read(out)) { + Ok(result) => return result, + Err(_err) => continue, + } + } + } +} + +impl std::os::unix::prelude::AsRawFd for RoutingSocketInner { + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + +impl AsyncWrite for RoutingSocketInner { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + loop { + let mut guard = ready!(self.socket.poll_write_ready(cx))?; + + match guard.try_io(|inner| inner.get_ref().write(buf)) { + Ok(result) => return Poll::Ready(result), + Err(_would_block) => continue, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // no need for a shutdown on the routing socket + Poll::Ready(Ok(())) + } +} diff --git a/talpid-routing/src/unix/macos/watch.rs b/talpid-routing/src/unix/macos/watch.rs new file mode 100644 index 0000000000..4cf1799276 --- /dev/null +++ b/talpid-routing/src/unix/macos/watch.rs @@ -0,0 +1,147 @@ +use super::{ + data::{self, MessageType, RouteMessage, RouteSocketMessage}, + routing_socket, +}; +use std::io; + +type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug, err_derive::Error)] +pub enum Error { + #[error(display = "Failed to open routing socket")] + RoutingSocket(routing_socket::Error), + #[error(display = "Invalid message")] + InvalidMessage(data::Error), + #[error(display = "Failed to send routing message")] + Send(routing_socket::Error), + #[error(display = "Unexpected message type")] + UnexpectedMessageType(RouteSocketMessage, MessageType), + #[error(display = "Route not found")] + RouteNotFound, + #[error(display = "Destination unreachable")] + Unreachable, + #[error(display = "Failed to delete a route")] + Deletion(RouteMessage), +} + +/// Provides an interface for manipulating the routing table on macOS using a PF_ROUTE socket. +pub struct RoutingTable { + socket: routing_socket::RoutingSocket, +} + +impl RoutingTable { + pub fn new() -> Result<Self> { + let socket = routing_socket::RoutingSocket::new().map_err(Error::RoutingSocket)?; + + Ok(Self { socket }) + } + + pub async fn next_message(&mut self) -> Result<RouteSocketMessage> { + let mut buf = [0u8; 2048]; + let bytes_read = self + .socket + .recv_msg(&mut buf) + .await + .map_err(Error::RoutingSocket)?; + let msg_buf = &buf[0..bytes_read]; + data::RouteSocketMessage::parse_message(msg_buf).map_err(Error::InvalidMessage) + } + + pub async fn add_route(&mut self, message: &RouteMessage) -> Result<()> { + let msg = self + .alter_routing_table(message, MessageType::RTM_ADD) + .await; + + match msg { + Ok(RouteSocketMessage::AddRoute(_route)) => Ok(()), + Err(Error::Send(routing_socket::Error::Write(err))) + if err.kind() == io::ErrorKind::AlreadyExists => + { + Ok(()) + } + Ok(anything_else) => { + log::error!("Unexpected route message: {anything_else:?}"); + Err(Error::UnexpectedMessageType( + anything_else, + MessageType::RTM_ADD, + )) + } + + Err(err) => Err(err), + } + } + + async fn alter_routing_table( + &mut self, + message: &RouteMessage, + message_type: MessageType, + ) -> Result<RouteSocketMessage> { + let result = self.socket.send_route_message(message, message_type).await; + + match result { + Ok(response) => { + data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage) + } + + Err(routing_socket::Error::Write(err)) if err.kind() == io::ErrorKind::NotFound => { + Err(Error::RouteNotFound) + } + Err(routing_socket::Error::Write(err)) + if [Some(libc::ENETUNREACH), Some(libc::ESRCH)].contains(&err.raw_os_error()) => + { + Err(Error::Unreachable) + } + Err(err) => Err(Error::Send(err)), + } + } + + pub async fn delete_route(&mut self, message: &RouteMessage) -> Result<()> { + let response = self + .alter_routing_table(message, MessageType::RTM_DELETE) + .await?; + + match response { + RouteSocketMessage::DeleteRoute(route) if route.errno() == 0 => Ok(()), + RouteSocketMessage::DeleteRoute(route) if route.errno() != 0 => { + Err(Error::Deletion(route)) + } + anything_else => Err(Error::UnexpectedMessageType( + anything_else, + MessageType::RTM_DELETE, + )), + } + } + + pub async fn get_route( + &mut self, + message: &RouteMessage, + ) -> Result<Option<data::RouteMessage>> { + let response = self + .socket + .send_route_message(message, MessageType::RTM_GET) + .await; + + let response = match response { + Ok(response) => response, + Err(routing_socket::Error::Write(err)) => { + if let Some(err) = err.raw_os_error() { + if [libc::ENETUNREACH, libc::ESRCH].contains(&err) { + return Ok(None); + } + } + return Err(Error::RoutingSocket(routing_socket::Error::Write(err))); + } + Err(other_err) => { + return Err(Error::RoutingSocket(other_err)); + } + }; + + match data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage)? { + data::RouteSocketMessage::GetRoute(route) => Ok(Some(route)), + unexpected_route_message => Err(Error::UnexpectedMessageType( + unexpected_route_message, + MessageType::RTM_GET, + )), + } + } +} diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 822615d0f1..757d3775fc 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -1,9 +1,7 @@ -// TODO: remove the allow(dead_code) for android once it's up to scratch. -#![cfg_attr(target_os = "android", allow(dead_code))] +#[cfg(any(target_os = "linux", target_os = "macos"))] +use crate::Route; use super::RequiredRoute; -#[cfg(target_os = "linux")] -use super::Route; use futures::channel::{ mpsc::{self, UnboundedSender}, @@ -11,7 +9,7 @@ use futures::channel::{ }; use std::{collections::HashSet, io}; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "macos"))] use futures::stream::Stream; #[cfg(target_os = "linux")] @@ -19,11 +17,8 @@ use std::net::IpAddr; #[allow(clippy::module_inception)] #[cfg(target_os = "macos")] -#[path = "macos.rs"] -mod imp; - -#[cfg(target_os = "macos")] -pub use imp::{get_default_routes, listen_for_default_route_changes}; +#[path = "macos/mod.rs"] +pub mod imp; #[allow(clippy::module_inception)] #[cfg(target_os = "linux")] @@ -76,6 +71,28 @@ impl RouteManagerHandle { .map_err(Error::PlatformError) } + /// Listen for non-tunnel default route changes. + #[cfg(target_os = "macos")] + pub async fn default_route_listener( + &self, + ) -> Result<impl Stream<Item = DefaultRouteEvent>, Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::NewDefaultRouteListener(response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown) + } + + /// Get current non-tunnel default routes. + #[cfg(target_os = "macos")] + pub async fn get_default_routes(&self) -> Result<(Option<Route>, Option<Route>), Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::GetDefaultRoutes(response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown) + } + /// Ensure that packets are routed using the correct tables. #[cfg(target_os = "linux")] pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { @@ -163,6 +180,10 @@ pub(crate) enum RouteManagerCommand { ), ClearRoutes, Shutdown(oneshot::Sender<()>), + #[cfg(target_os = "macos")] + NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>), + #[cfg(target_os = "macos")] + GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>), #[cfg(target_os = "linux")] CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>), #[cfg(target_os = "linux")] @@ -180,6 +201,21 @@ pub(crate) enum RouteManagerCommand { ), } +/// Event that is sent when a preferred non-tunnel default route is +/// added or removed. +#[cfg(target_os = "macos")] +#[derive(Debug, Clone, Copy)] +pub enum DefaultRouteEvent { + /// Added or updated a non-tunnel default IPv4 route + AddedOrChangedV4, + /// Added or updated a non-tunnel default IPv6 route + AddedOrChangedV6, + /// Non-tunnel default IPv4 route was removed + RemovedV4, + /// Non-tunnel default IPv6 route was removed + RemovedV6, +} + #[cfg(target_os = "linux")] #[derive(Debug, Clone)] pub enum CallbackMessage { @@ -196,17 +232,13 @@ pub struct RouteManager { } impl RouteManager { - /// Constructs a RouteManager and applies the required routes. - /// Takes a set of network destinations and network nodes as an argument, and applies said - /// routes. + /// Construct a RouteManager. pub async fn new( - required_routes: HashSet<RequiredRoute>, #[cfg(target_os = "linux")] fwmark: u32, #[cfg(target_os = "linux")] table_id: u32, ) -> Result<Self, Error> { let (manage_tx, manage_rx) = mpsc::unbounded(); let manager = imp::RouteManagerImpl::new( - required_routes, #[cfg(target_os = "linux")] fwmark, #[cfg(target_os = "linux")] @@ -260,8 +292,7 @@ impl RouteManager { } } - /// Removes all routes previously applied in [`RouteManager::new`] or - /// [`RouteManager::add_routes`]. + /// Removes all routes previously applied in [`RouteManager::add_routes`]. pub fn clear_routes(&mut self) -> Result<(), Error> { if let Some(tx) = &self.manage_tx { if tx.unbounded_send(RouteManagerCommand::ClearRoutes).is_err() { diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs index e68c9255fc..51ac345f82 100644 --- a/talpid-routing/src/windows/mod.rs +++ b/talpid-routing/src/windows/mod.rs @@ -148,9 +148,9 @@ pub enum RouteManagerCommand { } impl RouteManager { - /// Creates a new route manager that will apply the provided routes and ensure they exist until - /// it's stopped. - pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { + /// Create a new route manager + #[allow(clippy::unused_async)] + pub async fn new() -> Result<Self> { let internal = match RouteManagerInternal::new() { Ok(internal) => internal, Err(_) => return Err(Error::FailedToStartManager), @@ -160,7 +160,6 @@ impl RouteManager { manage_tx: Some(manage_tx), }; tokio::spawn(RouteManager::listen(manage_rx, internal)); - manager.add_routes(required_routes).await?; Ok(manager) } @@ -270,8 +269,7 @@ impl RouteManager { } } - /// Removes all routes previously applied in [`RouteManager::new`] or - /// [`RouteManager::add_routes`]. + /// Removes all routes previously applied in [`RouteManager::add_routes`]. pub fn clear_routes(&self) -> Result<()> { if let Some(tx) = &self.manage_tx { tx.unbounded_send(RouteManagerCommand::ClearRoutes) diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index da70b63bfc..cd250856df 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -355,6 +355,7 @@ impl WireguardMonitor { let routes = Self::get_pre_tunnel_routes(&iface_name, &config) .chain(Self::get_endpoint_routes(&endpoint_addrs)) .collect(); + args.route_manager .add_routes(routes) .await @@ -914,7 +915,7 @@ impl WireguardMonitor { /// Replace default (0-prefix) routes with more specific routes. fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> { - #[cfg(not(any(target_os = "linux", target_os = "android")))] + #[cfg(windows)] if network.prefix() == 0 { if network.is_ipv4() { vec!["0.0.0.0/1".parse().unwrap(), "128.0.0.0/1".parse().unwrap()] @@ -925,7 +926,7 @@ impl WireguardMonitor { vec![network] } - #[cfg(any(target_os = "linux", target_os = "android"))] + #[cfg(not(windows))] vec![network] } |
