diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-06-23 13:39:20 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-06-23 13:39:20 +0200 |
| commit | 3ff99131790ca0a953247a07a45909e02a0cf313 (patch) | |
| tree | 361df08ae8b0fbe062f5b04ea9d361a0d97e0a08 | |
| parent | 26022a683c2e1f7aa8fdd912e7e991c5c66118d5 (diff) | |
| parent | 9adc4d7ee90e0c6e4d73f14cf63041021302ce88 (diff) | |
| download | mullvadvpn-3ff99131790ca0a953247a07a45909e02a0cf313.tar.xz mullvadvpn-3ff99131790ca0a953247a07a45909e02a0cf313.zip | |
Merge branch 'linux-refactor-routing'
| -rw-r--r-- | Cargo.lock | 41 | ||||
| -rw-r--r-- | talpid-core/Cargo.toml | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/mod.rs | 17 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/routing.rs | 192 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/systemd_resolved.rs | 19 | ||||
| -rw-r--r-- | talpid-core/src/dns/mod.rs | 16 | ||||
| -rw-r--r-- | talpid-core/src/offline/linux.rs | 201 | ||||
| -rw-r--r-- | talpid-core/src/offline/mod.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/routing/linux.rs | 125 | ||||
| -rw-r--r-- | talpid-core/src/routing/mod.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/routing/unix.rs | 102 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 14 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 52 |
17 files changed, 410 insertions, 432 deletions
diff --git a/Cargo.lock b/Cargo.lock index 1a24bf4a7f..e7f61ca913 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -635,9 +635,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" [[package]] name = "futures" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150" +checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27" dependencies = [ "futures-channel", "futures-core", @@ -650,9 +650,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d31b7ec7efab6eefc7c57233bb10b847986139d88cc2f5a02a1ae6871a1846" +checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2" dependencies = [ "futures-core", "futures-sink", @@ -660,15 +660,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65" +checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" [[package]] name = "futures-executor" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9" +checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79" dependencies = [ "futures-core", "futures-task", @@ -677,16 +677,17 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28be053525281ad8259d47e4de5de657b25e7bac113458555bb4b70bc6870500" +checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1" [[package]] name = "futures-macro" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c287d25add322d9f9abdcdc5927ca398917996600182178774032e9f8258fedd" +checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121" dependencies = [ + "autocfg", "proc-macro-hack", "proc-macro2", "quote", @@ -695,25 +696,23 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf5c69029bda2e743fddd0582d1083951d65cc9539aebf8812f36c3491342d6" +checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282" [[package]] name = "futures-task" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13de07eb8ea81ae445aca7b69f5f7bf15d7bf4912d8ca37d6645c77ae8a58d86" -dependencies = [ - "once_cell", -] +checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" [[package]] name = "futures-util" -version = "0.3.12" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b" +checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967" dependencies = [ + "autocfg", "futures-channel", "futures-core", "futures-io", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index f84e4a5bd4..9d372a7c00 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -13,7 +13,7 @@ atty = "0.2" cfg-if = "1.0" duct = "0.13" err-derive = "0.3.0" -futures = "0.3" +futures = "0.3.15" hex = "0.4" ipnetwork = "0.16" lazy_static = "1.0" diff --git a/talpid-core/src/dns/linux/mod.rs b/talpid-core/src/dns/linux/mod.rs index 62827211bd..4b046035aa 100644 --- a/talpid-core/src/dns/linux/mod.rs +++ b/talpid-core/src/dns/linux/mod.rs @@ -8,6 +8,7 @@ use self::{ network_manager::NetworkManager, resolvconf::Resolvconf, static_resolv_conf::StaticResolvConf, systemd_resolved::SystemdResolved, }; +use crate::routing::RouteManagerHandle; use std::{env, fmt, net::IpAddr, path::Path}; @@ -40,6 +41,7 @@ pub enum Error { } pub struct DnsMonitor { + route_manager: RouteManagerHandle, handle: tokio::runtime::Handle, inner: Option<DnsMonitorHolder>, } @@ -47,8 +49,13 @@ pub struct DnsMonitor { impl super::DnsMonitorT for DnsMonitor { type Error = Error; - fn new(handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> { + fn new( + handle: tokio::runtime::Handle, + _cache_dir: impl AsRef<Path>, + route_manager: RouteManagerHandle, + ) -> Result<Self> { Ok(DnsMonitor { + route_manager, handle, inner: None, }) @@ -58,7 +65,7 @@ impl super::DnsMonitorT for DnsMonitor { self.reset()?; // Creating a new DNS monitor for each set, in case the system changed how it manages DNS. let mut inner = DnsMonitorHolder::new()?; - inner.set(&self.handle, interface, servers)?; + inner.set(&self.handle, &self.route_manager, interface, servers)?; self.inner = Some(inner); Ok(()) } @@ -128,6 +135,7 @@ impl DnsMonitorHolder { fn set( &mut self, handle: &tokio::runtime::Handle, + route_manager: &RouteManagerHandle, interface: &str, servers: &[IpAddr], ) -> Result<()> { @@ -137,9 +145,8 @@ impl DnsMonitorHolder { StaticResolvConf(ref mut static_resolv_conf) => { static_resolv_conf.set_dns(servers.to_vec())? } - SystemdResolved(ref mut systemd_resolved) => { - handle.block_on(systemd_resolved.set_dns(interface, &servers))? - } + SystemdResolved(ref mut systemd_resolved) => handle + .block_on(systemd_resolved.set_dns(route_manager.clone(), interface, &servers))?, NetworkManager(ref mut network_manager) => { network_manager.set_dns(interface, servers)? } diff --git a/talpid-core/src/dns/linux/routing.rs b/talpid-core/src/dns/linux/routing.rs index 5c90b57a33..e04b4cc29c 100644 --- a/talpid-core/src/dns/linux/routing.rs +++ b/talpid-core/src/dns/linux/routing.rs @@ -1,19 +1,16 @@ -use futures::{ - channel::mpsc::UnboundedSender, future::abortable, FutureExt, StreamExt, TryStream, - TryStreamExt, -}; -use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST}; -use netlink_packet_route::{ - rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage, +use crate::{ + linux::{iface_index, IfaceIndexLookupError}, + routing::{self, RouteManagerHandle}, }; -use rtnetlink::{ - constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_NOTIFY}, - sys::SocketAddr, - Handle, IpVersion, +use futures::{ + channel::mpsc::UnboundedSender, + stream::{abortable, AbortHandle}, + StreamExt, }; +use rtnetlink::IpVersion; use std::{ collections::BTreeMap, - fmt, io, + fmt, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; use talpid_types::ErrorExt; @@ -27,29 +24,20 @@ const PUBLIC_INTERNET_ADDRESS_V6: IpAddr = #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { - #[error(display = "Failed to get a route for an arbitrary IP address")] - GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>), - - #[error(display = "Failed to connect to bind to netlink socket")] - BindError(#[error(source)] io::Error), + #[error(display = "The route manager returned an error")] + RouteManagerError(#[error(source)] routing::Error), - #[error(display = "No netlink response for route query")] - NoRouteError, - - #[error(display = "Route is missing an output interface")] - RouteNoInterfaceError, + #[error(display = "Failed to resolve interface index with error {}", _0)] + InterfaceNameError(#[error(source)] IfaceIndexLookupError), } pub struct DnsRouteMonitor { - _handle: rtnetlink::Handle, - stop_tx: Option<futures::channel::oneshot::Sender<()>>, + abort_handle: AbortHandle, } impl Drop for DnsRouteMonitor { fn drop(&mut self) { - if let Some(stop_tx) = self.stop_tx.take() { - let _ = stop_tx.send(()); - } + self.abort_handle.abort(); } } @@ -70,70 +58,50 @@ impl fmt::Display for DnsConfig { } pub async fn spawn_monitor( + route_manager: RouteManagerHandle, destinations: Vec<IpAddr>, update_tx: UnboundedSender<BTreeMap<u32, DnsConfig>>, ) -> Result<(DnsRouteMonitor, BTreeMap<u32, DnsConfig>)> { - let (mut connection, handle, messages) = - rtnetlink::new_connection().expect("Failed to create a netlink connection"); - - let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_NOTIFY; - let addr = SocketAddr::new(0, mgroup_flags); + let listener = route_manager + .change_listener() + .await + .map_err(Error::RouteManagerError)?; + let (mut listener, abort_handle) = abortable(listener); - connection - .socket_mut() - .bind(&addr) - .map_err(Error::BindError)?; + let monitor = DnsRouteMonitor { abort_handle }; - let (abortable_connection, abort_connection) = abortable(connection); - tokio::spawn(abortable_connection); - - let (stop_tx, stop_rx) = futures::channel::oneshot::channel(); - - let monitor = DnsRouteMonitor { - _handle: handle.clone(), - stop_tx: Some(stop_tx), - }; - - let mut last_config = setup_configurations(&handle, &destinations).await?; + let mut last_config = setup_configurations(&route_manager, &destinations).await?; let initial_config = last_config.clone(); tokio::spawn(async move { - let mut messages = messages.fuse(); - let mut stop_rx = stop_rx.fuse(); - loop { - futures::select! { - _new_message = messages.next() => { - match setup_configurations(&handle, &destinations).await { - Ok(new_config) => { - if last_config != new_config { - last_config = new_config.clone(); - if update_tx.unbounded_send(new_config).is_err() { - log::trace!("Stopping DNS monitor: channel is closed"); - break; - } - } - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to determine new DNS interface settings" - ) - ); + while let Some(_event) = listener.next().await { + match setup_configurations(&route_manager, &destinations).await { + Ok(new_config) => { + if last_config != new_config { + last_config = new_config.clone(); + if update_tx.unbounded_send(new_config).is_err() { + log::trace!("Stopping DNS monitor: channel is closed"); + break; } } - }, - _ = stop_rx => break, + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to determine new DNS interface settings" + ) + ); + } } } - abort_connection.abort(); }); Ok((monitor, initial_config)) } async fn setup_configurations( - handle: &Handle, + handle: &RouteManagerHandle, destinations: &[IpAddr], ) -> Result<BTreeMap<u32, DnsConfig>> { let mut interface_to_destinations = BTreeMap::<u32, DnsConfig>::new(); @@ -141,8 +109,8 @@ async fn setup_configurations( let interface = if destination.is_loopback() { get_default_route_interface(handle, get_ip_version(destination), true).await? } else { - if crate::firewall::is_local_address(&destination) { - get_destination_interface(handle, destination, true).await? + if crate::firewall::is_local_address(destination) { + get_destination_interface(handle, *destination, true).await? } else { get_default_route_interface(handle, get_ip_version(destination), false).await? } @@ -174,80 +142,34 @@ async fn setup_configurations( } async fn get_default_route_interface( - handle: &Handle, + handle: &RouteManagerHandle, ip_version: IpVersion, set_mark: bool, ) -> Result<Option<u32>> { match ip_version { IpVersion::V4 => { - get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V4, set_mark).await + get_destination_interface(handle, PUBLIC_INTERNET_ADDRESS_V4, set_mark).await } IpVersion::V6 => { - get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V6, set_mark).await + get_destination_interface(handle, PUBLIC_INTERNET_ADDRESS_V6, set_mark).await } } } async fn get_destination_interface( - handle: &Handle, - destination: &IpAddr, + handle: &RouteManagerHandle, + destination: IpAddr, set_mark: bool, ) -> Result<Option<u32>> { - let mut request = handle.route().get(get_ip_version(destination)); - let octets = match destination { - IpAddr::V4(address) => address.octets().to_vec(), - IpAddr::V6(address) => address.octets().to_vec(), - }; - let message = request.message_mut(); - if set_mark { - message - .nlas - .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK)); - } - message.header.destination_prefix_length = 8u8 * (octets.len() as u8); - message.header.flags = RouteFlags::RTM_F_FIB_MATCH; - message.nlas.push(RouteNla::Destination(octets)); - let mut stream = execute_route_get_request(handle.clone(), message.clone()); - match stream.try_next().await { - Ok(Some(route_msg)) => { - for nla in &route_msg.nlas { - if let RouteNla::Oif(interface) = nla { - return Ok(Some(*interface)); - } - } - Err(Error::RouteNoInterfaceError) - } - Ok(None) => Err(Error::NoRouteError), - Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { - Ok(None) - } - Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))), - } -} - -pub fn execute_route_get_request( - mut handle: Handle, - message: RouteMessage, -) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> { - use futures::future::{self, Either}; - use rtnetlink::Error; - - let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); - req.header.flags = NLM_F_REQUEST; - - match handle.request(req) { - Ok(response) => Either::Left(response.map(move |msg| { - let (header, payload) = msg.into_parts(); - match payload { - NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg), - NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)), - _ => Err(Error::UnexpectedMessage(NetlinkMessage::new( - header, payload, - ))), - } - })), - Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()), - } + let route = handle + .get_destination_route(destination, set_mark) + .await + .map_err(Error::RouteManagerError)?; + route + .map(|route| route.get_node().get_device().map(iface_index)) + .flatten() + .transpose() + .map_err(Error::InterfaceNameError) } fn get_ip_version(addr: &IpAddr) -> IpVersion { diff --git a/talpid-core/src/dns/linux/systemd_resolved.rs b/talpid-core/src/dns/linux/systemd_resolved.rs index 01b6d3ea22..ea487a3077 100644 --- a/talpid-core/src/dns/linux/systemd_resolved.rs +++ b/talpid-core/src/dns/linux/systemd_resolved.rs @@ -1,4 +1,7 @@ -use crate::linux::{iface_index, IfaceIndexLookupError}; +use crate::{ + linux::{iface_index, IfaceIndexLookupError}, + routing::RouteManagerHandle, +}; use futures::{channel::mpsc, StreamExt}; use std::{ collections::BTreeMap, @@ -56,11 +59,17 @@ impl SystemdResolved { Ok(systemd_resolved) } - pub async fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> { + pub async fn set_dns( + &mut self, + route_manager: RouteManagerHandle, + interface_name: &str, + servers: &[IpAddr], + ) -> Result<()> { let (update_tx, mut update_rx) = mpsc::unbounded(); - let (monitor, initial_config) = super::routing::spawn_monitor(servers.to_vec(), update_tx) - .await - .map_err(Error::SpawnInterfaceMonitor)?; + let (monitor, initial_config) = + super::routing::spawn_monitor(route_manager, servers.to_vec(), update_tx) + .await + .map_err(Error::SpawnInterfaceMonitor)?; let tunnel_index = iface_index(interface_name)?; self.tunnel_index = tunnel_index; diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs index b07ce13eac..229896cf98 100644 --- a/talpid-core/src/dns/mod.rs +++ b/talpid-core/src/dns/mod.rs @@ -1,3 +1,5 @@ +#[cfg(target_os = "linux")] +use crate::routing::RouteManagerHandle; use std::{net::IpAddr, path::Path}; #[cfg(target_os = "macos")] @@ -28,9 +30,18 @@ pub struct DnsMonitor { impl DnsMonitor { /// Returns a new `DnsMonitor` that can set and monitor the system DNS. - pub fn new(handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> { + pub fn new( + handle: tokio::runtime::Handle, + cache_dir: impl AsRef<Path>, + #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, + ) -> Result<Self, Error> { Ok(DnsMonitor { - inner: imp::DnsMonitor::new(handle, cache_dir)?, + inner: imp::DnsMonitor::new( + handle, + cache_dir, + #[cfg(target_os = "linux")] + route_manager, + )?, }) } @@ -61,6 +72,7 @@ trait DnsMonitorT: Sized { fn new( handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>, + #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, ) -> Result<Self, Self::Error>; fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Self::Error>; diff --git a/talpid-core/src/offline/linux.rs b/talpid-core/src/offline/linux.rs index cf255b0541..ceaa864cc7 100644 --- a/talpid-core/src/offline/linux.rs +++ b/talpid-core/src/offline/linux.rs @@ -1,21 +1,12 @@ -use crate::tunnel_state_machine::TunnelCommand; -use futures::{ - channel::{mpsc::UnboundedSender, oneshot}, - FutureExt, StreamExt, TryStream, TryStreamExt, +use crate::{ + routing::{self, RouteManagerHandle}, + tunnel_state_machine::TunnelCommand, }; -use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST}; -use netlink_packet_route::{ - rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage, +use futures::{channel::mpsc::UnboundedSender, StreamExt}; +use std::{ + net::{IpAddr, Ipv4Addr}, + sync::Weak, }; -use rtnetlink::{ - constants::{ - RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_IFADDR, RTMGRP_IPV6_ROUTE, RTMGRP_LINK, - RTMGRP_NOTIFY, - }, - sys::SocketAddr, - Handle, IpVersion, -}; -use std::{io, net::Ipv4Addr, sync::Weak}; use talpid_types::ErrorExt; pub type Result<T> = std::result::Result<T, Error>; @@ -23,52 +14,21 @@ pub type Result<T> = std::result::Result<T, Error>; #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { - #[error(display = "Failed to resolve output interface index")] - GetLinkError(#[error(source)] failure::Compat<rtnetlink::Error>), - - #[error(display = "No netlink response for output interface query")] - NoLinkError, - - #[error(display = "Failed to get list of IP addresses")] - GetAddressesError(#[error(source)] failure::Compat<rtnetlink::Error>), - - #[error(display = "Failed to get a route for an arbitrary IP address")] - GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>), - - #[error(display = "No netlink response for route query")] - NoRouteError, - - #[error(display = "Failed to connect to netlink socket")] - NetlinkConnectionError(#[error(source)] io::Error), - - #[error(display = "Failed to connect to bind to netlink socket")] - BindError(#[error(source)] io::Error), - - #[error(display = "Failed to start listening on netlink socket")] - NetlinkBindError(#[error(source)] io::Error), - - #[error(display = "Error while processing netlink messages")] - MonitorNetlinkError, - - #[error(display = "Netlink connection has unexpectedly disconnected")] - NetlinkDisconnected, - - #[error(display = "Failed to initialize event loop")] - EventLoopError(#[error(source)] io::Error), + #[error(display = "The route manager returned an error")] + RouteManagerError(#[error(source)] routing::Error), } pub struct MonitorHandle { - handle: rtnetlink::Handle, - _stop_connection_tx: oneshot::Sender<()>, + route_manager: RouteManagerHandle, } // Mullvad API's public IP address, correct at the time of writing, but any public IP address will // work. -const PUBLIC_INTERNET_ADDRESS: Ipv4Addr = Ipv4Addr::new(193, 138, 218, 78); +const PUBLIC_INTERNET_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78)); impl MonitorHandle { pub async fn is_offline(&mut self) -> bool { - match public_ip_unreachable(&self.handle).await { + match public_ip_unreachable(&self.route_manager).await { Ok(is_offline) => is_offline, Err(err) => { log::error!( @@ -81,46 +41,28 @@ impl MonitorHandle { } } -pub async fn spawn_monitor(sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<MonitorHandle> { - let (mut connection, handle, mut messages) = - rtnetlink::new_connection().map_err(Error::NetlinkConnectionError)?; - - let mgroup_flags = RTMGRP_IPV4_IFADDR - | RTMGRP_IPV4_ROUTE - | RTMGRP_IPV6_IFADDR - | RTMGRP_IPV6_ROUTE - | RTMGRP_LINK - | RTMGRP_NOTIFY; - let addr = SocketAddr::new(0, mgroup_flags); - - connection - .socket_mut() - .bind(&addr) - .map_err(Error::BindError)?; +pub async fn spawn_monitor( + sender: Weak<UnboundedSender<TunnelCommand>>, + route_manager: RouteManagerHandle, +) -> Result<MonitorHandle> { + let mut is_offline = public_ip_unreachable(&route_manager).await?; - let (stop_connection_tx, stop_rx) = oneshot::channel(); - - // Connection will be closed once the channel is dropped - tokio::spawn(async { - futures::select! { - _ = connection.fuse() => (), - _ = stop_rx.fuse() => (), - } - }); - let mut is_offline = public_ip_unreachable(&handle).await?; + let mut listener = route_manager + .change_listener() + .await + .map_err(Error::RouteManagerError)?; let monitor_handle = MonitorHandle { - handle: handle.clone(), - _stop_connection_tx: stop_connection_tx, + route_manager: route_manager.clone(), }; - tokio::spawn(async move { - while let Some(_new_message) = messages.next().await { + while let Some(_event) = listener.next().await { match sender.upgrade() { Some(sender) => { - let new_offline_state = - public_ip_unreachable(&handle).await.unwrap_or_else(|err| { + let new_offline_state = public_ip_unreachable(&route_manager) + .await + .unwrap_or_else(|err| { log::error!( "{}", err.display_chain_with_msg("Failed to infer offline state") @@ -141,89 +83,10 @@ pub async fn spawn_monitor(sender: Weak<UnboundedSender<TunnelCommand>>) -> Resu } -async fn public_ip_unreachable(handle: &Handle) -> Result<bool> { - let mut request = handle.route().get(IpVersion::V4); - let message = request.message_mut(); - message - .nlas - .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK)); - message.nlas.push(RouteNla::Destination( - PUBLIC_INTERNET_ADDRESS.octets().to_vec(), - )); - message.header.destination_prefix_length = 32; - message.header.flags = RouteFlags::RTM_F_LOOKUP_TABLE; - let mut stream = execute_route_get_request(handle.clone(), message.clone()); - match stream.try_next().await { - // Presance of any route implies connectivity, even if it's a loopback route - Ok(Some(_)) => Ok(false), - Ok(None) => Err(Error::NoRouteError), - // ENETUNREACH implies that there exists no route that'd reach our random API address, - // as such, the host is assumed to be offline - Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { - Ok(true) - } - Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))), - } -} - -pub fn execute_route_get_request( - mut handle: Handle, - message: RouteMessage, -) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> { - use futures::future::{self, Either}; - use rtnetlink::Error; - - let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); - req.header.flags = NLM_F_REQUEST; - - match handle.request(req) { - Ok(response) => Either::Left(response.map(move |msg| { - let (header, payload) = msg.into_parts(); - match payload { - NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg), - NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)), - _ => Err(Error::UnexpectedMessage(NetlinkMessage::new( - header, payload, - ))), - } - })), - Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()), - } -} - -#[cfg(test)] -mod test { - use super::*; - use rtnetlink::{ - constants::{ - RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_IFADDR, RTMGRP_IPV6_ROUTE, - RTMGRP_LINK, RTMGRP_NOTIFY, - }, - sys::SocketAddr, - }; - - #[test] - fn test_route_table_query() { - let mut runtime = tokio::runtime::Runtime::new().expect("failed to initialize runtime"); - let (mut connection, handle, _) = runtime.block_on(async { - rtnetlink::new_connection() - .map_err(Error::NetlinkConnectionError) - .expect("Failed to create a netlink connection") - }); - - let mgroup_flags = RTMGRP_IPV4_IFADDR - | RTMGRP_IPV4_ROUTE - | RTMGRP_IPV6_IFADDR - | RTMGRP_IPV6_ROUTE - | RTMGRP_LINK - | RTMGRP_NOTIFY; - let addr = SocketAddr::new(0, mgroup_flags); - - connection.socket_mut().bind(&addr).unwrap(); - runtime.spawn(connection); - - runtime - .block_on(public_ip_unreachable(&handle)) - .expect("Failed to query routing table"); - } +async fn public_ip_unreachable(handle: &RouteManagerHandle) -> Result<bool> { + Ok(handle + .get_destination_route(PUBLIC_INTERNET_ADDRESS, true) + .await + .map_err(Error::RouteManagerError)? + .is_none()) } diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs index b059d1f216..ac8b10e222 100644 --- a/talpid-core/src/offline/mod.rs +++ b/talpid-core/src/offline/mod.rs @@ -1,3 +1,5 @@ +#[cfg(target_os = "linux")] +use crate::routing::RouteManagerHandle; use crate::tunnel_state_machine::TunnelCommand; use futures::channel::mpsc::UnboundedSender; use std::sync::Weak; @@ -42,12 +44,15 @@ impl MonitorHandle { pub async fn spawn_monitor( sender: Weak<UnboundedSender<TunnelCommand>>, + #[cfg(target_os = "linux")] route_manager: RouteManagerHandle, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<MonitorHandle, Error> { let monitor = if !*FORCE_DISABLE_OFFLINE_MONITOR { Some( imp::spawn_monitor( sender, + #[cfg(target_os = "linux")] + route_manager, #[cfg(target_os = "android")] android_context, ) diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index afc02eedb2..a63430fed3 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -1,4 +1,7 @@ -use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route}; +use crate::routing::{ + imp::{CallbackMessage, RouteManagerCommand}, + NetNode, Node, RequiredRoute, Route, +}; use std::{ collections::{BTreeMap, HashSet}, io, @@ -6,11 +9,15 @@ use std::{ }; use talpid_types::ErrorExt; -use futures::{channel::mpsc::UnboundedReceiver, future::FutureExt, StreamExt, TryStreamExt}; +use futures::{ + channel::mpsc::{UnboundedReceiver, UnboundedSender}, + future::FutureExt, + StreamExt, TryStream, TryStreamExt, +}; use ipnetwork::IpNetwork; use lazy_static::lazy_static; use netlink_packet_route::{ - constants::{ARPHRD_LOOPBACK, FIB_RULE_INVERT, FR_ACT_TO_TBL}, + constants::{ARPHRD_LOOPBACK, FIB_RULE_INVERT, FR_ACT_TO_TBL, NLM_F_REQUEST}, link::{nlas::Nla as LinkNla, LinkMessage}, route::{nlas::Nla as RouteNla, RouteHeader, RouteMessage}, rtnl::{ @@ -26,7 +33,7 @@ use netlink_packet_route::{ use rtnetlink::{ constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_LINK, RTMGRP_NOTIFY}, sys::SocketAddr, - Handle, + Handle, IpVersion, }; use libc::{AF_INET, AF_INET6}; @@ -102,6 +109,12 @@ pub enum Error { #[error(display = "Unknown device index - {}", _0)] UnknownDeviceIndex(u32), + #[error(display = "Failed to get a route for the given IP address")] + GetRouteError(#[error(source)] rtnetlink::Error), + + #[error(display = "No netlink response for route query")] + NoRouteError, + /// Unable to create routing table for tagged connections and packets. #[error(display = "Cannot find a free routing table ID")] NoFreeRoutingTableId, @@ -115,6 +128,7 @@ pub struct RouteManagerImpl { handle: Handle, messages: UnboundedReceiver<(NetlinkMessage<RtnlMessage>, SocketAddr)>, iface_map: BTreeMap<u32, NetworkInterface>, + listeners: Vec<UnboundedSender<CallbackMessage>>, // currently added routes added_routes: HashSet<Route>, @@ -137,9 +151,10 @@ impl RouteManagerImpl { let iface_map = Self::initialize_link_map(&handle).await?; let mut monitor = Self { - iface_map, handle, messages, + iface_map, + listeners: vec![], added_routes: HashSet::new(), }; @@ -295,8 +310,8 @@ impl RouteManagerImpl { .map(|(idx, _name)| *idx) } - async fn process_deleted_route(&mut self, route: Route) -> Result<()> { - self.added_routes.remove(&route); + fn process_deleted_route(&mut self, route: &Route) -> Result<()> { + self.added_routes.remove(route); Ok(()) } @@ -347,6 +362,12 @@ impl RouteManagerImpl { RouteManagerCommand::ClearRoutingRules(result_tx) => { let _ = result_tx.send(self.clear_routing_rules().await); } + RouteManagerCommand::NewChangeListener(result_tx) => { + let _ = result_tx.send(self.listen()); + } + RouteManagerCommand::GetDestinationRoute(destination, set_mark, result_tx) => { + let _ = result_tx.send(self.get_destination_route(&destination, set_mark).await); + } RouteManagerCommand::ClearRoutes => { log::debug!("Clearing routes"); self.cleanup_routes().await; @@ -367,9 +388,15 @@ impl RouteManagerImpl { self.iface_map.remove(&idx); } } + NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(new_route)) => { + if let Some(addition) = self.parse_route_message(new_route)? { + self.notify_change_listeners(CallbackMessage::NewRoute(addition)); + } + } NetlinkPayload::InnerMessage(RtnlMessage::DelRoute(old_route)) => { if let Some(deletion) = self.parse_route_message(old_route)? { - self.process_deleted_route(deletion).await?; + self.process_deleted_route(&deletion)?; + self.notify_change_listeners(CallbackMessage::DelRoute(deletion)); } } _ => (), @@ -377,18 +404,13 @@ impl RouteManagerImpl { Ok(()) } - // Tries to coax a Route out of a RouteMessage, but only if it's a route from the main routing - // table - // TODO: Change to account for different routing tables. - fn parse_route_message(&self, msg: RouteMessage) -> Result<Option<Route>> { - if msg.header.table != RT_TABLE_MAIN { - return Ok(None); - } - self.parse_route_message_inner(msg) + fn notify_change_listeners(&mut self, message: CallbackMessage) { + self.listeners + .retain(|listener| listener.unbounded_send(message.clone()).is_ok()); } // Tries to coax a Route out of a RouteMessage - fn parse_route_message_inner(&self, msg: RouteMessage) -> Result<Option<Route>> { + fn parse_route_message(&self, msg: RouteMessage) -> Result<Option<Route>> { let af_spec = msg.header.address_family; let destination_length = msg.header.destination_prefix_length; let is_ipv4 = match af_spec as i32 { @@ -686,6 +708,12 @@ impl RouteManagerImpl { Ok(()) } + fn listen(&mut self) -> UnboundedReceiver<CallbackMessage> { + let (tx, rx) = futures::channel::mpsc::unbounded(); + self.listeners.push(tx); + rx + } + async fn destructor(&mut self) { self.cleanup_routes().await; @@ -696,6 +724,36 @@ impl RouteManagerImpl { ); } } + + async fn get_destination_route( + &self, + destination: &IpAddr, + set_mark: bool, + ) -> Result<Option<Route>> { + let mut request = self.handle.route().get(get_ip_version(destination)); + let octets = match destination { + IpAddr::V4(address) => address.octets().to_vec(), + IpAddr::V6(address) => address.octets().to_vec(), + }; + let message = request.message_mut(); + if set_mark { + message + .nlas + .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK)); + } + message.header.destination_prefix_length = 8u8 * (octets.len() as u8); + message.header.flags = RouteFlags::RTM_F_FIB_MATCH; + message.nlas.push(RouteNla::Destination(octets)); + let mut stream = execute_route_get_request(self.handle.clone(), message.clone()); + match stream.try_next().await { + Ok(Some(route_msg)) => self.parse_route_message(route_msg), + Ok(None) => Err(Error::NoRouteError), + Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { + Ok(None) + } + Err(err) => Err(Error::GetRouteError(err)), + } + } } fn ip_to_bytes(addr: IpAddr) -> Vec<u8> { @@ -714,6 +772,39 @@ fn compat_table_id(id: u32) -> u8 { } } +fn get_ip_version(addr: &IpAddr) -> IpVersion { + if addr.is_ipv4() { + IpVersion::V4 + } else { + IpVersion::V6 + } +} + +fn execute_route_get_request( + mut handle: Handle, + message: RouteMessage, +) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> { + use futures::future::{self, Either}; + use rtnetlink::Error; + + let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); + req.header.flags = NLM_F_REQUEST; + + match handle.request(req) { + Ok(response) => Either::Left(response.map(move |msg| { + let (header, payload) = msg.into_parts(); + match payload { + NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg), + NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)), + _ => Err(Error::UnexpectedMessage(NetlinkMessage::new( + header, payload, + ))), + } + })), + Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()), + } +} + #[derive(Debug)] struct NetworkInterface { name: String, diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs index a8491a8588..8dc8dcd778 100644 --- a/talpid-core/src/routing/mod.rs +++ b/talpid-core/src/routing/mod.rs @@ -45,6 +45,11 @@ impl Route { self.table_id = new_id; self } + + /// Returns the network node of the route. + pub fn get_node(&self) -> &Node { + &self.node + } } impl fmt::Display for Route { diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index c1d2da522b..989ec7ad24 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -2,6 +2,8 @@ #![cfg_attr(target_os = "windows", allow(dead_code))] // TODO: remove the allow(dead_code) for android once it's up to scratch. use super::RequiredRoute; +#[cfg(target_os = "linux")] +use super::Route; use futures::channel::{ mpsc::{self, UnboundedSender}, @@ -9,6 +11,12 @@ use futures::channel::{ }; use std::{collections::HashSet, io}; +#[cfg(target_os = "linux")] +use futures::stream::Stream; + +#[cfg(target_os = "linux")] +use std::net::IpAddr; + #[cfg(target_os = "macos")] #[path = "macos.rs"] mod imp; @@ -46,26 +54,25 @@ pub enum Error { /// Handle to a route manager. #[derive(Clone)] pub struct RouteManagerHandle { - runtime: tokio::runtime::Handle, tx: UnboundedSender<RouteManagerCommand>, } impl RouteManagerHandle { /// Applies the given routes while the route manager is running. - pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { + pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { let (response_tx, response_rx) = oneshot::channel(); self.tx .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) .map_err(|_| Error::RouteManagerDown)?; - self.runtime - .block_on(response_rx) + response_rx + .await .map_err(|_| Error::ManagerChannelDown)? .map_err(Error::PlatformError) } /// Ensure that packets are routed using the correct tables. #[cfg(target_os = "linux")] - pub fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { + pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> { let (response_tx, response_rx) = oneshot::channel(); self.tx .unbounded_send(RouteManagerCommand::CreateRoutingRules( @@ -73,21 +80,52 @@ impl RouteManagerHandle { response_tx, )) .map_err(|_| Error::RouteManagerDown)?; - self.runtime - .block_on(response_rx) + response_rx + .await .map_err(|_| Error::ManagerChannelDown)? .map_err(Error::PlatformError) } /// Remove any routing rules created by [`create_routing_rules`]. #[cfg(target_os = "linux")] - pub fn clear_routing_rules(&self) -> Result<(), Error> { + pub async fn clear_routing_rules(&self) -> Result<(), Error> { let (response_tx, response_rx) = oneshot::channel(); self.tx .unbounded_send(RouteManagerCommand::ClearRoutingRules(response_tx)) .map_err(|_| Error::RouteManagerDown)?; - self.runtime - .block_on(response_rx) + response_rx + .await + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) + } + + /// Listen for route changes. + #[cfg(target_os = "linux")] + pub async fn change_listener(&self) -> Result<impl Stream<Item = CallbackMessage>, Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::NewChangeListener(response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown) + } + + /// Listen for route changes. + #[cfg(target_os = "linux")] + pub async fn get_destination_route( + &self, + destination: IpAddr, + set_mark: bool, + ) -> Result<Option<Route>, Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::GetDestinationRoute( + destination, + set_mark, + response_tx, + )) + .map_err(|_| Error::RouteManagerDown)?; + response_rx + .await .map_err(|_| Error::ManagerChannelDown)? .map_err(Error::PlatformError) } @@ -106,6 +144,21 @@ pub(crate) enum RouteManagerCommand { CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>), #[cfg(target_os = "linux")] ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>), + #[cfg(target_os = "linux")] + NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>), + #[cfg(target_os = "linux")] + GetDestinationRoute( + IpAddr, + bool, + oneshot::Sender<Result<Option<Route>, PlatformError>>, + ), +} + +#[cfg(target_os = "linux")] +#[derive(Debug, Clone)] +pub enum CallbackMessage { + NewRoute(Route), + DelRoute(Route), } /// RouteManager applies a set of routes to the route table. @@ -120,12 +173,12 @@ impl RouteManager { /// Constructs a RouteManager and applies the required routes. /// Takes a set of network destinations and network nodes as an argument, and applies said /// routes. - pub fn new( + pub async fn new( runtime: tokio::runtime::Handle, required_routes: HashSet<RequiredRoute>, ) -> Result<Self, Error> { let (manage_tx, manage_rx) = mpsc::unbounded(); - let manager = runtime.block_on(imp::RouteManagerImpl::new(required_routes))?; + let manager = imp::RouteManagerImpl::new(required_routes).await?; runtime.spawn(manager.run(manage_rx)); Ok(Self { @@ -135,7 +188,7 @@ impl RouteManager { } /// Stops RouteManager and removes all of the applied routes. - pub fn stop(&mut self) { + pub async fn stop(&mut self) { if let Some(tx) = self.manage_tx.take() { let (wait_tx, wait_rx) = oneshot::channel(); @@ -147,14 +200,14 @@ impl RouteManager { return; } - if self.runtime.block_on(wait_rx).is_err() { + if wait_rx.await.is_err() { log::error!("{}", Error::ManagerChannelDown); } } } /// Applies the given routes until [`RouteManager::stop`] is called. - pub fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { + pub async fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { if let Some(tx) = &self.manage_tx { let (result_tx, result_rx) = oneshot::channel(); if tx @@ -164,8 +217,8 @@ impl RouteManager { return Err(Error::RouteManagerDown); } - self.runtime - .block_on(result_rx) + result_rx + .await .map_err(|_| Error::ManagerChannelDown)? .map_err(Error::PlatformError) } else { @@ -188,23 +241,20 @@ impl RouteManager { /// Ensure that packets are routed using the correct tables. #[cfg(target_os = "linux")] - pub fn create_routing_rules(&mut self, enable_ipv6: bool) -> Result<(), Error> { - self.handle()?.create_routing_rules(enable_ipv6) + pub async fn create_routing_rules(&mut self, enable_ipv6: bool) -> Result<(), Error> { + self.handle()?.create_routing_rules(enable_ipv6).await } /// Remove any routing rules created by [`create_routing_rules`]. #[cfg(target_os = "linux")] - pub fn clear_routing_rules(&mut self) -> Result<(), Error> { - self.handle()?.clear_routing_rules() + pub async fn clear_routing_rules(&mut self) -> Result<(), Error> { + self.handle()?.clear_routing_rules().await } /// Retrieve a sender directly to the command channel. pub fn handle(&self) -> Result<RouteManagerHandle, Error> { if let Some(tx) = &self.manage_tx { - Ok(RouteManagerHandle { - runtime: self.runtime.clone(), - tx: tx.clone(), - }) + Ok(RouteManagerHandle { tx: tx.clone() }) } else { Err(Error::RouteManagerDown) } @@ -219,6 +269,6 @@ impl RouteManager { impl Drop for RouteManager { fn drop(&mut self) { - self.stop(); + self.runtime.clone().block_on(self.stop()); } } diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index ca5fbb8ea2..93574884f6 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -41,20 +41,17 @@ pub struct RouteManager { /// Handle to a route manager. #[derive(Clone)] pub struct RouteManagerHandle { - runtime: tokio::runtime::Handle, tx: UnboundedSender<RouteManagerCommand>, } impl RouteManagerHandle { /// Applies the given routes while the route manager is running. - pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { let (response_tx, response_rx) = oneshot::channel(); self.tx .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) .map_err(|_| Error::RouteManagerDown)?; - self.runtime - .block_on(response_rx) - .map_err(|_| Error::ManagerChannelDown)? + response_rx.await.map_err(|_| Error::ManagerChannelDown)? } } @@ -67,7 +64,7 @@ pub enum RouteManagerCommand { impl RouteManager { /// Creates a new route manager that will apply the provided routes and ensure they exist until /// it's stopped. - pub fn new( + pub async fn new( runtime: tokio::runtime::Handle, required_routes: HashSet<RequiredRoute>, ) -> Result<Self> { @@ -89,10 +86,7 @@ impl RouteManager { /// Retrieve a sender directly to the command channel. pub fn handle(&self) -> Result<RouteManagerHandle> { if let Some(tx) = &self.manage_tx { - Ok(RouteManagerHandle { - runtime: self.runtime.clone(), - tx: tx.clone(), - }) + Ok(RouteManagerHandle { tx: tx.clone() }) } else { Err(Error::RouteManagerDown) } diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index e3709b0183..87dbc55101 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -1028,19 +1028,14 @@ mod event_server { .filter(|route| route.prefix.is_ipv4() || ipv6_enabled) .collect(); - tokio::task::spawn_blocking(move || { - if let Err(error) = route_handle.add_routes(routes) { - log::error!("{}", error.display_chain()); - return Err(tonic::Status::failed_precondition("Failed to add routes")); - } - if let Err(error) = route_handle.create_routing_rules(ipv6_enabled) { - log::error!("{}", error.display_chain()); - return Err(tonic::Status::failed_precondition("Failed to add routes")); - } - Ok(()) - }) - .await - .map_err(|_| tonic::Status::internal("task failed to complete"))??; + if let Err(error) = route_handle.add_routes(routes).await { + log::error!("{}", error.display_chain()); + return Err(tonic::Status::failed_precondition("Failed to add routes")); + } + if let Err(error) = route_handle.create_routing_rules(ipv6_enabled).await { + log::error!("{}", error.display_chain()); + return Err(tonic::Status::failed_precondition("Failed to add routes")); + } } let tunnel_alias = env diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 4d7fa478b1..899589dbc4 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -253,20 +253,24 @@ impl WireguardMonitor { } } - let setup_iface_routes = move || -> Result<()> { + let setup_iface_routes = || -> Result<()> { #[cfg(target_os = "windows")] if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) { return Err(Error::SetIpAddressesError); } - #[cfg(target_os = "linux")] - route_handle - .create_routing_rules(config.enable_ipv6) - .map_err(Error::SetupRoutingError)?; + runtime.block_on(async move { + #[cfg(target_os = "linux")] + route_handle + .create_routing_rules(config.enable_ipv6) + .await + .map_err(Error::SetupRoutingError)?; - route_handle - .add_routes(Self::get_routes(&iface_name, &config)) - .map_err(Error::SetupRoutingError) + route_handle + .add_routes(Self::get_routes(&iface_name, &config)) + .await + .map_err(Error::SetupRoutingError) + }) }; if let Err(error) = setup_iface_routes() { diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index ea8ce8e992..72af2b6f38 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -139,7 +139,10 @@ impl ConnectedState { log::error!("{}", error.display_chain_with_msg("Failed to clear routes")); } #[cfg(target_os = "linux")] - if let Err(error) = shared_values.route_manager.clear_routing_rules() { + if let Err(error) = shared_values + .runtime + .block_on(shared_values.route_manager.clear_routing_rules()) + { log::error!( "{}", error.display_chain_with_msg("Failed to clear routing rules") diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 8486ae0a3e..b0c87acdb4 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -204,7 +204,10 @@ impl ConnectingState { log::error!("{}", error.display_chain_with_msg("Failed to clear routes")); } #[cfg(target_os = "linux")] - if let Err(error) = shared_values.route_manager.clear_routing_rules() { + if let Err(error) = shared_values + .runtime + .block_on(shared_values.route_manager.clear_routing_rules()) + { log::error!( "{}", error.display_chain_with_msg("Failed to clear routing rules") diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 1bcc6961d9..38496865a4 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -89,18 +89,10 @@ pub async fn spawn( ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); - let mut offline_monitor = offline::spawn_monitor( - Arc::downgrade(&command_tx), - #[cfg(target_os = "android")] - android_context.clone(), - ) - .await - .map_err(Error::OfflineMonitorError)?; - let is_offline = offline_monitor.is_offline().await; let tun_provider = TunProvider::new( #[cfg(target_os = "android")] - android_context, + android_context.clone(), #[cfg(target_os = "android")] allow_lan, #[cfg(target_os = "android")] @@ -112,12 +104,13 @@ pub async fn spawn( let runtime = tokio::runtime::Handle::current(); let (startup_result_tx, startup_result_rx) = sync_mpsc::channel(); + let weak_command_tx = Arc::downgrade(&command_tx); std::thread::spawn(move || { - let state_machine = TunnelStateMachine::new( + let state_machine = runtime.block_on(TunnelStateMachine::new( runtime.clone(), + weak_command_tx, allow_lan, block_when_disconnected, - is_offline, dns_servers, allowed_endpoint, tunnel_parameters_generator, @@ -127,7 +120,9 @@ pub async fn spawn( cache_dir, command_rx, reset_firewall, - ); + #[cfg(target_os = "android")] + android_context, + )); let state_machine = match state_machine { Ok(state_machine) => { startup_result_tx.send(Ok(())).unwrap(); @@ -144,8 +139,6 @@ pub async fn spawn( if shutdown_tx.send(()).is_err() { log::error!("Can't send shutdown completion to daemon"); } - - std::mem::drop(offline_monitor); }); startup_result_rx @@ -199,11 +192,11 @@ struct TunnelStateMachine { } impl TunnelStateMachine { - fn new( + async fn new( runtime: tokio::runtime::Handle, + command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>, allow_lan: bool, block_when_disconnected: bool, - is_offline: bool, dns_servers: Option<Vec<IpAddr>>, allowed_endpoint: Endpoint, tunnel_parameters_generator: impl TunnelParametersGenerator, @@ -213,6 +206,7 @@ impl TunnelStateMachine { cache_dir: impl AsRef<Path>, commands: mpsc::UnboundedReceiver<TunnelCommand>, reset_firewall: bool, + #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result<Self, Error> { let args = FirewallArguments { initialize_blocked: block_when_disconnected || !reset_firewall, @@ -221,15 +215,36 @@ impl TunnelStateMachine { }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; - let dns_monitor = - DnsMonitor::new(runtime.clone(), cache_dir).map_err(Error::InitDnsMonitorError)?; let route_manager = RouteManager::new(runtime.clone(), HashSet::new()) + .await .map_err(Error::InitRouteManagerError)?; + let dns_monitor = DnsMonitor::new( + runtime.clone(), + cache_dir, + #[cfg(target_os = "linux")] + route_manager + .handle() + .map_err(Error::InitRouteManagerError)?, + ) + .map_err(Error::InitDnsMonitorError)?; + let mut offline_monitor = offline::spawn_monitor( + command_tx, + #[cfg(target_os = "linux")] + route_manager + .handle() + .map_err(Error::InitRouteManagerError)?, + #[cfg(target_os = "android")] + android_context, + ) + .await + .map_err(Error::OfflineMonitorError)?; + let is_offline = offline_monitor.is_offline().await; let mut shared_values = SharedTunnelStateValues { runtime, firewall, dns_monitor, route_manager, + _offline_monitor: offline_monitor, allow_lan, block_when_disconnected, is_offline, @@ -299,6 +314,7 @@ struct SharedTunnelStateValues { firewall: Firewall, dns_monitor: DnsMonitor, route_manager: RouteManager, + _offline_monitor: offline::MonitorHandle, /// Should LAN access be allowed outside the tunnel. allow_lan: bool, /// Should network access be allowed when in the disconnected state. |
