summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-06-05 20:24:24 +0200
committerDavid Lönnhager <david.l@mullvad.net>2023-06-05 20:24:24 +0200
commit4f60f2368fba733c1aaa968c34d0a33d43e4e90b (patch)
treeebbae33de7e464d85bc794b344df14076d8a9923
parent2adf93174484af166c2104627530874698498b7d (diff)
parent38d7c4b0a02c293a84ccec4a4cc632934a36cdfb (diff)
downloadmullvadvpn-4f60f2368fba733c1aaa968c34d0a33d43e4e90b.tar.xz
mullvadvpn-4f60f2368fba733c1aaa968c34d0a33d43e4e90b.zip
Merge branch 'macos-routing-rework-2'
-rw-r--r--CHANGELOG.md3
-rw-r--r--Cargo.lock30
-rw-r--r--deny.toml3
-rw-r--r--talpid-core/Cargo.toml2
-rw-r--r--talpid-core/src/offline/macos.rs213
-rw-r--r--talpid-core/src/offline/mod.rs6
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs10
-rw-r--r--talpid-routing/Cargo.toml9
-rw-r--r--talpid-routing/src/lib.rs2
-rw-r--r--talpid-routing/src/unix/android.rs5
-rw-r--r--talpid-routing/src/unix/linux.rs12
-rw-r--r--talpid-routing/src/unix/macos.rs362
-rw-r--r--talpid-routing/src/unix/macos/data.rs1147
-rw-r--r--talpid-routing/src/unix/macos/interface.rs78
-rw-r--r--talpid-routing/src/unix/macos/mod.rs533
-rw-r--r--talpid-routing/src/unix/macos/routing_socket.rs185
-rw-r--r--talpid-routing/src/unix/macos/watch.rs147
-rw-r--r--talpid-routing/src/unix/mod.rs65
-rw-r--r--talpid-routing/src/windows/mod.rs10
-rw-r--r--talpid-wireguard/src/lib.rs5
20 files changed, 2321 insertions, 506 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 19d4f60383..042f059b1c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -95,6 +95,9 @@ Line wrap the file at 100 chars. Th
- Fix adaptive app icon which previously had a displaced nose and some other oddities.
- Fix app version sometimes missing in the settings menu.
+#### macOS
+- Fix inability to sync iCloud and Safari bookmarks while connected to the VPN.
+
## [2023.4-beta1] - 2023-05-02
### Added
diff --git a/Cargo.lock b/Cargo.lock
index 17847ae5cd..61c6c83af4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1851,6 +1851,15 @@ dependencies = [
]
[[package]]
+name = "memoffset"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
name = "mime"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2304,6 +2313,19 @@ dependencies = [
[[package]]
name = "nix"
+version = "0.26.1"
+source = "git+https://github.com/nix-rust/nix?rev=b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28#b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "libc",
+ "memoffset 0.8.0",
+ "pin-utils",
+ "static_assertions",
+]
+
+[[package]]
+name = "nix"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a"
@@ -3511,9 +3533,9 @@ dependencies = [
[[package]]
name = "system-configuration"
-version = "0.5.0"
+version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d75182f12f490e953596550b65ee31bda7c8e043d9386174b353bda50838c3fd"
+checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags",
"core-foundation",
@@ -3678,6 +3700,7 @@ dependencies = [
name = "talpid-routing"
version = "0.0.0"
dependencies = [
+ "bitflags",
"err-derive",
"futures",
"ipnetwork",
@@ -3686,12 +3709,13 @@ dependencies = [
"log",
"netlink-packet-route",
"netlink-sys",
+ "nix 0.26.1",
"rtnetlink",
"socket2",
+ "system-configuration",
"talpid-types",
"talpid-windows-net",
"tokio",
- "tokio-stream",
"widestring 1.0.2",
"windows-sys 0.45.0",
]
diff --git a/deny.toml b/deny.toml
index 0f6b77cbbd..9285d2202c 100644
--- a/deny.toml
+++ b/deny.toml
@@ -96,7 +96,8 @@ skip-tree = []
unknown-registry = "deny"
unknown-git = "deny"
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
-allow-git = []
+# TODO: The PF socket type isn't released yet
+allow-git = ["https://github.com/nix-rust/nix"]
[sources.allow-org]
# 1 or more github.com organizations to allow git sources for
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 9a833353d3..f16aa7307c 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -67,7 +67,7 @@ talpid-dbus = { path = "../talpid-dbus" }
[target.'cfg(target_os = "macos")'.dependencies]
pfctl = "0.4.4"
-system-configuration = "0.5"
+system-configuration = "0.5.1"
trust-dns-server = { version = "0.22.0", features = ["resolver"] }
tun = "0.5.1"
subslice = "0.2"
diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs
index baae3c40f1..7263f9ab0f 100644
--- a/talpid-core/src/offline/macos.rs
+++ b/talpid-core/src/offline/macos.rs
@@ -2,119 +2,156 @@
//! that the app gets stuck in an offline state, blocking all internet access and preventing the
//! user from connecting to a relay.
//!
-//! Currently, this functionality is implemented by using `route monitor -n` to observe routing
-//! table changes and then use the CLI once more to query if there exists a default route.
-//! Generally, it is assumed that a machine is online if there exists a route to a public IP
-//! address that isn't using a tunnel adapter. On macOS, there were various ways of deducing this:
-//! - watching the `State:/Network/Global/IPv4` key in SystemConfiguration via
-//! `system-configuration-rs`, relying on a CoreFoundation runloop to drive callbacks.
-//! The issue with this is that sometimes during early boot or after a re-install, the callbacks
-//! won't be called, often leaving the daemon stuck in an offline state.
-//! - setting a callback via [`SCNetworkReachability`]. The callback should be called whenever the
-//! reachability of a remote host changes, but sometimes the callbacks just don't get called.
-//! - [`NWPathMonitor`] is a macOS native interface to watch changes in the routing table. It works
-//! great, but it seems to deliver updates before they actually get added to the routing table,
-//! effectively calling our callbacks with routes that aren't yet usable, so starting tunnels
-//! would fail anyway. This would be the API to use if we were able to bind the sockets our tunnel
-//! implementations would use, but that is far too much complexity.
-//!
-//! [`SCNetworkReachability`]: https://developer.apple.com/documentation/systemconfiguration/scnetworkreachability-g7d
-//! [`NWPathMonitor`]: https://developer.apple.com/documentation/network/nwpathmonitor
-use futures::{channel::mpsc::UnboundedSender, Future, StreamExt};
-use std::sync::{Arc, Weak};
-use talpid_types::ErrorExt;
+//! Currently, this functionality is implemented by watching for changes to the default route
+//! in [`RouteManager`] using a `PF_ROUTE` socket. If there is no default route for neither IPv4 nor
+//! IPv6, the host is considered to be offline.
+use futures::{channel::mpsc::UnboundedSender, StreamExt};
+use std::{
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc, Mutex,
+ },
+ time::Duration,
+};
+use talpid_routing::{DefaultRouteEvent, RouteManagerHandle};
+
+/// How long to wait before announcing changes to the offline state
+const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(2);
#[derive(err_derive::Error, Debug)]
pub enum Error {
#[error(display = "Failed to initialize route monitor")]
- StartMonitorError(#[error(source)] talpid_routing::PlatformError),
+ StartMonitorError(#[error(source)] talpid_routing::Error),
}
pub struct MonitorHandle {
+ state: Arc<Mutex<ConnectivityState>>,
_notify_tx: Arc<UnboundedSender<bool>>,
}
-impl MonitorHandle {
- /// Host is considered to be offline if the IPv4 internet is considered to be unreachable by the
- /// given reachability flags *or* there are no active physical interfaces.
- pub async fn host_is_offline(&self) -> bool {
- !exists_non_tunnel_default_route().await
+struct ConnectivityState {
+ v4_connectivity: bool,
+ v6_connectivity: bool,
+}
+
+impl ConnectivityState {
+ fn get_connectivity(&self) -> bool {
+ self.v4_connectivity || self.v6_connectivity
}
}
-async fn exists_non_tunnel_default_route() -> bool {
- match talpid_routing::get_default_routes().await {
- Ok((Some(node), _)) | Ok((None, Some(node))) => {
- let route_exists = node
- .get_device()
- .map(|iface_name| !iface_name.contains("tun"))
- .unwrap_or(true);
- log::debug!("Assuming non-tunnel default route exists due to {:?}", node);
- route_exists
- }
- Ok((None, None)) => {
- log::debug!("No default routes exist, assuming machine is offline");
- false
- }
- Err(err) => {
- log::error!(
- "{}",
- err.display_chain_with_msg(
- "Failed to obtain default routes, assuming machine is online."
- )
- );
- true
- }
+impl MonitorHandle {
+ /// Host is considered to be offline if macOS doesn't assign a non-tunnel default route
+ #[allow(clippy::unused_async)]
+ pub async fn host_is_offline(&self) -> bool {
+ let state = self.state.lock().unwrap();
+ !state.get_connectivity()
}
}
-pub async fn spawn_monitor(notify_tx: UnboundedSender<bool>) -> Result<MonitorHandle, Error> {
+
+pub async fn spawn_monitor(
+ notify_tx: UnboundedSender<bool>,
+ route_manager_handle: RouteManagerHandle,
+) -> Result<MonitorHandle, Error> {
let notify_tx = Arc::new(notify_tx);
- let context = OfflineStateContext {
- sender: Arc::downgrade(&notify_tx),
- is_offline: !exists_non_tunnel_default_route().await,
+ let (v4_connectivity, v6_connectivity) = match route_manager_handle.get_default_routes().await {
+ Ok((v4_route, v6_route)) => (v4_route.is_some(), v6_route.is_some()),
+ Err(error) => {
+ log::warn!("Failed to initialize offline monitor: {error}");
+ // Fail open: Assume that we have connectivity if we cannot determine the existence of
+ // a default route, since we don't want to block the user from connecting
+ (true, true)
+ }
};
- let route_monitor = watch_route_monitor(context)?;
- tokio::spawn(route_monitor);
- Ok(MonitorHandle {
- _notify_tx: notify_tx,
- })
-}
+ let state = ConnectivityState {
+ v4_connectivity,
+ v6_connectivity,
+ };
+ let initial_connectivity = state.get_connectivity();
+ let state = Arc::new(Mutex::new(state));
+
+ let mut route_listener = route_manager_handle.default_route_listener().await?;
+ let weak_state = Arc::downgrade(&state);
+ let weak_notify_tx = Arc::downgrade(&notify_tx);
+
+ // Detect changes to the default route
+ tokio::spawn(async move {
+ let mut state_update_handle: Option<tokio::task::JoinHandle<()>> = None;
+ let prev_notified_state = Arc::new(AtomicBool::new(initial_connectivity));
-fn watch_route_monitor(
- mut context: OfflineStateContext,
-) -> Result<impl Future<Output = ()>, Error> {
- let mut monitor = talpid_routing::listen_for_default_route_changes()?;
+ while let Some(event) = route_listener.next().await {
+ let state = match weak_state.upgrade() {
+ Some(state) => state,
+ None => break,
+ };
- Ok(async move {
- while let Some(_route_change) = monitor.next().await {
- context.new_state(!exists_non_tunnel_default_route().await);
- if context.should_shut_down() {
- break;
+ let mut state = state.lock().unwrap();
+
+ log::trace!("Default route event: {event:?}");
+
+ let previous_connectivity = state.get_connectivity();
+
+ match event {
+ DefaultRouteEvent::AddedOrChangedV4 => {
+ state.v4_connectivity = true;
+ }
+ DefaultRouteEvent::AddedOrChangedV6 => {
+ state.v6_connectivity = true;
+ }
+ DefaultRouteEvent::RemovedV4 => {
+ state.v4_connectivity = false;
+ }
+ DefaultRouteEvent::RemovedV6 => {
+ state.v6_connectivity = false;
+ }
}
- }
- log::debug!("Stopping offline monitor");
- })
-}
-#[derive(Clone)]
-struct OfflineStateContext {
- sender: Weak<UnboundedSender<bool>>,
- is_offline: bool,
-}
+ let new_connectivity = state.get_connectivity();
+ if previous_connectivity != new_connectivity {
+ if let Some(update_state) = state_update_handle.take() {
+ update_state.abort();
+ }
-impl OfflineStateContext {
- fn should_shut_down(&self) -> bool {
- self.sender.upgrade().is_none()
- }
+ let prev_notified = prev_notified_state.clone();
+
+ let notify_copy = weak_notify_tx.clone();
+ let update_task = tokio::spawn(async move {
+ let notify_tx = match notify_copy.upgrade() {
+ Some(tx) => tx,
+ None => return,
+ };
+
+ // Debounce event updates
+ tokio::time::sleep(DEBOUNCE_INTERVAL).await;
- fn new_state(&mut self, is_offline: bool) {
- if self.is_offline != is_offline {
- self.is_offline = is_offline;
- if let Some(sender) = self.sender.upgrade() {
- let _ = sender.unbounded_send(is_offline);
+ if prev_notified.swap(new_connectivity, Ordering::AcqRel) == new_connectivity {
+ // We don't care about network changes here
+ return;
+ }
+
+ log::info!(
+ "Connectivity changed: {}",
+ if new_connectivity {
+ "Connected"
+ } else {
+ "Offline"
+ }
+ );
+
+ let _ = notify_tx.unbounded_send(!new_connectivity);
+ });
+
+ state_update_handle = Some(update_task);
}
}
- }
+
+ log::trace!("Offline monitor exiting");
+ });
+
+ Ok(MonitorHandle {
+ state,
+ _notify_tx: notify_tx,
+ })
}
diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs
index e2df51e4e1..76f9b1c315 100644
--- a/talpid-core/src/offline/mod.rs
+++ b/talpid-core/src/offline/mod.rs
@@ -1,5 +1,5 @@
use futures::channel::mpsc::UnboundedSender;
-#[cfg(any(target_os = "linux", target_os = "windows"))]
+#[cfg(not(target_os = "android"))]
use talpid_routing::RouteManagerHandle;
#[cfg(target_os = "android")]
use talpid_types::android::AndroidContext;
@@ -42,7 +42,7 @@ impl MonitorHandle {
pub async fn spawn_monitor(
sender: UnboundedSender<bool>,
- #[cfg(any(target_os = "linux", target_os = "windows"))] route_manager: RouteManagerHandle,
+ #[cfg(not(target_os = "android"))] route_manager: RouteManagerHandle,
#[cfg(target_os = "linux")] fwmark: Option<u32>,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<MonitorHandle, Error> {
@@ -50,7 +50,7 @@ pub async fn spawn_monitor(
Some(
imp::spawn_monitor(
sender,
- #[cfg(any(target_os = "windows", target_os = "linux"))]
+ #[cfg(not(target_os = "android"))]
route_manager,
#[cfg(target_os = "linux")]
fwmark,
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 86ad1c6dd2..f4c58b849c 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -31,7 +31,6 @@ use futures::{
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
use std::{
- collections::HashSet,
future::Future,
io,
net::IpAddr,
@@ -271,7 +270,6 @@ impl TunnelStateMachine {
let filtering_resolver = crate::resolver::start_resolver().await?;
let route_manager = RouteManager::new(
- HashSet::new(),
#[cfg(target_os = "linux")]
args.linux_ids.fwmark,
#[cfg(target_os = "linux")]
@@ -332,16 +330,12 @@ impl TunnelStateMachine {
});
let offline_monitor = offline::spawn_monitor(
offline_tx,
- #[cfg(target_os = "linux")]
- route_manager
- .handle()
- .map_err(Error::InitRouteManagerError)?,
+ #[cfg(not(target_os = "android"))]
+ route_manager.handle()?,
#[cfg(target_os = "linux")]
Some(args.linux_ids.fwmark),
#[cfg(target_os = "android")]
android_context,
- #[cfg(target_os = "windows")]
- route_manager.handle()?,
)
.await
.map_err(Error::OfflineMonitorError)?;
diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml
index a2f7e780be..27ffccce2b 100644
--- a/talpid-routing/Cargo.toml
+++ b/talpid-routing/Cargo.toml
@@ -13,7 +13,8 @@ err-derive = "0.3.1"
futures = "0.3.15"
ipnetwork = "0.16"
log = "0.4"
-tokio = { version = "1.8", features = ["process", "rt-multi-thread"] }
+talpid-types = { path = "../talpid-types" }
+tokio = { version = "1.8", features = ["process", "rt-multi-thread", "net"] }
[target.'cfg(not(target_os="android"))'.dependencies]
talpid-types = { path = "../talpid-types" }
@@ -26,7 +27,11 @@ netlink-packet-route = "0.13"
netlink-sys = "0.8.3"
[target.'cfg(target_os = "macos")'.dependencies]
-tokio-stream = { version = "0.1", features = ["io-util"] }
+# TODO: The PF socket type isn't released yet
+nix = { git = "https://github.com/nix-rust/nix", rev = "b13b7d18e0d2f4a8c05e41576c7ebf26d6dbfb28", features = ["socket"] }
+libc = "0.2"
+bitflags = "1.2"
+system-configuration = "0.5.1"
[target.'cfg(windows)'.dependencies]
diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs
index d8c65e80da..8985b4e394 100644
--- a/talpid-routing/src/lib.rs
+++ b/talpid-routing/src/lib.rs
@@ -21,7 +21,7 @@ mod imp;
use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN;
#[cfg(target_os = "macos")]
-pub use imp::{get_default_routes, listen_for_default_route_changes, PlatformError};
+pub use imp::{DefaultRouteEvent, PlatformError};
pub use imp::{Error, RouteManager};
diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs
index 953f0901ca..b8ee26e480 100644
--- a/talpid-routing/src/unix/android.rs
+++ b/talpid-routing/src/unix/android.rs
@@ -1,6 +1,5 @@
-use crate::{imp::RouteManagerCommand, RequiredRoute};
+use crate::imp::RouteManagerCommand;
use futures::{channel::mpsc, stream::StreamExt};
-use std::collections::HashSet;
/// Stub error type for routing errors on Android.
#[derive(Debug, err_derive::Error)]
@@ -12,7 +11,7 @@ pub struct RouteManagerImpl {}
impl RouteManagerImpl {
#[allow(clippy::unused_async)]
- pub async fn new(_required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> {
+ pub async fn new() -> Result<Self, Error> {
Ok(RouteManagerImpl {})
}
diff --git a/talpid-routing/src/unix/linux.rs b/talpid-routing/src/unix/linux.rs
index a642bf5e3d..d2f98b0701 100644
--- a/talpid-routing/src/unix/linux.rs
+++ b/talpid-routing/src/unix/linux.rs
@@ -148,11 +148,7 @@ pub struct RouteManagerImpl {
}
impl RouteManagerImpl {
- pub async fn new(
- required_routes: HashSet<RequiredRoute>,
- table_id: u32,
- fwmark: u32,
- ) -> Result<Self> {
+ pub async fn new(table_id: u32, fwmark: u32) -> Result<Self> {
let (mut connection, handle, messages) =
rtnetlink::new_connection().map_err(Error::Connect)?;
@@ -179,7 +175,6 @@ impl RouteManagerImpl {
};
monitor.clear_routing_rules().await?;
- monitor.add_required_routes(required_routes).await?;
Ok(monitor)
}
@@ -903,14 +898,13 @@ impl NetworkInterface {
#[cfg(test)]
mod test {
use super::*;
- use std::collections::HashSet;
/// Tests if dropping inside a tokio runtime panics
#[test]
fn test_drop_in_executor() {
let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
runtime.block_on(async {
- let manager = RouteManagerImpl::new(HashSet::new(), 0, 0)
+ let manager = RouteManagerImpl::new(0, 0)
.await
.expect("Failed to initialize route manager");
std::mem::drop(manager);
@@ -922,7 +916,7 @@ mod test {
fn test_drop() {
let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
let manager = runtime.block_on(async {
- RouteManagerImpl::new(HashSet::new(), 1000, 1000)
+ RouteManagerImpl::new(1000, 1000)
.await
.expect("Failed to initialize route manager")
});
diff --git a/talpid-routing/src/unix/macos.rs b/talpid-routing/src/unix/macos.rs
deleted file mode 100644
index 893416a699..0000000000
--- a/talpid-routing/src/unix/macos.rs
+++ /dev/null
@@ -1,362 +0,0 @@
-use super::RouteManagerCommand;
-use crate::{NetNode, Node, RequiredRoute, Route};
-
-use futures::{
- channel::mpsc,
- future,
- stream::{FusedStream, Stream, StreamExt, TryStreamExt},
-};
-use ipnetwork::IpNetwork;
-use std::{
- collections::HashSet,
- io,
- net::IpAddr,
- process::{ExitStatus, Stdio},
-};
-use talpid_types::net::IpVersion;
-use tokio::{io::AsyncBufReadExt, process::Command};
-use tokio_stream::wrappers::LinesStream;
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-/// Errors that can happen in the macOS routing integration.
-#[derive(err_derive::Error, Debug)]
-#[error(no_from)]
-pub enum Error {
- /// Failed to add route.
- #[error(display = "Failed to add route")]
- FailedToAddRoute(#[error(source)] io::Error),
-
- /// Failed to remove route.
- #[error(display = "Failed to remove route")]
- FailedToRemoveRoute(#[error(source)] io::Error),
-
- /// Error while running "ip route".
- #[error(display = "Error while running \"route get\"")]
- FailedToRunRoute(#[error(source)] io::Error),
-
- /// Error while monitoring routes with `route -nv monitor`
- #[error(display = "Error while running \"route -nv monitor\"")]
- FailedToMonitorRoutes(#[error(source)] io::Error),
-
- /// Unexpected output from netstat
- #[error(display = "Unexpected output from netstat")]
- BadOutputFromNetstat,
-}
-
-/// Route manager can be in 1 of 4 states -
-/// - waiting for a route to be added or removed from the route table
-/// - obtaining default routes
-/// - applying changes to the route table
-/// - shutting down
-///
-/// Only the _shutting down_ state can be reached from all other states, but during normal
-/// operation, the route manager will add all the required routes during startup and will start
-/// waiting for changes to the route table. If any change is detected, it will stop listening for
-/// new changes, obtain new default routes and reapply routes that should be routed through the
-/// default nodes. Once the routes are reapplied, the route table changes are monitored again.
-pub struct RouteManagerImpl {
- default_destinations: HashSet<IpNetwork>,
- applied_routes: HashSet<Route>,
- v4_gateway: Option<Node>,
- v6_gateway: Option<Node>,
- connectivity_change:
- Option<Box<dyn FusedStream<Item = std::io::Result<()>> + Unpin + Send + Sync>>,
-}
-
-impl RouteManagerImpl {
- pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
- let v4_gateway = Self::get_default_node(IpVersion::V4).await?;
- let v6_gateway = Self::get_default_node(IpVersion::V6).await?;
-
- let monitor = listen_for_default_route_changes()?;
-
- let mut manager = Self {
- default_destinations: HashSet::new(),
- applied_routes: HashSet::new(),
- connectivity_change: Some(Box::new(monitor.fuse())),
- v4_gateway,
- v6_gateway,
- };
-
- manager.add_required_routes(required_routes).await?;
-
- Ok(manager)
- }
-
- pub(crate) async fn run(mut self, manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>) {
- let mut manage_rx = manage_rx.fuse();
- let mut connectivity_change = self.connectivity_change.take().unwrap();
-
- loop {
- futures::select! {
- command = manage_rx.next() => {
- match command {
- Some(RouteManagerCommand::Shutdown(tx)) => {
- self.cleanup_routes().await;
- let _ = tx.send(());
- return;
- },
-
- Some(RouteManagerCommand::AddRoutes(routes, result_tx)) => {
- let result = self.add_required_routes(routes).await;
- let _ = result_tx.send(result);
- },
- Some(RouteManagerCommand::ClearRoutes) => {
- self.cleanup_routes().await;
- },
- None => {
- break;
- }
- }
- },
-
- _result = connectivity_change.select_next_some() => {
- let v4_gateway = Self::get_default_node(IpVersion::V4).await.unwrap_or(None);
- let v6_gateway = Self::get_default_node(IpVersion::V6).await.unwrap_or(None);
-
- if v4_gateway != self.v4_gateway {
- self.v4_gateway = v4_gateway;
- self.apply_new_default_route(&self.v4_gateway, true).await;
- }
-
- if v6_gateway != self.v6_gateway {
- self.v6_gateway = v6_gateway;
- self.apply_new_default_route(&self.v6_gateway, false).await;
- }
- },
- complete => {
- break;
- }
- };
- }
- self.cleanup_routes().await;
- }
-
- async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> {
- let mut routes_to_apply = vec![];
- let mut default_destinations = HashSet::new();
-
- for route in required_routes {
- match route.node {
- NetNode::DefaultNode => {
- default_destinations.insert(route.prefix);
- }
-
- NetNode::RealNode(node) => routes_to_apply.push(Route::new(node, route.prefix)),
- }
- }
-
- for route in routes_to_apply {
- Self::add_route(&route).await?;
- self.applied_routes.insert(route);
- }
-
- for destination in default_destinations.iter() {
- match (&self.v4_gateway, &self.v6_gateway, destination.is_ipv4()) {
- (Some(gateway), _, true) | (_, Some(gateway), false) => {
- let route = Route::new(gateway.clone(), *destination);
- Self::add_route(&route).await?;
- self.applied_routes.insert(route);
- }
- _ => (),
- };
- }
-
- self.default_destinations = default_destinations;
-
- Ok(())
- }
-
- // Retrieves the node that's currently used to reach 0.0.0.0/0
- pub(crate) async fn get_default_node(ip_version: IpVersion) -> Result<Option<Node>> {
- let ip_version_arg = match ip_version {
- IpVersion::V4 => "-inet",
- IpVersion::V6 => "-inet6",
- };
- let mut cmd = Command::new("route");
- cmd.arg("-n").arg("get").arg(ip_version_arg).arg("default");
-
- let output = cmd.output().await.map_err(Error::FailedToRunRoute)?;
- let output = String::from_utf8(output.stdout).map_err(|e| {
- log::error!("Failed to parse utf-8 bytes from output of netstat: {}", e);
- Error::BadOutputFromNetstat
- })?;
- Ok(Self::parse_route(&output))
- }
-
- fn parse_route(route_output: &str) -> Option<Node> {
- let mut address: Option<IpAddr> = None;
- let mut device = None;
- for line in route_output.lines() {
- // we're looking for just 2 different lines:
- // interface: utun0
- // gateway: 192.168.3.1
- let tokens: Vec<_> = line.split_whitespace().collect();
- if tokens.len() == 2 {
- match tokens[0].trim() {
- "interface:" => {
- device = Some(tokens[1].to_string());
- }
- "gateway:" => {
- address = Self::parse_gateway_line(tokens[1]);
- }
- _ => continue,
- }
- }
- }
-
- match (address, device) {
- (Some(address), Some(device)) => Some(Node::new(address, device)),
- (Some(address), None) => Some(Node::address(address)),
- (None, Some(device)) => Some(Node::device(device)),
- _ => None,
- }
- }
-
- fn parse_gateway_line(line: &str) -> Option<IpAddr> {
- // IPv6 addresses may contain interfaces
- // if line contains '%' it should be split off
- line.split('%')
- .next()
- .and_then(|ip_str| ip_str.parse().ok())
- }
-
- async fn delete_route(destination: IpNetwork) -> Result<ExitStatus> {
- let mut cmd = Command::new("route");
- cmd.arg("-q")
- .arg("-n")
- .arg("delete")
- .arg(ip_vers(destination))
- .arg(destination.to_string())
- .stderr(Stdio::null());
-
- cmd.status().await.map_err(Error::FailedToRemoveRoute)
- }
-
- async fn add_route(route: &Route) -> Result<ExitStatus> {
- let mut cmd = Command::new("route");
- cmd.arg("-q")
- .arg("-n")
- .arg("add")
- .arg(ip_vers(route.prefix))
- .arg(route.prefix.to_string());
-
- if let Some(addr) = route.node.get_address() {
- cmd.arg("-gateway").arg(addr.to_string());
- } else if let Some(device) = route.node.get_device() {
- cmd.arg("-interface").arg(device);
- }
-
- cmd.status().await.map_err(Error::FailedToAddRoute)
- }
-
- async fn cleanup_routes(&self) {
- let destinations_to_remove = self
- .applied_routes
- .iter()
- .map(|route| &route.prefix)
- .chain(self.default_destinations.iter());
-
- for destination in destinations_to_remove {
- match Self::delete_route(*destination).await {
- Ok(status) => {
- if !status.success() {
- log::debug!("Failed to remove route during shutdown");
- }
- }
- Err(e) => log::error!("Failed to remove route during shutdown: {}", e),
- };
- }
- }
-
- async fn apply_new_default_route(&self, new_node: &Option<Node>, v4: bool) {
- for destination in self.default_destinations.iter() {
- if destination.is_ipv4() == v4 {
- let _ = Self::delete_route(*destination).await;
-
- if let Some(node) = new_node {
- log::error!("Resetting default route for {}", destination);
- match Self::add_route(&Route::new(node.clone(), *destination)).await {
- Ok(status) => {
- if !status.success() {
- log::error!("Failed to reapply route");
- }
- }
- Err(e) => log::error!("Failed to reset route: {}", e),
- }
- }
- }
- }
- }
-}
-
-fn ip_vers(prefix: IpNetwork) -> &'static str {
- if prefix.is_ipv4() {
- "-inet"
- } else {
- "-inet6"
- }
-}
-
-/// Returns a tuple containing a IPv4 and IPv6 default route nodes.
-pub async fn get_default_routes() -> Result<(Option<Node>, Option<Node>)> {
- futures::try_join!(
- RouteManagerImpl::get_default_node(IpVersion::V4),
- RouteManagerImpl::get_default_node(IpVersion::V6)
- )
-}
-
-/// Returns a stream that produces an item whenever a default route is either added or deleted from
-/// the routing table.
-pub fn listen_for_default_route_changes() -> Result<impl Stream<Item = std::io::Result<()>>> {
- let mut cmd = Command::new("route");
- cmd.arg("-n")
- .arg("monitor")
- .arg("-")
- .stderr(Stdio::null())
- .stdout(Stdio::piped())
- .stdin(Stdio::null());
-
- let mut process = cmd.spawn().map_err(Error::FailedToMonitorRoutes)?;
- let reader = tokio::io::BufReader::new(process.stdout.take().unwrap());
- let lines = reader.lines();
-
- // route -n monitor will produce netlink messages in the following format
- // ```
- // got message of size 176 on Thu Jun 4 10:08:05 2020
- // RTM_DELETE: Delete Route: len 176, pid: 109, seq 1151, errno 3, ifscope 23,
- // flags:<UP,GATEWAY,STATIC,IFSCOPE>
- // locks: inits:
- // sockaddrs: <DST,GATEWAY,NETMASK,IFP,IFA>
- // default 192.168.44.1 default 192.168.44.90
- // ```
- // On the second line of the message, the message type is specified. Only messages with the
- // type 'RTM_ADD' or 'RTM_DELETE' are considered. On the 6th line, message attribute values are
- // shown. To detect a change for a default route in the routing table, check whether this line
- // contains 'default'. Whenever an empty line is encountered, the message has been sent, so
- // the state can be reset.
-
- let mut add_or_delete_message = false;
- let mut contains_default = false;
-
- let monitor = LinesStream::new(lines).try_filter_map(move |line| {
- if add_or_delete_message {
- if line.contains("default") {
- contains_default = true;
- }
- if line.trim().is_empty() {
- add_or_delete_message = false;
- if contains_default {
- contains_default = false;
- return future::ready(Ok(Some(())));
- }
- }
- } else {
- add_or_delete_message = line.starts_with("RTM_ADD:") || line.starts_with("RTM_DELETE:");
- }
- future::ready(Ok(None))
- });
-
- Ok(monitor)
-}
diff --git a/talpid-routing/src/unix/macos/data.rs b/talpid-routing/src/unix/macos/data.rs
new file mode 100644
index 0000000000..1e955c5fa0
--- /dev/null
+++ b/talpid-routing/src/unix/macos/data.rs
@@ -0,0 +1,1147 @@
+use ipnetwork::IpNetwork;
+use nix::{
+ ifaddrs::InterfaceAddress,
+ sys::socket::{SockaddrLike, SockaddrStorage},
+};
+use std::{
+ collections::BTreeMap,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
+};
+
+/// Message that describes a route - either an added, removed, changed or plainly retrieved route.
+#[derive(Debug, Clone, PartialEq)]
+pub struct RouteMessage {
+ sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>,
+ route_flags: RouteFlag,
+ interface_index: u16,
+ errno: i32,
+}
+
+impl RouteMessage {
+ pub fn new_route(destination: Destination) -> Self {
+ let mut route_flags = RouteFlag::RTF_STATIC | RouteFlag::RTF_DONE | RouteFlag::RTF_UP;
+ let mut sockaddrs = BTreeMap::new();
+ match destination {
+ Destination::Network(net) => {
+ let dest_addr = SockaddrStorage::from(SocketAddr::from((net.ip(), 0)));
+ let destination = RouteSocketAddress::Destination(Some(dest_addr));
+ let netmask =
+ RouteSocketAddress::Netmask(Some(SocketAddr::from((net.mask(), 0)).into()));
+ sockaddrs.insert(destination.address_flag(), destination);
+ sockaddrs.insert(netmask.address_flag(), netmask);
+ }
+ Destination::Host(addr) => {
+ let destination =
+ RouteSocketAddress::Destination(Some(SocketAddr::from((addr, 0)).into()));
+ route_flags |= RouteFlag::RTF_HOST;
+ sockaddrs.insert(destination.address_flag(), destination);
+ }
+ };
+
+ Self {
+ sockaddrs,
+ route_flags,
+ interface_index: 0,
+ errno: 0,
+ }
+ }
+
+ pub fn route_addrs(&self) -> impl Iterator<Item = &RouteSocketAddress> {
+ self.sockaddrs.values()
+ }
+
+ fn socketaddress_to_ip(sockaddr: &SockaddrStorage) -> Option<IpAddr> {
+ saddr_to_ipv4(sockaddr)
+ .map(Into::into)
+ .or_else(|| saddr_to_ipv6(sockaddr).map(Into::into))
+ }
+
+ pub fn netmask(&self) -> Option<IpAddr> {
+ self.route_addrs()
+ .find_map(|saddr| match saddr {
+ RouteSocketAddress::Netmask(netmask) => Some(netmask),
+ _ => None,
+ })?
+ .as_ref()
+ .and_then(Self::socketaddress_to_ip)
+ }
+
+ pub fn is_default(&self) -> Result<bool> {
+ Ok(self.is_default_v4()? || self.is_default_v6()?)
+ }
+
+ pub fn is_default_v4(&self) -> Result<bool> {
+ let destination_is_default = self
+ .destination_v4()?
+ .map(|addr| addr == Ipv4Addr::UNSPECIFIED)
+ .unwrap_or(false);
+ let netmask = self.route_addrs().find_map(|saddr| match saddr {
+ RouteSocketAddress::Netmask(addr) => Some(addr),
+ _ => None,
+ });
+
+ // TODO: This might be superfluous
+ let netmask_is_default = match netmask {
+ // empty socket address implies that it is a 'default' netmask
+ Some(None) => true,
+ Some(Some(addr)) => {
+ if let Some(netmask_addr) = saddr_to_ipv4(addr) {
+ netmask_addr.is_unspecified()
+ } else if let Some(netmask_addr) = saddr_to_ipv6(addr) {
+ netmask_addr.is_unspecified()
+ } else {
+ // if the route socket address describing the netmask isn't a sockaddr_in or a
+ // sockaddr_in6, it can't possibly be a default route for IP
+ false
+ }
+ }
+ // absence of a netmask socket address implies that it is a host route
+ None => false,
+ };
+
+ Ok(destination_is_default && netmask_is_default)
+ }
+
+ pub fn is_default_v6(&self) -> Result<bool> {
+ Ok(self
+ .destination_v6()?
+ .map(|addr| addr == Ipv6Addr::UNSPECIFIED)
+ .unwrap_or(false))
+ }
+
+ fn from_byte_buffer(buffer: &[u8]) -> Result<Self> {
+ let header: rt_msghdr = rt_msghdr::from_bytes(buffer)?;
+
+ let msg_len = usize::from(header.rtm_msglen);
+ if msg_len > buffer.len() {
+ return Err(Error::BufferTooSmall(
+ "Message is shorter than it's msg_len indicates",
+ msg_len,
+ buffer.len(),
+ ));
+ }
+
+ let payload = &buffer[ROUTE_MESSAGE_HEADER_SIZE..std::cmp::min(msg_len, buffer.len())];
+
+ let route_flags = RouteFlag::from_bits(header.rtm_flags)
+ .ok_or(Error::UnknownRouteFlag(header.rtm_flags))?;
+ let address_flags = AddressFlag::from_bits(header.rtm_addrs)
+ .ok_or(Error::UnknownAddressFlag(header.rtm_addrs))?;
+ if !address_flags.contains(AddressFlag::RTA_DST) {
+ return Err(Error::NoDestination);
+ }
+ let sockaddrs = RouteSockAddrIterator::new(payload, address_flags)
+ .map(|addr| addr.map(|a| (a.address_flag(), a)))
+ .collect::<Result<BTreeMap<_, _>>>()?;
+ let interface_index = header.rtm_index;
+
+ Ok(Self {
+ route_flags,
+ sockaddrs,
+ interface_index,
+ errno: header.rtm_errno,
+ })
+ }
+
+ fn insert_sockaddr(&mut self, saddr: RouteSocketAddress) {
+ self.sockaddrs.insert(saddr.address_flag(), saddr);
+ }
+
+ pub fn set_destination(mut self, destination: Destination) -> Self {
+ match destination {
+ Destination::Network(net) => {
+ let sockaddr: SocketAddr = (net.ip(), 0).into();
+ let netmask: SocketAddr = (net.mask(), 0).into();
+ self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into())));
+ self.insert_sockaddr(RouteSocketAddress::Netmask(Some(netmask.into())));
+ self.route_flags.remove(RouteFlag::RTF_HOST);
+ }
+ Destination::Host(addr) => {
+ self.route_flags.insert(RouteFlag::RTF_HOST);
+ let sockaddr: SocketAddr = (addr, 0).into();
+ self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into())));
+ }
+ };
+
+ self
+ }
+
+ pub fn set_interface_addr(mut self, link: &InterfaceAddress) -> Self {
+ self.insert_sockaddr(RouteSocketAddress::Gateway(link.address));
+ self.route_flags |= RouteFlag::RTF_GATEWAY;
+ self
+ }
+
+ pub fn set_gateway_sockaddr(mut self, sockaddr: SockaddrStorage) -> Self {
+ self.insert_sockaddr(RouteSocketAddress::Gateway(Some(sockaddr)));
+ self.route_flags |= RouteFlag::RTF_GATEWAY;
+ self
+ }
+
+ pub fn set_gateway_addr(mut self, addr: IpAddr) -> Self {
+ let gateway: SocketAddr = (addr, 0).into();
+ self.insert_sockaddr(RouteSocketAddress::Gateway(Some(gateway.into())));
+ self.route_flags |= RouteFlag::RTF_GATEWAY;
+
+ self
+ }
+
+ pub fn set_gateway_route(mut self, is_gateway_route: bool) -> Self {
+ if is_gateway_route {
+ self.route_flags.insert(RouteFlag::RTF_GATEWAY);
+ } else {
+ self.route_flags.remove(RouteFlag::RTF_GATEWAY);
+ }
+ self
+ }
+
+ pub fn route_flag(mut self, route_flags: RouteFlag) -> Self {
+ self.route_flags = route_flags;
+ self
+ }
+
+ pub fn gateway(&self) -> Option<&SockaddrStorage> {
+ self.route_addrs()
+ .find_map(|saddr| match saddr {
+ RouteSocketAddress::Gateway(gateway) => Some(gateway),
+ _ => None,
+ })?
+ .as_ref()
+ }
+
+ pub fn gateway_ip(&self) -> Option<IpAddr> {
+ self.gateway_v4()
+ .map(IpAddr::V4)
+ .or(self.gateway_v6().map(IpAddr::V6))
+ }
+
+ pub fn gateway_v4(&self) -> Option<Ipv4Addr> {
+ saddr_to_ipv4(self.gateway()?)
+ }
+
+ pub fn gateway_v6(&self) -> Option<Ipv6Addr> {
+ saddr_to_ipv6(self.gateway()?)
+ }
+
+ pub fn destination_ip(&self) -> Result<IpNetwork> {
+ if let Some(saddr) = self.destination()? {
+ if let Some(v4) = saddr.as_sockaddr_in() {
+ let ip_addr = *SocketAddrV4::from(*v4).ip();
+ let netmask = self.netmask().unwrap_or(Ipv4Addr::UNSPECIFIED.into());
+ let destination = IpNetwork::with_netmask(ip_addr.into(), netmask)
+ .map_err(Error::InvalidNetmask)?;
+ return Ok(destination);
+ }
+
+ if let Some(v6) = saddr.as_sockaddr_in6() {
+ let ip_addr = *SocketAddrV6::from(*v6).ip();
+ let netmask = self.netmask().unwrap_or(Ipv6Addr::UNSPECIFIED.into());
+ let destination = IpNetwork::with_netmask(ip_addr.into(), netmask)
+ .map_err(Error::InvalidNetmask)?;
+ return Ok(destination);
+ }
+
+ return Err(Error::MismatchedSocketAddress(
+ AddressFlag::RTA_DST,
+ Box::new(*saddr),
+ ));
+ }
+ Err(Error::NoDestination)
+ }
+
+ pub fn destination(&self) -> Result<Option<&SockaddrStorage>> {
+ Ok(self
+ .route_addrs()
+ .find_map(|saddr| match saddr {
+ RouteSocketAddress::Destination(destination) => Some(destination),
+ _ => None,
+ })
+ .ok_or(Error::NoDestination)?
+ .as_ref())
+ }
+
+ pub fn destination_v4(&self) -> Result<Option<Ipv4Addr>> {
+ Ok(self.destination()?.and_then(saddr_to_ipv4))
+ }
+
+ pub fn destination_v6(&self) -> Result<Option<Ipv6Addr>> {
+ Ok(self.destination()?.and_then(saddr_to_ipv6))
+ }
+
+ pub fn flags(&self) -> &RouteFlag {
+ &self.route_flags
+ }
+
+ pub fn payload(
+ &self,
+ message_type: MessageType,
+ sequence: i32,
+ pid: i32,
+ ) -> (rt_msghdr, Vec<Vec<u8>>) {
+ let address_flags = self.route_addrs().fold(AddressFlag::empty(), |flag, addr| {
+ flag | addr.address_flag()
+ });
+
+ // The sockaddrs should be ordered by their address flag in the payload,
+ // because the payload does not contain their flags. Flags are only specified
+ // in the header.
+ let mut sockaddrs = self.route_addrs().collect::<Vec<_>>();
+ sockaddrs.sort_by_key(|saddr| saddr.address_flag());
+ let payload_bytes = sockaddrs
+ .into_iter()
+ .map(RouteSocketAddress::to_bytes)
+ .collect::<Vec<_>>();
+
+ let payload_len: usize = payload_bytes.iter().map(Vec::len).sum();
+
+ let rtm_msglen = (payload_len + ROUTE_MESSAGE_HEADER_SIZE)
+ .try_into()
+ .expect("route message buffer size cannot fit in 32 bits");
+
+ let header = super::data::rt_msghdr {
+ rtm_msglen,
+ rtm_version: libc::RTM_VERSION.try_into().unwrap(),
+ rtm_type: message_type.bits(),
+ rtm_index: self.interface_index,
+ rtm_flags: self.route_flags.bits(),
+ rtm_addrs: address_flags.bits(),
+ rtm_pid: pid,
+ rtm_seq: sequence,
+ rtm_errno: 0,
+ rtm_use: 0,
+ rtm_inits: 0,
+ rtm_rmx: Default::default(),
+ };
+
+ (header, payload_bytes)
+ }
+
+ pub fn interface_index(&self) -> u16 {
+ self.interface_index
+ }
+
+ pub fn interface_address(&self) -> Option<IpAddr> {
+ self.get_address(&AddressFlag::RTA_IFA)
+ }
+
+ fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> {
+ let addr = self.sockaddrs.get(address_flag)?;
+ saddr_to_ipv4(addr.inner()?)
+ .map(IpAddr::from)
+ .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from))
+ }
+
+ pub fn interface_sockaddr_index(&self) -> Option<u16> {
+ self.sockaddrs
+ .values()
+ .find_map(|addr| addr.interface_index())
+ }
+
+ pub fn errno(&self) -> i32 {
+ self.errno
+ }
+
+ pub fn is_ipv4(&self) -> bool {
+ self.destination_v4()
+ .map(|addr| addr.is_some())
+ .unwrap_or(false)
+ }
+
+ pub fn is_ipv6(&self) -> bool {
+ self.destination_v6()
+ .map(|addr| addr.is_some())
+ .unwrap_or(false)
+ }
+
+ pub fn is_ifscope(&self) -> bool {
+ self.route_flags.contains(RouteFlag::RTF_IFSCOPE)
+ }
+
+ pub fn ifscope(&self) -> Option<u16> {
+ if self.is_ifscope() {
+ Some(self.interface_index)
+ } else {
+ None
+ }
+ }
+
+ pub fn unset_ifscope(mut self) -> Self {
+ self.route_flags.remove(RouteFlag::RTF_IFSCOPE);
+ self
+ }
+
+ pub fn set_ifscope(mut self, iface_index: u16) -> Self {
+ if iface_index > 0 {
+ self.interface_index = iface_index;
+ self.route_flags.insert(RouteFlag::RTF_IFSCOPE);
+ } else {
+ self.route_flags.remove(RouteFlag::RTF_IFSCOPE);
+ }
+
+ self
+ }
+}
+
+#[derive(Debug)]
+#[repr(C)]
+struct ifa_msghdr {
+ ifam_msglen: libc::c_ushort,
+ ifam_version: libc::c_uchar,
+ ifam_type: libc::c_uchar,
+ ifam_addrs: libc::c_int,
+ ifam_flags: libc::c_int,
+ ifam_index: libc::c_ushort,
+ ifam_metric: libc::c_int,
+}
+
+#[derive(Debug)]
+pub struct AddressMessage {
+ sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>,
+ interface_index: u16,
+}
+
+impl AddressMessage {
+ pub fn index(&self) -> u16 {
+ self.interface_index
+ }
+
+ pub fn address(&self) -> Result<IpAddr> {
+ self.get_address(&AddressFlag::RTA_IFP)
+ .or_else(|| self.get_address(&AddressFlag::RTA_IFA))
+ .ok_or(Error::NoInterfaceAddress)
+ }
+
+ fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> {
+ let addr = self.sockaddrs.get(address_flag)?;
+ saddr_to_ipv4(addr.inner()?)
+ .map(IpAddr::from)
+ .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from))
+ }
+
+ pub fn netmask(&self) -> Result<IpAddr> {
+ self.get_address(&AddressFlag::RTA_NETMASK)
+ .ok_or(Error::NoNetmaskAddress)
+ }
+
+ pub fn from_byte_buffer(buffer: &[u8]) -> Result<Self> {
+ const HEADER_SIZE: usize = std::mem::size_of::<ifa_msghdr>();
+ if HEADER_SIZE > buffer.len() {
+ return Err(Error::BufferTooSmall(
+ "ifa_msghdr",
+ buffer.len(),
+ HEADER_SIZE,
+ ));
+ }
+
+ // SAFETY: buffer is pointing to enough memory to contain a valid value for ifa_msghdr
+ let header: ifa_msghdr = unsafe { std::ptr::read_unaligned(buffer.as_ptr() as *const _) };
+
+ let msg_len = usize::from(header.ifam_msglen);
+ if msg_len > buffer.len() {
+ return Err(Error::BufferTooSmall(
+ "Message is shorter than it's msg_len indicates",
+ msg_len,
+ buffer.len(),
+ ));
+ }
+
+ let payload = &buffer[HEADER_SIZE..std::cmp::min(msg_len, buffer.len())];
+
+ let address_flags = AddressFlag::from_bits(header.ifam_addrs)
+ .ok_or(Error::UnknownAddressFlag(header.ifam_addrs))?;
+
+ let sockaddrs = RouteSockAddrIterator::new(payload, address_flags)
+ .map(|addr| addr.map(|addr| (addr.address_flag(), addr)))
+ .collect::<Result<BTreeMap<_, _>>>()?;
+
+ Ok(Self {
+ sockaddrs,
+ interface_index: header.ifam_index,
+ })
+ }
+}
+
+#[derive(Debug)]
+pub enum RouteSocketMessage {
+ AddRoute(RouteMessage),
+ DeleteRoute(RouteMessage),
+ ChangeRoute(RouteMessage),
+ GetRoute(RouteMessage),
+ Interface(Interface),
+ AddAddress(AddressMessage),
+ DeleteAddress(AddressMessage),
+ Other {
+ header: rt_msghdr_short,
+ payload: Vec<u8>,
+ },
+ Error {
+ header: rt_msghdr_short,
+ payload: Vec<u8>,
+ },
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Destination {
+ Host(IpAddr),
+ Network(IpNetwork),
+}
+
+impl Destination {
+ pub fn is_network(&self) -> bool {
+ matches!(self, Self::Network(_))
+ }
+
+ pub fn default_v4() -> Self {
+ Destination::Network(IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap())
+ }
+
+ pub fn default_v6() -> Self {
+ Destination::Network(IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap())
+ }
+}
+
+impl From<IpAddr> for Destination {
+ fn from(addr: IpAddr) -> Self {
+ Self::Host(addr)
+ }
+}
+
+impl From<IpNetwork> for Destination {
+ fn from(net: IpNetwork) -> Self {
+ if net.prefix() == 32 && net.is_ipv4() {
+ return Self::Host(net.ip());
+ }
+
+ Self::Network(net)
+ }
+}
+
+#[derive(Debug)]
+pub enum Error {
+ /// Payload buffer didn't match the reported message size in header
+ InvalidBuffer(Vec<u8>, AddressFlag),
+ /// Buffer too small for specific message type
+ BufferTooSmall(&'static str, usize, usize),
+ /// Unknown route flag
+ UnknownRouteFlag(i32),
+ /// Socket address is empty for the given address flag
+ EmptySockaddr(AddressFlag),
+ /// Unrecognized message
+ UnknownMessageType(u8),
+ /// Unrecognized address flag
+ UnknownAddressFlag(libc::c_int),
+ /// Mismatched socket address type
+ MismatchedSocketAddress(AddressFlag, Box<SockaddrStorage>),
+ /// Link socket address contains no identifier
+ NoLinkIdentifier(nix::libc::sockaddr_dl),
+ /// Failed to resolve an interface name to an index
+ InterfaceIndex(nix::Error),
+ /// Invalid netmask
+ InvalidNetmask(ipnetwork::IpNetworkError),
+ /// Route contains no netmask socket address
+ NoDestination,
+ /// Address message does not contain an interface address
+ NoInterfaceAddress,
+ /// Address message does not contain an interface address
+ NoNetmaskAddress,
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+impl RouteSocketMessage {
+ pub fn parse_message(buffer: &[u8]) -> Result<Self> {
+ let route_message = |route_constructor: fn(RouteMessage) -> RouteSocketMessage, buffer| {
+ let route = RouteMessage::from_byte_buffer(buffer)?;
+ Ok(route_constructor(route))
+ };
+
+ match rt_msghdr_short::from_bytes(buffer) {
+ Some(header) if header.is_type(libc::RTM_ADD) => route_message(Self::AddRoute, buffer),
+
+ Some(header) if header.is_type(libc::RTM_CHANGE) => {
+ route_message(Self::ChangeRoute, buffer)
+ }
+
+ Some(header) if header.is_type(libc::RTM_DELETE) => {
+ route_message(Self::DeleteRoute, buffer)
+ }
+
+ Some(header) if header.is_type(libc::RTM_GET) => route_message(Self::GetRoute, buffer),
+
+ Some(header) if header.is_type(libc::RTM_IFINFO) => Ok(RouteSocketMessage::Interface(
+ Interface::from_byte_buffer(buffer)?,
+ )),
+
+ Some(header) if header.is_type(libc::RTM_NEWADDR) => Ok(
+ RouteSocketMessage::AddAddress(AddressMessage::from_byte_buffer(buffer)?),
+ ),
+ Some(header) if header.is_type(libc::RTM_DELADDR) => Ok(
+ RouteSocketMessage::DeleteAddress(AddressMessage::from_byte_buffer(buffer)?),
+ ),
+ Some(header) => Ok(Self::Other {
+ header,
+ payload: buffer.to_vec(),
+ }),
+ None => Err(Error::BufferTooSmall(
+ "rt_msghdr_short",
+ buffer.len(),
+ ROUTE_MESSAGE_HEADER_SHORT_SIZE,
+ )),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Interface {
+ header: libc::if_msghdr,
+}
+
+impl Interface {
+ pub fn is_up(&self) -> bool {
+ self.header.ifm_flags & nix::libc::IFF_UP != 0
+ }
+
+ pub fn index(&self) -> u16 {
+ self.header.ifm_index
+ }
+
+ fn from_byte_buffer(buffer: &[u8]) -> Result<Self> {
+ const INTERFACE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<libc::if_msghdr>();
+ if INTERFACE_MESSAGE_HEADER_SIZE > buffer.len() {
+ return Err(Error::BufferTooSmall(
+ "if_msghdr",
+ buffer.len(),
+ INTERFACE_MESSAGE_HEADER_SIZE,
+ ));
+ }
+ let header: libc::if_msghdr = unsafe { std::ptr::read(buffer.as_ptr() as *const _) };
+ // let payload = buffer[INTERFACE_MESSAGE_HEADER_SIZE..header.ifm_msglen.into()].to_vec();
+ Ok(Self { header })
+ }
+}
+
+// #define RTA_DST 0x1 /* destination sockaddr present */
+// #define RTA_GATEWAY 0x2 /* gateway sockaddr present */
+// #define RTA_NETMASK 0x4 /* netmask sockaddr present */
+// #define RTA_GENMASK 0x8 /* cloning mask sockaddr present */
+// #define RTA_IFP 0x10 /* interface name sockaddr present */
+// #define RTA_IFA 0x20 /* interface addr sockaddr present */
+// #define RTA_AUTHOR 0x40 /* sockaddr for author of redirect */
+// #define RTA_BRD 0x80 /* for NEWADDR, broadcast or p-p dest addr */
+bitflags::bitflags! {
+ /// All enum values of address flags can be iterated via `flag <<= 1`, starting from 1.
+ /// See https://www.manpagez.com/man/4/route/.
+ pub struct AddressFlag: i32 {
+ /// Destination socket address
+ const RTA_DST = 0x1;
+ /// Gateway socket address
+ const RTA_GATEWAY = 0x2;
+ /// Netmask socket address
+ const RTA_NETMASK = 0x4;
+ /// Cloning mask socket address
+ const RTA_GENMASK = 0x8;
+ /// Interface name socket address
+ const RTA_IFP = 0x10;
+ /// Interface address socket address
+ const RTA_IFA = 0x20;
+ /// Socket address for author of redirect
+ const RTA_AUTHOR = 0x40;
+ /// Socket address for `NEWADDR`, broadcast or point-to-point destination address
+ const RTA_BRD = 0x80;
+ }
+}
+
+bitflags::bitflags! {
+ /// Types of routing messages
+ /// See https://www.manpagez.com/man/4/route/.
+ pub struct MessageType: u8 {
+ /// Add Route
+ const RTM_ADD = 0x1;
+ /// Delete Route
+ const RTM_DELETE = 0x2;
+ /// Change Metrics or flags
+ const RTM_CHANGE = 0x3;
+ /// Report Metrics
+ const RTM_GET = 0x4;
+ /// RTM_LOSING is no longer generated by and is deprecated
+ const RTM_LOSING = 0x5;
+ /// Told to use different route
+ const RTM_REDIRECT = 0x6;
+ /// Lookup failed on this address
+ const RTM_MISS = 0x7;
+ /// fix specified metrics
+ const RTM_LOCK = 0x8;
+ /// caused by SIOCADDRT
+ const RTM_OLDADD = 0x9;
+ /// caused by SIOCDELRT
+ const RTM_OLDDEL = 0xa;
+ /// req to resolve dst to LL addr
+ const RTM_RESOLVE = 0xb;
+ /// address being added to iface
+ const RTM_NEWADDR = 0xc;
+ /// address being removed from iface
+ const RTM_DELADDR = 0xd;
+ /// iface going up/down etc.
+ const RTM_IFINFO = 0xe;
+ /// mcast group membership being added to if
+ const RTM_NEWMADDR = 0xf;
+ /// mcast group membership being deleted
+ const RTM_DELMADDR = 0x10;
+ }
+
+ /// Routing message flags
+ /// See https://www.manpagez.com/man/4/route/.
+ pub struct RouteFlag: i32 {
+ /// route usable
+ const RTF_UP = 0x1;
+ /// destination is a gateway
+ const RTF_GATEWAY = 0x2;
+ /// host entry (net otherwise)
+ const RTF_HOST = 0x4;
+ /// host or net unreachable
+ const RTF_REJECT = 0x8;
+ /// created dynamically (by redirect)
+ const RTF_DYNAMIC = 0x10;
+ /// modified dynamically (by redirect)
+ const RTF_MODIFIED = 0x20;
+ /// message confirmed
+ const RTF_DONE = 0x40;
+ /// delete cloned route
+ const RTF_DELCLONE = 0x80;
+ /// generate new routes on use
+ const RTF_CLONING = 0x100;
+ /// external daemon resolves name
+ const RTF_XRESOLVE = 0x200;
+ /// DEPRECATED - exists ONLY for backwards compatibility
+ const RTF_LLINFO = 0x400;
+ /// used by apps to add/del L2 entries
+ const RTF_LLDATA = 0x400;
+ /// manually added
+ const RTF_STATIC = 0x800;
+ /// just discard pkts (during updates)
+ const RTF_BLACKHOLE = 0x1000;
+ /// not eligible for RTF_IFREF
+ const RTF_NOIFREF = 0x2000;
+ /// protocol specific routing flag
+ const RTF_PROTO2 = 0x4000;
+ /// protocol specific routing flag
+ const RTF_PROTO1 = 0x8000;
+ /// protocol requires cloning
+ const RTF_PRCLONING = 0x10000;
+ /// route generated through cloning
+ const RTF_WASCLONED = 0x20000;
+ /// protocol specific routing flag
+ const RTF_PROTO3 = 0x40000;
+ /// future use
+ const RTF_PINNED = 0x100000;
+ /// route represents a local address
+ const RTF_LOCAL = 0x200000;
+ /// route represents a bcast address
+ const RTF_BROADCAST = 0x400000;
+ /// route represents a mcast address
+ const RTF_MULTICAST = 0x800000;
+ /// has valid interface scope
+ const RTF_IFSCOPE = 0x1000000;
+ /// defunct; no longer modifiable
+ const RTF_CONDEMNED = 0x2000000;
+ /// route holds a ref to interface
+ const RTF_IFREF = 0x4000000;
+ /// proxying, no interface scope
+ const RTF_PROXY = 0x8000000;
+ /// host is a router
+ const RTF_ROUTER = 0x10000000;
+ /// Route entry is being freed
+ const RTF_DEAD = 0x20000000;
+ /// route to destination of the global internet
+ const RTF_GLOBAL = 0x40000000;
+ }
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum RouteSocketAddress {
+ /// Corresponds to RTA_DST
+ Destination(Option<SockaddrStorage>),
+ /// RTA_GATEWAY
+ Gateway(Option<SockaddrStorage>),
+ /// RTA_NETMASK
+ Netmask(Option<SockaddrStorage>),
+ /// RTA_GENMASK
+ CloningMask(Option<SockaddrStorage>),
+ /// RTA_IFP
+ IfName(Option<SockaddrStorage>),
+ /// RTA_IFA
+ IfSockaddr(Option<SockaddrStorage>),
+ /// RTA_AUTHOR
+ RedirectAuthor(Option<SockaddrStorage>),
+ /// RTA_BRD
+ Broadcast(Option<SockaddrStorage>),
+}
+
+impl RouteSocketAddress {
+ // Returns a new route socket address and number of bytes read from the buffer
+ pub fn new(flag: AddressFlag, buf: &[u8]) -> Result<(Self, u8)> {
+ // If buffer is empty, then the socket address is empty too, the backing buffer shouldn't
+ // be advanced.
+ if buf.is_empty() {
+ return Ok((Self::with_sockaddr(flag, None)?, 0));
+ }
+
+ // to get the length and type of
+ if buf.len() < std::mem::size_of::<sockaddr_hdr>() {
+ return Err(Error::BufferTooSmall(
+ "sockaddr buffer too small",
+ buf.len(),
+ std::mem::size_of::<sockaddr_hdr>(),
+ ));
+ }
+
+ let addr_header_ptr = buf.as_ptr() as *const sockaddr_hdr;
+ // SAFETY: Since `buf` is at least as long as a `sockaddr_hdr`, it's perfectly valid to
+ // read from.
+ let addr_header = unsafe { std::ptr::read(addr_header_ptr) };
+ let saddr_len = addr_header.sa_len;
+ if saddr_len == 0 {
+ return Ok((Self::with_sockaddr(flag, None)?, 4));
+ }
+
+ if Into::<usize>::into(saddr_len) > buf.len() {
+ return Err(Error::InvalidBuffer(buf.to_vec(), flag));
+ }
+
+ // SAFETY: the buffer is big enough for the sockaddr struct inside it, so accessing as a
+ // `sockaddr` is valid.
+ let saddr = unsafe {
+ SockaddrStorage::from_raw(
+ addr_header_ptr as *const nix::libc::sockaddr,
+ Some(saddr_len.into()),
+ )
+ };
+
+ Ok((Self::with_sockaddr(flag, saddr)?, saddr_len))
+ }
+
+ pub fn to_bytes(&self) -> Vec<u8> {
+ match self.inner() {
+ None => vec![0u8; 4],
+ Some(addr) => {
+ let len = usize::try_from(addr.len()).unwrap();
+ assert!(len >= 4);
+
+ // The "serialized" socket addresses must be padded to be aligned to 4 bytes, with
+ // the smallest size being 4 bytes.
+ let buffer_size = len + len % 4;
+ let mut buffer = vec![0u8; buffer_size];
+ unsafe {
+ // SAFETY: copying conents of addr into buffer is safe, as long as addr.len()
+ // returns a correct size for the socket address pointer.
+ std::ptr::copy_nonoverlapping(
+ addr.as_ptr() as *const _,
+ buffer.as_mut_ptr(),
+ len,
+ );
+ }
+ buffer
+ }
+ }
+ }
+
+ pub fn address_flag(&self) -> AddressFlag {
+ match &self {
+ Self::Destination(_) => AddressFlag::RTA_DST,
+ Self::Gateway(_) => AddressFlag::RTA_GATEWAY,
+ Self::Netmask(_) => AddressFlag::RTA_NETMASK,
+ Self::CloningMask(_) => AddressFlag::RTA_GENMASK,
+ Self::IfName(_) => AddressFlag::RTA_IFP,
+ Self::IfSockaddr(_) => AddressFlag::RTA_IFA,
+ Self::RedirectAuthor(_) => AddressFlag::RTA_AUTHOR,
+ Self::Broadcast(_) => AddressFlag::RTA_BRD,
+ }
+ }
+
+ pub fn inner(&self) -> Option<&SockaddrStorage> {
+ match &self {
+ Self::Gateway(addr)
+ | Self::Destination(addr)
+ | Self::Netmask(addr)
+ | Self::CloningMask(addr)
+ | Self::IfName(addr)
+ | Self::IfSockaddr(addr)
+ | Self::RedirectAuthor(addr)
+ | Self::Broadcast(addr) => addr.as_ref(),
+ }
+ }
+
+ fn with_sockaddr(flag: AddressFlag, sockaddr: Option<SockaddrStorage>) -> Result<Self> {
+ let constructor = match flag {
+ AddressFlag::RTA_GATEWAY => Self::Gateway,
+ AddressFlag::RTA_DST => Self::Destination,
+ AddressFlag::RTA_NETMASK => Self::Netmask,
+ AddressFlag::RTA_GENMASK => Self::CloningMask,
+ AddressFlag::RTA_IFP => Self::IfName,
+ AddressFlag::RTA_IFA => Self::IfSockaddr,
+ AddressFlag::RTA_AUTHOR => Self::RedirectAuthor,
+ AddressFlag::RTA_BRD => Self::Broadcast,
+ unknown => return Err(Error::UnknownAddressFlag(unknown.bits())),
+ };
+
+ Ok(constructor(sockaddr))
+ }
+
+ pub fn interface_index(&self) -> Option<u16> {
+ match self {
+ Self::IfName(Some(iface)) => {
+ let index = iface.as_link_addr()?.ifindex();
+ Some(
+ u16::try_from(index)
+ .expect("interface indexes actually are u16s, nix is just *interesting*"),
+ )
+ }
+ _ => None,
+ }
+ }
+}
+
+/// Route socket addreses should be ordered by their corresponding address flag when a route
+/// message is constructed
+impl std::cmp::PartialOrd for RouteSocketAddress {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ self.address_flag().partial_cmp(&other.address_flag())
+ }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug)]
+struct sockaddr_hdr {
+ sa_len: u8,
+ sa_family: libc::sa_family_t,
+ padding: u16,
+}
+
+/// An iterator to consume a byte buffer containing socket address structures originating from a
+/// routing socket message.
+pub struct RouteSockAddrIterator<'a> {
+ buffer: &'a [u8],
+ flags: AddressFlag,
+ // Cursor used to iterate through address flags
+ flag_cursor: i32,
+}
+
+impl<'a> RouteSockAddrIterator<'a> {
+ fn new(buffer: &'a [u8], flags: AddressFlag) -> Self {
+ Self {
+ buffer,
+ flags,
+ flag_cursor: AddressFlag::RTA_DST.bits(),
+ }
+ }
+
+ /// Advances internal byte buffer by given amount. The byte amount will be padded to be
+ /// aligned to 4 bytes if there's more data in the buffer.
+ fn advance_buffer(&mut self, saddr_len: u8) {
+ let saddr_len = usize::from(saddr_len);
+
+ // if consumed as many bytes as are left in the buffer, the buffer can be cleared
+ if saddr_len == self.buffer.len() {
+ self.buffer = &[];
+ return;
+ }
+
+ let padded_saddr_len = if saddr_len % 4 != 0 {
+ saddr_len + (4 - saddr_len % 4)
+ } else {
+ saddr_len
+ };
+
+ // if offset is larger than current buffer, ensure slice gets truncated
+ // since the socket address should've already be read from the buffer at this point, this
+ // probably should be an invariant?
+ self.buffer = &self.buffer[padded_saddr_len..];
+ }
+}
+
+impl<'a> Iterator for RouteSockAddrIterator<'a> {
+ type Item = Result<RouteSocketAddress>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ loop {
+ // If address flags don't contain the current one, try the next one.
+ // Will return None if it runs out of valid flags.
+ let current_flag = AddressFlag::from_bits(self.flag_cursor)?;
+ self.flag_cursor <<= 1;
+
+ if !self.flags.contains(current_flag) {
+ continue;
+ }
+ return match RouteSocketAddress::new(current_flag, self.buffer) {
+ Ok((next_addr, addr_len)) => {
+ self.advance_buffer(addr_len);
+ Some(Ok(next_addr))
+ }
+ Err(err) => {
+ self.buffer = &[];
+ Some(Err(err))
+ }
+ };
+ }
+ }
+}
+
+// struct rt_msghdr {
+// u_short rtm_msglen; /* to skip over non-understood messages */
+// u_char rtm_version; /* future binary compatibility */
+// u_char rtm_type; /* message type */
+// u_short rtm_index; /* index for associated ifp */
+// int rtm_flags; /* flags, incl. kern & message, e.g. DONE */
+// int rtm_addrs; /* bitmask identifying sockaddrs in msg */
+// pid_t rtm_pid; /* identify sender */
+// int rtm_seq; /* for sender to identify action */
+// int rtm_errno; /* why failed */
+// int rtm_use; /* from rtentry */
+// u_int32_t rtm_inits; /* which metrics we are initializing */
+// struct rt_metrics rtm_rmx; /* metrics themselves */
+// };
+#[derive(Debug, Clone)]
+#[repr(C)]
+pub struct rt_msghdr {
+ pub rtm_msglen: libc::c_ushort,
+ pub rtm_version: libc::c_uchar,
+ pub rtm_type: libc::c_uchar,
+ pub rtm_index: libc::c_ushort,
+ pub rtm_flags: libc::c_int,
+ pub rtm_addrs: libc::c_int,
+ pub rtm_pid: libc::pid_t,
+ pub rtm_seq: libc::c_int,
+ pub rtm_errno: libc::c_int,
+ pub rtm_use: libc::c_int,
+ pub rtm_inits: u32,
+ pub rtm_rmx: rt_metrics,
+}
+const ROUTE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<rt_msghdr>();
+
+fn saddr_to_ipv4(saddr: &SockaddrStorage) -> Option<Ipv4Addr> {
+ let addr = saddr.as_sockaddr_in()?;
+ Some(*SocketAddrV4::from(*addr).ip())
+}
+
+fn saddr_to_ipv6(saddr: &SockaddrStorage) -> Option<Ipv6Addr> {
+ let addr = saddr.as_sockaddr_in6()?;
+ Some(*SocketAddrV6::from(*addr).ip())
+}
+
+impl rt_msghdr {
+ pub fn from_bytes(buf: &[u8]) -> Result<Self> {
+ if buf.len() >= ROUTE_MESSAGE_HEADER_SIZE {
+ let ptr = buf.as_ptr();
+ // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr value and it's
+ // readable. rt_msghdr doesn't contain any pointers so any values are valid.
+ Ok(unsafe { std::ptr::read(ptr as *const _) })
+ } else {
+ Err(Error::BufferTooSmall(
+ "if_msghdr",
+ buf.len(),
+ ROUTE_MESSAGE_HEADER_SIZE,
+ ))
+ }
+ }
+}
+
+/// Shorter rt_msghdr version that matches all routing messages
+#[derive(Debug)]
+#[repr(C)]
+pub struct rt_msghdr_short {
+ pub rtm_msglen: libc::c_ushort,
+ pub rtm_version: libc::c_uchar,
+ pub rtm_type: libc::c_uchar,
+ pub rtm_index: libc::c_ushort,
+ pub rtm_flags: libc::c_int,
+ pub rtm_addrs: libc::c_int,
+ pub rtm_pid: libc::pid_t,
+ pub rtm_seq: libc::c_int,
+ pub rtm_errno: libc::c_int,
+}
+const ROUTE_MESSAGE_HEADER_SHORT_SIZE: usize = std::mem::size_of::<rt_msghdr_short>();
+
+impl rt_msghdr_short {
+ fn is_type(&self, expected_type: i32) -> bool {
+ u8::try_from(expected_type)
+ .map(|expected| self.rtm_type == expected)
+ .unwrap_or(false)
+ }
+
+ pub fn from_bytes(buf: &[u8]) -> Option<Self> {
+ if buf.len() >= ROUTE_MESSAGE_HEADER_SHORT_SIZE {
+ let ptr = buf.as_ptr();
+ // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr_short value and
+ // is readable. `rt_msghdr_short` doesn't contain any pointers so any values are valid.
+ Some(unsafe { std::ptr::read(ptr as *const rt_msghdr_short) })
+ } else {
+ None
+ }
+ }
+}
+
+#[derive(PartialEq, PartialOrd, Ord, Eq, Clone)]
+pub struct RouteDestination {
+ pub network: IpNetwork,
+ pub interface: Option<u16>,
+ pub gateway: Option<IpAddr>,
+}
+
+impl TryFrom<&RouteMessage> for RouteDestination {
+ type Error = Error;
+
+ fn try_from(msg: &RouteMessage) -> std::result::Result<Self, Self::Error> {
+ let network = msg.destination_ip()?;
+ let interface = msg.ifscope();
+ let gateway = msg.gateway_ip();
+ Ok(Self {
+ network,
+ interface,
+ gateway,
+ })
+ }
+}
+
+// Struct containing metrics of various metrics for a specific route
+// struct rt_metrics {
+// u_int32_t rmx_locks; /* Kernel leaves these values alone */
+// u_int32_t rmx_mtu; /* MTU for this path */
+// u_int32_t rmx_hopcount; /* max hops expected */
+// int32_t rmx_expire; /* lifetime for route, e.g. redirect */
+// u_int32_t rmx_recvpipe; /* inbound delay-bandwidth product */
+// u_int32_t rmx_sendpipe; /* outbound delay-bandwidth product */
+// u_int32_t rmx_ssthresh; /* outbound gateway buffer limit */
+// u_int32_t rmx_rtt; /* estimated round trip time */
+// u_int32_t rmx_rttvar; /* estimated rtt variance */
+// u_int32_t rmx_pksent; /* packets sent using this route */
+// u_int32_t rmx_state; /* route state */
+// u_int32_t rmx_filler[3]; /* will be used for TCP's peer-MSS cache */
+// };
+#[derive(Debug, Default, Clone)]
+#[repr(C)]
+pub struct rt_metrics {
+ pub rmx_locks: u32,
+ pub rmx_mtu: u32,
+ pub rmx_hopcount: u32,
+ pub rmx_expire: i32,
+ pub rmx_recvpipe: u32,
+ pub rmx_sendpipe: u32,
+ pub rmx_ssthresh: u32,
+ pub rmx_rtt: u32,
+ pub rmx_rttvar: u32,
+ pub rmx_pksent: u32,
+ pub rmx_state: u32,
+ pub rmx_filler: [u32; 3],
+}
+
+#[test]
+fn test_failing_rtmsg() {
+ let bytes = [
+ 135, 0, 5, 1, 11, 0, 0, 0, 1, 1, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 16, 2, 0, 0, 192, 168, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 18, 11, 0, 6, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 255, 255, 255, 255, 255, 255,
+ ];
+ let _ = RouteSocketMessage::parse_message(&bytes).unwrap();
+}
diff --git a/talpid-routing/src/unix/macos/interface.rs b/talpid-routing/src/unix/macos/interface.rs
new file mode 100644
index 0000000000..8ec4863ab7
--- /dev/null
+++ b/talpid-routing/src/unix/macos/interface.rs
@@ -0,0 +1,78 @@
+use nix::net::if_::if_nametoindex;
+use std::ffi::CString;
+use system_configuration::{
+ core_foundation::string::CFString,
+ network_configuration::{SCNetworkService, SCNetworkSet},
+ preferences::SCPreferences,
+};
+
+use super::{
+ data::{Destination, RouteMessage},
+ watch::RoutingTable,
+};
+
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum Family {
+ V4,
+ V6,
+}
+
+/// Attempt to retrieve the best current default route.
+/// Note: The tunnel interface is not even listed in the service order, so it will be skipped.
+pub async fn get_best_default_route(
+ routing_table: &mut RoutingTable,
+ family: Family,
+) -> Option<RouteMessage> {
+ let destination = match family {
+ Family::V4 => super::v4_default(),
+ Family::V6 => super::v6_default(),
+ };
+
+ let mut msg = RouteMessage::new_route(Destination::Network(destination));
+ msg = msg.set_gateway_route(true);
+
+ for iface in network_service_order() {
+ let iface_bytes = match CString::new(iface.as_bytes()) {
+ Ok(name) => name,
+ Err(error) => {
+ log::error!("Invalid interface name: {iface}, {error}");
+ continue;
+ }
+ };
+
+ // Get interface ID
+ let index = match if_nametoindex(iface_bytes.as_c_str()) {
+ Ok(index) => index,
+ Err(error) => {
+ log::error!("Failed to get index of network interface: {error}");
+ continue;
+ }
+ };
+
+ // Request ifscoped route for this interface
+ let route_msg = msg.clone().set_ifscope(u16::try_from(index).unwrap());
+ if let Ok(Some(route)) = routing_table.get_route(&route_msg).await {
+ return Some(route);
+ }
+ }
+
+ None
+}
+
+fn network_service_order() -> Vec<String> {
+ let prefs = SCPreferences::default(&CFString::new("talpid-routing"));
+ let services = SCNetworkService::get_services(&prefs);
+ let set = SCNetworkSet::new(&prefs);
+ let service_order = set.service_order();
+
+ service_order
+ .iter()
+ .filter_map(|service_id| {
+ services
+ .iter()
+ .find(|service| service.id().as_ref() == Some(&*service_id))
+ .and_then(|service| service.network_interface()?.bsd_name())
+ .map(|cf_name| cf_name.to_string())
+ })
+ .collect::<Vec<_>>()
+}
diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs
new file mode 100644
index 0000000000..3e5d8b0aed
--- /dev/null
+++ b/talpid-routing/src/unix/macos/mod.rs
@@ -0,0 +1,533 @@
+use crate::{NetNode, Node, RequiredRoute, Route};
+
+use futures::{channel::mpsc, future::FutureExt, stream::StreamExt};
+use ipnetwork::IpNetwork;
+use nix::sys::socket::{AddressFamily, SockaddrLike, SockaddrStorage};
+use std::{
+ collections::{BTreeMap, HashSet},
+ net::{Ipv4Addr, Ipv6Addr},
+};
+use talpid_types::ErrorExt;
+use watch::RoutingTable;
+
+use super::{DefaultRouteEvent, RouteManagerCommand};
+use data::{Destination, RouteDestination, RouteMessage, RouteSocketMessage};
+
+mod data;
+mod interface;
+mod routing_socket;
+mod watch;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Errors that can happen in the macOS routing integration.
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Encountered an error when interacting with the routing socket
+ #[error(display = "Error occurred when interfacing with the routing table")]
+ RoutingTable(#[error(source)] watch::Error),
+
+ /// Failed to remvoe route
+ #[error(display = "Error occurred when deleting a route")]
+ DeleteRoute(#[error(source)] watch::Error),
+
+ /// Failed to add route
+ #[error(display = "Error occurred when adding a route")]
+ AddRoute(#[error(source)] watch::Error),
+
+ /// Failed to fetch link addresses
+ #[error(display = "Failed to fetch link addresses")]
+ FetchLinkAddresses(nix::Error),
+
+ /// Received message isn't valid
+ #[error(display = "Invalid data")]
+ InvalidData(data::Error),
+}
+
+/// Route manager can be in 1 of 4 states -
+/// - waiting for a route to be added or removed from the route table
+/// - obtaining default routes
+/// - applying changes to the route table
+/// - shutting down
+///
+/// Only the _shutting down_ state can be reached from all other states, but during normal
+/// operation, the route manager will add all the required routes during startup and will start
+/// waiting for changes to the route table. If any change is detected, it will stop listening for
+/// new changes, obtain new default routes and reapply routes that should be routed through the
+/// default nodes. Once the routes are reapplied, the route table changes are monitored again.
+pub struct RouteManagerImpl {
+ routing_table: RoutingTable,
+ // Routes that use the default non-tunnel interface
+ non_tunnel_routes: HashSet<IpNetwork>,
+ v4_tunnel_default_route: Option<data::RouteMessage>,
+ v6_tunnel_default_route: Option<data::RouteMessage>,
+ applied_routes: BTreeMap<RouteDestination, RouteMessage>,
+ v4_default_route: Option<data::RouteMessage>,
+ v6_default_route: Option<data::RouteMessage>,
+ default_route_listeners: Vec<mpsc::UnboundedSender<DefaultRouteEvent>>,
+}
+
+impl RouteManagerImpl {
+ /// Create new route manager
+ #[allow(clippy::unused_async)]
+ pub async fn new() -> Result<Self> {
+ let routing_table = RoutingTable::new().map_err(Error::RoutingTable)?;
+ Ok(Self {
+ routing_table,
+ non_tunnel_routes: HashSet::new(),
+ v4_tunnel_default_route: None,
+ v6_tunnel_default_route: None,
+ applied_routes: BTreeMap::new(),
+ v4_default_route: None,
+ v6_default_route: None,
+ default_route_listeners: vec![],
+ })
+ }
+
+ pub(crate) async fn run(mut self, manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>) {
+ let mut manage_rx = manage_rx.fuse();
+
+ // Initialize default routes
+ // NOTE: This isn't race-free, as we're not listening for route changes before initializing
+ self.update_best_default_route(interface::Family::V4)
+ .await
+ .unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to get initial default v4 route")
+ );
+ });
+ self.update_best_default_route(interface::Family::V6)
+ .await
+ .unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to get initial default v6 route")
+ );
+ });
+
+ loop {
+ futures::select_biased! {
+ route_message = self.routing_table.next_message().fuse() => {
+ self.handle_route_message(route_message).await;
+ }
+
+ command = manage_rx.next() => {
+ match command {
+ Some(RouteManagerCommand::Shutdown(tx)) => {
+ if let Err(err) = self.cleanup_routes().await {
+ log::error!("Failed to clean up routes: {err}");
+ }
+ let _ = tx.send(());
+ return;
+ },
+
+ Some(RouteManagerCommand::NewDefaultRouteListener(tx)) => {
+ let (events_tx, events_rx) = mpsc::unbounded();
+ self.default_route_listeners.push(events_tx);
+ let _ = tx.send(events_rx);
+ }
+ Some(RouteManagerCommand::GetDefaultRoutes(tx)) => {
+ // NOTE: The device name isn't really relevant here,
+ // as we only care about routes with a gateway IP.
+ let v4_route = self.v4_default_route.as_ref().map(|route| {
+ Route {
+ node: Node {
+ device: None,
+ ip: route.gateway_ip(),
+ },
+ prefix: v4_default(),
+ metric: None,
+ }
+ });
+ let v6_route = self.v6_default_route.as_ref().map(|route| {
+ Route {
+ node: Node {
+ device: None,
+ ip: route.gateway_ip(),
+ },
+ prefix: v6_default(),
+ metric: None,
+ }
+ });
+
+ let _ = tx.send((v4_route, v6_route));
+ }
+
+ Some(RouteManagerCommand::AddRoutes(routes, tx)) => {
+ log::debug!("Adding routes: {routes:?}");
+ let _ = tx.send(self.add_required_routes(routes).await);
+ }
+ Some(RouteManagerCommand::ClearRoutes) => {
+ if let Err(err) = self.cleanup_routes().await {
+ log::error!("Failed to clean up rotues: {err}");
+ }
+ },
+ None => {
+ break;
+ }
+ }
+ },
+ };
+ }
+
+ if let Err(err) = self.cleanup_routes().await {
+ log::error!("Failed to clean up routing table when shutting down: {err}");
+ }
+ }
+
+ async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> {
+ let mut routes_to_apply = vec![];
+
+ for route in required_routes {
+ match route.node {
+ NetNode::DefaultNode => {
+ self.non_tunnel_routes.insert(route.prefix);
+ }
+
+ NetNode::RealNode(node) => routes_to_apply.push(Route::new(node, route.prefix)),
+ }
+ }
+
+ // Map all interfaces to their link addresses
+ let interface_link_addrs = get_interface_link_addresses()?;
+
+ // Add routes not using the default interface
+ for route in routes_to_apply {
+ let message = if let Some(ref device) = route.node.device {
+ // If we specify route by interface name, use the link address of the given
+ // interface
+ match interface_link_addrs.get(device) {
+ Some(link_addr) => RouteMessage::new_route(Destination::from(route.prefix))
+ .set_gateway_sockaddr(*link_addr),
+ None => {
+ log::error!("Route with unknown device: {route:?}, {device}");
+ continue;
+ }
+ }
+ } else {
+ log::error!("Specifying gateway by IP rather than device is unimplemented");
+ continue;
+ };
+
+ // Default routes are a special case: We must apply it after replacing the current
+ // default route with an ifscope route.
+ if route.prefix.prefix() == 0 {
+ if route.prefix.is_ipv4() {
+ self.v4_tunnel_default_route = Some(message);
+ } else {
+ self.v6_tunnel_default_route = Some(message);
+ }
+ continue;
+ }
+
+ // Add route
+ self.add_route_with_record(message).await?;
+ }
+
+ self.apply_tunnel_default_route().await?;
+
+ // Add routes that use the default interface
+ if let Err(error) = self.apply_non_tunnel_routes().await {
+ self.non_tunnel_routes.clear();
+ return Err(error);
+ }
+
+ Ok(())
+ }
+
+ async fn handle_route_message(
+ &mut self,
+ message: std::result::Result<RouteSocketMessage, watch::Error>,
+ ) {
+ match message {
+ Ok(RouteSocketMessage::DeleteRoute(route)) => {
+ // Forget about applied route, if relevant. This is simply prevent ourselves from
+ // deleting it later.
+ match RouteDestination::try_from(&route).map_err(Error::InvalidData) {
+ Ok(destination) => {
+ self.applied_routes.remove(&destination);
+ }
+ Err(err) => {
+ log::error!("Failed to process deleted route: {err}");
+ }
+ }
+
+ if let Err(error) = self.handle_route_change(route).await {
+ log::error!("Failed to process route change: {error}");
+ }
+ }
+ Ok(RouteSocketMessage::AddRoute(route))
+ | Ok(RouteSocketMessage::ChangeRoute(route)) => {
+ // Refresh routes that are using the default interface
+ if let Err(error) = self.handle_route_change(route).await {
+ log::error!("Failed to process route change: {error}");
+ }
+ }
+ // ignore all other message types
+ Ok(_) => {}
+ Err(err) => {
+ log::error!("Failed to receive a message from the routing table: {err}");
+ }
+ }
+ }
+
+ /// Update routes that use the non-tunnel default interface
+ async fn handle_route_change(&mut self, route: data::RouteMessage) -> Result<()> {
+ // Ignore routes that aren't default routes
+ if !route.is_default().map_err(Error::InvalidData)? {
+ return Ok(());
+ }
+
+ let new_gateway_link_addr = route.gateway().and_then(|addr| addr.as_link_addr());
+
+ // Ignore the new route if it is our tunnel route, lest we create a loop
+ for tunnel_default_route in [&self.v4_tunnel_default_route, &self.v6_tunnel_default_route] {
+ if let Some(tunnel_route) = tunnel_default_route.clone() {
+ let tun_gateway_link_addr =
+ tunnel_route.gateway().and_then(|addr| addr.as_link_addr());
+
+ if new_gateway_link_addr == tun_gateway_link_addr {
+ return Ok(());
+ }
+ }
+ }
+
+ let ip_version = if route.is_ipv4() {
+ interface::Family::V4
+ } else {
+ interface::Family::V6
+ };
+ self.update_best_default_route(ip_version).await
+ }
+
+ async fn update_best_default_route(&mut self, family: interface::Family) -> Result<()> {
+ let best_route = interface::get_best_default_route(&mut self.routing_table, family).await;
+ log::trace!("Best route: {best_route:?}");
+
+ let default_route = match family {
+ interface::Family::V4 => &mut self.v4_default_route,
+ interface::Family::V6 => &mut self.v6_default_route,
+ };
+
+ if default_route == &best_route {
+ log::trace!("Default route is unchanged");
+ return Ok(());
+ }
+
+ let old_route = std::mem::replace(default_route, best_route);
+
+ log::debug!("New default route: {old_route:?} -> {default_route:?}");
+
+ // Notify default route listeners
+ let event = match (family, default_route.is_some()) {
+ (interface::Family::V4, true) => DefaultRouteEvent::AddedOrChangedV4,
+ (interface::Family::V6, true) => DefaultRouteEvent::AddedOrChangedV6,
+ (interface::Family::V4, false) => DefaultRouteEvent::RemovedV4,
+ (interface::Family::V6, false) => DefaultRouteEvent::RemovedV6,
+ };
+ self.default_route_listeners
+ .retain(|tx| tx.unbounded_send(event).is_ok());
+
+ // Substitute route with a tunnel route
+ self.apply_tunnel_default_route().await?;
+
+ // Update routes using default interface
+ self.apply_non_tunnel_routes().await?;
+
+ Ok(())
+ }
+
+ /// Replace the default routes with an ifscope route, and
+ /// add a new default tunnel route.
+ async fn apply_tunnel_default_route(&mut self) -> Result<()> {
+ // As long as the relay route has a way of reaching the internet, we'll want to add a tunnel
+ // route for both IPv4 and IPv6.
+ // NOTE: This is incorrect. We're assuming that any "default destination" is used for
+ // tunneling.
+ let (v4_conn, v6_conn) = self
+ .non_tunnel_routes
+ .iter()
+ .fold((false, false), |(v4, v6), route| {
+ (v4 || route.is_ipv4(), v6 || route.is_ipv6())
+ });
+ let relay_route_is_valid = (v4_conn && self.v4_default_route.is_some())
+ || (v6_conn && self.v6_default_route.is_some());
+
+ if !relay_route_is_valid {
+ return Ok(());
+ }
+
+ for tunnel_route in [
+ self.v4_tunnel_default_route.clone(),
+ self.v6_tunnel_default_route.clone(),
+ ] {
+ let tunnel_route = match tunnel_route {
+ Some(route) => route,
+ None => continue,
+ };
+
+ log::debug!("Adding default route for tunnel");
+
+ // Replace the default route with an ifscope route
+ self.set_default_route_ifscope(tunnel_route.is_ipv4(), true)
+ .await?;
+ self.add_route_with_record(tunnel_route).await?;
+ }
+
+ Ok(())
+ }
+
+ /// Update/add routes that use the default non-tunnel interface. If some applied destination is
+ /// a default route, this function replaces the non-tunnel default route with an ifscope route.
+ async fn apply_non_tunnel_routes(&mut self) -> Result<()> {
+ let v4_gateway = self
+ .v4_default_route
+ .as_ref()
+ .and_then(|route| route.gateway())
+ .cloned();
+ let v6_gateway = self
+ .v6_default_route
+ .as_ref()
+ .and_then(|route| route.gateway())
+ .cloned();
+
+ // Reapply routes that use the default (non-tunnel) node
+ for dest in self.non_tunnel_routes.clone() {
+ let gateway = if dest.is_ipv4() {
+ v4_gateway
+ } else {
+ v6_gateway
+ };
+ let gateway = match gateway {
+ Some(gateway) => gateway,
+ None => continue,
+ };
+ let route =
+ RouteMessage::new_route(Destination::Network(dest)).set_gateway_sockaddr(gateway);
+
+ if let Some(dest) = self
+ .applied_routes
+ .keys()
+ .find(|applied_dest| applied_dest.network == dest)
+ .cloned()
+ {
+ let _ = self.routing_table.delete_route(&route).await;
+ self.applied_routes.remove(&dest);
+ }
+
+ self.add_route_with_record(route).await?;
+ }
+
+ Ok(())
+ }
+
+ /// Replace a known default route with an ifscope route, if should_be_ifscoped is true.
+ /// If should_be_ifscoped is false, the route is replaced with a non-ifscoped default route
+ /// instead.
+ async fn set_default_route_ifscope(
+ &mut self,
+ ipv4: bool,
+ should_be_ifscoped: bool,
+ ) -> Result<()> {
+ let default_route = match (ipv4, &mut self.v4_default_route, &mut self.v6_default_route) {
+ (true, Some(default_route), _) | (false, _, Some(default_route)) => default_route,
+ _ => {
+ return Ok(());
+ }
+ };
+
+ if default_route.is_ifscope() == should_be_ifscoped {
+ return Ok(());
+ }
+
+ log::trace!("Setting non-ifscope: {default_route:?}");
+
+ let interface_index = if should_be_ifscoped {
+ let interface_index = default_route.interface_index();
+ if interface_index == 0 {
+ log::error!("Cannot find interface index of default interface");
+ }
+ interface_index
+ } else {
+ 0
+ };
+ let new_route = default_route.clone().set_ifscope(interface_index);
+ let old_route = std::mem::replace(default_route, new_route);
+
+ self.routing_table
+ .delete_route(&old_route)
+ .await
+ .map_err(Error::DeleteRoute)?;
+
+ self.routing_table
+ .add_route(default_route)
+ .await
+ .map_err(Error::AddRoute)
+ }
+
+ async fn add_route_with_record(&mut self, route: RouteMessage) -> Result<()> {
+ let destination = RouteDestination::try_from(&route).map_err(Error::InvalidData)?;
+
+ self.routing_table
+ .add_route(&route)
+ .await
+ .map_err(Error::AddRoute)?;
+
+ self.applied_routes.insert(destination, route);
+ Ok(())
+ }
+
+ async fn cleanup_routes(&mut self) -> Result<()> {
+ // Remove all applied routes. This includes default destination routes
+ let old_routes = std::mem::take(&mut self.applied_routes);
+ for (_dest, route) in old_routes.into_iter() {
+ log::trace!("Removing route: {route:?}");
+ match self.routing_table.delete_route(&route).await {
+ Ok(_) | Err(watch::Error::RouteNotFound) | Err(watch::Error::Unreachable) => (),
+ Err(err) => {
+ log::error!("Failed to remove relay route: {err:?}");
+ }
+ }
+ }
+
+ // Reset default route
+ if let Err(error) = self
+ .set_default_route_ifscope(true, false)
+ .await
+ .and(self.set_default_route_ifscope(false, false).await)
+ {
+ log::error!("Failed to restore default routes: {error}");
+ }
+
+ // We have already removed the applied default routes
+ self.v4_tunnel_default_route = None;
+ self.v6_tunnel_default_route = None;
+
+ self.non_tunnel_routes.clear();
+
+ Ok(())
+ }
+}
+
+fn v4_default() -> IpNetwork {
+ IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap()
+}
+
+fn v6_default() -> IpNetwork {
+ IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap()
+}
+
+/// Return a map from interface name to link addresses (AF_LINK)
+fn get_interface_link_addresses() -> Result<BTreeMap<String, SockaddrStorage>> {
+ let mut gateway_link_addrs = BTreeMap::new();
+ let addrs = nix::ifaddrs::getifaddrs().map_err(Error::FetchLinkAddresses)?;
+ for addr in addrs.into_iter() {
+ if addr.address.and_then(|addr| addr.family()) != Some(AddressFamily::Link) {
+ continue;
+ }
+ gateway_link_addrs.insert(addr.interface_name, addr.address.unwrap());
+ }
+ Ok(gateway_link_addrs)
+}
diff --git a/talpid-routing/src/unix/macos/routing_socket.rs b/talpid-routing/src/unix/macos/routing_socket.rs
new file mode 100644
index 0000000000..213128fd8c
--- /dev/null
+++ b/talpid-routing/src/unix/macos/routing_socket.rs
@@ -0,0 +1,185 @@
+use std::{
+ collections::VecDeque,
+ mem::size_of,
+ os::unix::prelude::{FromRawFd, RawFd},
+ pin::Pin,
+ task::{ready, Context, Poll},
+};
+
+use nix::{
+ fcntl,
+ sys::socket::{socket, AddressFamily, SockFlag, SockType},
+};
+use std::{
+ fs::File,
+ io::{self, Read, Write},
+};
+
+use super::data::{rt_msghdr_short, MessageType, RouteMessage};
+
+use tokio::io::{unix::AsyncFd, AsyncWrite, AsyncWriteExt};
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "Failed to open routing socket")]
+ OpenSocket(io::Error),
+ #[error(display = "Failed to write to routing socket")]
+ Write(io::Error),
+ #[error(display = "Failed to read from routing socket")]
+ Read(io::Error),
+ #[error(display = "Received a message that's too small")]
+ MessageTooSmall(usize),
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+/// Wraps a `PF_ROUTE` socket, keeps track of sent message IDs, and facilitates sending and
+/// receiving [route socket messages](#RouteMessage)
+pub struct RoutingSocket {
+ socket: RoutingSocketInner,
+ seq: i32,
+ // buffers up messages received whilst waiting on a response
+ // TODO: might we want to limit the max size of this?
+ buf: VecDeque<Vec<u8>>,
+ own_pid: i32,
+}
+
+impl RoutingSocket {
+ pub fn new() -> Result<Self> {
+ Ok(Self {
+ socket: RoutingSocketInner::new().map_err(Error::OpenSocket)?,
+ seq: 1,
+ buf: Default::default(),
+ own_pid: std::process::id().try_into().unwrap(),
+ })
+ }
+
+ pub async fn recv_msg(&mut self, mut buf: &mut [u8]) -> Result<usize> {
+ if let Some(buffered_msg) = self.buf.pop_front() {
+ let bytes_written = buf.write(&buffered_msg).map_err(Error::Write)?;
+ return Ok(bytes_written);
+ }
+ self.read_next_msg(buf).await
+ }
+
+ async fn read_next_msg(&mut self, buf: &mut [u8]) -> Result<usize> {
+ self.socket.read(buf).await.map_err(Error::Read)
+ }
+
+ pub async fn send_route_message(
+ &mut self,
+ message: &RouteMessage,
+ message_type: MessageType,
+ ) -> Result<Vec<u8>> {
+ let (msg, seq) = self.next_route_msg(message, message_type);
+ match self.socket.write(&msg).await {
+ Ok(_) => self.wait_for_response(seq).await,
+ Err(err) => Err(Error::Write(err)),
+ }
+ }
+
+ pub async fn wait_for_response(&mut self, response_num: i32) -> Result<Vec<u8>> {
+ loop {
+ let mut buffer = vec![0u8; 2048];
+ // do not truncate the buffer - trailing empty bytes won't be written but will be
+ // assumed in the data format.
+ let bytes_read = self.read_next_msg(&mut buffer).await?;
+
+ {
+ let header = rt_msghdr_short::from_bytes(buffer.as_slice())
+ .ok_or(Error::MessageTooSmall(bytes_read))?;
+
+ if header.rtm_pid == self.own_pid && response_num == header.rtm_seq {
+ return Ok(buffer);
+ }
+ }
+
+ self.buf.push_back(buffer);
+ }
+ }
+
+ fn next_route_msg(&mut self, message: &RouteMessage, msg_type: MessageType) -> (Vec<u8>, i32) {
+ let seq = self.seq;
+ self.seq = seq.wrapping_add(1);
+
+ let (header, payload) = message.payload(msg_type, seq, self.own_pid);
+ let mut msg_buffer = vec![0u8; header.rtm_msglen.into()];
+
+ // SAFETY: `msg_buffer` is guaranteed to be at least as large as `rt_msghdr`.
+ unsafe {
+ std::ptr::copy_nonoverlapping(
+ &header as *const _ as *const u8,
+ msg_buffer.as_mut_ptr(),
+ size_of::<super::data::rt_msghdr>(),
+ );
+ }
+ let mut sockaddr_buf = &mut msg_buffer[std::mem::size_of::<super::data::rt_msghdr>()..];
+ for socket_addr in payload {
+ sockaddr_buf
+ .write_all(socket_addr.as_slice())
+ .expect("faled to write socket address into message buffer");
+ }
+ (msg_buffer, header.rtm_seq)
+ }
+}
+
+struct RoutingSocketInner {
+ // storing the file handle in a std::file::File automagically provides sane io::{Write,
+ // Read} and Drop implementations.
+ socket: AsyncFd<File>,
+}
+
+impl RoutingSocketInner {
+ fn new() -> io::Result<Self> {
+ let fd = socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None)?;
+ let _ = fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::OFlag::O_NONBLOCK))?;
+ // SAFETY: File handle is valid here
+ let socket = unsafe { File::from_raw_fd(fd) };
+ Ok(Self {
+ socket: AsyncFd::new(socket)?,
+ })
+ }
+
+ async fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
+ loop {
+ let mut guard = self.socket.readable().await?;
+ match guard.try_io(|sock| sock.get_ref().read(out)) {
+ Ok(result) => return result,
+ Err(_err) => continue,
+ }
+ }
+ }
+}
+
+impl std::os::unix::prelude::AsRawFd for RoutingSocketInner {
+ fn as_raw_fd(&self) -> RawFd {
+ self.socket.as_raw_fd()
+ }
+}
+
+impl AsyncWrite for RoutingSocketInner {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ loop {
+ let mut guard = ready!(self.socket.poll_write_ready(cx))?;
+
+ match guard.try_io(|inner| inner.get_ref().write(buf)) {
+ Ok(result) => return Poll::Ready(result),
+ Err(_would_block) => continue,
+ }
+ }
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ // tcp flush is a no-op
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ // no need for a shutdown on the routing socket
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/talpid-routing/src/unix/macos/watch.rs b/talpid-routing/src/unix/macos/watch.rs
new file mode 100644
index 0000000000..4cf1799276
--- /dev/null
+++ b/talpid-routing/src/unix/macos/watch.rs
@@ -0,0 +1,147 @@
+use super::{
+ data::{self, MessageType, RouteMessage, RouteSocketMessage},
+ routing_socket,
+};
+use std::io;
+
+type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug, err_derive::Error)]
+pub enum Error {
+ #[error(display = "Failed to open routing socket")]
+ RoutingSocket(routing_socket::Error),
+ #[error(display = "Invalid message")]
+ InvalidMessage(data::Error),
+ #[error(display = "Failed to send routing message")]
+ Send(routing_socket::Error),
+ #[error(display = "Unexpected message type")]
+ UnexpectedMessageType(RouteSocketMessage, MessageType),
+ #[error(display = "Route not found")]
+ RouteNotFound,
+ #[error(display = "Destination unreachable")]
+ Unreachable,
+ #[error(display = "Failed to delete a route")]
+ Deletion(RouteMessage),
+}
+
+/// Provides an interface for manipulating the routing table on macOS using a PF_ROUTE socket.
+pub struct RoutingTable {
+ socket: routing_socket::RoutingSocket,
+}
+
+impl RoutingTable {
+ pub fn new() -> Result<Self> {
+ let socket = routing_socket::RoutingSocket::new().map_err(Error::RoutingSocket)?;
+
+ Ok(Self { socket })
+ }
+
+ pub async fn next_message(&mut self) -> Result<RouteSocketMessage> {
+ let mut buf = [0u8; 2048];
+ let bytes_read = self
+ .socket
+ .recv_msg(&mut buf)
+ .await
+ .map_err(Error::RoutingSocket)?;
+ let msg_buf = &buf[0..bytes_read];
+ data::RouteSocketMessage::parse_message(msg_buf).map_err(Error::InvalidMessage)
+ }
+
+ pub async fn add_route(&mut self, message: &RouteMessage) -> Result<()> {
+ let msg = self
+ .alter_routing_table(message, MessageType::RTM_ADD)
+ .await;
+
+ match msg {
+ Ok(RouteSocketMessage::AddRoute(_route)) => Ok(()),
+ Err(Error::Send(routing_socket::Error::Write(err)))
+ if err.kind() == io::ErrorKind::AlreadyExists =>
+ {
+ Ok(())
+ }
+ Ok(anything_else) => {
+ log::error!("Unexpected route message: {anything_else:?}");
+ Err(Error::UnexpectedMessageType(
+ anything_else,
+ MessageType::RTM_ADD,
+ ))
+ }
+
+ Err(err) => Err(err),
+ }
+ }
+
+ async fn alter_routing_table(
+ &mut self,
+ message: &RouteMessage,
+ message_type: MessageType,
+ ) -> Result<RouteSocketMessage> {
+ let result = self.socket.send_route_message(message, message_type).await;
+
+ match result {
+ Ok(response) => {
+ data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage)
+ }
+
+ Err(routing_socket::Error::Write(err)) if err.kind() == io::ErrorKind::NotFound => {
+ Err(Error::RouteNotFound)
+ }
+ Err(routing_socket::Error::Write(err))
+ if [Some(libc::ENETUNREACH), Some(libc::ESRCH)].contains(&err.raw_os_error()) =>
+ {
+ Err(Error::Unreachable)
+ }
+ Err(err) => Err(Error::Send(err)),
+ }
+ }
+
+ pub async fn delete_route(&mut self, message: &RouteMessage) -> Result<()> {
+ let response = self
+ .alter_routing_table(message, MessageType::RTM_DELETE)
+ .await?;
+
+ match response {
+ RouteSocketMessage::DeleteRoute(route) if route.errno() == 0 => Ok(()),
+ RouteSocketMessage::DeleteRoute(route) if route.errno() != 0 => {
+ Err(Error::Deletion(route))
+ }
+ anything_else => Err(Error::UnexpectedMessageType(
+ anything_else,
+ MessageType::RTM_DELETE,
+ )),
+ }
+ }
+
+ pub async fn get_route(
+ &mut self,
+ message: &RouteMessage,
+ ) -> Result<Option<data::RouteMessage>> {
+ let response = self
+ .socket
+ .send_route_message(message, MessageType::RTM_GET)
+ .await;
+
+ let response = match response {
+ Ok(response) => response,
+ Err(routing_socket::Error::Write(err)) => {
+ if let Some(err) = err.raw_os_error() {
+ if [libc::ENETUNREACH, libc::ESRCH].contains(&err) {
+ return Ok(None);
+ }
+ }
+ return Err(Error::RoutingSocket(routing_socket::Error::Write(err)));
+ }
+ Err(other_err) => {
+ return Err(Error::RoutingSocket(other_err));
+ }
+ };
+
+ match data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage)? {
+ data::RouteSocketMessage::GetRoute(route) => Ok(Some(route)),
+ unexpected_route_message => Err(Error::UnexpectedMessageType(
+ unexpected_route_message,
+ MessageType::RTM_GET,
+ )),
+ }
+ }
+}
diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs
index 822615d0f1..757d3775fc 100644
--- a/talpid-routing/src/unix/mod.rs
+++ b/talpid-routing/src/unix/mod.rs
@@ -1,9 +1,7 @@
-// TODO: remove the allow(dead_code) for android once it's up to scratch.
-#![cfg_attr(target_os = "android", allow(dead_code))]
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+use crate::Route;
use super::RequiredRoute;
-#[cfg(target_os = "linux")]
-use super::Route;
use futures::channel::{
mpsc::{self, UnboundedSender},
@@ -11,7 +9,7 @@ use futures::channel::{
};
use std::{collections::HashSet, io};
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", target_os = "macos"))]
use futures::stream::Stream;
#[cfg(target_os = "linux")]
@@ -19,11 +17,8 @@ use std::net::IpAddr;
#[allow(clippy::module_inception)]
#[cfg(target_os = "macos")]
-#[path = "macos.rs"]
-mod imp;
-
-#[cfg(target_os = "macos")]
-pub use imp::{get_default_routes, listen_for_default_route_changes};
+#[path = "macos/mod.rs"]
+pub mod imp;
#[allow(clippy::module_inception)]
#[cfg(target_os = "linux")]
@@ -76,6 +71,28 @@ impl RouteManagerHandle {
.map_err(Error::PlatformError)
}
+ /// Listen for non-tunnel default route changes.
+ #[cfg(target_os = "macos")]
+ pub async fn default_route_listener(
+ &self,
+ ) -> Result<impl Stream<Item = DefaultRouteEvent>, Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::NewDefaultRouteListener(response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)
+ }
+
+ /// Get current non-tunnel default routes.
+ #[cfg(target_os = "macos")]
+ pub async fn get_default_routes(&self) -> Result<(Option<Route>, Option<Route>), Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::GetDefaultRoutes(response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)
+ }
+
/// Ensure that packets are routed using the correct tables.
#[cfg(target_os = "linux")]
pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> {
@@ -163,6 +180,10 @@ pub(crate) enum RouteManagerCommand {
),
ClearRoutes,
Shutdown(oneshot::Sender<()>),
+ #[cfg(target_os = "macos")]
+ NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>),
+ #[cfg(target_os = "macos")]
+ GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>),
#[cfg(target_os = "linux")]
CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>),
#[cfg(target_os = "linux")]
@@ -180,6 +201,21 @@ pub(crate) enum RouteManagerCommand {
),
}
+/// Event that is sent when a preferred non-tunnel default route is
+/// added or removed.
+#[cfg(target_os = "macos")]
+#[derive(Debug, Clone, Copy)]
+pub enum DefaultRouteEvent {
+ /// Added or updated a non-tunnel default IPv4 route
+ AddedOrChangedV4,
+ /// Added or updated a non-tunnel default IPv6 route
+ AddedOrChangedV6,
+ /// Non-tunnel default IPv4 route was removed
+ RemovedV4,
+ /// Non-tunnel default IPv6 route was removed
+ RemovedV6,
+}
+
#[cfg(target_os = "linux")]
#[derive(Debug, Clone)]
pub enum CallbackMessage {
@@ -196,17 +232,13 @@ pub struct RouteManager {
}
impl RouteManager {
- /// Constructs a RouteManager and applies the required routes.
- /// Takes a set of network destinations and network nodes as an argument, and applies said
- /// routes.
+ /// Construct a RouteManager.
pub async fn new(
- required_routes: HashSet<RequiredRoute>,
#[cfg(target_os = "linux")] fwmark: u32,
#[cfg(target_os = "linux")] table_id: u32,
) -> Result<Self, Error> {
let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = imp::RouteManagerImpl::new(
- required_routes,
#[cfg(target_os = "linux")]
fwmark,
#[cfg(target_os = "linux")]
@@ -260,8 +292,7 @@ impl RouteManager {
}
}
- /// Removes all routes previously applied in [`RouteManager::new`] or
- /// [`RouteManager::add_routes`].
+ /// Removes all routes previously applied in [`RouteManager::add_routes`].
pub fn clear_routes(&mut self) -> Result<(), Error> {
if let Some(tx) = &self.manage_tx {
if tx.unbounded_send(RouteManagerCommand::ClearRoutes).is_err() {
diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs
index e68c9255fc..51ac345f82 100644
--- a/talpid-routing/src/windows/mod.rs
+++ b/talpid-routing/src/windows/mod.rs
@@ -148,9 +148,9 @@ pub enum RouteManagerCommand {
}
impl RouteManager {
- /// Creates a new route manager that will apply the provided routes and ensure they exist until
- /// it's stopped.
- pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
+ /// Create a new route manager
+ #[allow(clippy::unused_async)]
+ pub async fn new() -> Result<Self> {
let internal = match RouteManagerInternal::new() {
Ok(internal) => internal,
Err(_) => return Err(Error::FailedToStartManager),
@@ -160,7 +160,6 @@ impl RouteManager {
manage_tx: Some(manage_tx),
};
tokio::spawn(RouteManager::listen(manage_rx, internal));
- manager.add_routes(required_routes).await?;
Ok(manager)
}
@@ -270,8 +269,7 @@ impl RouteManager {
}
}
- /// Removes all routes previously applied in [`RouteManager::new`] or
- /// [`RouteManager::add_routes`].
+ /// Removes all routes previously applied in [`RouteManager::add_routes`].
pub fn clear_routes(&self) -> Result<()> {
if let Some(tx) = &self.manage_tx {
tx.unbounded_send(RouteManagerCommand::ClearRoutes)
diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs
index da70b63bfc..cd250856df 100644
--- a/talpid-wireguard/src/lib.rs
+++ b/talpid-wireguard/src/lib.rs
@@ -355,6 +355,7 @@ impl WireguardMonitor {
let routes = Self::get_pre_tunnel_routes(&iface_name, &config)
.chain(Self::get_endpoint_routes(&endpoint_addrs))
.collect();
+
args.route_manager
.add_routes(routes)
.await
@@ -914,7 +915,7 @@ impl WireguardMonitor {
/// Replace default (0-prefix) routes with more specific routes.
fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec<ipnetwork::IpNetwork> {
- #[cfg(not(any(target_os = "linux", target_os = "android")))]
+ #[cfg(windows)]
if network.prefix() == 0 {
if network.is_ipv4() {
vec!["0.0.0.0/1".parse().unwrap(), "128.0.0.0/1".parse().unwrap()]
@@ -925,7 +926,7 @@ impl WireguardMonitor {
vec![network]
}
- #[cfg(any(target_os = "linux", target_os = "android"))]
+ #[cfg(not(windows))]
vec![network]
}