diff options
| author | Emīls <emils@mullvad.net> | 2022-10-18 22:21:03 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-06-05 19:31:49 +0200 |
| commit | 7747639ffa0a1f8d4ee685e88240ca3dff8a9a66 (patch) | |
| tree | b99853791db2fff37c358b8a0a48c325b6e99dbb | |
| parent | 2adf93174484af166c2104627530874698498b7d (diff) | |
| download | mullvadvpn-7747639ffa0a1f8d4ee685e88240ca3dff8a9a66.tar.xz mullvadvpn-7747639ffa0a1f8d4ee685e88240ca3dff8a9a66.zip | |
Attempt to setup routes the other way
| -rw-r--r-- | Cargo.lock | 139 | ||||
| -rw-r--r-- | talpid-core/src/offline/macos.rs | 82 | ||||
| -rw-r--r-- | talpid-core/src/routing/scutil.rs | 43 | ||||
| -rw-r--r-- | talpid-routing/Cargo.toml | 12 | ||||
| -rw-r--r-- | talpid-routing/src/bin/watch.rs | 30 | ||||
| -rw-r--r-- | talpid-routing/src/interfaces.rs | 462 | ||||
| -rw-r--r-- | talpid-routing/src/lib.rs | 5 | ||||
| -rw-r--r-- | talpid-routing/src/unix/ip6addr_ext.rs | 95 | ||||
| -rw-r--r-- | talpid-routing/src/unix/macos.rs | 1046 | ||||
| -rw-r--r-- | talpid-routing/src/unix/mod.rs | 86 | ||||
| -rw-r--r-- | talpid-routing/src/unix/route_watch.rs | 87 | ||||
| -rw-r--r-- | talpid-routing/src/unix/watch/data.rs | 1285 | ||||
| -rw-r--r-- | talpid-routing/src/unix/watch/routing_socket.rs | 191 | ||||
| -rw-r--r-- | talpid-routing/src/watch.rs | 369 | ||||
| -rw-r--r-- | talpid-wireguard/src/config.rs | 8 | ||||
| -rw-r--r-- | talpid-wireguard/src/lib.rs | 40 | ||||
| -rwxr-xr-x | wireguard/libwg/build-android.sh | 1 |
17 files changed, 3729 insertions, 252 deletions
diff --git a/Cargo.lock b/Cargo.lock index 17847ae5cd..0909ef3fc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,7 +85,7 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7ed72e1635e121ca3e79420540282af22da58be50de153d36f81ddc6b83aa9e" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -264,7 +264,7 @@ dependencies = [ "addr2line", "cc", "cfg-if", - "libc", + "libc 0.2.137", "miniz_oxide", "object", "rustc-demangle", @@ -601,7 +601,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a89e2ae426ea83155dccf10c0fa6b1463ef6d5fcb44cee0b224a408fa640a62" dependencies = [ "core-foundation-sys", - "libc", + "libc 0.2.137", ] [[package]] @@ -616,7 +616,7 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -751,7 +751,7 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de0a745c25b32caa56b82a3950f5fec7893a960f4c10ca3b02060b0c38d8c2ce" dependencies = [ - "libc", + "libc 0.2.137", "libdbus-sys", "winapi", ] @@ -860,7 +860,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" dependencies = [ - "libc", + "libc 0.2.137", "redox_users", "winapi", ] @@ -871,7 +871,7 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fc6a0a59ed0888e0041cf708e66357b7ae1a82f1c67247e1f93b5e0818f7d8d" dependencies = [ - "libc", + "libc 0.2.137", "once_cell", "os_pipe", "shared_child", @@ -988,7 +988,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" dependencies = [ "errno-dragonfly", - "libc", + "libc 0.2.137", "winapi", ] @@ -1010,7 +1010,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" dependencies = [ "cc", - "libc", + "libc 0.2.137", ] [[package]] @@ -1205,7 +1205,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" dependencies = [ "cfg-if", - "libc", + "libc 0.2.137", "wasi 0.9.0+wasi-snapshot-preview1", ] @@ -1216,7 +1216,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ "cfg-if", - "libc", + "libc 0.2.137", "wasi 0.10.2+wasi-snapshot-preview1", ] @@ -1273,7 +1273,7 @@ version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -1321,7 +1321,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" dependencies = [ - "libc", + "libc 0.2.137", "match_cfg", "winapi", ] @@ -1491,7 +1491,7 @@ dependencies = [ "bitflags", "futures-core", "inotify-sys", - "libc", + "libc 0.2.137", "tokio", ] @@ -1501,7 +1501,7 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -1872,7 +1872,7 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ - "libc", + "libc 0.2.137", "log", "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.45.0", @@ -1884,7 +1884,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d1a5469630da93e1813bb257964c0ccee3b26b6879dd858039ddec35cc8681ed" dependencies = [ - "libc", + "libc 0.2.137", "log", "mnl-sys", ] @@ -1895,7 +1895,7 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9750685b201e1ecfaaf7aa5d0387829170fa565989cc481b49080aa155f70457" dependencies = [ - "libc", + "libc 0.2.137", "pkg-config", ] @@ -1964,7 +1964,7 @@ dependencies = [ "futures", "ipnetwork", "lazy_static", - "libc", + "libc 0.2.137", "log", "log-panics", "mullvad-api", @@ -2188,7 +2188,7 @@ checksum = "345b8ab5bd4e71a2986663e88c56856699d060e78e152e6e9d7966fcd5491297" dependencies = [ "anyhow", "byteorder", - "libc", + "libc 0.2.137", "netlink-packet-utils", ] @@ -2201,7 +2201,7 @@ dependencies = [ "anyhow", "bitflags", "byteorder", - "libc", + "libc 0.2.137", "netlink-packet-core", "netlink-packet-utils", ] @@ -2241,7 +2241,7 @@ checksum = "92b654097027250401127914afb37cb1f311df6610a9891ff07a757e94199027" dependencies = [ "bytes", "futures", - "libc", + "libc 0.2.137", "log", "tokio", ] @@ -2265,7 +2265,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b5c587b6a5e76a3a5d51e0a757ae66dbff38c277563485807ae979ce361b56" dependencies = [ "cfg-if", - "libc", + "libc 0.2.137", "pkg-config", ] @@ -2299,7 +2299,18 @@ checksum = "195cdbc1741b8134346d515b3a56a1c94b0912758009cfd53f99ea0f57b065fc" dependencies = [ "bitflags", "cfg-if", - "libc", + "libc 0.2.137", +] + +[[package]] +name = "nix" +version = "0.25.0" +dependencies = [ + "bitflags", + "cfg-if", + "libc 0.2.135", + "memoffset", + "pin-utils", ] [[package]] @@ -2416,7 +2427,7 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb233f06c2307e1f5ce2ecad9f8121cffbbee2c95428f44ea85222e460d0d213" dependencies = [ - "libc", + "libc 0.2.137", "winapi", ] @@ -2464,7 +2475,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9981e32fb75e004cc148f5fb70342f393830e0a4aa62e3cc93b50976218d42b6" dependencies = [ "futures", - "libc", + "libc 0.2.137", "log", "rand 0.7.3", "tokio", @@ -2488,7 +2499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", - "libc", + "libc 0.2.137", "redox_syscall", "smallvec", "windows-sys 0.45.0", @@ -2570,7 +2581,7 @@ dependencies = [ "error-chain", "ioctl-sys", "ipnetwork", - "libc", + "libc 0.2.137", ] [[package]] @@ -2856,7 +2867,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" dependencies = [ "getrandom 0.1.16", - "libc", + "libc 0.2.137", "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc", @@ -2868,7 +2879,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", + "libc 0.2.137", "rand_chacha 0.3.1", "rand_core 0.6.4", ] @@ -2973,7 +2984,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" dependencies = [ "cc", - "libc", + "libc 0.2.137", "once_cell", "spin 0.5.2", "untrusted", @@ -3159,7 +3170,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa25200c6de90f8da82d63f8806bd2ea1261018620dd4881626d6b146e13bd7" dependencies = [ - "libc", + "libc 0.2.137", "tokio", ] @@ -3252,7 +3263,7 @@ dependencies = [ "bytes", "cfg-if", "futures", - "libc", + "libc 0.2.137", "log", "notify", "once_cell", @@ -3321,7 +3332,7 @@ dependencies = [ "ipnet", "iprange", "json5", - "libc", + "libc 0.2.137", "log", "lru_time_cache", "nix 0.26.2", @@ -3345,7 +3356,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6be9f7d5565b1483af3e72975e2dee33879b3b86bd48c0929fccf6585d79e65a" dependencies = [ - "libc", + "libc 0.2.137", "winapi", ] @@ -3361,7 +3372,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -3380,7 +3391,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53f7da44adcc42667d57483bd93f81295f27d66897804b757573b61b6f13288b" dependencies = [ "lazy_static", - "libc", + "libc 0.2.137", ] [[package]] @@ -3407,7 +3418,7 @@ version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ - "libc", + "libc 0.2.137", "winapi", ] @@ -3511,13 +3522,30 @@ dependencies = [ [[package]] name = "system-configuration" +version = "0.4.0" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys 0.4.1", +] + +[[package]] +name = "system-configuration" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75182f12f490e953596550b65ee31bda7c8e043d9386174b353bda50838c3fd" dependencies = [ "bitflags", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration-sys" +version = "0.4.1" +dependencies = [ + "core-foundation-sys", + "libc 0.2.137", ] [[package]] @@ -3527,7 +3555,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" dependencies = [ "core-foundation-sys", - "libc", + "libc 0.2.137", ] [[package]] @@ -3548,7 +3576,7 @@ dependencies = [ "ipnetwork", "jnix", "lazy_static", - "libc", + "libc 0.2.137", "log", "memoffset 0.6.4", "mnl", @@ -3571,7 +3599,7 @@ dependencies = [ "shell-escape", "socket2", "subslice", - "system-configuration", + "system-configuration 0.5.0", "talpid-dbus", "talpid-openvpn", "talpid-routing", @@ -3604,7 +3632,7 @@ dependencies = [ "dbus", "err-derive", "lazy_static", - "libc", + "libc 0.2.137", "log", "tokio", ] @@ -3678,16 +3706,21 @@ dependencies = [ name = "talpid-routing" version = "0.0.0" dependencies = [ + "base64 0.20.0", + "bitflags", "err-derive", "futures", "ipnetwork", "lazy_static", - "libc", + "libc 0.2.137", "log", "netlink-packet-route", "netlink-sys", + "nix 0.25.0", "rtnetlink", "socket2", + "system-configuration 0.4.0", + "talpid-time", "talpid-types", "talpid-windows-net", "tokio", @@ -3700,7 +3733,7 @@ dependencies = [ name = "talpid-time" version = "0.0.0" dependencies = [ - "libc", + "libc 0.2.137", "tokio", ] @@ -3763,7 +3796,7 @@ version = "0.0.0" dependencies = [ "err-derive", "futures", - "libc", + "libc 0.2.137", "socket2", "winapi", "windows-sys 0.45.0", @@ -3783,7 +3816,7 @@ dependencies = [ "internet-checksum", "ipnetwork", "lazy_static", - "libc", + "libc 0.2.137", "log", "netlink-packet-core", "netlink-packet-route", @@ -3862,7 +3895,7 @@ version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" dependencies = [ - "libc", + "libc 0.2.137", "winapi", ] @@ -3872,7 +3905,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41effe7cfa8af36f439fac33861b66b049edc6f9a32331e2312660529c1c24ad" dependencies = [ - "libc", + "libc 0.2.137", ] [[package]] @@ -3898,7 +3931,7 @@ checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" dependencies = [ "autocfg", "bytes", - "libc", + "libc 0.2.137", "memchr", "mio", "num_cpus", @@ -3961,7 +3994,7 @@ checksum = "35ccf89920b48afc418f18135342355d30ad048f3c95ba54670f50a52371a439" dependencies = [ "cfg-if", "futures", - "libc", + "libc 0.2.137", "log", "once_cell", "pin-project", @@ -4253,7 +4286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cb3f24867499300ae21771a95bbaede2761497ae51094bbefcfd40646815b2a" dependencies = [ "ioctl-sys", - "libc", + "libc 0.2.137", "thiserror", ] @@ -4490,7 +4523,7 @@ checksum = "ea187a8ef279bc014ec368c27a920da2024d2a711109bfbe3440585d5cf27ad9" dependencies = [ "either", "lazy_static", - "libc", + "libc 0.2.137", ] [[package]] diff --git a/talpid-core/src/offline/macos.rs b/talpid-core/src/offline/macos.rs index baae3c40f1..b14071839c 100644 --- a/talpid-core/src/offline/macos.rs +++ b/talpid-core/src/offline/macos.rs @@ -21,7 +21,7 @@ //! [`SCNetworkReachability`]: https://developer.apple.com/documentation/systemconfiguration/scnetworkreachability-g7d //! [`NWPathMonitor`]: https://developer.apple.com/documentation/network/nwpathmonitor use futures::{channel::mpsc::UnboundedSender, Future, StreamExt}; -use std::sync::{Arc, Weak}; +use std::sync::{Arc, Mutex, Weak}; use talpid_types::ErrorExt; #[derive(err_derive::Error, Debug)] @@ -31,6 +31,7 @@ pub enum Error { } pub struct MonitorHandle { + excluded_interface: Arc<Mutex<Option<String>>>, _notify_tx: Arc<UnboundedSender<bool>>, } @@ -38,58 +39,81 @@ impl MonitorHandle { /// Host is considered to be offline if the IPv4 internet is considered to be unreachable by the /// given reachability flags *or* there are no active physical interfaces. pub async fn host_is_offline(&self) -> bool { - !exists_non_tunnel_default_route().await + let excluded_interface: Option<String> = self + .excluded_interface + .lock() + .expect("excluded_interface lock poisoned") + .clone(); + !exists_non_tunnel_default_route(excluded_interface).await } -} -async fn exists_non_tunnel_default_route() -> bool { - match talpid_routing::get_default_routes().await { - Ok((Some(node), _)) | Ok((None, Some(node))) => { - let route_exists = node - .get_device() - .map(|iface_name| !iface_name.contains("tun")) - .unwrap_or(true); - log::debug!("Assuming non-tunnel default route exists due to {:?}", node); - route_exists - } - Ok((None, None)) => { - log::debug!("No default routes exist, assuming machine is offline"); - false - } - Err(err) => { - log::error!( - "{}", - err.display_chain_with_msg( - "Failed to obtain default routes, assuming machine is online." - ) - ); - true - } + /// Sets excluded interface - excluded interfaces will be assuemd to not provide any + /// connectivity. + pub fn set_excluded_interface(&self, excluded_interface: Option<String>) { + *self + .excluded_interface + .lock() + .expect("excluded interface lock poisoned") = excluded_interface; } } + +async fn exists_non_tunnel_default_route(excluded_interface: Option<String>) -> bool { + true + // match talpid_routing::get_default_routes(excluded_interface).await { + // Ok((Some(node), _)) | Ok((None, Some(node))) => { + // let route_exists = node + // .get_device() + // .map(|iface_name| !iface_name.contains("tun")) + // .unwrap_or(true); + // log::debug!("Assuming non-tunnel default route exists due to {:?}", node); + // route_exists + // } + // Ok((None, None)) => { + // log::debug!("No default routes exist, assuming machine is offline"); + // false + // } + // Err(err) => { + // log::error!( + // "{}", + // err.display_chain_with_msg( + // "Failed to obtain default routes, assuming machine is online." + // ) + // ); + // true + // } + // } +} pub async fn spawn_monitor(notify_tx: UnboundedSender<bool>) -> Result<MonitorHandle, Error> { let notify_tx = Arc::new(notify_tx); + let excluded_interface = Arc::new(Mutex::new(None)); let context = OfflineStateContext { sender: Arc::downgrade(¬ify_tx), - is_offline: !exists_non_tunnel_default_route().await, + is_offline: !exists_non_tunnel_default_route(None).await, }; - let route_monitor = watch_route_monitor(context)?; + let route_monitor = watch_route_monitor(context, excluded_interface.clone())?; tokio::spawn(route_monitor); Ok(MonitorHandle { + excluded_interface, _notify_tx: notify_tx, }) } fn watch_route_monitor( mut context: OfflineStateContext, + excluded_interface: Arc<Mutex<Option<String>>>, ) -> Result<impl Future<Output = ()>, Error> { let mut monitor = talpid_routing::listen_for_default_route_changes()?; Ok(async move { while let Some(_route_change) = monitor.next().await { - context.new_state(!exists_non_tunnel_default_route().await); + let interface = excluded_interface + .lock() + .expect("excluded_interface lock poisoned") + .clone(); + + context.new_state(!exists_non_tunnel_default_route(interface).await); if context.should_shut_down() { break; } diff --git a/talpid-core/src/routing/scutil.rs b/talpid-core/src/routing/scutil.rs new file mode 100644 index 0000000000..fa5f5dff6a --- /dev/null +++ b/talpid-core/src/routing/scutil.rs @@ -0,0 +1,43 @@ +use std::collections::HashMap; + +use tokio::process::Command; + +use super::{Error, Result}; + +fn get_default_interface() -> () { + () +} + +/// Expected to produce output like: +/// Network information +/// IPv4 network interface information +/// utun3 : flags : 0x5 (IPv4,DNS) +/// address : 10.113.48.185 +/// VPN server : 127.0.0.1 +/// reach : 0x00000003 (Reachable,Transient Connection) +/// en0 : flags : 0x5 (IPv4,DNS) +/// address : 192.168.102.106 +/// reach : 0x00000002 (Reachable) +/// +/// REACH : flags 0x00000003 (Reachable,Transient Connection) +/// +/// IPv6 network interface information +/// No IPv6 states found +/// +/// +/// REACH : flags 0x00000007 (Reachable,Transient Connection,Connection Required) +/// +/// Network interfaces: utun3 en0 + +/// +async fn obtain_output() -> Result<Vec<u8>> { + let mut cmd = Command::new("scutil"); + cmd.arg("--nwi"); + + Ok(cmd.output().await.map_err(|_| Error::ScUtilCommand)?.stdout) +} + +fn parse_scutil_output(output: &[u8]) -> Result<Vec<HashMap<&str, &[u8]>>> { + for line in output.split(|c| *c == b'\n') {} + Ok(vec![]) +} diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml index a2f7e780be..f3b9f5b9cf 100644 --- a/talpid-routing/Cargo.toml +++ b/talpid-routing/Cargo.toml @@ -13,7 +13,8 @@ err-derive = "0.3.1" futures = "0.3.15" ipnetwork = "0.16" log = "0.4" -tokio = { version = "1.8", features = ["process", "rt-multi-thread"] } +talpid-types = { path = "../talpid-types" } +tokio = { version = "1.8", features = ["process", "rt-multi-thread", "net"] } [target.'cfg(not(target_os="android"))'.dependencies] talpid-types = { path = "../talpid-types" } @@ -27,6 +28,14 @@ netlink-sys = "0.8.3" [target.'cfg(target_os = "macos")'.dependencies] tokio-stream = { version = "0.1", features = ["io-util"] } +nix = { path = "../../src/nix-rs", feautres = ["socket"] } +libc = "0.2" +bitflags = "1.2" +talpid-time = { path = "../talpid-time" } +# TODO: use real release before merging to master +system-configuration = { path = "../../src/system-configuration-rs/system-configuration/" } + +# system-configuration = { repo = "https://github.com/mullvad/system-configuration-rs", ref = "5df2d065fa6c40a75f4f93b8a1db598abfc16743" } [target.'cfg(windows)'.dependencies] @@ -41,3 +50,4 @@ windows-sys = { version = "0.45.0", features = [ [dev-dependencies] tokio = { version = "1", features = [ "test-util" ] } +base64 = "0.20" diff --git a/talpid-routing/src/bin/watch.rs b/talpid-routing/src/bin/watch.rs new file mode 100644 index 0000000000..e1de22eea6 --- /dev/null +++ b/talpid-routing/src/bin/watch.rs @@ -0,0 +1,30 @@ +use system_configuration::{ + core_foundation::string::CFString, + network_configuration::{SCNetworkService, SCNetworkSet}, + preferences::SCPreferences, +}; + +fn main() { + let rt = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); + println!("order_of-interfaces {}", order_of_interfaces().join(", ")); + rt.block_on(talpid_routing::watch_routes()) + .expect("rt panicked"); +} + +fn order_of_interfaces() -> Vec<String> { + let prefs = SCPreferences::default(&CFString::new("talpid-routing")); + let services = SCNetworkService::get_services(&prefs); + let set = SCNetworkSet::new(&prefs); + let service_order = set.service_order(); + + service_order + .iter() + .filter_map(|service_id| { + services + .iter() + .find(|service| service.id().as_ref() == Some(&*service_id)) + .and_then(|service| service.network_interface()?.bsd_name()) + .map(|cf_name| cf_name.to_string()) + }) + .collect::<Vec<_>>() +} diff --git a/talpid-routing/src/interfaces.rs b/talpid-routing/src/interfaces.rs new file mode 100644 index 0000000000..32b32db07a --- /dev/null +++ b/talpid-routing/src/interfaces.rs @@ -0,0 +1,462 @@ +use std::{collections::BTreeMap, net::IpAddr, time::Duration}; + +use system_configuration::{ + core_foundation::string::CFString, + network_configuration::{SCNetworkService, SCNetworkSet}, + preferences::SCPreferences, +}; +use talpid_time::Instant; + +use super::{ + ip6addr_ext::IpAddrExt, + watch::data::{self, AddressMessage, RouteDestination}, + Error, Result, +}; + +const NON_DEFAULT_ROUTE_VALIDITY_TIMEOUT: Duration = Duration::from_secs(120); + +/// Keepes track of valid interfaces +pub struct Interfaces { + map: BTreeMap<u16, Interface>, + current_v4_interface: Option<u16>, + current_v6_interface: Option<u16>, +} + +impl Interfaces { + pub fn new() -> Interfaces { + Self { + map: BTreeMap::new(), + current_v4_interface: None, + current_v6_interface: None, + } + } + + pub fn confirm_route(&mut self, interface: u16, destination: &RouteDestination) { + if let Some(iface) = self.map.get_mut(&interface) { + iface.confirm_route(destination); + } + } + + // Currently only lookgs for best V4 default interface + pub fn get_best_default_interface_v4(&self) -> Result<Option<BestRoute>> { + self.get_best_interface(&|interface| interface.best_v4_route()) + } + + // Currently only lookgs for best V6 default interface + pub fn get_best_default_interface_v6(&self) -> Result<Option<BestRoute>> { + self.get_best_interface(&|interface| interface.best_v6_route()) + } + + fn get_best_interface( + &self, + ipv_fn: &dyn Fn(&Interface) -> Option<BestRoute>, + ) -> Result<Option<BestRoute>> { + let mut ordered_interfaces = order_of_interfaces() + .into_iter() + .filter_map(|iface_name| { + self.map.iter().find_map(|(idx, interface)| { + if interface.name == iface_name { + Some(*idx) + } else { + None + } + }) + }) + .filter_map(|index| self.map.get(&index)) + .filter_map(|interface| Some((interface, ipv_fn(interface)?))) + .collect::<Vec<_>>(); + + if ordered_interfaces.is_empty() { + log::error!("Failed to obtain a list of valid network services"); + return Ok(None); + } + + ordered_interfaces.sort_by_key(|(_interface, best_route)| best_route.validity); + Ok(ordered_interfaces + .into_iter() + .next() + .map(|(_, best_route)| best_route)) + } + + pub fn handle_add_address(&mut self, address: AddressMessage) -> bool { + let interface = match self.map.get_mut(&address.index()) { + Some(interface) => interface, + None => { + log::error!( + "Received address message for non-existant interface with index {}", + address.index() + ); + return false; + } + }; + + match address.address() { + Ok(addr) => interface.add_address(addr), + Err(err) => { + log::error!("Failed to get interface address from address message: {err:?}"); + false + } + } + } + + pub fn handle_delete_address(&mut self, address: AddressMessage) -> bool { + let interface = match self.map.get_mut(&address.index()) { + Some(interface) => interface, + None => { + log::error!( + "Received address message with an unknown interface {}", + address.index() + ); + return false; + } + }; + + match address.address() { + Ok(addr) => interface.addresses.remove(&addr).is_some(), + Err(err) => { + log::error!("Failed to get interface address from address message: {err:?}"); + false + } + } + } + + // returning true implies that the best default interface might've changed + pub fn handle_iface_msg(&mut self, interface: data::Interface) -> Result<bool> { + let index = interface.index(); + if interface.is_up() { + if self.map.contains_key(&index) { + return Ok(false); + } + self.map.insert(index, Interface::new(index)?); + // just because an interface is added doesn't imply that routes will change - have to + // wait for new addresses and routes to come in. + Ok(false) + } else { + Ok(self.map.remove(&index).is_some()) + } + } + + pub fn handle_add_route(&mut self, route: &data::RouteMessage) -> Result<bool> { + let destination = route.destination_ip().map_err(Error::InvalidData)?; + + let mut new_v4_route = false; + let mut new_v6_route = false; + + if route.is_ipv4() { + match (route.ifscope(), route.interface_sockaddr_index()) { + (Some(index), _) | (_, Some(index)) => match self.map.get_mut(&index) { + Some(interface) => { + let destination = + RouteDestination::try_from(route).map_err(Error::InvalidData)?; + interface.add_route(destination); + if route.is_ipv4() { + new_v4_route = Some(index) != self.current_v4_interface; + } else { + new_v6_route = Some(index) != self.current_v6_interface; + } + } + None => { + log::error!("Received a route with destination {:?} about through an unknown interface {index}", route.destination_ip()); + } + }, + _ => (), + } + } + + if new_v4_route { + let new_interface = self + .get_best_default_interface_v4()? + .map(|route| route.iface_index); + if new_interface != self.current_v4_interface { + self.current_v4_interface = new_interface; + return Ok(true); + } + } + + if new_v6_route { + let new_interface = self + .get_best_default_interface_v6()? + .map(|route| route.iface_index); + if new_interface != self.current_v6_interface { + self.current_v6_interface = new_interface; + return Ok(true); + } + } + + Ok(false) + } + + pub fn handle_delete_route(&mut self, route: &data::RouteMessage) -> Result<bool> { + if let Some(ifscope) = route.ifscope() { + if let Some(interface) = self.map.get_mut(&ifscope) { + let destination = route.try_into().map_err(Error::RouteDestination)?; + interface.remove_route(&destination); + return Ok(true); + } else { + log::error!("Received route message about unknown interface"); + return Ok(false); + } + } + + let iface_addr = match route.interface_address() { + Some(addr) => addr, + None => { + return Ok(false); + } + }; + + let interface = self + .map + .values_mut() + .find(|interface| interface.has_addr(&iface_addr)); + + if let Some(interface) = interface { + let destination = route.try_into().map_err(Error::RouteDestination)?; + interface.remove_route(&destination); + return Ok(true); + } + + Ok(false) + } + + pub fn handle_changed_route(&mut self, route: &data::RouteMessage) -> Result<bool> { + // If an ifscoped route is changed, it can be interpreted as though a new route has been + // added, if the old is removed first. + if route.is_ifscope() { + return Ok(false); + } + + Ok(self.handle_delete_route(route)? || self.handle_add_route(route)?) + } +} + +/// Represents all the data about the current best route to the internet +pub struct BestRoute { + iface_index: u16, + destination: RouteDestination, + validity: RouteValidity, +} + +impl BestRoute { + /// Returns amount of time between now and until the route will no longer be considered valid. + pub fn timeout(&self) -> Option<Instant> { + None + // match self.validity { + // RouteValidity::Unknown() + // } + } +} + +pub struct InterfaceIdentifier { + index: u16, + name: String, +} + +pub struct Interface { + /// Network interface index + index: u16, + /// BSD name of the network interface + name: String, + /// routes assigned to interface, should not be used to track ifscoped routes + routes: BTreeMap<RouteDestination, RouteValidity>, + /// Addresses assigned to the network interface + addresses: BTreeMap<IpAddr, Instant>, +} + +impl Interface { + fn new(index: u16) -> Result<Self> { + let interfaces = nix::net::if_::if_nameindex().map_err(Error::GetInterfaceNames)?; + let c_name = interfaces + .iter() + .find(|iface| iface.index() == index.into()) + .map(|iface| iface.name()) + .ok_or(Error::GetInterfaceName)?; + + let name = match c_name.to_str() { + Ok(name) => name.to_owned(), + Err(_) => { + log::error!("Interface name is not valid UTF-8: {:?}", c_name); + return Err(Error::GetInterfaceName); + } + }; + + Ok(Self { + name, + index, + routes: Default::default(), + addresses: Default::default(), + }) + } + + fn best_v4_route(&self) -> Option<BestRoute> { + let mut candidates = self + .routes + .iter() + .filter(|(destination, validity)| destination.is_ipv4() && validity.is_valid()) + .collect::<Vec<_>>(); + + candidates.first().map(|(destination, validity)| BestRoute { + iface_index: self.index, + destination: (*destination).clone(), + validity: **validity, + }) + } + + fn best_v6_route(&self) -> Option<BestRoute> { + let mut candidates = self + .routes + .iter() + .filter(|(destination, validity)| !destination.is_ipv4() && validity.is_valid()) + .collect::<Vec<_>>(); + + candidates.sort_by_key(|(destination, validity)| *validity); + + candidates.first().map(|(destination, validity)| BestRoute { + iface_index: self.index, + destination: (*destination).clone(), + validity: **validity, + }) + } + + fn add_route(&mut self, destination: RouteDestination) -> bool { + let validity = if destination.is_default() { + RouteValidity::Default + } else { + RouteValidity::Unknown(Instant::now()) + }; + self.routes.insert(destination, validity).is_some() + } + + fn confirm_route(&mut self, destination: &RouteDestination) { + if let Some(route) = self.routes.get_mut(&destination) { + *route = RouteValidity::Confirmed; + } + } + + fn remove_route(&mut self, destination: &RouteDestination) { + self.routes.remove(destination); + } + + fn has_v4_default_route(&self) -> bool { + self.routes + .keys() + .any(|route| route.is_ipv4() && route.is_default()) + } + + fn has_v6_default_route(&self) -> bool { + self.routes + .keys() + .any(|route| !route.is_ipv4() && route.is_default()) + } + + fn add_address(&mut self, address: IpAddr) -> bool { + self.addresses.insert(address, Instant::now()).is_some() + } + + fn has_addr(&self, iface_addr: &IpAddr) -> bool { + self.addresses.contains_key(iface_addr) + } +} + +fn order_of_interfaces() -> Vec<String> { + let prefs = SCPreferences::default(&CFString::new("talpid-routing")); + let services = SCNetworkService::get_services(&prefs); + let set = SCNetworkSet::new(&prefs); + let service_order = set.service_order(); + + service_order + .iter() + .filter_map(|service_id| { + services + .iter() + .find(|service| service.id().as_ref() == Some(&*service_id)) + .and_then(|service| service.network_interface()?.bsd_name()) + .map(|cf_name| cf_name.to_string()) + }) + .collect::<Vec<_>>() +} + +/// Represents whether a route is considered to provide internet. +#[derive(Clone, Copy)] +pub enum RouteValidity { + /// Route is default, hence it's expected it must provide connectivity. + Default, + /// If a non-default route seems to provide connectivity, then this should be considered to be + /// a valid route. + Confirmed, + /// When a new route appears that isn't a default one, we can assume it's interface may provide + /// internet connectivity. In this case, it should be valid for a specific amount of time - + /// afterwards, if connectivity cannot be established, it it can be assumed it doesn't provide + /// connectivity. + Unknown(Instant), +} + +impl RouteValidity { + fn is_valid(&self) -> bool { + match self { + Self::Default | Self::Confirmed => true, + Self::Unknown(time_seen) => { + time_seen.duration_since(Instant::now()) <= NON_DEFAULT_ROUTE_VALIDITY_TIMEOUT + } + } + } +} + +impl std::fmt::Debug for RouteValidity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RouteValidity::")?; + match self { + Self::Default => write!(f, "Default"), + Self::Confirmed => write!(f, "Confirmed"), + Self::Unknown(time_seen) => { + if self.is_valid() { + let msecs_valid = time_seen.duration_since(Instant::now()).as_millis(); + write!(f, "Uknown(valid for {msecs_valid} ms)") + } else { + write!(f, "Uknown(expired)") + } + } + } + } +} + +impl PartialEq for RouteValidity { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Default, Self::Default) => true, + (Self::Confirmed, Self::Confirmed) => true, + (Self::Unknown(time_created), Self::Unknown(other)) => { + let now = Instant::now(); + time_created.duration_since(now) == other.duration_since(now) + } + _ => false, + } + } +} +impl Eq for RouteValidity {} + +impl PartialOrd for RouteValidity { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for RouteValidity { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + if self == other { + return std::cmp::Ordering::Equal; + } + match (self, other) { + (Self::Default, _) | (Self::Confirmed, Self::Unknown(_)) => std::cmp::Ordering::Greater, + (Self::Unknown(unknown), Self::Unknown(other)) => { + let now = Instant::now(); + // the newer a route is, the higher preference it has + unknown + .duration_since(now) + .cmp(&other.duration_since(now)) + .reverse() + } + _ => std::cmp::Ordering::Less, + } + } +} diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index d8c65e80da..8c1d423fac 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -21,7 +21,10 @@ mod imp; use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; #[cfg(target_os = "macos")] -pub use imp::{get_default_routes, listen_for_default_route_changes, PlatformError}; +pub use imp::{ + get_default_routes, imp::watch::watch_routes, listen_for_default_route_changes, PlatformError, + TunnelRoutesV4, TunnelRoutesV6, +}; pub use imp::{Error, RouteManager}; diff --git a/talpid-routing/src/unix/ip6addr_ext.rs b/talpid-routing/src/unix/ip6addr_ext.rs new file mode 100644 index 0000000000..314263a614 --- /dev/null +++ b/talpid-routing/src/unix/ip6addr_ext.rs @@ -0,0 +1,95 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +/// To distinguish between globally addressable and unadressable IPv6 addresses with the current +/// stable Rust standard library, one has to reimplement the nightly code. +pub trait IpAddrExt { + /// Returns true if the IPv6 address is globally addressable. + fn is_global(&self) -> bool; +} + +fn addr_has_a_chance_to_reach_internet(addr: &IpAddr) -> bool { + match addr { + IpAddr::V4(v4) => v4.is_private() || v4.is_global(), + + IpAddr::V6(v6) => IpAddrExt::is_global(v6), + } +} + +impl IpAddrExt for IpAddr { + fn is_global(&self) -> bool { + match self { + IpAddr::V4(addr) => addr.is_global(), + IpAddr::V6(addr) => addr.is_global(), + } + } +} + +impl IpAddrExt for Ipv6Addr { + fn is_global(&self) -> bool { + !(self.is_unspecified() + || self.is_loopback() + // IPv4-mapped Address (`::ffff:0:0/96`) + || matches!(self.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) + // IPv4-IPv6 Translat. (`64:ff9b:1::/48`) + || matches!(self.segments(), [0x64, 0xff9b, 1, _, _, _, _, _]) + // Discard-Only Address Block (`100::/64`) + || matches!(self.segments(), [0x100, 0, 0, 0, _, _, _, _]) + // IETF Protocol Assignments (`2001::/23`) + || (matches!(self.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200) + && !( + // Port Control Protocol Anycast (`2001:1::1`) + u128::from_be_bytes(self.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001 + // Traversal Using Relays around NAT Anycast (`2001:1::2`) + || u128::from_be_bytes(self.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002 + // AMT (`2001:3::/32`) + || matches!(self.segments(), [0x2001, 3, _, _, _, _, _, _]) + // AS112-v6 (`2001:4:112::/48`) + || matches!(self.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) + // ORCHIDv2 (`2001:20::/28`) + || matches!(self.segments(), [0x2001, b, _, _, _, _, _, _] if b >= 0x20 && b <= 0x2F) + )) + || ipv6_is_documentation(self) + || ipv6_is_unique_local(self) + || ipv6_is_unicast_link_local(self)) + } +} + +impl IpAddrExt for Ipv4Addr { + fn is_global(&self) -> bool { + !(self.octets()[0] == 0 // "This network" + || self.is_private() + || ipv4_is_shared(self) + || self.is_loopback() + || self.is_link_local() + // addresses reserved for future protocols (`192.0.0.0/24`) + ||(self.octets()[0] == 192 && self.octets()[1] == 0 && self.octets()[2] == 0) + || self.is_documentation() + || ipv4_is_benchmarking(self) + || ipv4_is_reserved(self) + || self.is_broadcast()) + } +} + +fn ipv6_is_documentation(addr: &Ipv6Addr) -> bool { + (addr.segments()[0] == 0x2001) && (addr.segments()[1] == 0xdb8) +} + +fn ipv6_is_unique_local(addr: &Ipv6Addr) -> bool { + (addr.segments()[0] & 0xfe00) == 0xfc00 +} + +fn ipv6_is_unicast_link_local(addr: &Ipv6Addr) -> bool { + (addr.segments()[0] & 0xffc0) == 0xfe80 +} + +fn ipv4_is_shared(addr: &Ipv4Addr) -> bool { + addr.octets()[0] == 100 && (addr.octets()[1] & 0b1100_0000 == 0b0100_0000) +} + +fn ipv4_is_benchmarking(addr: &Ipv4Addr) -> bool { + addr.octets()[0] == 198 && (addr.octets()[1] & 0xfe) == 18 +} + +fn ipv4_is_reserved(addr: &Ipv4Addr) -> bool { + addr.octets()[0] & 240 == 240 && !addr.is_broadcast() +} diff --git a/talpid-routing/src/unix/macos.rs b/talpid-routing/src/unix/macos.rs index 893416a699..e6c72b6520 100644 --- a/talpid-routing/src/unix/macos.rs +++ b/talpid-routing/src/unix/macos.rs @@ -1,22 +1,51 @@ -use super::RouteManagerCommand; -use crate::{NetNode, Node, RequiredRoute, Route}; +use crate::{ + imp::{imp::watch::data::RouteSocketMessage, RouteManagerCommand}, + NetNode, Node, RequiredRoute, Route, +}; use futures::{ channel::mpsc, - future, + future::{self, FutureExt}, stream::{FusedStream, Stream, StreamExt, TryStreamExt}, }; -use ipnetwork::IpNetwork; +use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; +use nix::ifaddrs; use std::{ - collections::HashSet, + collections::{BTreeMap, BTreeSet, HashSet}, io, - net::IpAddr, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, process::{ExitStatus, Stdio}, + time::Duration, +}; +use system_configuration::{ + core_foundation::string::CFString, + network_configuration::{SCNetworkService, SCNetworkSet}, + preferences::SCPreferences, }; + +use talpid_time::Instant; use talpid_types::net::IpVersion; use tokio::{io::AsyncBufReadExt, process::Command}; use tokio_stream::wrappers::LinesStream; +use self::{ + interfaces::{RouteValidity, BestRoute}, + watch::{ + data::{self, AddressMessage, Destination, RouteDestination, RouteMessage}, + RoutingTable, + }, +}; + +mod ip6addr_ext; +use ip6addr_ext::IpAddrExt; + +use super::{TunnelRoutesV4, TunnelRoutesV6}; + +mod interfaces; +mod route_watch; +pub mod watch; +use interfaces::Interfaces; + pub type Result<T> = std::result::Result<T, Error>; /// Errors that can happen in the macOS routing integration. @@ -25,7 +54,11 @@ pub type Result<T> = std::result::Result<T, Error>; pub enum Error { /// Failed to add route. #[error(display = "Failed to add route")] - FailedToAddRoute(#[error(source)] io::Error), + FailedToAddRoute(#[error(source)] watch::Error), + + /// Failed to add route via 'route' subcommand. + #[error(display = "Failed to add route via subcommand")] + FailedToAddRouteExec(#[error(source)] io::Error), /// Failed to remove route. #[error(display = "Failed to remove route")] @@ -42,6 +75,103 @@ pub enum Error { /// Unexpected output from netstat #[error(display = "Unexpected output from netstat")] BadOutputFromNetstat, + + /// Failed to run scutil + #[error(display = "Failed to run scutil command")] + ScUtilCommand, + + /// Encountered unexpected output from scutil + #[error(display = "Unexpected scutil output")] + ScUtilUnexpectedOutput, + + /// Encountered an error when interacting with the routing socket + #[error(display = "Error occured when interfaceing with the routing table")] + RoutingTable(watch::Error), + + /// Unknown interface + #[error(display = "Unknown interface: {}", _0)] + UnkownInterface(String), + + /// Failed to remvoe route + #[error(display = "Error occured when deleting a route")] + DeleteRoute(watch::Error), + + /// Failed to change route + #[error(display = "Failed to change route")] + ChangeRoute(watch::Error), + + /// Failed to add route + #[error(display = "Error occured when adding a route")] + AddRoute(watch::Error), + + /// Failed to fetch link addresses + #[error(display = "Failed to fetch link addresses")] + FetchLinkAddresses(nix::Error), + + /// Received message isn't valid + #[error(display = "Invalid data")] + InvalidData(watch::data::Error), + + /// Failed to resolve tunnel interface name to an interface index + #[error(display = "Failed to find tunnel interface by name")] + NoTunnelInterface, + + /// Gateway route has no IP + #[error(display = "Gateway route has no gateway address")] + NoGatewayAddress, + + /// Invalid gateway route + #[error(display = "Received gateway route is invalid")] + InvalidGatewayRoute(watch::data::RouteMessage), + + /// Failed to obtain interface indices + #[error(display = "Failed to obtain list of interface names and indices")] + GetInterfaceNames(nix::Error), + + /// Failed to find interface name + #[error(display = "Failed to find name for interface")] + GetInterfaceName, + + /// Failed to create route destination from route message + #[error(display = "Failed to derive destination from route message")] + RouteDestination(watch::data::Error), +} + +pub async fn get_default_routes() -> std::result::Result<bool, watch::Error> { + let mut routing_table = RoutingTable::new()?; + Ok(routing_table + .get_route(v4_default()) + .await? + .or(routing_table.get_route(v6_default()).await?) + .is_some()) +} + +impl Error { + fn is_delete_err(&self) -> bool { + matches!(&self, Error::DeleteRoute(_)) + } + + fn is_add_err(&self) -> bool { + matches!(&self, Error::AddRoute(_)) + } +} + +#[derive(Clone, PartialEq)] +struct AppliedRoute { + destination: watch::data::RouteDestination, + route: RouteMessage, +} + +impl AppliedRoute { + fn uses(&self, route: &watch::data::RouteMessage) -> bool { + // // unimplemented!() + // self.route.ga == route.gateway_ip() + // && self + // .route.interface_index() + // .and_then(|iface| RouteManagerImpl::get_interface_index(iface)) + // == route.interface_index().unwrap_or(Some(0)) + false + } } /// Route manager can be in 1 of 4 states - @@ -56,247 +186,837 @@ pub enum Error { /// new changes, obtain new default routes and reapply routes that should be routed through the /// default nodes. Once the routes are reapplied, the route table changes are monitored again. pub struct RouteManagerImpl { - default_destinations: HashSet<IpNetwork>, - applied_routes: HashSet<Route>, - v4_gateway: Option<Node>, - v6_gateway: Option<Node>, - connectivity_change: - Option<Box<dyn FusedStream<Item = std::io::Result<()>> + Unpin + Send + Sync>>, + routing_table: RoutingTable, + v4_gateway: Option<watch::data::RouteMessage>, + v6_gateway: Option<watch::data::RouteMessage>, + interface_map: BTreeMap<u16, ifaddrs::InterfaceAddress>, + applied_interface: Option<AppliedInterface>, + changed_default_routes: BTreeSet<ifaddrs::InterfaceAddress>, + // required_routes: BTreeMap<RouteDestination, > + applied_routes: BTreeMap<RouteDestination, AppliedRoute>, + interfaces: Interfaces, + unsatisifed_routes: BTreeSet<IpNetwork>, + v4_default_route_check: Option<tokio::time::Sleep>, + v6_default_route_check: Option<tokio::time::Sleep>, +} + +struct GatewayInterface { + route_msg: watch::data::RouteMessage, + interface_address: ifaddrs::InterfaceAddress, +} + +struct AppliedInterface { + index: u16, + tunnel_routes_v4: TunnelRoutesV4, + tunnel_routes_v6: Option<TunnelRoutesV6>, + relay_address: IpAddr, +} + +struct PrimaryIfaceBackupV4 { + interface: String, + gateway: Ipv4Addr, + address: Ipv4Addr, +} + +struct PrimaryIfaceBackupV6 { + interface: String, + gateway: Ipv6Addr, + address: Ipv6Addr, +} + +struct BestV4Interface { + idx: Option<usize>, + gateway: Option<Ipv4Addr>, +} + +struct BestV6Interface { + idx: Option<usize>, + gateway: Option<Ipv6Addr>, } impl RouteManagerImpl { - pub async fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { - let v4_gateway = Self::get_default_node(IpVersion::V4).await?; - let v6_gateway = Self::get_default_node(IpVersion::V6).await?; + /// create new route manager + pub async fn new(_required_routes: HashSet<RequiredRoute>) -> Result<Self> { + let mut routing_table = RoutingTable::new().map_err(Error::RoutingTable)?; + + let v4_gateway = routing_table + .get_route(v4_default()) + .await + .map_err(Error::RoutingTable)?; - let monitor = listen_for_default_route_changes()?; + let v6_gateway = routing_table + .get_route(v6_default()) + .await + .map_err(Error::RoutingTable)?; - let mut manager = Self { - default_destinations: HashSet::new(), - applied_routes: HashSet::new(), - connectivity_change: Some(Box::new(monitor.fuse())), + let manager = Self { + unsatisifed_routes: BTreeSet::new(), + routing_table, + applied_routes: BTreeMap::new(), + applied_interface: None, + interface_map: Self::collect_interfaces()?, + changed_default_routes: BTreeSet::new(), v4_gateway, + v4_default_route_check: None, v6_gateway, + v6_default_route_check: None, + interfaces: Interfaces::new(), }; - manager.add_required_routes(required_routes).await?; - Ok(manager) } + fn collect_interfaces() -> Result<BTreeMap<u16, ifaddrs::InterfaceAddress>> { + Ok(nix::ifaddrs::getifaddrs() + .map_err(Error::FetchLinkAddresses)? + .filter(|iface| iface.interface_name != "lo0") + .filter_map(|iface: ifaddrs::InterfaceAddress| { + // forcing the interface index to be a 'usize' is an incredibly questionable + // design choice made in the `nix` + let ifindex = Self::get_interface_index(&iface)?; + Some((ifindex, iface)) + }) + .collect::<BTreeMap<_, _>>()) + } + pub(crate) async fn run(mut self, manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>) { let mut manage_rx = manage_rx.fuse(); - let mut connectivity_change = self.connectivity_change.take().unwrap(); loop { - futures::select! { + futures::select_biased! { + route_message = self.routing_table.next_message().fuse() => { + self.handle_route_mesage(route_message).await; + } + command = manage_rx.next() => { match command { + Some(RouteManagerCommand::SetupTunnelRoutes { + tunnel_interface, + relay_address, + tunnel_routes_v4, + tunnel_routes_v6, + response_tx + }) => { + let result = self.setup_tunnel_routes(tunnel_interface, + relay_address, + tunnel_routes_v4, + tunnel_routes_v6, + ).await; + if result.is_err() { + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to restore routes {err}"); + } + } + let _ = response_tx.send(result); + }, Some(RouteManagerCommand::Shutdown(tx)) => { - self.cleanup_routes().await; + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up routes: {err}"); + } let _ = tx.send(()); return; }, - Some(RouteManagerCommand::AddRoutes(routes, result_tx)) => { - let result = self.add_required_routes(routes).await; - let _ = result_tx.send(result); - }, + Some(RouteManagerCommand::AddRoutes(routes, tx)) => { + let _ = tx.send(Ok(())); + } Some(RouteManagerCommand::ClearRoutes) => { - self.cleanup_routes().await; + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up rotues: {err}"); + } + self.applied_interface = None; }, None => { break; } } }, + }; + } + + if let Err(err) = self.cleanup_routes().await { + log::error!("Failed to clean up routing table when shutitng down: {err}"); + } + } - _result = connectivity_change.select_next_some() => { - let v4_gateway = Self::get_default_node(IpVersion::V4).await.unwrap_or(None); - let v6_gateway = Self::get_default_node(IpVersion::V6).await.unwrap_or(None); + async fn handle_route_mesage( + &mut self, + message: std::result::Result<RouteSocketMessage, watch::Error>, + ) { + match message { + Ok(RouteSocketMessage::Interface(interface)) => { + // handle changes in interfaces, possibly recollect all interfaces + self.handle_interface_change(interface).await; + } - if v4_gateway != self.v4_gateway { - self.v4_gateway = v4_gateway; - self.apply_new_default_route(&self.v4_gateway, true).await; - } + Ok(RouteSocketMessage::AddAddress(address)) => { + self.handle_add_address(address).await; + } + Ok(RouteSocketMessage::DeleteAddress(address)) => { + self.handle_delete_address(address).await; + } + Ok(RouteSocketMessage::DeleteRoute(route)) => { + self.handle_deleted_route(route).await; + // handle deletion of a route - only interested default route removals + // or routes that were applied for our tunnel interface + } + + Ok(RouteSocketMessage::AddRoute(route)) => { + self.handle_added_route(route).await; + // handle new route - if it's a default route, current best default + // route should be updated. if it's a default route whilst engaged, + // remove it, route the tunne traffic through it, and apply + } + + Ok(RouteSocketMessage::ChangeRoute(route)) => { + self.handle_changed_route(route).await; + } + // ignore all other message types + Ok(_) => {} + Err(err) => { + log::error!("Failed to receive a message from the routing table: {err}"); + } + } + } + + async fn handle_deleted_route(&mut self, route: watch::data::RouteMessage) { + match RouteDestination::try_from(&route).map_err(Error::InvalidData) { + Ok(destination) => { + self.applied_routes.remove(&destination); + } + Err(err) => { + log::error!("Failed to process deleted route: {}", err); + } + } + } - if v6_gateway != self.v6_gateway { - self.v6_gateway = v6_gateway; - self.apply_new_default_route(&self.v6_gateway, false).await; + async fn handle_added_route(&mut self, route: watch::data::RouteMessage) { + if let Err(err) = self.handle_added_route_inner(route).await { + log::error!("Failed to process an added route: {}", err); + } + } + + async fn handle_added_route_inner(&mut self, route: watch::data::RouteMessage) -> Result<()> { + let updated_interface = self.interfaces.handle_add_route(&route)?; + if let Some(tunnel) = &self.applied_interface { + if updated_interface { + match self.try_update_best_interface(tunnel.relay_address).await { + Ok(true) => return Ok(()), + Ok(false) => { + // TODO: enter offline state + } + Err(_err) => { + // TODO: consider removing routes here } - }, - complete => { - break; } - }; + } + } + + Ok(()) + } + + fn get_interface_index(iface: &ifaddrs::InterfaceAddress) -> Option<u16> { + Some(iface.address?.as_link_addr()?.ifindex().try_into().unwrap()) + } + + async fn update_tracked_routes( + &mut self, + old_route: watch::data::RouteMessage, + new_route: &watch::data::RouteMessage, + ) -> Result<()> { + let old_interface = self.get_interface_for_route(&old_route); + let old_gateway = old_route.gateway_ip(); + let interface = self.get_interface_for_route(new_route); + let gateway = new_route.gateway_ip(); + + for (destination, applied_route) in &mut self.applied_routes { + if applied_route.uses(&old_route) {} } - self.cleanup_routes().await; + Ok(()) } - async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { - let mut routes_to_apply = vec![]; - let mut default_destinations = HashSet::new(); + async fn drain_unsatisifed_routes( + &mut self, + new_route: &watch::data::RouteMessage, + ) -> Result<()> { + let new_route_is_ipv4 = new_route + .destination_ip() + .map_err(Error::InvalidData)? + .is_ipv4(); - for route in required_routes { - match route.node { - NetNode::DefaultNode => { - default_destinations.insert(route.prefix); - } + let gateway = new_route.gateway_ip(); + let interface = self.get_interface_for_route(&new_route); + + let satisfieable_destinations = self + .unsatisifed_routes + .iter() + .filter(|destination| new_route_is_ipv4 == destination.is_ipv4()) + .cloned() + .collect::<Vec<_>>(); - NetNode::RealNode(node) => routes_to_apply.push(Route::new(node, route.prefix)), + for destination in satisfieable_destinations { + let mut route = RouteMessage::new_route(destination.into()); + if let Some(gateway) = gateway { + route = route.set_gateway_addr(gateway); } + + if let Some(interface) = &interface { + route = route.set_interface_addr(interface); + } + + self.add_route_with_record(destination, route).await?; + self.unsatisifed_routes.remove(&destination); } - for route in routes_to_apply { - Self::add_route(&route).await?; - self.applied_routes.insert(route); + Ok(()) + } + + async fn handle_changed_route(&mut self, route: watch::data::RouteMessage) { + if let Err(err) = self.handle_changed_route_inner(route).await { + log::error!("Failed to process route change: {err}"); } + } - for destination in default_destinations.iter() { - match (&self.v4_gateway, &self.v6_gateway, destination.is_ipv4()) { - (Some(gateway), _, true) | (_, Some(gateway), false) => { - let route = Route::new(gateway.clone(), *destination); - Self::add_route(&route).await?; - self.applied_routes.insert(route); - } - _ => (), - }; + async fn handle_changed_route_inner(&mut self, route: watch::data::RouteMessage) -> Result<()> { + if self.interfaces.handle_changed_route(&route)? { + self.refresh_routes().await?; } - self.default_destinations = default_destinations; + Ok(()) + } + /// Used to refresh routes when routes should be tracked. + async fn refresh_routes(&mut self) -> Result<()> { + if let Some(applied_routes) = &self.applied_interface {} Ok(()) } - // Retrieves the node that's currently used to reach 0.0.0.0/0 - pub(crate) async fn get_default_node(ip_version: IpVersion) -> Result<Option<Node>> { - let ip_version_arg = match ip_version { - IpVersion::V4 => "-inet", - IpVersion::V6 => "-inet6", + async fn add_route_through_default_interface(&mut self, route: RequiredRoute) -> Result<()> { + let gateway = match (route.prefix, &self.v4_gateway, &self.v6_gateway) { + (IpNetwork::V4(_), Some(gateway_route), _) + | (IpNetwork::V6(_), _, Some(gateway_route)) => gateway_route.gateway(), + _ => { + log::error!("UNSATISFIABLE ROUTE"); + self.unsatisifed_routes.insert(route.prefix); + return Ok(()); + } }; - let mut cmd = Command::new("route"); - cmd.arg("-n").arg("get").arg(ip_version_arg).arg("default"); - let output = cmd.output().await.map_err(Error::FailedToRunRoute)?; - let output = String::from_utf8(output.stdout).map_err(|e| { - log::error!("Failed to parse utf-8 bytes from output of netstat: {}", e); - Error::BadOutputFromNetstat - })?; - Ok(Self::parse_route(&output)) + match gateway { + Some(gateway_addr) => { + let new_route = RouteMessage::new_route(route.prefix.into()) + .set_gateway_sockaddr(gateway_addr.clone()); + + self.add_route_with_record(route.prefix, new_route).await + } + None => { + log::debug!("Gateway route has no gateway IP address"); + self.unsatisifed_routes.insert(route.prefix); + Ok(()) + } + } + } + + async fn add_route_with_record( + &mut self, + destination: IpNetwork, + route: RouteMessage, + ) -> Result<()> { + let _ = self + .routing_table + .add_route(&route) + .await + .map_err(Error::RoutingTable)?; + + let destination = RouteDestination::try_from(&route).map_err(Error::InvalidData)?; + + self.applied_routes + .insert(destination.clone(), AppliedRoute { destination, route }); + Ok(()) + } + + async fn add_faux_default_routes_v4( + &mut self, + tunnel_routes: super::TunnelRoutesV4, + ) -> Result<()> { + for half in v4_faux_destinations() { + let route = RouteMessage::new_route(half.into()) + .set_gateway_addr(tunnel_routes.tunnel_gateway.into()); + self.add_route_with_record(half, route).await?; + } + + Ok(()) } - fn parse_route(route_output: &str) -> Option<Node> { - let mut address: Option<IpAddr> = None; - let mut device = None; - for line in route_output.lines() { - // we're looking for just 2 different lines: - // interface: utun0 - // gateway: 192.168.3.1 - let tokens: Vec<_> = line.split_whitespace().collect(); - if tokens.len() == 2 { - match tokens[0].trim() { - "interface:" => { - device = Some(tokens[1].to_string()); + async fn setup_v4_default_route( + &mut self, + v4_routes: &super::TunnelRoutesV4, + _interface: &ifaddrs::InterfaceAddress, + ) -> Result<()> { + if let Some(v4_route) = self.v4_gateway.clone() { + if !v4_route.is_ifscope() { + if let Err(route_err) = self.ifscope_route(&v4_route).await { + if route_err.is_add_err() { + if let Err(err) = self.restore_default_v4().await { + log::error!("Failed to restore v4 routes {err}"); + } } - "gateway:" => { - address = Self::parse_gateway_line(tokens[1]); + return Err(route_err); + } + } + } + + let default_route = RouteMessage::new_route(Destination::default_v4()) + .set_gateway_addr(v4_routes.tunnel_gateway.into()); + + self.routing_table + .add_route(&default_route) + .await + .map_err(Error::AddRoute)?; + + Ok(()) + } + + async fn setup_v6_default_route( + &mut self, + _tunnel_interface: &ifaddrs::InterfaceAddress, + gateway: Ipv6Addr, + ) -> Result<()> { + if let Some(v6_route) = self.v6_gateway.clone() { + if let Err(route_err) = self.ifscope_route(&v6_route).await { + if route_err.is_add_err() { + if let Err(err) = self.restore_default_v6().await { + log::error!("Failed to restore v6 routes {err}"); } - _ => continue, } + return Err(route_err); } } - match (address, device) { - (Some(address), Some(device)) => Some(Node::new(address, device)), - (Some(address), None) => Some(Node::address(address)), - (None, Some(device)) => Some(Node::device(device)), - _ => None, + let default_route = + RouteMessage::new_route(Destination::default_v4()).set_gateway_addr(gateway.into()); + + let _ = self + .routing_table + .add_route(&default_route) + .await + .map_err(Error::AddRoute)?; + Ok(()) + } + + async fn add_faux_default_routes_v6( + &mut self, + tunnel_routes: super::TunnelRoutesV6, + ) -> Result<()> { + for half in v6_faux_destinations() { + let route = RouteMessage::new_route(half.into()) + .set_gateway_addr(tunnel_routes.tunnel_gateway.into()); + self.add_route_with_record(half, route).await? } + + Ok(()) } - fn parse_gateway_line(line: &str) -> Option<IpAddr> { - // IPv6 addresses may contain interfaces - // if line contains '%' it should be split off - line.split('%') - .next() - .and_then(|ip_str| ip_str.parse().ok()) + // async fn setup_v4_default_route( + // &mut self, + // tunnel_interface: &InterfaceAddress, + // gateway: Ipv4Addr, + // ) -> Result<()> { + // let v4_default_destination: IpNetwork = + // IpNetwork::V4(Ipv4Network::new(Ipv4Addr::UNSPECIFIED, 0).unwrap()); + // if let Some(v4_gateway) = &self.v4_gateway { + // let real_interface = self.get_interface_for_route(&v4_gateway); + // let gateway_addr = v4_gateway.gateway_v4(); + // let _ = self + // .routing_table + // .delete_route( + // v4_default_destination, + // real_interface.as_ref(), + // v4_gateway.is_ifscoped().map_err(Error::InvalidData)?, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // let _ = self + // .routing_table + // .add_route( + // v4_default_destination, + // gateway_addr.map(Into::into), + // real_interface.as_ref(), + // true, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // } + + // let _ = self + // .routing_table + // .add_route( + // v4_default_destination, + // Some(gateway.into()), + // Some(tunnel_interface), + // false, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // Ok(()) + // } + + // async fn setup_v6_default_route( + // &mut self, + // tunnel_interface: &InterfaceAddress, + // gateway: Ipv6Addr, + // // ) -> Result<()> { + // let v6_default_destination: IpNetwork = + // IpNetwork::V6(Ipv6Network::new(Ipv6Addr::UNSPECIFIED, 0).unwrap()); + // if let Some(v6_gateway) = &self.v6_gateway { + // let real_interface = self.get_interface_for_route(&v6_gateway); + // let gateway_addr = v6_gateway.gateway_v4(); + // let _ = self + // .routing_table + // .delete_route( + // v6_default_destination, + // real_interface.as_ref(), + // v6_gateway.is_ifscoped().map_err(Error::InvalidData)?, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // let _ = self + // .routing_table + // .add_route( + // v6_default_destination, + // gateway_addr.map(Into::into), + // real_interface.as_ref(), + // true, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // } + + // let _ = self + // .routing_table + // .add_route( + // v6_default_destination, + // Some(gateway.into()), + // Some(tunnel_interface), + // false, + // ) + // .await + // .map_err(Error::RoutingTable)?; + // Ok(()) + // } + + async fn update_v6_relay_route( + &mut self, + new_default_route: watch::data::RouteMessage, + ) -> Result<()> { + if let Some(relay_addr) = self.get_v6_relay_addr() { + let gateway = new_default_route.gateway_v6(); + let interface_addrs = self.get_interface_for_route(&new_default_route); + // self.routing_table + // .change_route( + // IpAddr::from(relay_addr).into(), + // gateway.map(Into::into), + // interface_addrs.as_ref(), + // new_default_route + // .is_ifscoped() + // .map_err(Error::InvalidData)?, + // ) + // .await + // .map_err(Error::RoutingTable)?; + } + Ok(()) } - async fn delete_route(destination: IpNetwork) -> Result<ExitStatus> { - let mut cmd = Command::new("route"); - cmd.arg("-q") - .arg("-n") - .arg("delete") - .arg(ip_vers(destination)) - .arg(destination.to_string()) - .stderr(Stdio::null()); + async fn update_v4_relay_route( + &mut self, + new_default_route: watch::data::RouteMessage, + ) -> Result<()> { + if let Some(relay_addr) = self.get_v4_relay_addr() { + // let gateway = new_default_route.gateway_v4(); + // let interface_addrs = self.get_interface_for_route(&new_default_route); + // self.routing_table + // .change_route( + // IpAddr::from(relay_addr).into(), + // gateway.map(Into::into), + // interface_addrs.as_ref(), + // new_default_route + // .is_ifscoped() + // .map_err(Error::InvalidData)?, + // ) + // .await + // .map_err(Error::RoutingTable)?; + } + Ok(()) + } - cmd.status().await.map_err(Error::FailedToRemoveRoute) + fn get_v4_relay_addr(&self) -> Option<Ipv4Addr> { + if let Some(tunnel) = &self.applied_interface { + if let IpAddr::V4(addr) = tunnel.relay_address { + return Some(addr); + } + } + None } - async fn add_route(route: &Route) -> Result<ExitStatus> { - let mut cmd = Command::new("route"); - cmd.arg("-q") - .arg("-n") - .arg("add") - .arg(ip_vers(route.prefix)) - .arg(route.prefix.to_string()); + fn get_v6_relay_addr(&self) -> Option<Ipv6Addr> { + if let Some(tunnel) = &self.applied_interface { + if let IpAddr::V6(addr) = tunnel.relay_address { + return Some(addr); + } + } + None + } - if let Some(addr) = route.node.get_address() { - cmd.arg("-gateway").arg(addr.to_string()); - } else if let Some(device) = route.node.get_device() { - cmd.arg("-interface").arg(device); + fn route_change_relevant(&self, route: &watch::data::RouteMessage) -> Result<bool> { + // If route is non-default, it should be disregarded. + if !route.is_default().map_err(Error::InvalidData)? { + return Ok(false); + } + // If the default route is changed on our interface, it doesn't matter - if it was removed, + // the correct route will be re-applied when + // TODO: consider adding a timer to check if a default route was added later + if Some(route.interface_index()) == self.applied_interface.as_ref().map(|iface| iface.index) + { + return Ok(false); } - cmd.status().await.map_err(Error::FailedToAddRoute) + Ok(true) } - async fn cleanup_routes(&self) { - let destinations_to_remove = self - .applied_routes - .iter() - .map(|route| &route.prefix) - .chain(self.default_destinations.iter()); + async fn handle_interface_change(&mut self, interface: data::Interface) { + let interfaces_changed = match self.interfaces.handle_iface_msg(interface) { + Ok(interface_changed) => interface_changed, + Err(err) => { + log::error!("Failed to handle interface change: {err:?}"); + return; + } + }; - for destination in destinations_to_remove { - match Self::delete_route(*destination).await { - Ok(status) => { - if !status.success() { - log::debug!("Failed to remove route during shutdown"); - } - } - Err(e) => log::error!("Failed to remove route during shutdown: {}", e), - }; + if interfaces_changed {} + + // TODO: recalculate default route here, if necessary + } + + async fn restore_default_v4(&mut self) -> Result<()> { + if self.applied_interface.is_some() { + if let Some(route) = self.v4_gateway.clone() { + self.restore_gateway_routes(&route).await?; + } } + + Ok(()) } - async fn apply_new_default_route(&self, new_node: &Option<Node>, v4: bool) { - for destination in self.default_destinations.iter() { - if destination.is_ipv4() == v4 { - let _ = Self::delete_route(*destination).await; + async fn restore_gateway_routes(&mut self, gateway_route: &RouteMessage) -> Result<()> { + if !gateway_route.is_ifscope() { + let ifscoped_route = gateway_route + .clone() + .set_ifscope(gateway_route.interface_index()); + if let Err(err) = self + .routing_table + .delete_route(&ifscoped_route) + .await + .map_err(Error::DeleteRoute) + { + log::error!("Failed to remove ifscoped route: {err}"); + } - if let Some(node) = new_node { - log::error!("Resetting default route for {}", destination); - match Self::add_route(&Route::new(node.clone(), *destination)).await { - Ok(status) => { - if !status.success() { - log::error!("Failed to reapply route"); - } - } - Err(e) => log::error!("Failed to reset route: {}", e), - } + let old_route = gateway_route.clone().set_ifscope(0); + self.routing_table + .add_route(&old_route) + .await + .map_err(Error::AddRoute)?; + } + Ok(()) + } + + async fn restore_default_v6(&mut self) -> Result<()> { + if self.applied_interface.is_some() { + if let Some(route) = &self.v6_gateway.clone() { + self.restore_gateway_routes(&route).await?; + } + } + Ok(()) + } + + /// Setup routes specifically for a tunnel + async fn setup_tunnel_routes( + &mut self, + tunnel_interface: String, + relay_address: IpAddr, + tunnel_routes_v4: super::TunnelRoutesV4, + tunnel_routes_v6: Option<super::TunnelRoutesV6>, + ) -> Result<()> { + let (index, interface) = self + .resolve_interface_name(&tunnel_interface) + .ok_or(Error::NoTunnelInterface)?; + self.setup_v4_default_route(&tunnel_routes_v4, &interface) + .await?; + self.add_faux_default_routes_v4(tunnel_routes_v4).await?; + + if let Some(v6) = tunnel_routes_v6 { + self.setup_v6_default_route(&interface, v6.tunnel_gateway) + .await?; + self.add_faux_default_routes_v6(v6).await?; + } + + self.applied_interface = Some(AppliedInterface { + index, + relay_address, + tunnel_routes_v4, + tunnel_routes_v6, + }); + + Ok(()) + } + + /// Removes a route and adds the same route, but ifscoped. Maybe this can be done by just + /// changing the route - haven't tested but I don't believe so, since the ifscope flag is used + /// to identify a route. + async fn ifscope_route(&mut self, original_route: &watch::data::RouteMessage) -> Result<()> { + let interface_index = original_route.interface_index(); + log::error!("iface index {interface_index} original route {original_route:?}"); + let ifscoped_route = original_route.clone().set_ifscope(interface_index); + + self.routing_table + .delete_route(original_route) + .await + .map_err(Error::DeleteRoute)?; + + self.routing_table + .add_route(&ifscoped_route) + .await + .map_err(Error::AddRoute)?; + + Ok(()) + } + + fn get_interface_for_route( + &self, + route: &watch::data::RouteMessage, + ) -> Option<ifaddrs::InterfaceAddress> { + let idx = route.interface_index(); + self.interface_map.get(&idx).cloned() + } + + fn resolve_interface_name(&self, name: &str) -> Option<(u16, ifaddrs::InterfaceAddress)> { + self.interface_map + .iter() + .find(|(_idx, interface)| interface.interface_name == name) + .map(|(idx, interface)| (*idx, interface.clone())) + } + + async fn cleanup_routes(&mut self) -> Result<()> { + self.cleanup_relay_routes().await; + log::error!("CLEANED UP RELAY"); + let v4_default = self.restore_default_v4().await; + log::error!("CLEANED UP v4"); + let v6_default = self.restore_default_v6().await; + log::error!("CLEANED UP v6"); + v4_default.and(v6_default) + } + + async fn cleanup_relay_routes(&mut self) { + let old_routes = std::mem::replace(&mut self.applied_routes, BTreeMap::new()); + let mut routes_to_delete = old_routes + .into_iter() + .map(|(_, route)| route.route) + .collect::<Vec<_>>(); + + if let Some(iface) = &self.applied_interface { + for v4_dest in v4_faux_destinations().chain(std::iter::once(v4_default())) { + let route = RouteMessage::new_route(v4_dest.into()) + .set_gateway_addr(iface.tunnel_routes_v4.tunnel_gateway.into()); + routes_to_delete.push(route); + } + } + + for route in routes_to_delete { + match self.routing_table.delete_route(&route).await { + Ok(_) | Err(watch::Error::RouteNotFound) | Err(watch::Error::Unreachable) => (), + Err(err) => { + log::error!("Failed to remove relay route: {err:?}"); } } } } -} -fn ip_vers(prefix: IpNetwork) -> &'static str { - if prefix.is_ipv4() { - "-inet" - } else { - "-inet6" + async fn handle_add_address(&mut self, address: AddressMessage) { + if self.interfaces.handle_add_address(address) { + // TODO: recalculate best interface if need be + } + } + + async fn handle_delete_address(&mut self, address: AddressMessage) { + if self.interfaces.handle_delete_address(address) { + // TODO: recalculate best interface if need be + } + } + + /// Change routes for tunnel traffic, returns true if V4 + async fn try_update_best_interface(&self, relay_address: IpAddr) -> Result<bool> { + match relay_address { + IpAddr::V4(addr) => {} + IpAddr::V6(addr) => {} + }; + + // Ok(v4_result? || v6_result?) + Ok(true) + } + + async fn try_update_best_interface_v4(&self, addr: Ipv4Addr) -> Result<bool> { + if let Some(interface) = self.interfaces.get_best_default_interface_v4()? {} + + Ok(false) + } + + async fn try_update_best_interface_v6(&self, addr: Ipv4Addr) -> Result<bool> { + if let Some(interface) = self.interfaces.get_best_default_interface_v6()? {} + + Ok(false) + } + + async fn add_route_to_tunnel( + &self, + addr: IpAddr, + interface: &BestRoute, + ) -> Result<()> { + let mut route = RouteMessage::new_route(addr.into()); + + Ok(()) + } + + async fn change_route_to_tunnel( + &self, + addr: &IpAddr, + interface: &interfaces::Interface, + ) -> Result<()> { + Ok(()) } + + async fn delete_route_to_tunnel(&self, addr: IpAddr) -> Result<()> { + Ok(()) + } +} + +fn v4_faux_destinations() -> impl Iterator<Item = IpNetwork> { + let half_of_internet: IpNetwork = "0.0.0.0/1".parse().unwrap(); + let other_half_of_internet: IpNetwork = "128.0.0.0/1".parse().unwrap(); + [half_of_internet, other_half_of_internet].into_iter() +} + +fn v6_faux_destinations() -> impl Iterator<Item = IpNetwork> { + let half_of_internet: IpNetwork = "::/1".parse().unwrap(); + let other_half_of_internet: IpNetwork = "128::/1".parse().unwrap(); + [half_of_internet, other_half_of_internet].into_iter() +} + +fn v4_default() -> IpNetwork { + IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap() +} + +fn v6_default() -> IpNetwork { + IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap() } /// Returns a tuple containing a IPv4 and IPv6 default route nodes. diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 822615d0f1..c6c1c49ada 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -9,7 +9,13 @@ use futures::channel::{ mpsc::{self, UnboundedSender}, oneshot, }; -use std::{collections::HashSet, io}; +use std::{ + collections::HashSet, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; +#[cfg(target_os = "macos")] +use talpid_types::net::IpVersion; #[cfg(target_os = "linux")] use futures::stream::Stream; @@ -20,8 +26,7 @@ use std::net::IpAddr; #[allow(clippy::module_inception)] #[cfg(target_os = "macos")] #[path = "macos.rs"] -mod imp; - +pub mod imp; #[cfg(target_os = "macos")] pub use imp::{get_default_routes, listen_for_default_route_changes}; @@ -55,6 +60,10 @@ pub enum Error { /// Attempt to use route manager that has been dropped #[error(display = "Cannot send message to route manager since it is down")] RouteManagerDown, + /// Failed to obtain a default route + // TODO: elaborate on this variant, possibly add more data + #[error(display = "Failed to obtain default routes")] + DefaultRoute, } /// Handle to a route manager. @@ -76,6 +85,31 @@ impl RouteManagerHandle { .map_err(Error::PlatformError) } + /// Setup tunnel routes + #[cfg(target_os = "macos")] + pub async fn setup_tunnel_routes( + &self, + tunnel_interface: String, + relay_address: IpAddr, + tunnel_routes_v4: TunnelRoutesV4, + tunnel_routes_v6: Option<TunnelRoutesV6>, + ) -> Result<(), Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::SetupTunnelRoutes { + tunnel_interface, + relay_address, + tunnel_routes_v4, + tunnel_routes_v6, + response_tx, + }) + .map_err(|_| Error::ManagerChannelDown)?; + response_rx + .await + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) + } + /// Ensure that packets are routed using the correct tables. #[cfg(target_os = "linux")] pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { @@ -154,13 +188,42 @@ impl RouteManagerHandle { #[cfg(target_os = "linux")] type Fwmark = u32; +/// IPv6 addresess for tunnel interface +#[cfg(target_os = "macos")] +#[derive(Clone, Copy, Debug)] +pub struct TunnelRoutesV4 { + /// IPv4 gateway of the tunnel + pub tunnel_gateway: Ipv4Addr, + /// IPv4 interface address + pub tunnel_ip: Ipv4Addr, +} + +/// IPv6 addresess for tunnel interface +#[cfg(target_os = "macos")] +#[derive(Copy, Clone, Debug)] +pub struct TunnelRoutesV6 { + /// IPv6 gateway of the tunnel + pub tunnel_gateway: Ipv6Addr, + /// IPv6 interface address + pub tunnel_ip: Ipv6Addr, +} + /// Commands for the underlying route manager object. #[derive(Debug)] pub(crate) enum RouteManagerCommand { + #[cfg(target_os = "macos")] + SetupTunnelRoutes { + tunnel_interface: String, + relay_address: IpAddr, + tunnel_routes_v4: TunnelRoutesV4, + tunnel_routes_v6: Option<TunnelRoutesV6>, + response_tx: oneshot::Sender<Result<(), PlatformError>>, + }, AddRoutes( HashSet<RequiredRoute>, oneshot::Sender<Result<(), PlatformError>>, ), + ClearRoutes, Shutdown(oneshot::Sender<()>), #[cfg(target_os = "linux")] @@ -300,3 +363,20 @@ impl Drop for RouteManager { self.runtime.clone().block_on(self.stop()); } } + +/// Returns a tuple containing a IPv4 and IPv6 default route nodes. +#[cfg(target_os = "macos")] +pub async fn get_default_routes( + _excluded_interface: Option<String>, +) -> Result<(Option<super::Node>, Option<super::Node>), Error> { + // TODO: Fix this + Ok((None, None)) + // return get_default_routes().await; + // use futures::TryFutureExt; + // futures::try_join!( + // imp::RouteManagerImpl::get_default_node(IpVersion::V4, excluded_interface.clone()) + // .map_err(Into::into), + // imp::RouteManagerImpl::get_default_node(IpVersion::V6, excluded_interface) + // .map_err(Into::into) + //) +} diff --git a/talpid-routing/src/unix/route_watch.rs b/talpid-routing/src/unix/route_watch.rs new file mode 100644 index 0000000000..5956a19784 --- /dev/null +++ b/talpid-routing/src/unix/route_watch.rs @@ -0,0 +1,87 @@ +use std::{ffi::OsString, io, net::IpAddr}; +use tokio::{ + io::{AsyncBufReadExt, BufReader, Lines}, + process::{ChildStdout, Command}, +}; +use tokio_stream::Stream; + +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// Failed to spawn route. + #[error(display = "Failed to spawn route command")] + Spawn(#[error(source)] io::Error), + + /// Route subprocess has no stdout pipe. + #[error(display = "`route monitor -n` subprocess has no stdout pipe")] + NoStdout, + + /// Unexpected output from `route`. + #[error(display = "Encountered unexpected output from route: _0")] + UnexpectedOutput(OsString), + + /// `route monitor` subcommand exited unexpectedly + #[error(display = "route subcommand exited unexpectedly")] + UnexpectedShutdown, +} + +#[derive(PartialEq, Debug)] +pub enum RouteChange { + Add(Route), + Remove(Route), +} + +#[derive(PartialEq, Debug)] +pub struct Route { + netmask: IpAddr, + destination: IpAddr, + iflocal: bool, + ifaddr: Option<IpAddr>, + interface: Option<String>, +} + +pub struct RouteWatcher { + route_output: Lines<BufReader<ChildStdout>>, +} + +impl RouteWatcher { + pub async fn new() -> Result<Self, Error> { + let child = Command::new("/sbin/route").spawn().map_err(Error::Spawn)?; + let route_output = BufReader::new(child.stdout.ok_or(Error::NoStdout)?).lines(); + + Ok(Self { route_output }) + } + + pub async fn next(&mut self) -> Result<RouteChange, Error> { + let mut buffer = vec![]; + + while let Ok(Some(line)) = self.route_output.next_line().await { + let line_empty = line.is_empty(); + if line_empty { + if buffer.is_empty() { + return Self::parse_route(&buffer); + } + } else { + buffer.push(line); + } + } + + Err(Error::UnexpectedShutdown) + } + + fn parse_route(buffer: &[String]) -> Result<RouteChange, Error> { + unimplemented!() + } +} + +// Gateway strings can be "255", "(255)" "fff fff fff" +fn parse_netmask_v4(input: &str) -> Result<IpAddr, Error> { + unimplemented!() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_route_delete_message() {} +} diff --git a/talpid-routing/src/unix/watch/data.rs b/talpid-routing/src/unix/watch/data.rs new file mode 100644 index 0000000000..b283d48647 --- /dev/null +++ b/talpid-routing/src/unix/watch/data.rs @@ -0,0 +1,1285 @@ +use std::{ + collections::{BTreeMap, BTreeSet, HashSet}, + ffi::{OsStr, OsString}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + os::unix::prelude::OsStringExt, +}; + +use ipnetwork::IpNetwork; +use nix::{ + ifaddrs::InterfaceAddress, + net::if_::if_nametoindex, + sys::socket::{SockAddr, SockaddrIn, SockaddrIn6, SockaddrLike, SockaddrStorage}, +}; + +/// Message that describes a route - either an added, removed, changed or plainly retrieved route. +#[derive(Debug, Clone, PartialEq)] +pub struct RouteMessage { + sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>, + route_flags: RouteFlag, + interface_index: u16, + errno: i32, +} + +impl RouteMessage { + pub fn new_route(destination: Destination) -> Self { + let mut route_flags = RouteFlag::RTF_STATIC | RouteFlag::RTF_DONE | RouteFlag::RTF_UP; + let mut sockaddrs = BTreeMap::new(); + match destination { + Destination::Network(net) => { + let destination = + RouteSocketAddress::Destination(Some(SocketAddr::from((net.ip(), 0)).into())); + let netmask = + RouteSocketAddress::Netmask(Some(SocketAddr::from((net.mask(), 0)).into())); + sockaddrs.insert(destination.address_flag(), destination); + sockaddrs.insert(netmask.address_flag(), netmask); + } + Destination::Host(addr) => { + let destination = + RouteSocketAddress::Destination(Some(SocketAddr::from((addr, 0)).into())); + route_flags |= RouteFlag::RTF_HOST; + sockaddrs.insert(destination.address_flag(), destination); + } + }; + + Self { + sockaddrs, + route_flags, + interface_index: 0, + errno: 0, + } + } + + pub fn route_addrs(&self) -> impl Iterator<Item = &RouteSocketAddress> { + self.sockaddrs.values() + } + + fn socketaddress_to_ip(sockaddr: &SockaddrStorage) -> Option<IpAddr> { + saddr_to_ipv4(sockaddr) + .map(Into::into) + .or_else(|| saddr_to_ipv6(sockaddr).map(Into::into)) + } + + pub fn netmask(&self) -> Option<IpAddr> { + self.route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Netmask(netmask) => Some(netmask), + _ => None, + })? + .as_ref() + .and_then(Self::socketaddress_to_ip) + } + + pub fn is_default(&self) -> Result<bool> { + Ok(self.is_default_v4()? || self.is_default_v6()?) + } + + pub fn is_default_v4(&self) -> Result<bool> { + let destination_is_default = self + .destination_v4()? + .map(|addr| addr == Ipv4Addr::UNSPECIFIED) + .unwrap_or(false); + let netmask = self.route_addrs().find_map(|saddr| match saddr { + RouteSocketAddress::Netmask(addr) => Some(addr), + _ => None, + }); + + let netmask_is_default = match netmask { + // empty socket address implies that it is a 'default' netmask + Some(None) => true, + Some(Some(addr)) => { + if let Some(netmask_addr) = addr.as_sockaddr_in() { + let std_addr = SocketAddrV4::from(netmask_addr.clone()); + *std_addr.ip() == Ipv4Addr::UNSPECIFIED + } else if let Some(netmask_addr) = addr.as_sockaddr_in6() { + let std_addr = SocketAddrV6::from(netmask_addr.clone()); + *std_addr.ip() == Ipv6Addr::UNSPECIFIED + } else { + // if the route socket address describing the netmask isn't a sockaddr_in or a + // sockaddr_in6, it can't possibly be a default route for IP + false + } + } + // absence of a netmask socket address implies that it is a host route + None => false, + }; + + Ok(destination_is_default && netmask_is_default) + } + + pub fn is_default_v6(&self) -> Result<bool> { + Ok(self + .destination_v6()? + .map(|addr| addr == Ipv6Addr::UNSPECIFIED) + .unwrap_or(false)) + } + + pub fn print_route(&self) { + println!( + "route is default - {:?} - interface index: {} - is iscoped - {}", + self.is_default(), + self.interface_index, + self.is_ifscope() + ); + for sa in &self.sockaddrs { + println!("\t{:?}", &sa); + } + } + + fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + let header: rt_msghdr = rt_msghdr::from_bytes(buffer)?; + + let msg_len = usize::from(header.rtm_msglen); + if msg_len > buffer.len() { + return Err(Error::BufferTooSmall( + "Message is shorter than it's msg_len indicates", + msg_len, + buffer.len(), + )); + } + + let payload = &buffer[ROUTE_MESSAGE_HEADER_SIZE..std::cmp::min(msg_len, buffer.len())]; + + let route_flags = RouteFlag::from_bits(header.rtm_flags) + .ok_or(Error::UnknownRouteFlag(header.rtm_flags))?; + let address_flags = AddressFlag::from_bits(header.rtm_addrs) + .ok_or(Error::UnknownAddressFlag(header.rtm_addrs))?; + if !address_flags.contains(AddressFlag::RTA_DST) { + return Err(Error::NoDestination); + } + let sockaddrs = RouteSockAddrIterator::new(payload, address_flags) + .map(|addr| addr.map(|a| (a.address_flag(), a))) + .collect::<Result<BTreeMap<_, _>>>()?; + let interface_index = header.rtm_index; + + Ok(Self { + route_flags, + sockaddrs, + interface_index, + errno: header.rtm_errno, + }) + } + + fn insert_sockaddr(&mut self, saddr: RouteSocketAddress) { + self.sockaddrs.insert(saddr.address_flag(), saddr); + } + + pub fn set_destination(mut self, destination: Destination) -> Self { + match destination { + Destination::Network(net) => { + let sockaddr: SocketAddr = (net.ip(), 0).into(); + let netmask: SocketAddr = (net.mask(), 0).into(); + self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into()))); + self.insert_sockaddr(RouteSocketAddress::Netmask(Some(netmask.into()))); + self.route_flags.remove(RouteFlag::RTF_HOST); + } + Destination::Host(addr) => { + self.route_flags.insert(RouteFlag::RTF_HOST); + let sockaddr: SocketAddr = (addr, 0).into(); + self.insert_sockaddr(RouteSocketAddress::Destination(Some(sockaddr.into()))); + } + }; + + self + } + + pub fn set_interface_addr(mut self, link: &InterfaceAddress) -> Self { + self.insert_sockaddr(RouteSocketAddress::Gateway(link.address.clone())); + self.route_flags |= RouteFlag::RTF_GATEWAY; + self + } + + pub fn set_gateway_sockaddr(mut self, sockaddr: SockaddrStorage) -> Self { + self.insert_sockaddr(RouteSocketAddress::Gateway(Some(sockaddr))); + self.route_flags |= RouteFlag::RTF_GATEWAY; + self + } + + pub fn set_gateway_addr(mut self, addr: IpAddr) -> Self { + let gateway: SocketAddr = (addr, 0).into(); + self.insert_sockaddr(RouteSocketAddress::Gateway(Some(gateway.into()))); + self.route_flags |= RouteFlag::RTF_GATEWAY; + + self + } + + pub fn set_gateway_route(mut self, is_gateway_route: bool) -> Self { + if is_gateway_route { + self.route_flags.insert(RouteFlag::RTF_GATEWAY); + } else { + self.route_flags.remove(RouteFlag::RTF_GATEWAY); + } + self + } + + pub fn route_flag(mut self, route_flags: RouteFlag) -> Self { + self.route_flags = route_flags; + self + } + + pub fn gateway(&self) -> Option<&SockaddrStorage> { + self.route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Gateway(gateway) => Some(gateway), + _ => None, + })? + .as_ref() + } + + pub fn gateway_ip(&self) -> Option<IpAddr> { + self.gateway_v4() + .map(IpAddr::V4) + .or(self.gateway_v6().map(IpAddr::V6)) + } + + pub fn gateway_v4(&self) -> Option<Ipv4Addr> { + saddr_to_ipv4(self.gateway()?) + } + + pub fn gateway_v6(&self) -> Option<Ipv6Addr> { + saddr_to_ipv6(self.gateway()?) + } + + pub fn destination_ip(&self) -> Result<IpNetwork> { + if let Some(saddr) = self.destination()? { + if let Some(v4) = saddr.as_sockaddr_in() { + let ip_addr = *SocketAddrV4::from(*v4).ip(); + let netmask = self.netmask().unwrap_or(Ipv4Addr::UNSPECIFIED.into()); + let destination = IpNetwork::with_netmask(ip_addr.into(), netmask) + .map_err(Error::InvalidNetmask)?; + return Ok(destination.into()); + } + + if let Some(v6) = saddr.as_sockaddr_in6() { + let ip_addr = *SocketAddrV6::from(*v6).ip(); + let netmask = self.netmask().unwrap_or(Ipv6Addr::UNSPECIFIED.into()); + let destination = IpNetwork::with_netmask(ip_addr.into(), netmask) + .map_err(Error::InvalidNetmask)?; + return Ok(destination.into()); + } + + return Err(Error::MismatchedSocketAddress( + AddressFlag::RTA_DST, + saddr.clone(), + )); + }; + return Err(Error::NoDestination); + } + + pub fn destination(&self) -> Result<Option<&SockaddrStorage>> { + Ok(self + .route_addrs() + .find_map(|saddr| match saddr { + RouteSocketAddress::Destination(destination) => Some(destination), + _ => None, + }) + .ok_or(Error::NoDestination)? + .as_ref()) + } + + pub fn destination_v4(&self) -> Result<Option<Ipv4Addr>> { + Ok(self.destination()?.and_then(saddr_to_ipv4)) + } + + pub fn destination_v6(&self) -> Result<Option<Ipv6Addr>> { + Ok(self.destination()?.and_then(saddr_to_ipv6)) + } + + pub fn flags(&self) -> &RouteFlag { + &self.route_flags + } + + pub fn payload( + &self, + message_type: MessageType, + sequence: i32, + pid: i32, + ) -> (rt_msghdr, Vec<Vec<u8>>) { + let address_flags = self.route_addrs().fold(AddressFlag::empty(), |flag, addr| { + flag | addr.address_flag() + }); + + let mut sockaddrs = self.route_addrs().collect::<Vec<_>>(); + sockaddrs.sort_by_key(|saddr| saddr.address_flag()); + let payload_bytes = sockaddrs + .into_iter() + .map(RouteSocketAddress::to_bytes) + .collect::<Vec<_>>(); + + let payload_len: usize = payload_bytes.iter().map(Vec::len).sum(); + + let rtm_msglen = (payload_len + std::mem::size_of::<super::data::rt_msghdr>()) + .try_into() + .expect("route message buffer size cannot fit in 32 bits"); + + let header = super::data::rt_msghdr { + rtm_msglen, + rtm_version: libc::RTM_VERSION.try_into().unwrap(), + rtm_type: message_type.bits(), + rtm_index: self.interface_index, + rtm_flags: self.route_flags.bits(), + rtm_addrs: address_flags.bits(), + rtm_pid: pid, + rtm_seq: sequence, + rtm_errno: 0, + rtm_use: 0, + rtm_inits: 0, + rtm_rmx: Default::default(), + }; + + (header, payload_bytes) + } + + pub fn interface_index(&self) -> u16 { + self.interface_index + } + + pub fn interface_address(&self) -> Option<IpAddr> { + self.get_address(&AddressFlag::RTA_IFA) + } + + fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> { + let addr = self.sockaddrs.get(&address_flag)?; + saddr_to_ipv4(addr.inner()?) + .map(IpAddr::from) + .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from)) + } + + pub fn interface_sockaddr_index(&self) -> Option<u16> { + self.sockaddrs + .values() + .find_map(|addr| addr.interface_index()) + } + + pub fn errno(&self) -> i32 { + self.errno + } + + pub fn is_ipv4(&self) -> bool { + self.destination_v4() + .map(|addr| addr.is_some()) + .unwrap_or(false) + } + + pub fn is_ipv6(&self) -> bool { + self.destination_v6() + .map(|addr| addr.is_some()) + .unwrap_or(false) + } + + pub fn is_ifscope(&self) -> bool { + self.route_flags.contains(RouteFlag::RTF_IFSCOPE) + } + + pub fn ifscope(&self) -> Option<u16> { + if self.is_ifscope() { + Some(self.interface_index) + } else { + None + } + } + + pub fn unset_ifscope(mut self) -> Self { + self.route_flags.remove(RouteFlag::RTF_IFSCOPE); + self + } + + pub fn set_ifscope(mut self, iface_index: u16) -> Self { + if iface_index > 0 { + self.interface_index = iface_index; + self.route_flags.insert(RouteFlag::RTF_IFSCOPE); + } else { + self.interface_index = iface_index; + self.route_flags.remove(RouteFlag::RTF_IFSCOPE); + } + + self + } +} + +#[derive(Debug)] +#[repr(C)] +struct ifa_msghdr { + ifam_msglen: libc::c_ushort, + ifam_version: libc::c_uchar, + ifam_type: libc::c_uchar, + ifam_addrs: libc::c_int, + ifam_flags: libc::c_int, + ifam_index: libc::c_ushort, + ifam_metric: libc::c_int, +} + +#[derive(Debug)] +pub struct AddressMessage { + sockaddrs: BTreeMap<AddressFlag, RouteSocketAddress>, + interface_index: u16, + ifam_type: libc::c_uchar, + flags: RouteFlag, +} + +impl AddressMessage { + pub fn index(&self) -> u16 { + self.interface_index + } + + pub fn print_sockaddrs(&self) { + println!("ifam_type - {}", self.ifam_type); + match self.address() { + Ok(addr) => println!("address - {addr}"), + Err(err) => { + println!("failed to get address {err:?}"); + } + } + for (flag, addr) in &self.sockaddrs { + println!("{flag:?} - {addr:?}"); + } + } + + pub fn address(&self) -> Result<IpAddr> { + self.get_address(&AddressFlag::RTA_IFP) + .or_else(|| self.get_address(&AddressFlag::RTA_IFA)) + .ok_or(Error::NoInterfaceAddress) + } + + fn get_address(&self, address_flag: &AddressFlag) -> Option<IpAddr> { + let addr = self.sockaddrs.get(&address_flag)?; + saddr_to_ipv4(addr.inner()?) + .map(IpAddr::from) + .or_else(|| saddr_to_ipv6(addr.inner()?).map(IpAddr::from)) + } + + pub fn netmask(&self) -> Result<IpAddr> { + self.get_address(&AddressFlag::RTA_NETMASK) + .ok_or(Error::NoNetmaskAddress) + } + + pub fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + const HEADER_SIZE: usize = std::mem::size_of::<ifa_msghdr>(); + if HEADER_SIZE > buffer.len() { + return Err(Error::BufferTooSmall( + "ifa_msghdr", + buffer.len(), + HEADER_SIZE, + )); + } + + // SAFETY: buffer is pointing to enough memory to contain a valid value for ifa_msghdr + let header: ifa_msghdr = unsafe { std::ptr::read_unaligned(buffer.as_ptr() as *const _) }; + + let msg_len = usize::from(header.ifam_msglen); + if msg_len > buffer.len() { + return Err(Error::BufferTooSmall( + "Mesage is shorter than it's msg_len indicates", + msg_len, + buffer.len(), + )); + } + + let payload = &buffer[HEADER_SIZE..std::cmp::min(msg_len, buffer.len())]; + + let flags = RouteFlag::from_bits(header.ifam_flags) + .ok_or(Error::UnknownRouteFlag(header.ifam_flags))?; + + let address_flags = AddressFlag::from_bits(header.ifam_addrs) + .ok_or(Error::UnknownAddressFlag(header.ifam_addrs))?; + + let sockaddrs = RouteSockAddrIterator::new(payload, address_flags) + .map(|addr| addr.map(|addr| (addr.address_flag(), addr))) + .collect::<Result<BTreeMap<_, _>>>()?; + + Ok(Self { + sockaddrs, + flags, + ifam_type: header.ifam_type, + interface_index: header.ifam_index, + }) + } +} + +#[derive(Debug)] +pub enum RouteSocketMessage { + AddRoute(RouteMessage), + DeleteRoute(RouteMessage), + ChangeRoute(RouteMessage), + GetRoute(RouteMessage), + Interface(Interface), + AddAddress(AddressMessage), + DeleteAddress(AddressMessage), + Other { + header: rt_msghdr_short, + payload: Vec<u8>, + }, + Error { + header: rt_msghdr_short, + payload: Vec<u8>, + }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Destination { + Host(IpAddr), + Network(IpNetwork), +} + +impl Destination { + pub fn is_network(&self) -> bool { + matches!(self, Self::Network(_)) + } + + pub fn default_v4() -> Self { + Destination::Network(IpNetwork::new(Ipv4Addr::UNSPECIFIED.into(), 0).unwrap()) + } + + pub fn default_v6() -> Self { + Destination::Network(IpNetwork::new(Ipv6Addr::UNSPECIFIED.into(), 0).unwrap()) + } +} + +impl From<IpAddr> for Destination { + fn from(addr: IpAddr) -> Self { + Self::Host(addr) + } +} + +impl From<IpNetwork> for Destination { + fn from(net: IpNetwork) -> Self { + if net.prefix() == 32 && net.is_ipv4() { + return Self::Host(net.ip()); + } + + Self::Network(net) + } +} + +#[derive(Debug)] +pub enum Error { + /// Payload buffer didn't match the reported message size in header + InvalidBuffer(Vec<u8>, AddressFlag), + /// Buffer too small for specific message type + BufferTooSmall(&'static str, usize, usize), + /// Unknown route flag + UnknownRouteFlag(i32), + /// Socket address is empty for the given address flag + EmptySockaddr(AddressFlag), + /// Unrecognized message + UnknownMessageType(u8), + /// Unrecognized address flag + UnknownAddressFlag(libc::c_int), + /// Mismatched socket address type + MismatchedSocketAddress(AddressFlag, SockaddrStorage), + /// Link socket address contains no identifier + NoLinkIdentifier(nix::libc::sockaddr_dl), + /// Failed to resolve an interface name to an index + InterfaceIndex(nix::Error), + /// An error message as received from the routing socket + RouteError(rt_msghdr, Vec<u8>), + /// Invalid netmask + InvalidNetmask(ipnetwork::IpNetworkError), + /// Route contains no netmask socket address + NoDestination, + /// Address message does not contain an interface address + NoInterfaceAddress, + /// Address message does not contain an interface address + NoNetmaskAddress, +} + +type Result<T> = std::result::Result<T, Error>; + +impl RouteSocketMessage { + pub fn parse_message(buffer: &[u8]) -> Result<Self> { + let route_message = |route_constructor: fn(RouteMessage) -> RouteSocketMessage, buffer| { + let route = RouteMessage::from_byte_buffer(buffer)?; + Ok(route_constructor(route)) + }; + + match rt_msghdr_short::from_bytes(buffer) { + Some(header) if header.is_type(libc::RTM_ADD) => route_message(Self::AddRoute, buffer), + + Some(header) if header.is_type(libc::RTM_CHANGE) => { + route_message(Self::ChangeRoute, buffer) + } + + Some(header) if header.is_type(libc::RTM_DELETE) => { + route_message(Self::DeleteRoute, buffer) + } + + Some(header) if header.is_type(libc::RTM_GET) => route_message(Self::GetRoute, buffer), + + Some(header) if header.is_type(libc::RTM_IFINFO) => Ok(RouteSocketMessage::Interface( + Interface::from_byte_buffer(buffer)?, + )), + + Some(header) if header.is_type(libc::RTM_NEWADDR) => Ok( + RouteSocketMessage::AddAddress(AddressMessage::from_byte_buffer(buffer)?), + ), + Some(header) if header.is_type(libc::RTM_DELADDR) => Ok( + RouteSocketMessage::DeleteAddress(AddressMessage::from_byte_buffer(buffer)?), + ), + Some(header) => Ok(Self::Other { + header, + payload: buffer.to_vec(), + }), + None => Err(Error::BufferTooSmall( + "rt_msghdr_short", + buffer.len(), + std::mem::size_of::<rt_msghdr_short>(), + )), + } + } +} + +/// hush, this will come in later +fn align_to_nearest_u32(idx: usize) -> usize { + if idx > 0 { + 1 + (((idx) - 1) | (std::mem::size_of::<u32>() - 1)) + } else { + std::mem::size_of::<u32>() + } +} + +pub struct Interface { + header: libc::if_msghdr, + payload: Vec<u8>, +} + +impl Interface { + pub fn is_up(&self) -> bool { + self.header.ifm_flags & nix::libc::IFF_UP != 0 + } + + pub fn index(&self) -> u16 { + self.header.ifm_index + } +} + +impl std::fmt::Debug for Interface { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let if_data = f + .debug_struct("if_data") + .field("ifi_type", &self.header.ifm_data.ifi_type) + .field("ifi_typelen", &self.header.ifm_data.ifi_typelen) + .field("ifi_physical", &self.header.ifm_data.ifi_physical) + .field("ifi_addrlen", &self.header.ifm_data.ifi_addrlen) + .field("ifi_hdrlen", &self.header.ifm_data.ifi_hdrlen) + .field("ifi_recvquota", &self.header.ifm_data.ifi_recvquota) + .field("ifi_xmitquota", &self.header.ifm_data.ifi_xmitquota) + .field("ifi_unused1", &self.header.ifm_data.ifi_unused1) + .field("ifi_mtu", &self.header.ifm_data.ifi_mtu) + .field("ifi_metric", &self.header.ifm_data.ifi_metric) + .field("ifi_baudrate", &self.header.ifm_data.ifi_baudrate) + .field("ifi_ipackets", &self.header.ifm_data.ifi_ipackets) + .field("ifi_ierrors", &self.header.ifm_data.ifi_ierrors) + .field("ifi_opackets", &self.header.ifm_data.ifi_opackets) + .field("ifi_oerrors", &self.header.ifm_data.ifi_oerrors) + .field("ifi_collisions", &self.header.ifm_data.ifi_collisions) + .field("ifi_ibytes", &self.header.ifm_data.ifi_ibytes) + .field("ifi_obytes", &self.header.ifm_data.ifi_obytes) + .field("ifi_imcasts", &self.header.ifm_data.ifi_imcasts) + .field("ifi_omcasts", &self.header.ifm_data.ifi_omcasts) + .field("ifi_iqdrops", &self.header.ifm_data.ifi_iqdrops) + .field("ifi_noproto", &self.header.ifm_data.ifi_noproto) + .field("ifi_recvtiming", &self.header.ifm_data.ifi_recvtiming) + .field("ifi_xmittiming", &self.header.ifm_data.ifi_xmittiming) + .field( + "ifi_lastchange", + &( + self.header.ifm_data.ifi_lastchange.tv_sec, + self.header.ifm_data.ifi_lastchange.tv_usec, + ), + ) + .field("ifi_unused2", &self.header.ifm_data.ifi_unused2) + .field("ifi_hwassist", &self.header.ifm_data.ifi_hwassist) + .field("ifi_reserved1", &self.header.ifm_data.ifi_reserved1) + .field("ifi_reserved2", &self.header.ifm_data.ifi_reserved2) + .finish(); + let header = f + .debug_struct("if_msghdr") + .field("ifm_msglen", &self.header.ifm_msglen) + .field("ifm_version", &self.header.ifm_version) + .field("ifm_type", &self.header.ifm_type) + .field("ifm_addrs", &self.header.ifm_addrs) + .field("ifm_flags", &self.header.ifm_flags) + .field("ifm_index", &self.header.ifm_index) + .field("ifm_data", &if_data) + .finish()?; + f.debug_struct("Interface") + .field("header", &header) + .field("payload", &self.payload) + .finish() + } +} + +impl Interface { + fn from_byte_buffer(buffer: &[u8]) -> Result<Self> { + const INTERFACE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<libc::if_msghdr>(); + if INTERFACE_MESSAGE_HEADER_SIZE > buffer.len() { + return Err(Error::BufferTooSmall( + "if_msghdr", + buffer.len(), + INTERFACE_MESSAGE_HEADER_SIZE, + )); + } + let header: libc::if_msghdr = unsafe { std::ptr::read(buffer.as_ptr() as *const _) }; + let payload = buffer[INTERFACE_MESSAGE_HEADER_SIZE..header.ifm_msglen.into()].to_vec(); + Ok(Self { header, payload }) + } +} + +// #define RTA_DST 0x1 /* destination sockaddr present */ +// #define RTA_GATEWAY 0x2 /* gateway sockaddr present */ +// #define RTA_NETMASK 0x4 /* netmask sockaddr present */ +// #define RTA_GENMASK 0x8 /* cloning mask sockaddr present */ +// #define RTA_IFP 0x10 /* interface name sockaddr present */ +// #define RTA_IFA 0x20 /* interface addr sockaddr present */ +// #define RTA_AUTHOR 0x40 /* sockaddr for author of redirect */ +// #define RTA_BRD 0x80 /* for NEWADDR, broadcast or p-p dest addr */ +bitflags::bitflags! { + /// All enum values of address flags can be iterated via `flag <<= 1`, starting from 1. + // #[derive(Clone, Copy, PartialOrd)] + pub struct AddressFlag: i32 { + /// Destination socket address + const RTA_DST = 0x1; + /// Gateway socket address + const RTA_GATEWAY = 0x2; + /// Netmask socket address + const RTA_NETMASK = 0x4; + /// Cloning mask socket address + const RTA_GENMASK = 0x8; + /// Interface name socket address + const RTA_IFP = 0x10; + /// Interface address socket address + const RTA_IFA = 0x20; + /// Socket address for author of redirect + const RTA_AUTHOR = 0x40; + /// Socket address for `NEWADDR`, broadcast or point-to-point destination address + const RTA_BRD = 0x80; + } +} + +bitflags::bitflags! { + /// Types of routing messages + // #[derive(Clone, Copy, PartialOrd)] + pub struct MessageType: u8 { + /// Add Route + const RTM_ADD = 0x1; + /// Delete Route + const RTM_DELETE = 0x2; + /// Change Metrics or flags + const RTM_CHANGE = 0x3; + /// Report Metrics + const RTM_GET = 0x4; + /// RTM_LOSING is no longer generated by and is deprecated + const RTM_LOSING = 0x5; + /// Told to use different route + const RTM_REDIRECT = 0x6; + /// Lookup failed on this address + const RTM_MISS = 0x7; + /// fix specified metrics + const RTM_LOCK = 0x8; + /// caused by SIOCADDRT + const RTM_OLDADD = 0x9; + /// caused by SIOCDELRT + const RTM_OLDDEL = 0xa; + /// req to resolve dst to LL addr + const RTM_RESOLVE = 0xb; + /// address being added to iface + const RTM_NEWADDR = 0xc; + /// address being removed from iface + const RTM_DELADDR = 0xd; + /// iface going up/down etc. + const RTM_IFINFO = 0xe; + /// mcast group membership being added to if + const RTM_NEWMADDR = 0xf; + /// mcast group membership being deleted + const RTM_DELMADDR = 0x10; + } + + + + /// Types of routing messages + // #[derive(Clone, Copy, PartialOrd)] + pub struct RouteFlag: i32 { + /// route usable + const RTF_UP = 0x1; + /// destination is a gateway + const RTF_GATEWAY = 0x2; + /// host entry (net otherwise) + const RTF_HOST = 0x4; + /// host or net unreachable + const RTF_REJECT = 0x8; + /// created dynamically (by redirect) + const RTF_DYNAMIC = 0x10; + /// modified dynamically (by redirect) + const RTF_MODIFIED = 0x20; + /// message confirmed + const RTF_DONE = 0x40; + /// delete cloned route + const RTF_DELCLONE = 0x80; + /// generate new routes on use + const RTF_CLONING = 0x100; + /// external daemon resolves name + const RTF_XRESOLVE = 0x200; + /// DEPRECATED - exists ONLY for backwards compatibility + const RTF_LLINFO = 0x400; + /// used by apps to add/del L2 entries + const RTF_LLDATA = 0x400; + /// manually added + const RTF_STATIC = 0x800; + /// just discard pkts (during updates) + const RTF_BLACKHOLE = 0x1000; + /// not eligible for RTF_IFREF + const RTF_NOIFREF = 0x2000; + /// protocol specific routing flag + const RTF_PROTO2 = 0x4000; + /// protocol specific routing flag + const RTF_PROTO1 = 0x8000; + /// protocol requires cloning + const RTF_PRCLONING = 0x10000; + /// route generated through cloning + const RTF_WASCLONED = 0x20000; + /// protocol specific routing flag + const RTF_PROTO3 = 0x40000; + /// future use + const RTF_PINNED = 0x100000; + /// route represents a local address + const RTF_LOCAL = 0x200000; + /// route represents a bcast address + const RTF_BROADCAST = 0x400000; + /// route represents a mcast address + const RTF_MULTICAST = 0x800000; + /// has valid interface scope + const RTF_IFSCOPE = 0x1000000; + /// defunct; no longer modifiable + const RTF_CONDEMNED = 0x2000000; + /// route holds a ref to interface + const RTF_IFREF = 0x4000000; + /// proxying, no interface scope + const RTF_PROXY = 0x8000000; + /// host is a router + const RTF_ROUTER = 0x10000000; + /// Route entry is being freed + const RTF_DEAD = 0x20000000; + /// route to destination of the global internet + const RTF_GLOBAL = 0x40000000; + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RouteSocketAddress { + /// Corresponds to RTA_DST + Destination(Option<SockaddrStorage>), + /// RTA_GATEWAY + Gateway(Option<SockaddrStorage>), + /// RTA_NETMASK + Netmask(Option<SockaddrStorage>), + /// RTA_GENMASK + CloningMask(Option<SockaddrStorage>), + /// RTA_IFP + IfName(Option<SockaddrStorage>), + /// RTA_IFA + IfSockaddr(Option<SockaddrStorage>), + /// RTA_AUTHOR + RedirectAuthor(Option<SockaddrStorage>), + /// RTA_BRD + Broadcast(Option<SockaddrStorage>), +} + +impl RouteSocketAddress { + // Returns a new route socket address and number of bytes read from the buffer + pub fn new(flag: AddressFlag, buf: &[u8]) -> Result<(Self, u8)> { + // If buffer is empty, then the socket address is empty too, the backing buffer shouldn't + // be advanced. + if buf.is_empty() { + return Ok((Self::with_sockaddr(flag, None)?, 0)); + } + + // to get the length and type of + if buf.len() < std::mem::size_of::<sockaddr_hdr>() { + return Err(Error::BufferTooSmall( + "sockaddr buffer too small", + buf.len(), + std::mem::size_of::<sockaddr_hdr>(), + )); + } + + let addr_header_ptr = buf.as_ptr() as *const sockaddr_hdr; + // safety - since `buf` is at least as long as a `sockaddr_hdr`, it's perfectly valid to + // read from. + let addr_header = unsafe { std::ptr::read(addr_header_ptr) }; + let saddr_len = addr_header.sa_len; + if saddr_len == 0 { + return Ok((Self::with_sockaddr(flag, None)?, 4)); + } + + if Into::<usize>::into(saddr_len) > buf.len() { + return Err(Error::InvalidBuffer(buf.to_vec(), flag)); + } + + // SAFETY: the buffer is big enough for the sockaddr struct inside it, so accessing as a + // `sockaddr` is valid. + let saddr = unsafe { + SockaddrStorage::from_raw( + addr_header_ptr as *const nix::libc::sockaddr, + Some(saddr_len.into()), + ) + }; + + return Ok((Self::with_sockaddr(flag, saddr)?, saddr_len)); + } + + pub fn to_bytes(&self) -> Vec<u8> { + match self.inner() { + None => vec![0u8; 4], + Some(addr) => { + let len = addr.len(); + // The "serialized" socket addresses must be padded to be aligned to 4 bytes, with + // the smallest size being 4 bytes. + let buffer_size = len + len % 4; + let mut buffer = vec![0u8; buffer_size as usize]; + let mut buffer_ptr = buffer.as_mut_ptr(); + unsafe { + // SAFETY: copying conents of addr into buffer is safe, as long as addr.len() + // returns a correct size for the socket address pointer. + std::ptr::copy_nonoverlapping( + addr.as_ptr() as *const _, + buffer.as_mut_ptr(), + addr.len() as usize, + ); + } + buffer + } + } + } + + pub fn address_flag(&self) -> AddressFlag { + match &self { + Self::Destination(_) => AddressFlag::RTA_DST, + Self::Gateway(_) => AddressFlag::RTA_GATEWAY, + Self::Netmask(_) => AddressFlag::RTA_NETMASK, + Self::CloningMask(_) => AddressFlag::RTA_GENMASK, + Self::IfName(_) => AddressFlag::RTA_IFP, + Self::IfSockaddr(_) => AddressFlag::RTA_IFA, + Self::RedirectAuthor(_) => AddressFlag::RTA_AUTHOR, + Self::Broadcast(_) => AddressFlag::RTA_BRD, + } + } + + pub fn inner(&self) -> Option<&SockaddrStorage> { + match &self { + Self::Gateway(addr) + | Self::Destination(addr) + | Self::Netmask(addr) + | Self::CloningMask(addr) + | Self::IfName(addr) + | Self::IfSockaddr(addr) + | Self::RedirectAuthor(addr) + | Self::Broadcast(addr) => addr.as_ref(), + } + } + + fn with_sockaddr(flag: AddressFlag, sockaddr: Option<SockaddrStorage>) -> Result<Self> { + let constructor = match flag { + AddressFlag::RTA_GATEWAY => Self::Gateway, + AddressFlag::RTA_DST => Self::Destination, + AddressFlag::RTA_NETMASK => Self::Netmask, + AddressFlag::RTA_GENMASK => Self::CloningMask, + AddressFlag::RTA_IFP => Self::IfName, + AddressFlag::RTA_IFA => Self::IfSockaddr, + AddressFlag::RTA_AUTHOR => Self::RedirectAuthor, + AddressFlag::RTA_BRD => Self::Broadcast, + unknown => return Err(Error::UnknownAddressFlag(unknown.bits())), + }; + + Ok(constructor(sockaddr)) + } + + pub fn interface_index(&self) -> Option<u16> { + match self { + Self::IfName(Some(iface)) => { + let index = iface.as_link_addr()?.ifindex(); + Some( + u16::try_from(index) + .expect("interface indexes actually are u16s, nix is just *interesting*"), + ) + } + _ => None, + } + } + + pub fn set_interface_index(mut self, index: u16) -> Self { + unimplemented!() + // self.insert_sockaddr(RouteSocketAddress::IfName(Some(sockaddr))); + // self + } +} + +/// Route socket addreses should be ordered by their corresponding address flag when a route +/// message is constructed +impl std::cmp::PartialOrd for RouteSocketAddress { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + self.address_flag().partial_cmp(&other.address_flag()) + } +} + +#[repr(C)] +#[derive(Copy, Clone, Debug)] +struct sockaddr_hdr { + sa_len: u8, + sa_family: libc::sa_family_t, + padding: u16, +} + +pub enum InterfaceIdentifier { + Index(u16), + Name(OsString), + Unspecified, +} + +/// An iterator to consume a byte buffer containing socket address structures originating from a +/// routing socket message. +pub struct RouteSockAddrIterator<'a> { + buffer: &'a [u8], + flags: AddressFlag, + // Cursor used to iterate through address flags + flag_cursor: i32, +} + +impl<'a> RouteSockAddrIterator<'a> { + fn new(buffer: &'a [u8], flags: AddressFlag) -> Self { + Self { + buffer, + flags, + flag_cursor: AddressFlag::RTA_DST.bits(), + } + } + + /// Advances internal byte buffer by given amount. The byte amount will be padded to be + /// aligned to 4 bytes if there's more data in the buffer. + fn advance_buffer(&mut self, mut saddr_len: u8) { + let saddr_len = usize::from(saddr_len); + + // if consumed as many bytes as are left in the buffer, the buffer can be cleared + if saddr_len == self.buffer.len() { + self.buffer = &[]; + return; + } + + let padded_saddr_len = if saddr_len % 4 != 0 { + usize::from(saddr_len + (4 - saddr_len % 4)) + } else { + usize::from(saddr_len) + }; + + // if offest is larger than current buffer, ensure slice gets truncated + // since the socket address should've already be read from the buffer at this point, this + // probably should be an invariant? + self.buffer = &self.buffer[padded_saddr_len..]; + } +} + +impl<'a> Iterator for RouteSockAddrIterator<'a> { + type Item = Result<RouteSocketAddress>; + + fn next(&mut self) -> Option<Self::Item> { + loop { + // If address flags don't contain the current one, try the next one. + // Will return None if it runs out of valid flags. + let current_flag = AddressFlag::from_bits(self.flag_cursor)?; + self.flag_cursor <<= 1; + + if !self.flags.contains(current_flag) { + continue; + } + return match RouteSocketAddress::new(current_flag, self.buffer) { + Ok((next_addr, addr_len)) => { + self.advance_buffer(addr_len); + Some(Ok(next_addr)) + } + Err(err) => { + self.buffer = &[]; + Some(Err(err)) + } + }; + } + } +} + +// struct rt_msghdr { +// u_short rtm_msglen; /* to skip over non-understood messages */ +// u_char rtm_version; /* future binary compatibility */ +// u_char rtm_type; /* message type */ +// u_short rtm_index; /* index for associated ifp */ +// int rtm_flags; /* flags, incl. kern & message, e.g. DONE */ +// int rtm_addrs; /* bitmask identifying sockaddrs in msg */ +// pid_t rtm_pid; /* identify sender */ +// int rtm_seq; /* for sender to identify action */ +// int rtm_errno; /* why failed */ +// int rtm_use; /* from rtentry */ +// u_int32_t rtm_inits; /* which metrics we are initializing */ +// struct rt_metrics rtm_rmx; /* metrics themselves */ +// }; +#[derive(Debug, Clone)] +#[repr(C)] +pub struct rt_msghdr { + pub rtm_msglen: libc::c_ushort, + pub rtm_version: libc::c_uchar, + pub rtm_type: libc::c_uchar, + pub rtm_index: libc::c_ushort, + pub rtm_flags: libc::c_int, + pub rtm_addrs: libc::c_int, + pub rtm_pid: libc::pid_t, + pub rtm_seq: libc::c_int, + pub rtm_errno: libc::c_int, + pub rtm_use: libc::c_int, + pub rtm_inits: u32, + pub rtm_rmx: rt_metrics, +} +const ROUTE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::<rt_msghdr>(); + +fn saddr_to_ipv4(saddr: &SockaddrStorage) -> Option<Ipv4Addr> { + let addr = saddr.as_sockaddr_in()?; + Some(*SocketAddrV4::from(*addr).ip()) +} + +fn saddr_to_ipv6(saddr: &SockaddrStorage) -> Option<Ipv6Addr> { + let addr = saddr.as_sockaddr_in6()?; + Some(*SocketAddrV6::from(*addr).ip()) +} + +impl rt_msghdr { + pub fn from_bytes(buf: &[u8]) -> Result<Self> { + if buf.len() >= std::mem::size_of::<rt_msghdr>() { + let ptr = buf.as_ptr(); + // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr value and it's + // readable. rt_msghdr doesn't contain any pointers so any values are valid. + Ok(unsafe { std::ptr::read(ptr as *const _) }) + } else { + Err(Error::BufferTooSmall( + "if_msghdr", + buf.len(), + ROUTE_MESSAGE_HEADER_SIZE, + )) + } + } +} + +/// Shorter rt_msghdr version that matches all routing messages +#[derive(Debug)] +#[repr(C)] +pub struct rt_msghdr_short { + pub rtm_msglen: libc::c_ushort, + pub rtm_version: libc::c_uchar, + pub rtm_type: libc::c_uchar, + pub rtm_index: libc::c_ushort, + pub rtm_flags: libc::c_int, + pub rtm_addrs: libc::c_int, + pub rtm_pid: libc::pid_t, + pub rtm_seq: libc::c_int, + pub rtm_errno: libc::c_int, +} + +impl rt_msghdr_short { + fn is_type(&self, expected_type: i32) -> bool { + u8::try_from(expected_type) + .map(|expected| self.rtm_type == expected) + .unwrap_or(false) + } + + pub fn from_bytes<'a>(buf: &'a [u8]) -> Option<Self> { + if buf.len() >= std::mem::size_of::<rt_msghdr_short>() { + let ptr = buf.as_ptr(); + // SAFETY: `ptr` is backed by enough valid bytes to contain a rt_msghdr_short value and + // it's readable. rt_msghdr_short doesn't contain any pointers so any values are valid. + Some(unsafe { std::ptr::read(ptr as *const rt_msghdr_short) }) + } else { + None + } + } + + fn is_err(&self) -> bool { + self.rtm_errno != 0 + } +} + +#[derive(PartialEq, PartialOrd, Ord, Eq, Clone)] +pub struct RouteDestination { + pub network: IpNetwork, + pub interface: Option<u16>, + pub gateway: Option<IpAddr>, +} + +impl RouteDestination { + pub fn is_default(&self) -> bool { + if self.network.prefix() != 0 { + return false; + } + match self.network.ip() { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) => true, + IpAddr::V6(Ipv6Addr::UNSPECIFIED) => true, + _ => false, + } + } + + pub fn is_ipv4(&self) -> bool { + self.network.is_ipv4() + } +} + +impl TryFrom<&RouteMessage> for RouteDestination { + type Error = Error; + + fn try_from(msg: &RouteMessage) -> std::result::Result<Self, Self::Error> { + let network = msg.destination_ip()?; + let interface = msg.ifscope(); + let gateway = msg.gateway_ip(); + Ok(Self { + network, + interface, + gateway, + }) + } +} + +// Struct containing metrics of various metrics for a specific route +// struct rt_metrics { +// u_int32_t rmx_locks; /* Kernel leaves these values alone */ +// u_int32_t rmx_mtu; /* MTU for this path */ +// u_int32_t rmx_hopcount; /* max hops expected */ +// int32_t rmx_expire; /* lifetime for route, e.g. redirect */ +// u_int32_t rmx_recvpipe; /* inbound delay-bandwidth product */ +// u_int32_t rmx_sendpipe; /* outbound delay-bandwidth product */ +// u_int32_t rmx_ssthresh; /* outbound gateway buffer limit */ +// u_int32_t rmx_rtt; /* estimated round trip time */ +// u_int32_t rmx_rttvar; /* estimated rtt variance */ +// u_int32_t rmx_pksent; /* packets sent using this route */ +// u_int32_t rmx_state; /* route state */ +// u_int32_t rmx_filler[3]; /* will be used for TCP's peer-MSS cache */ +// }; +#[derive(Debug, Default, Clone)] +#[repr(C)] +pub struct rt_metrics { + pub rmx_locks: u32, + pub rmx_mtu: u32, + pub rmx_hopcount: u32, + pub rmx_expire: i32, + pub rmx_recvpipe: u32, + pub rmx_sendpipe: u32, + pub rmx_ssthresh: u32, + pub rmx_rtt: u32, + pub rmx_rttvar: u32, + pub rmx_pksent: u32, + pub rmx_state: u32, + pub rmx_filler: [u32; 3], +} + +#[test] +fn test_failing_rtmsg() { + let bytes = [ + 135, 0, 5, 1, 11, 0, 0, 0, 1, 1, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 16, 2, 0, 0, 192, 168, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 18, 11, 0, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 255, 255, 255, 255, 255, 255, + ]; + let _ = RouteSocketMessage::parse_message(&bytes).unwrap(); +} diff --git a/talpid-routing/src/unix/watch/routing_socket.rs b/talpid-routing/src/unix/watch/routing_socket.rs new file mode 100644 index 0000000000..72bf50df8a --- /dev/null +++ b/talpid-routing/src/unix/watch/routing_socket.rs @@ -0,0 +1,191 @@ +use std::{ + collections::VecDeque, + mem::size_of, + num::NonZeroU16, + os::unix::prelude::{FromRawFd, RawFd}, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use nix::{ + fcntl::{self, OFlag}, + sys::socket::{socket, AddressFamily, SockFlag, SockType}, +}; +use std::{ + fs::File, + io::{self, Read, Write}, +}; + +use super::data::{ + rt_msghdr_short, AddressFlag, MessageType, RouteFlag, RouteMessage, RouteSocketAddress, +}; + +use tokio::io::{unix::AsyncFd, AsyncWrite, AsyncWriteExt}; + +/// Routing socket interface version +const RTM_VERSION: libc::c_uchar = 5; + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Failed to open routing socket")] + OpenSocket(io::Error), + #[error(display = "Failed to write to routing socket")] + WriteError(io::Error), + #[error(display = "Failed to read from routing socket")] + ReadError(io::Error), + #[error(display = "Received a message that's too small")] + MessageTooSmall(usize), +} + +type Result<T> = std::result::Result<T, Error>; + +pub struct RoutingSocket { + socket: RoutingSocketInner, + seq: i32, + // buffers up messages received whilst waiting on a response + buf: VecDeque<Vec<u8>>, + own_pid: i32, +} + +impl RoutingSocket { + pub fn new() -> Result<Self> { + Ok(Self { + socket: RoutingSocketInner::new().map_err(Error::OpenSocket)?, + seq: 1, + buf: Default::default(), + own_pid: std::process::id().try_into().unwrap(), + }) + } + + pub async fn recv_msg(&mut self, mut buf: &mut [u8]) -> Result<usize> { + if let Some(buffered_msg) = self.buf.pop_front() { + let bytes_written = buf.write(&buffered_msg).map_err(Error::WriteError)?; + return Ok(bytes_written); + } + self.read_next_msg(buf).await + } + + async fn read_next_msg(&mut self, buf: &mut [u8]) -> Result<usize> { + self.socket.read(buf).await.map_err(Error::ReadError) + } + + pub async fn send_route_message( + &mut self, + message: &RouteMessage, + message_type: MessageType, + ) -> Result<Vec<u8>> { + let (msg, seq) = self.next_route_msg(message, message_type); + match self.socket.write(&msg).await { + Ok(_) => self.wait_for_response(seq).await, + Err(err) => { + self.seq -= 1; + Err(Error::WriteError(err)) + } + } + } + + pub async fn wait_for_response(&mut self, response_num: i32) -> Result<Vec<u8>> { + loop { + let mut buffer = vec![0u8; 2024]; + // do not truncate the buffer - trailing empty bytes won't be written but will be + // assumed in the data format. + let bytes_read = self.read_next_msg(&mut buffer).await?; + + { + let header = rt_msghdr_short::from_bytes(buffer.as_slice()) + .ok_or(Error::MessageTooSmall(bytes_read))?; + + if header.rtm_pid == self.own_pid && response_num == header.rtm_seq { + return Ok(buffer); + } + } + + self.buf.push_back(buffer); + } + } + + /// Will panic if the message length ends up overflowing an i32. + fn next_route_msg(&mut self, message: &RouteMessage, msg_type: MessageType) -> (Vec<u8>, i32) { + let seq = self.seq; + self.seq += 1; + + let (header, payload) = message.payload(msg_type, seq, self.own_pid); + let mut msg_buffer = vec![0u8; header.rtm_msglen.into()]; + + unsafe { + std::ptr::copy_nonoverlapping( + &header as *const _ as *const u8, + msg_buffer.as_mut_ptr(), + size_of::<super::data::rt_msghdr>(), + ); + } + let mut sockaddr_buf = &mut msg_buffer[std::mem::size_of::<super::data::rt_msghdr>()..]; + for socket_addr in payload { + sockaddr_buf + .write(socket_addr.as_slice()) + .expect("faled to write socket address into message buffer"); + } + (msg_buffer, header.rtm_seq) + } +} + +struct RoutingSocketInner { + // storing the file handle in a std::file::File automagically provides sane io::{Write, + // Read} and Drop implementations. + socket: AsyncFd<File>, +} + +impl RoutingSocketInner { + fn new() -> io::Result<Self> { + let fd = socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None)?; + let _ = fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::OFlag::O_NONBLOCK))?; + // SAFETY: File handle is valid here + let socket = unsafe { File::from_raw_fd(fd) }; + Ok(Self { + socket: AsyncFd::new(socket)?, + }) + } + + async fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> { + loop { + let mut guard = self.socket.readable().await?; + match guard.try_io(|sock| sock.get_ref().read(out)) { + Ok(result) => return result, + Err(err) => continue, + } + } + } +} + +impl std::os::unix::prelude::AsRawFd for RoutingSocketInner { + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + +impl AsyncWrite for RoutingSocketInner { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + loop { + let mut guard = ready!(self.socket.poll_write_ready(cx))?; + + match guard.try_io(|inner| inner.get_ref().write(buf)) { + Ok(result) => return Poll::Ready(result), + Err(_would_block) => continue, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // no need for a shutdown on the routing socket + Poll::Ready(Ok(())) + } +} diff --git a/talpid-routing/src/watch.rs b/talpid-routing/src/watch.rs new file mode 100644 index 0000000000..e7ce148096 --- /dev/null +++ b/talpid-routing/src/watch.rs @@ -0,0 +1,369 @@ +use ipnetwork::IpNetwork; +use nix::{ + ifaddrs::{getifaddrs, InterfaceAddress}, + sys::socket::{socket, AddressFamily, SockFlag, SockType, SockaddrLike, SockaddrStorage}, +}; +use std::{ + collections::BTreeMap, + io::{self, Read}, + mem, + net::{IpAddr, SocketAddr}, +}; +use tokio::io::unix::AsyncFd; + +use crate::imp::imp::watch::data::RouteSocketAddress; + +use self::data::{ + Destination, Interface, MessageType, RouteFlag, RouteMessage, RouteSocketMessage, +}; + +pub(crate) mod data; +pub(crate) mod routing_socket; + +type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug, err_derive::Error)] +pub enum Error { + #[error(display = "Failed to resolve interface name to index")] + ResolveInterfaceName(nix::Error), + #[error(display = "Failed to open routing socket")] + RoutingSocket(routing_socket::Error), + #[error(display = "Invalid message")] + InvalidMessage(data::Error), + #[error(display = "Failed to send routing message")] + SendError(routing_socket::Error), + #[error(display = "Unexepcted message type")] + UnexpectedMessageType(RouteSocketMessage, MessageType), + #[error(display = "Route not found")] + RouteNotFound, + #[error(display = "Destination unreachable")] + Unreachable, + #[error(display = "Failed to delete a route")] + Deletion(RouteMessage), + #[error(display = "Failed to add a route")] + Add(RouteMessage), + #[error(display = "Faield to fetch a route")] + Get(RouteMessage), +} + +pub struct RoutingTable { + socket: routing_socket::RoutingSocket, +} + +// TODO: Ensure that route socket messages with error messages in them get returned as a result +impl RoutingTable { + pub fn new() -> Result<Self> { + let socket = routing_socket::RoutingSocket::new().map_err(Error::RoutingSocket)?; + + Ok(Self { socket }) + } + + pub async fn next_message(&mut self) -> Result<RouteSocketMessage> { + let mut buf = [0u8; 2048]; + let bytes_read = self + .socket + .recv_msg(&mut buf) + .await + .map_err(Error::RoutingSocket)?; + let msg_buf = &buf[0..bytes_read]; + data::RouteSocketMessage::parse_message(&msg_buf).map_err(Error::InvalidMessage) + } + + pub async fn add_route(&mut self, message: &RouteMessage) -> Result<()> { + let msg = self + .alter_routing_table(&message, MessageType::RTM_ADD) + .await; + + match msg { + Ok(RouteSocketMessage::AddRoute(_route)) => Ok(()), + Err(Error::SendError(routing_socket::Error::WriteError(err))) + if err.kind() == io::ErrorKind::AlreadyExists => + { + Ok(()) + } + Ok(anything_else) => Err(Error::UnexpectedMessageType( + anything_else, + MessageType::RTM_ADD, + )), + + Err(err) => Err(err), + } + } + + pub async fn change_route(&mut self, message: &RouteMessage) -> Result<()> { + let response = self + .alter_routing_table(message, MessageType::RTM_CHANGE) + .await?; + + match response { + RouteSocketMessage::ChangeRoute(_route) => Ok(()), + anything_else => Err(Error::UnexpectedMessageType( + anything_else, + MessageType::RTM_CHANGE, + )), + } + } + + async fn alter_routing_table( + &mut self, + message: &RouteMessage, + message_type: MessageType, + ) -> Result<RouteSocketMessage> { + let result = self.socket.send_route_message(&message, message_type).await; + + match result { + Ok(response) => { + data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage) + } + + Err(routing_socket::Error::WriteError(err)) + if err.kind() == io::ErrorKind::NotFound => + { + Err(Error::RouteNotFound) + } + Err(routing_socket::Error::WriteError(err)) + if [Some(libc::ENETUNREACH), Some(libc::ESRCH)].contains(&err.raw_os_error()) => + { + Err(Error::Unreachable) + } + Err(err) => Err(Error::SendError(err)), + } + } + + pub async fn delete_route(&mut self, message: &RouteMessage) -> Result<()> { + let response = self + .alter_routing_table(message, MessageType::RTM_DELETE) + .await?; + + match response { + RouteSocketMessage::DeleteRoute(route) if route.errno() == 0 => Ok(()), + RouteSocketMessage::DeleteRoute(route) if route.errno() != 0 => { + Err(Error::Deletion(route)) + } + anything_else => Err(Error::UnexpectedMessageType( + anything_else, + MessageType::RTM_DELETE, + )), + } + } + + pub async fn get_route( + &mut self, + destination: impl Into<Destination>, + ) -> Result<Option<data::RouteMessage>> { + let destination = destination.into(); + + let mut msg = RouteMessage::new_route(destination); + if destination.is_network() { + msg = msg.set_gateway_route(true); + } + + let response = self + .socket + .send_route_message(&msg, MessageType::RTM_GET) + .await; + + let response = match response { + Ok(response) => response, + Err(routing_socket::Error::WriteError(err)) => { + if let Some(err) = err.raw_os_error() { + if [libc::ENETUNREACH, libc::ESRCH].contains(&err) { + return Ok(None); + } + } + return Err(Error::RoutingSocket(routing_socket::Error::WriteError(err))); + } + Err(other_err) => { + return Err(Error::RoutingSocket(other_err)); + } + }; + + match data::RouteSocketMessage::parse_message(&response).map_err(Error::InvalidMessage)? { + data::RouteSocketMessage::GetRoute(route) => Ok(Some(route)), + unexpected_route_message => Err(Error::UnexpectedMessageType( + unexpected_route_message, + MessageType::RTM_GET, + )), + } + } +} + +fn if_sockaddr_for_default_route_fetching() -> SockaddrStorage { + let interface_sockaddr = nix::libc::sockaddr_dl { + sdl_len: mem::size_of::<nix::libc::sockaddr_dl>() as u8, + sdl_family: libc::AF_LINK as u8, + sdl_index: 0, + sdl_type: 0, + sdl_nlen: 0, + sdl_alen: 0, + sdl_slen: 0, + sdl_data: [0; 12], + }; + unsafe { + SockaddrStorage::from_raw( + &interface_sockaddr as *const _ as *const _, + Some(interface_sockaddr.sdl_len.into()), + ) + } + .unwrap() +} + +/// Watch routes +pub async fn watch_routes() -> Result<()> { + // read_from_file(&"add-two-no-slash"); + // read_from_file(&"add-two-slash-28"); + // read_from_file(&"add-two-slash-32"); + read_from_file(&"add-tensixtyfour"); + let mut table = RoutingTable::new().expect("failed to open routing table"); + let default_route_v4 = table.get_route(Destination::default_v4()).await; + let default_route_v6 = table.get_route(Destination::default_v6()).await; + let route_to_gw = table + .get_route(Destination::Host("10.64.0.1".parse().unwrap())) + .await; + + let dest: IpNetwork = "173.11.33.44/32".parse().unwrap(); + let gateway: IpAddr = "192.168.1.1".parse().unwrap(); + let new_route = RouteMessage::new_route(dest.into()).set_gateway_addr(gateway); + + // let add_route = table + // .add_route( + // "1.1.1.1/32".parse().unwrap(), + // Some("8.8.8.8".parse().unwrap()), + // None, + // false, + // ) + // .await; + // let _ = std::io::stdin().read(&mut [0u8; 1]); + // + + // let remove_route = table + // .delete_route( + // "1.1.1.1/32".parse().unwrap(), + // None, + // false, + // ) + // .await; + + // let interface = nix::ifaddrs::getifaddrs() + // .unwrap() + // .find(|iface| iface.interface_name == "en0") + // .unwrap(); + + // let new_route = RouteMessage::new_route(Destination::Host("1.1.1.1".parse().unwrap())) + // .set_gateway_addr("192.168.88.1".parse().unwrap()); + + // let add = table.add_route(&new_route); + + // .delete_route( + // "1.1.1.1/32".parse().unwrap(), + // // Some("192.168.88.1".parse().unwrap()), + // None, + // Some(&interface), + // false, + // ) + // .await; + // println!("delet {delet:?}"); + // println!("add_route - {add_route:?}\nremove_route {remove_route:?}"); + + loop { + let msg = table.next_message().await?; + print_msg(msg); + } + Ok(()) +} + +fn read_from_file(path: impl AsRef<std::path::Path>) { + match std::fs::read(path.as_ref()) { + Ok(bytes) => { + let msg = data::RouteSocketMessage::parse_message(&bytes).unwrap(); + println!("{} contains {msg:?}", path.as_ref().display()); + } + Err(err) => { + println!("Failed to read file {}:{err}", path.as_ref().display()); + } + } +} + +fn print_msg(msg: data::RouteSocketMessage) { + match msg { + data::RouteSocketMessage::GetRoute(route) => { + println!( + "================================================================================" + ); + println!("get route"); + route.print_route(); + println!(""); + } + data::RouteSocketMessage::AddRoute(route) => { + println!( + "================================================================================" + ); + println!("add route"); + route.print_route(); + println!(""); + } + data::RouteSocketMessage::ChangeRoute(route) => { + println!( + "================================================================================" + ); + println!("change route"); + route.print_route(); + println!(""); + } + data::RouteSocketMessage::DeleteRoute(route) => { + println!( + "================================================================================" + ); + let addrs = route.route_addrs().collect::<Vec<_>>(); + println!("route-addrs = {}", addrs.len()); + route.print_route(); + println!(""); + } + + data::RouteSocketMessage::Interface(interface) => { + println!( + "================================================================================" + ); + let action_msg = if interface.is_up() { + "added" + } else { + "removed" + }; + let idx = interface.index(); + println!("{action_msg} interface {idx}"); + } + + data::RouteSocketMessage::AddAddress(address) => { + println!( + "================================================================================" + ); + let idx = address.index(); + println!("Added address {:?} for interface {idx}", address.address()); + address.print_sockaddrs(); + } + + data::RouteSocketMessage::DeleteAddress(address) => { + println!( + "================================================================================" + ); + let idx = address.index(); + let address = address.address(); + println!("Deleted address {address:?} for interface {idx}"); + } + // ignoring other kinds of route messages + _ => { + return; + } + }; +} + +const REMOVE_ROUTE_MSG: &[u8] = &[ + 164, 0, 5, 2, 11, 0, 0, 0, 66, 8, 1, 67, 55, 0, 0, 0, 64, 1, 0, 0, 71, 8, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 220, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 16, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 2, 0, 0, 192, 168, 185, 1, 11, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 20, 18, 11, 0, 6, 3, 6, 0, 101, 110, 48, 60, 6, 48, 3, 54, 249, 0, 0, 0, + 16, 2, 0, 0, 192, 168, 185, 116, 0, 0, 0, 0, 0, 0, 0, 0, +]; + +const GET_ROUTE_MSG: &'static str = "qAAFBBIAAABBCQBANwAAAAcPAAABAAAAAAAAAAAAAAAAAAAAAAAAAIwFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAgAAAAAAAAAAAAAAAAAAFBISAAAAAAAAAAAAAAAAAAAAAAAAAgAAFBISAAEFAAB1dHVuMwAAAAAAAAAQAgAACnYA7wAAAAAAAAAA"; diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index 358aa5d64a..308156b7aa 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -56,6 +56,14 @@ pub enum Error { } impl Config { + /// get the first IPv4 address + pub fn get_private_ipv4(&self) -> Option<Ipv4Addr> { + self.tunnel.addresses.iter().find_map(|addr| match addr { + std::net::IpAddr::V4(addr) => Some(*addr), + _ => None, + }) + } + /// Constructs a Config from parameters pub fn from_parameters(params: &wireguard::TunnelParameters) -> Result<Config, Error> { let tunnel = params.connection.tunnel.clone(); diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index da70b63bfc..e3085226d5 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -103,6 +103,11 @@ pub enum Error { #[error(display = "Failed to set up IP interfaces")] IpInterfacesError, + /// Configuration has no tunnel IPv4 address + #[cfg(target_os = "macos")] + #[error(display = "No IPv4 tunnel interface address")] + NoIpv4Address, + /// Failed to set IP addresses on WireGuard interface #[cfg(target_os = "windows")] #[error(display = "Failed to set IP addresses on WireGuard interface")] @@ -243,6 +248,19 @@ impl WireguardMonitor { ) -> Result<WireguardMonitor> { let on_event = args.on_event.clone(); + #[cfg(target_os = "macos")] + let tunnel_ipv4 = config + .tunnel + .addresses + .iter() + .cloned() + .find(|ip| ip.is_ipv4()) + .and_then(|addr| match addr { + IpAddr::V4(addr) => Some(addr), + _ => None, + }) + .ok_or(Error::NoIpv4Address)?; + let endpoint_addrs: Vec<IpAddr> = config.peers.iter().map(|peer| peer.endpoint.ip()).collect(); @@ -355,6 +373,8 @@ impl WireguardMonitor { let routes = Self::get_pre_tunnel_routes(&iface_name, &config) .chain(Self::get_endpoint_routes(&endpoint_addrs)) .collect(); + + log::error!("Routes being added {routes:?}"); args.route_manager .add_routes(routes) .await @@ -395,13 +415,29 @@ impl WireguardMonitor { .unwrap()?; // Add any default route(s) that may exist. + #[cfg(not(target_os = "macos"))] args.route_manager .add_routes(Self::get_post_tunnel_routes(&iface_name, &config).collect()) .await .map_err(Error::SetupRoutingError) .map_err(CloseMsg::SetupError)?; - let metadata = Self::tunnel_metadata(&iface_name, &config); + let relay_addr = config.peers[0].endpoint.ip(); + #[cfg(target_os = "macos")] + args.route_manager + .setup_tunnel_routes( + iface_name.to_string(), + relay_addr, + talpid_routing::TunnelRoutesV4 { + tunnel_gateway: config.ipv4_gateway, + tunnel_ip: tunnel_ipv4, + }, + None, + ) + .await + .map_err(Error::SetupRoutingError) + .map_err(CloseMsg::SetupError)?; + (on_event)(TunnelEvent::Up(metadata)).await; tokio::task::spawn_blocking(move || { @@ -860,7 +896,7 @@ impl WireguardMonitor { gateway_node.clone(), )) .chain(config.ipv6_gateway.map(|gateway| { - RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), gateway_node) + return RequiredRoute::new(ipnetwork::Ipv6Network::from(gateway).into(), gateway_node); })); let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); diff --git a/wireguard/libwg/build-android.sh b/wireguard/libwg/build-android.sh index dcfbb34dff..23a872d621 100755 --- a/wireguard/libwg/build-android.sh +++ b/wireguard/libwg/build-android.sh @@ -54,6 +54,7 @@ for arch in ${ARCHITECTURES:-armv7 aarch64 x86_64 i686}; do mkdir -m 777 -p "$(dirname "$STRIPPED_LIB_PATH")" $ANDROID_STRIP_TOOL --strip-unneeded --strip-debug -o "$STRIPPED_LIB_PATH" "$UNSTRIPPED_LIB_PATH" + cp $UNSTRIPPED_LIB_PATH $STRIPPED_LIB_PATH # Set permissions so that the build server can clean the outputs afterwards chmod 777 "$STRIPPED_LIB_PATH" |
