summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-06-23 13:39:20 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-06-23 13:39:20 +0200
commit3ff99131790ca0a953247a07a45909e02a0cf313 (patch)
tree361df08ae8b0fbe062f5b04ea9d361a0d97e0a08
parent26022a683c2e1f7aa8fdd912e7e991c5c66118d5 (diff)
parent9adc4d7ee90e0c6e4d73f14cf63041021302ce88 (diff)
downloadmullvadvpn-3ff99131790ca0a953247a07a45909e02a0cf313.tar.xz
mullvadvpn-3ff99131790ca0a953247a07a45909e02a0cf313.zip
Merge branch 'linux-refactor-routing'
-rw-r--r--Cargo.lock41
-rw-r--r--talpid-core/Cargo.toml2
-rw-r--r--talpid-core/src/dns/linux/mod.rs17
-rw-r--r--talpid-core/src/dns/linux/routing.rs192
-rw-r--r--talpid-core/src/dns/linux/systemd_resolved.rs19
-rw-r--r--talpid-core/src/dns/mod.rs16
-rw-r--r--talpid-core/src/offline/linux.rs201
-rw-r--r--talpid-core/src/offline/mod.rs5
-rw-r--r--talpid-core/src/routing/linux.rs125
-rw-r--r--talpid-core/src/routing/mod.rs5
-rw-r--r--talpid-core/src/routing/unix.rs102
-rw-r--r--talpid-core/src/routing/windows.rs14
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs21
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs20
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs5
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs5
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs52
17 files changed, 410 insertions, 432 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 1a24bf4a7f..e7f61ca913 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -635,9 +635,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
[[package]]
name = "futures"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150"
+checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27"
dependencies = [
"futures-channel",
"futures-core",
@@ -650,9 +650,9 @@ dependencies = [
[[package]]
name = "futures-channel"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f2d31b7ec7efab6eefc7c57233bb10b847986139d88cc2f5a02a1ae6871a1846"
+checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2"
dependencies = [
"futures-core",
"futures-sink",
@@ -660,15 +660,15 @@ dependencies = [
[[package]]
name = "futures-core"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65"
+checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1"
[[package]]
name = "futures-executor"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9"
+checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79"
dependencies = [
"futures-core",
"futures-task",
@@ -677,16 +677,17 @@ dependencies = [
[[package]]
name = "futures-io"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28be053525281ad8259d47e4de5de657b25e7bac113458555bb4b70bc6870500"
+checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1"
[[package]]
name = "futures-macro"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c287d25add322d9f9abdcdc5927ca398917996600182178774032e9f8258fedd"
+checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121"
dependencies = [
+ "autocfg",
"proc-macro-hack",
"proc-macro2",
"quote",
@@ -695,25 +696,23 @@ dependencies = [
[[package]]
name = "futures-sink"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "caf5c69029bda2e743fddd0582d1083951d65cc9539aebf8812f36c3491342d6"
+checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282"
[[package]]
name = "futures-task"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13de07eb8ea81ae445aca7b69f5f7bf15d7bf4912d8ca37d6645c77ae8a58d86"
-dependencies = [
- "once_cell",
-]
+checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae"
[[package]]
name = "futures-util"
-version = "0.3.12"
+version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b"
+checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967"
dependencies = [
+ "autocfg",
"futures-channel",
"futures-core",
"futures-io",
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index f84e4a5bd4..9d372a7c00 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -13,7 +13,7 @@ atty = "0.2"
cfg-if = "1.0"
duct = "0.13"
err-derive = "0.3.0"
-futures = "0.3"
+futures = "0.3.15"
hex = "0.4"
ipnetwork = "0.16"
lazy_static = "1.0"
diff --git a/talpid-core/src/dns/linux/mod.rs b/talpid-core/src/dns/linux/mod.rs
index 62827211bd..4b046035aa 100644
--- a/talpid-core/src/dns/linux/mod.rs
+++ b/talpid-core/src/dns/linux/mod.rs
@@ -8,6 +8,7 @@ use self::{
network_manager::NetworkManager, resolvconf::Resolvconf, static_resolv_conf::StaticResolvConf,
systemd_resolved::SystemdResolved,
};
+use crate::routing::RouteManagerHandle;
use std::{env, fmt, net::IpAddr, path::Path};
@@ -40,6 +41,7 @@ pub enum Error {
}
pub struct DnsMonitor {
+ route_manager: RouteManagerHandle,
handle: tokio::runtime::Handle,
inner: Option<DnsMonitorHolder>,
}
@@ -47,8 +49,13 @@ pub struct DnsMonitor {
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
- fn new(handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> {
+ fn new(
+ handle: tokio::runtime::Handle,
+ _cache_dir: impl AsRef<Path>,
+ route_manager: RouteManagerHandle,
+ ) -> Result<Self> {
Ok(DnsMonitor {
+ route_manager,
handle,
inner: None,
})
@@ -58,7 +65,7 @@ impl super::DnsMonitorT for DnsMonitor {
self.reset()?;
// Creating a new DNS monitor for each set, in case the system changed how it manages DNS.
let mut inner = DnsMonitorHolder::new()?;
- inner.set(&self.handle, interface, servers)?;
+ inner.set(&self.handle, &self.route_manager, interface, servers)?;
self.inner = Some(inner);
Ok(())
}
@@ -128,6 +135,7 @@ impl DnsMonitorHolder {
fn set(
&mut self,
handle: &tokio::runtime::Handle,
+ route_manager: &RouteManagerHandle,
interface: &str,
servers: &[IpAddr],
) -> Result<()> {
@@ -137,9 +145,8 @@ impl DnsMonitorHolder {
StaticResolvConf(ref mut static_resolv_conf) => {
static_resolv_conf.set_dns(servers.to_vec())?
}
- SystemdResolved(ref mut systemd_resolved) => {
- handle.block_on(systemd_resolved.set_dns(interface, &servers))?
- }
+ SystemdResolved(ref mut systemd_resolved) => handle
+ .block_on(systemd_resolved.set_dns(route_manager.clone(), interface, &servers))?,
NetworkManager(ref mut network_manager) => {
network_manager.set_dns(interface, servers)?
}
diff --git a/talpid-core/src/dns/linux/routing.rs b/talpid-core/src/dns/linux/routing.rs
index 5c90b57a33..e04b4cc29c 100644
--- a/talpid-core/src/dns/linux/routing.rs
+++ b/talpid-core/src/dns/linux/routing.rs
@@ -1,19 +1,16 @@
-use futures::{
- channel::mpsc::UnboundedSender, future::abortable, FutureExt, StreamExt, TryStream,
- TryStreamExt,
-};
-use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST};
-use netlink_packet_route::{
- rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage,
+use crate::{
+ linux::{iface_index, IfaceIndexLookupError},
+ routing::{self, RouteManagerHandle},
};
-use rtnetlink::{
- constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_NOTIFY},
- sys::SocketAddr,
- Handle, IpVersion,
+use futures::{
+ channel::mpsc::UnboundedSender,
+ stream::{abortable, AbortHandle},
+ StreamExt,
};
+use rtnetlink::IpVersion;
use std::{
collections::BTreeMap,
- fmt, io,
+ fmt,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
use talpid_types::ErrorExt;
@@ -27,29 +24,20 @@ const PUBLIC_INTERNET_ADDRESS_V6: IpAddr =
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
- #[error(display = "Failed to get a route for an arbitrary IP address")]
- GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>),
-
- #[error(display = "Failed to connect to bind to netlink socket")]
- BindError(#[error(source)] io::Error),
+ #[error(display = "The route manager returned an error")]
+ RouteManagerError(#[error(source)] routing::Error),
- #[error(display = "No netlink response for route query")]
- NoRouteError,
-
- #[error(display = "Route is missing an output interface")]
- RouteNoInterfaceError,
+ #[error(display = "Failed to resolve interface index with error {}", _0)]
+ InterfaceNameError(#[error(source)] IfaceIndexLookupError),
}
pub struct DnsRouteMonitor {
- _handle: rtnetlink::Handle,
- stop_tx: Option<futures::channel::oneshot::Sender<()>>,
+ abort_handle: AbortHandle,
}
impl Drop for DnsRouteMonitor {
fn drop(&mut self) {
- if let Some(stop_tx) = self.stop_tx.take() {
- let _ = stop_tx.send(());
- }
+ self.abort_handle.abort();
}
}
@@ -70,70 +58,50 @@ impl fmt::Display for DnsConfig {
}
pub async fn spawn_monitor(
+ route_manager: RouteManagerHandle,
destinations: Vec<IpAddr>,
update_tx: UnboundedSender<BTreeMap<u32, DnsConfig>>,
) -> Result<(DnsRouteMonitor, BTreeMap<u32, DnsConfig>)> {
- let (mut connection, handle, messages) =
- rtnetlink::new_connection().expect("Failed to create a netlink connection");
-
- let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_NOTIFY;
- let addr = SocketAddr::new(0, mgroup_flags);
+ let listener = route_manager
+ .change_listener()
+ .await
+ .map_err(Error::RouteManagerError)?;
+ let (mut listener, abort_handle) = abortable(listener);
- connection
- .socket_mut()
- .bind(&addr)
- .map_err(Error::BindError)?;
+ let monitor = DnsRouteMonitor { abort_handle };
- let (abortable_connection, abort_connection) = abortable(connection);
- tokio::spawn(abortable_connection);
-
- let (stop_tx, stop_rx) = futures::channel::oneshot::channel();
-
- let monitor = DnsRouteMonitor {
- _handle: handle.clone(),
- stop_tx: Some(stop_tx),
- };
-
- let mut last_config = setup_configurations(&handle, &destinations).await?;
+ let mut last_config = setup_configurations(&route_manager, &destinations).await?;
let initial_config = last_config.clone();
tokio::spawn(async move {
- let mut messages = messages.fuse();
- let mut stop_rx = stop_rx.fuse();
- loop {
- futures::select! {
- _new_message = messages.next() => {
- match setup_configurations(&handle, &destinations).await {
- Ok(new_config) => {
- if last_config != new_config {
- last_config = new_config.clone();
- if update_tx.unbounded_send(new_config).is_err() {
- log::trace!("Stopping DNS monitor: channel is closed");
- break;
- }
- }
- }
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to determine new DNS interface settings"
- )
- );
+ while let Some(_event) = listener.next().await {
+ match setup_configurations(&route_manager, &destinations).await {
+ Ok(new_config) => {
+ if last_config != new_config {
+ last_config = new_config.clone();
+ if update_tx.unbounded_send(new_config).is_err() {
+ log::trace!("Stopping DNS monitor: channel is closed");
+ break;
}
}
- },
- _ = stop_rx => break,
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to determine new DNS interface settings"
+ )
+ );
+ }
}
}
- abort_connection.abort();
});
Ok((monitor, initial_config))
}
async fn setup_configurations(
- handle: &Handle,
+ handle: &RouteManagerHandle,
destinations: &[IpAddr],
) -> Result<BTreeMap<u32, DnsConfig>> {
let mut interface_to_destinations = BTreeMap::<u32, DnsConfig>::new();
@@ -141,8 +109,8 @@ async fn setup_configurations(
let interface = if destination.is_loopback() {
get_default_route_interface(handle, get_ip_version(destination), true).await?
} else {
- if crate::firewall::is_local_address(&destination) {
- get_destination_interface(handle, destination, true).await?
+ if crate::firewall::is_local_address(destination) {
+ get_destination_interface(handle, *destination, true).await?
} else {
get_default_route_interface(handle, get_ip_version(destination), false).await?
}
@@ -174,80 +142,34 @@ async fn setup_configurations(
}
async fn get_default_route_interface(
- handle: &Handle,
+ handle: &RouteManagerHandle,
ip_version: IpVersion,
set_mark: bool,
) -> Result<Option<u32>> {
match ip_version {
IpVersion::V4 => {
- get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V4, set_mark).await
+ get_destination_interface(handle, PUBLIC_INTERNET_ADDRESS_V4, set_mark).await
}
IpVersion::V6 => {
- get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V6, set_mark).await
+ get_destination_interface(handle, PUBLIC_INTERNET_ADDRESS_V6, set_mark).await
}
}
}
async fn get_destination_interface(
- handle: &Handle,
- destination: &IpAddr,
+ handle: &RouteManagerHandle,
+ destination: IpAddr,
set_mark: bool,
) -> Result<Option<u32>> {
- let mut request = handle.route().get(get_ip_version(destination));
- let octets = match destination {
- IpAddr::V4(address) => address.octets().to_vec(),
- IpAddr::V6(address) => address.octets().to_vec(),
- };
- let message = request.message_mut();
- if set_mark {
- message
- .nlas
- .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK));
- }
- message.header.destination_prefix_length = 8u8 * (octets.len() as u8);
- message.header.flags = RouteFlags::RTM_F_FIB_MATCH;
- message.nlas.push(RouteNla::Destination(octets));
- let mut stream = execute_route_get_request(handle.clone(), message.clone());
- match stream.try_next().await {
- Ok(Some(route_msg)) => {
- for nla in &route_msg.nlas {
- if let RouteNla::Oif(interface) = nla {
- return Ok(Some(*interface));
- }
- }
- Err(Error::RouteNoInterfaceError)
- }
- Ok(None) => Err(Error::NoRouteError),
- Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
- Ok(None)
- }
- Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))),
- }
-}
-
-pub fn execute_route_get_request(
- mut handle: Handle,
- message: RouteMessage,
-) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> {
- use futures::future::{self, Either};
- use rtnetlink::Error;
-
- let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message));
- req.header.flags = NLM_F_REQUEST;
-
- match handle.request(req) {
- Ok(response) => Either::Left(response.map(move |msg| {
- let (header, payload) = msg.into_parts();
- match payload {
- NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg),
- NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)),
- _ => Err(Error::UnexpectedMessage(NetlinkMessage::new(
- header, payload,
- ))),
- }
- })),
- Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()),
- }
+ let route = handle
+ .get_destination_route(destination, set_mark)
+ .await
+ .map_err(Error::RouteManagerError)?;
+ route
+ .map(|route| route.get_node().get_device().map(iface_index))
+ .flatten()
+ .transpose()
+ .map_err(Error::InterfaceNameError)
}
fn get_ip_version(addr: &IpAddr) -> IpVersion {
diff --git a/talpid-core/src/dns/linux/systemd_resolved.rs b/talpid-core/src/dns/linux/systemd_resolved.rs
index 01b6d3ea22..ea487a3077 100644
--- a/talpid-core/src/dns/linux/systemd_resolved.rs
+++ b/talpid-core/src/dns/linux/systemd_resolved.rs
@@ -1,4 +1,7 @@
-use crate::linux::{iface_index, IfaceIndexLookupError};
+use crate::{
+ linux::{iface_index, IfaceIndexLookupError},
+ routing::RouteManagerHandle,
+};
use futures::{channel::mpsc, StreamExt};
use std::{
collections::BTreeMap,
@@ -56,11 +59,17 @@ impl SystemdResolved {
Ok(systemd_resolved)
}
- pub async fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> {
+ pub async fn set_dns(
+ &mut self,
+ route_manager: RouteManagerHandle,
+ interface_name: &str,
+ servers: &[IpAddr],
+ ) -> Result<()> {
let (update_tx, mut update_rx) = mpsc::unbounded();
- let (monitor, initial_config) = super::routing::spawn_monitor(servers.to_vec(), update_tx)
- .await
- .map_err(Error::SpawnInterfaceMonitor)?;
+ let (monitor, initial_config) =
+ super::routing::spawn_monitor(route_manager, servers.to_vec(), update_tx)
+ .await
+ .map_err(Error::SpawnInterfaceMonitor)?;
let tunnel_index = iface_index(interface_name)?;
self.tunnel_index = tunnel_index;
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs
index b07ce13eac..229896cf98 100644
--- a/talpid-core/src/dns/mod.rs
+++ b/talpid-core/src/dns/mod.rs
@@ -1,3 +1,5 @@
+#[cfg(target_os = "linux")]
+use crate::routing::RouteManagerHandle;
use std::{net::IpAddr, path::Path};
#[cfg(target_os = "macos")]
@@ -28,9 +30,18 @@ pub struct DnsMonitor {
impl DnsMonitor {
/// Returns a new `DnsMonitor` that can set and monitor the system DNS.
- pub fn new(handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
+ pub fn new(
+ handle: tokio::runtime::Handle,
+ cache_dir: impl AsRef<Path>,
+ #[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
+ ) -> Result<Self, Error> {
Ok(DnsMonitor {
- inner: imp::DnsMonitor::new(handle, cache_dir)?,
+ inner: imp::DnsMonitor::new(
+ handle,
+ cache_dir,
+ #[cfg(target_os = "linux")]
+ route_manager,
+ )?,
})
}
@@ -61,6 +72,7 @@ trait DnsMonitorT: Sized {
fn new(
handle: tokio::runtime::Handle,
cache_dir: impl AsRef<Path>,
+ #[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
) -> Result<Self, Self::Error>;
fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Self::Error>;
diff --git a/talpid-core/src/offline/linux.rs b/talpid-core/src/offline/linux.rs
index cf255b0541..ceaa864cc7 100644
--- a/talpid-core/src/offline/linux.rs
+++ b/talpid-core/src/offline/linux.rs
@@ -1,21 +1,12 @@
-use crate::tunnel_state_machine::TunnelCommand;
-use futures::{
- channel::{mpsc::UnboundedSender, oneshot},
- FutureExt, StreamExt, TryStream, TryStreamExt,
+use crate::{
+ routing::{self, RouteManagerHandle},
+ tunnel_state_machine::TunnelCommand,
};
-use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST};
-use netlink_packet_route::{
- rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage,
+use futures::{channel::mpsc::UnboundedSender, StreamExt};
+use std::{
+ net::{IpAddr, Ipv4Addr},
+ sync::Weak,
};
-use rtnetlink::{
- constants::{
- RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_IFADDR, RTMGRP_IPV6_ROUTE, RTMGRP_LINK,
- RTMGRP_NOTIFY,
- },
- sys::SocketAddr,
- Handle, IpVersion,
-};
-use std::{io, net::Ipv4Addr, sync::Weak};
use talpid_types::ErrorExt;
pub type Result<T> = std::result::Result<T, Error>;
@@ -23,52 +14,21 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
- #[error(display = "Failed to resolve output interface index")]
- GetLinkError(#[error(source)] failure::Compat<rtnetlink::Error>),
-
- #[error(display = "No netlink response for output interface query")]
- NoLinkError,
-
- #[error(display = "Failed to get list of IP addresses")]
- GetAddressesError(#[error(source)] failure::Compat<rtnetlink::Error>),
-
- #[error(display = "Failed to get a route for an arbitrary IP address")]
- GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>),
-
- #[error(display = "No netlink response for route query")]
- NoRouteError,
-
- #[error(display = "Failed to connect to netlink socket")]
- NetlinkConnectionError(#[error(source)] io::Error),
-
- #[error(display = "Failed to connect to bind to netlink socket")]
- BindError(#[error(source)] io::Error),
-
- #[error(display = "Failed to start listening on netlink socket")]
- NetlinkBindError(#[error(source)] io::Error),
-
- #[error(display = "Error while processing netlink messages")]
- MonitorNetlinkError,
-
- #[error(display = "Netlink connection has unexpectedly disconnected")]
- NetlinkDisconnected,
-
- #[error(display = "Failed to initialize event loop")]
- EventLoopError(#[error(source)] io::Error),
+ #[error(display = "The route manager returned an error")]
+ RouteManagerError(#[error(source)] routing::Error),
}
pub struct MonitorHandle {
- handle: rtnetlink::Handle,
- _stop_connection_tx: oneshot::Sender<()>,
+ route_manager: RouteManagerHandle,
}
// Mullvad API's public IP address, correct at the time of writing, but any public IP address will
// work.
-const PUBLIC_INTERNET_ADDRESS: Ipv4Addr = Ipv4Addr::new(193, 138, 218, 78);
+const PUBLIC_INTERNET_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78));
impl MonitorHandle {
pub async fn is_offline(&mut self) -> bool {
- match public_ip_unreachable(&self.handle).await {
+ match public_ip_unreachable(&self.route_manager).await {
Ok(is_offline) => is_offline,
Err(err) => {
log::error!(
@@ -81,46 +41,28 @@ impl MonitorHandle {
}
}
-pub async fn spawn_monitor(sender: Weak<UnboundedSender<TunnelCommand>>) -> Result<MonitorHandle> {
- let (mut connection, handle, mut messages) =
- rtnetlink::new_connection().map_err(Error::NetlinkConnectionError)?;
-
- let mgroup_flags = RTMGRP_IPV4_IFADDR
- | RTMGRP_IPV4_ROUTE
- | RTMGRP_IPV6_IFADDR
- | RTMGRP_IPV6_ROUTE
- | RTMGRP_LINK
- | RTMGRP_NOTIFY;
- let addr = SocketAddr::new(0, mgroup_flags);
-
- connection
- .socket_mut()
- .bind(&addr)
- .map_err(Error::BindError)?;
+pub async fn spawn_monitor(
+ sender: Weak<UnboundedSender<TunnelCommand>>,
+ route_manager: RouteManagerHandle,
+) -> Result<MonitorHandle> {
+ let mut is_offline = public_ip_unreachable(&route_manager).await?;
- let (stop_connection_tx, stop_rx) = oneshot::channel();
-
- // Connection will be closed once the channel is dropped
- tokio::spawn(async {
- futures::select! {
- _ = connection.fuse() => (),
- _ = stop_rx.fuse() => (),
- }
- });
- let mut is_offline = public_ip_unreachable(&handle).await?;
+ let mut listener = route_manager
+ .change_listener()
+ .await
+ .map_err(Error::RouteManagerError)?;
let monitor_handle = MonitorHandle {
- handle: handle.clone(),
- _stop_connection_tx: stop_connection_tx,
+ route_manager: route_manager.clone(),
};
-
tokio::spawn(async move {
- while let Some(_new_message) = messages.next().await {
+ while let Some(_event) = listener.next().await {
match sender.upgrade() {
Some(sender) => {
- let new_offline_state =
- public_ip_unreachable(&handle).await.unwrap_or_else(|err| {
+ let new_offline_state = public_ip_unreachable(&route_manager)
+ .await
+ .unwrap_or_else(|err| {
log::error!(
"{}",
err.display_chain_with_msg("Failed to infer offline state")
@@ -141,89 +83,10 @@ pub async fn spawn_monitor(sender: Weak<UnboundedSender<TunnelCommand>>) -> Resu
}
-async fn public_ip_unreachable(handle: &Handle) -> Result<bool> {
- let mut request = handle.route().get(IpVersion::V4);
- let message = request.message_mut();
- message
- .nlas
- .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK));
- message.nlas.push(RouteNla::Destination(
- PUBLIC_INTERNET_ADDRESS.octets().to_vec(),
- ));
- message.header.destination_prefix_length = 32;
- message.header.flags = RouteFlags::RTM_F_LOOKUP_TABLE;
- let mut stream = execute_route_get_request(handle.clone(), message.clone());
- match stream.try_next().await {
- // Presance of any route implies connectivity, even if it's a loopback route
- Ok(Some(_)) => Ok(false),
- Ok(None) => Err(Error::NoRouteError),
- // ENETUNREACH implies that there exists no route that'd reach our random API address,
- // as such, the host is assumed to be offline
- Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
- Ok(true)
- }
- Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))),
- }
-}
-
-pub fn execute_route_get_request(
- mut handle: Handle,
- message: RouteMessage,
-) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> {
- use futures::future::{self, Either};
- use rtnetlink::Error;
-
- let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message));
- req.header.flags = NLM_F_REQUEST;
-
- match handle.request(req) {
- Ok(response) => Either::Left(response.map(move |msg| {
- let (header, payload) = msg.into_parts();
- match payload {
- NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg),
- NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)),
- _ => Err(Error::UnexpectedMessage(NetlinkMessage::new(
- header, payload,
- ))),
- }
- })),
- Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()),
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
- use rtnetlink::{
- constants::{
- RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_IFADDR, RTMGRP_IPV6_ROUTE,
- RTMGRP_LINK, RTMGRP_NOTIFY,
- },
- sys::SocketAddr,
- };
-
- #[test]
- fn test_route_table_query() {
- let mut runtime = tokio::runtime::Runtime::new().expect("failed to initialize runtime");
- let (mut connection, handle, _) = runtime.block_on(async {
- rtnetlink::new_connection()
- .map_err(Error::NetlinkConnectionError)
- .expect("Failed to create a netlink connection")
- });
-
- let mgroup_flags = RTMGRP_IPV4_IFADDR
- | RTMGRP_IPV4_ROUTE
- | RTMGRP_IPV6_IFADDR
- | RTMGRP_IPV6_ROUTE
- | RTMGRP_LINK
- | RTMGRP_NOTIFY;
- let addr = SocketAddr::new(0, mgroup_flags);
-
- connection.socket_mut().bind(&addr).unwrap();
- runtime.spawn(connection);
-
- runtime
- .block_on(public_ip_unreachable(&handle))
- .expect("Failed to query routing table");
- }
+async fn public_ip_unreachable(handle: &RouteManagerHandle) -> Result<bool> {
+ Ok(handle
+ .get_destination_route(PUBLIC_INTERNET_ADDRESS, true)
+ .await
+ .map_err(Error::RouteManagerError)?
+ .is_none())
}
diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs
index b059d1f216..ac8b10e222 100644
--- a/talpid-core/src/offline/mod.rs
+++ b/talpid-core/src/offline/mod.rs
@@ -1,3 +1,5 @@
+#[cfg(target_os = "linux")]
+use crate::routing::RouteManagerHandle;
use crate::tunnel_state_machine::TunnelCommand;
use futures::channel::mpsc::UnboundedSender;
use std::sync::Weak;
@@ -42,12 +44,15 @@ impl MonitorHandle {
pub async fn spawn_monitor(
sender: Weak<UnboundedSender<TunnelCommand>>,
+ #[cfg(target_os = "linux")] route_manager: RouteManagerHandle,
#[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<MonitorHandle, Error> {
let monitor = if !*FORCE_DISABLE_OFFLINE_MONITOR {
Some(
imp::spawn_monitor(
sender,
+ #[cfg(target_os = "linux")]
+ route_manager,
#[cfg(target_os = "android")]
android_context,
)
diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs
index afc02eedb2..a63430fed3 100644
--- a/talpid-core/src/routing/linux.rs
+++ b/talpid-core/src/routing/linux.rs
@@ -1,4 +1,7 @@
-use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route};
+use crate::routing::{
+ imp::{CallbackMessage, RouteManagerCommand},
+ NetNode, Node, RequiredRoute, Route,
+};
use std::{
collections::{BTreeMap, HashSet},
io,
@@ -6,11 +9,15 @@ use std::{
};
use talpid_types::ErrorExt;
-use futures::{channel::mpsc::UnboundedReceiver, future::FutureExt, StreamExt, TryStreamExt};
+use futures::{
+ channel::mpsc::{UnboundedReceiver, UnboundedSender},
+ future::FutureExt,
+ StreamExt, TryStream, TryStreamExt,
+};
use ipnetwork::IpNetwork;
use lazy_static::lazy_static;
use netlink_packet_route::{
- constants::{ARPHRD_LOOPBACK, FIB_RULE_INVERT, FR_ACT_TO_TBL},
+ constants::{ARPHRD_LOOPBACK, FIB_RULE_INVERT, FR_ACT_TO_TBL, NLM_F_REQUEST},
link::{nlas::Nla as LinkNla, LinkMessage},
route::{nlas::Nla as RouteNla, RouteHeader, RouteMessage},
rtnl::{
@@ -26,7 +33,7 @@ use netlink_packet_route::{
use rtnetlink::{
constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_LINK, RTMGRP_NOTIFY},
sys::SocketAddr,
- Handle,
+ Handle, IpVersion,
};
use libc::{AF_INET, AF_INET6};
@@ -102,6 +109,12 @@ pub enum Error {
#[error(display = "Unknown device index - {}", _0)]
UnknownDeviceIndex(u32),
+ #[error(display = "Failed to get a route for the given IP address")]
+ GetRouteError(#[error(source)] rtnetlink::Error),
+
+ #[error(display = "No netlink response for route query")]
+ NoRouteError,
+
/// Unable to create routing table for tagged connections and packets.
#[error(display = "Cannot find a free routing table ID")]
NoFreeRoutingTableId,
@@ -115,6 +128,7 @@ pub struct RouteManagerImpl {
handle: Handle,
messages: UnboundedReceiver<(NetlinkMessage<RtnlMessage>, SocketAddr)>,
iface_map: BTreeMap<u32, NetworkInterface>,
+ listeners: Vec<UnboundedSender<CallbackMessage>>,
// currently added routes
added_routes: HashSet<Route>,
@@ -137,9 +151,10 @@ impl RouteManagerImpl {
let iface_map = Self::initialize_link_map(&handle).await?;
let mut monitor = Self {
- iface_map,
handle,
messages,
+ iface_map,
+ listeners: vec![],
added_routes: HashSet::new(),
};
@@ -295,8 +310,8 @@ impl RouteManagerImpl {
.map(|(idx, _name)| *idx)
}
- async fn process_deleted_route(&mut self, route: Route) -> Result<()> {
- self.added_routes.remove(&route);
+ fn process_deleted_route(&mut self, route: &Route) -> Result<()> {
+ self.added_routes.remove(route);
Ok(())
}
@@ -347,6 +362,12 @@ impl RouteManagerImpl {
RouteManagerCommand::ClearRoutingRules(result_tx) => {
let _ = result_tx.send(self.clear_routing_rules().await);
}
+ RouteManagerCommand::NewChangeListener(result_tx) => {
+ let _ = result_tx.send(self.listen());
+ }
+ RouteManagerCommand::GetDestinationRoute(destination, set_mark, result_tx) => {
+ let _ = result_tx.send(self.get_destination_route(&destination, set_mark).await);
+ }
RouteManagerCommand::ClearRoutes => {
log::debug!("Clearing routes");
self.cleanup_routes().await;
@@ -367,9 +388,15 @@ impl RouteManagerImpl {
self.iface_map.remove(&idx);
}
}
+ NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(new_route)) => {
+ if let Some(addition) = self.parse_route_message(new_route)? {
+ self.notify_change_listeners(CallbackMessage::NewRoute(addition));
+ }
+ }
NetlinkPayload::InnerMessage(RtnlMessage::DelRoute(old_route)) => {
if let Some(deletion) = self.parse_route_message(old_route)? {
- self.process_deleted_route(deletion).await?;
+ self.process_deleted_route(&deletion)?;
+ self.notify_change_listeners(CallbackMessage::DelRoute(deletion));
}
}
_ => (),
@@ -377,18 +404,13 @@ impl RouteManagerImpl {
Ok(())
}
- // Tries to coax a Route out of a RouteMessage, but only if it's a route from the main routing
- // table
- // TODO: Change to account for different routing tables.
- fn parse_route_message(&self, msg: RouteMessage) -> Result<Option<Route>> {
- if msg.header.table != RT_TABLE_MAIN {
- return Ok(None);
- }
- self.parse_route_message_inner(msg)
+ fn notify_change_listeners(&mut self, message: CallbackMessage) {
+ self.listeners
+ .retain(|listener| listener.unbounded_send(message.clone()).is_ok());
}
// Tries to coax a Route out of a RouteMessage
- fn parse_route_message_inner(&self, msg: RouteMessage) -> Result<Option<Route>> {
+ fn parse_route_message(&self, msg: RouteMessage) -> Result<Option<Route>> {
let af_spec = msg.header.address_family;
let destination_length = msg.header.destination_prefix_length;
let is_ipv4 = match af_spec as i32 {
@@ -686,6 +708,12 @@ impl RouteManagerImpl {
Ok(())
}
+ fn listen(&mut self) -> UnboundedReceiver<CallbackMessage> {
+ let (tx, rx) = futures::channel::mpsc::unbounded();
+ self.listeners.push(tx);
+ rx
+ }
+
async fn destructor(&mut self) {
self.cleanup_routes().await;
@@ -696,6 +724,36 @@ impl RouteManagerImpl {
);
}
}
+
+ async fn get_destination_route(
+ &self,
+ destination: &IpAddr,
+ set_mark: bool,
+ ) -> Result<Option<Route>> {
+ let mut request = self.handle.route().get(get_ip_version(destination));
+ let octets = match destination {
+ IpAddr::V4(address) => address.octets().to_vec(),
+ IpAddr::V6(address) => address.octets().to_vec(),
+ };
+ let message = request.message_mut();
+ if set_mark {
+ message
+ .nlas
+ .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK));
+ }
+ message.header.destination_prefix_length = 8u8 * (octets.len() as u8);
+ message.header.flags = RouteFlags::RTM_F_FIB_MATCH;
+ message.nlas.push(RouteNla::Destination(octets));
+ let mut stream = execute_route_get_request(self.handle.clone(), message.clone());
+ match stream.try_next().await {
+ Ok(Some(route_msg)) => self.parse_route_message(route_msg),
+ Ok(None) => Err(Error::NoRouteError),
+ Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
+ Ok(None)
+ }
+ Err(err) => Err(Error::GetRouteError(err)),
+ }
+ }
}
fn ip_to_bytes(addr: IpAddr) -> Vec<u8> {
@@ -714,6 +772,39 @@ fn compat_table_id(id: u32) -> u8 {
}
}
+fn get_ip_version(addr: &IpAddr) -> IpVersion {
+ if addr.is_ipv4() {
+ IpVersion::V4
+ } else {
+ IpVersion::V6
+ }
+}
+
+fn execute_route_get_request(
+ mut handle: Handle,
+ message: RouteMessage,
+) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> {
+ use futures::future::{self, Either};
+ use rtnetlink::Error;
+
+ let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message));
+ req.header.flags = NLM_F_REQUEST;
+
+ match handle.request(req) {
+ Ok(response) => Either::Left(response.map(move |msg| {
+ let (header, payload) = msg.into_parts();
+ match payload {
+ NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg),
+ NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)),
+ _ => Err(Error::UnexpectedMessage(NetlinkMessage::new(
+ header, payload,
+ ))),
+ }
+ })),
+ Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()),
+ }
+}
+
#[derive(Debug)]
struct NetworkInterface {
name: String,
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs
index a8491a8588..8dc8dcd778 100644
--- a/talpid-core/src/routing/mod.rs
+++ b/talpid-core/src/routing/mod.rs
@@ -45,6 +45,11 @@ impl Route {
self.table_id = new_id;
self
}
+
+ /// Returns the network node of the route.
+ pub fn get_node(&self) -> &Node {
+ &self.node
+ }
}
impl fmt::Display for Route {
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index c1d2da522b..989ec7ad24 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -2,6 +2,8 @@
#![cfg_attr(target_os = "windows", allow(dead_code))]
// TODO: remove the allow(dead_code) for android once it's up to scratch.
use super::RequiredRoute;
+#[cfg(target_os = "linux")]
+use super::Route;
use futures::channel::{
mpsc::{self, UnboundedSender},
@@ -9,6 +11,12 @@ use futures::channel::{
};
use std::{collections::HashSet, io};
+#[cfg(target_os = "linux")]
+use futures::stream::Stream;
+
+#[cfg(target_os = "linux")]
+use std::net::IpAddr;
+
#[cfg(target_os = "macos")]
#[path = "macos.rs"]
mod imp;
@@ -46,26 +54,25 @@ pub enum Error {
/// Handle to a route manager.
#[derive(Clone)]
pub struct RouteManagerHandle {
- runtime: tokio::runtime::Handle,
tx: UnboundedSender<RouteManagerCommand>,
}
impl RouteManagerHandle {
/// Applies the given routes while the route manager is running.
- pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
+ pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
.map_err(|_| Error::RouteManagerDown)?;
- self.runtime
- .block_on(response_rx)
+ response_rx
+ .await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}
/// Ensure that packets are routed using the correct tables.
#[cfg(target_os = "linux")]
- pub fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> {
+ pub async fn create_routing_rules(&self, enable_ipv6: bool) -> Result<(), Error> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::CreateRoutingRules(
@@ -73,21 +80,52 @@ impl RouteManagerHandle {
response_tx,
))
.map_err(|_| Error::RouteManagerDown)?;
- self.runtime
- .block_on(response_rx)
+ response_rx
+ .await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}
/// Remove any routing rules created by [`create_routing_rules`].
#[cfg(target_os = "linux")]
- pub fn clear_routing_rules(&self) -> Result<(), Error> {
+ pub async fn clear_routing_rules(&self) -> Result<(), Error> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::ClearRoutingRules(response_tx))
.map_err(|_| Error::RouteManagerDown)?;
- self.runtime
- .block_on(response_rx)
+ response_rx
+ .await
+ .map_err(|_| Error::ManagerChannelDown)?
+ .map_err(Error::PlatformError)
+ }
+
+ /// Listen for route changes.
+ #[cfg(target_os = "linux")]
+ pub async fn change_listener(&self) -> Result<impl Stream<Item = CallbackMessage>, Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::NewChangeListener(response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)
+ }
+
+ /// Listen for route changes.
+ #[cfg(target_os = "linux")]
+ pub async fn get_destination_route(
+ &self,
+ destination: IpAddr,
+ set_mark: bool,
+ ) -> Result<Option<Route>, Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::GetDestinationRoute(
+ destination,
+ set_mark,
+ response_tx,
+ ))
+ .map_err(|_| Error::RouteManagerDown)?;
+ response_rx
+ .await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}
@@ -106,6 +144,21 @@ pub(crate) enum RouteManagerCommand {
CreateRoutingRules(bool, oneshot::Sender<Result<(), PlatformError>>),
#[cfg(target_os = "linux")]
ClearRoutingRules(oneshot::Sender<Result<(), PlatformError>>),
+ #[cfg(target_os = "linux")]
+ NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>),
+ #[cfg(target_os = "linux")]
+ GetDestinationRoute(
+ IpAddr,
+ bool,
+ oneshot::Sender<Result<Option<Route>, PlatformError>>,
+ ),
+}
+
+#[cfg(target_os = "linux")]
+#[derive(Debug, Clone)]
+pub enum CallbackMessage {
+ NewRoute(Route),
+ DelRoute(Route),
}
/// RouteManager applies a set of routes to the route table.
@@ -120,12 +173,12 @@ impl RouteManager {
/// Constructs a RouteManager and applies the required routes.
/// Takes a set of network destinations and network nodes as an argument, and applies said
/// routes.
- pub fn new(
+ pub async fn new(
runtime: tokio::runtime::Handle,
required_routes: HashSet<RequiredRoute>,
) -> Result<Self, Error> {
let (manage_tx, manage_rx) = mpsc::unbounded();
- let manager = runtime.block_on(imp::RouteManagerImpl::new(required_routes))?;
+ let manager = imp::RouteManagerImpl::new(required_routes).await?;
runtime.spawn(manager.run(manage_rx));
Ok(Self {
@@ -135,7 +188,7 @@ impl RouteManager {
}
/// Stops RouteManager and removes all of the applied routes.
- pub fn stop(&mut self) {
+ pub async fn stop(&mut self) {
if let Some(tx) = self.manage_tx.take() {
let (wait_tx, wait_rx) = oneshot::channel();
@@ -147,14 +200,14 @@ impl RouteManager {
return;
}
- if self.runtime.block_on(wait_rx).is_err() {
+ if wait_rx.await.is_err() {
log::error!("{}", Error::ManagerChannelDown);
}
}
}
/// Applies the given routes until [`RouteManager::stop`] is called.
- pub fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
+ pub async fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
if let Some(tx) = &self.manage_tx {
let (result_tx, result_rx) = oneshot::channel();
if tx
@@ -164,8 +217,8 @@ impl RouteManager {
return Err(Error::RouteManagerDown);
}
- self.runtime
- .block_on(result_rx)
+ result_rx
+ .await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
} else {
@@ -188,23 +241,20 @@ impl RouteManager {
/// Ensure that packets are routed using the correct tables.
#[cfg(target_os = "linux")]
- pub fn create_routing_rules(&mut self, enable_ipv6: bool) -> Result<(), Error> {
- self.handle()?.create_routing_rules(enable_ipv6)
+ pub async fn create_routing_rules(&mut self, enable_ipv6: bool) -> Result<(), Error> {
+ self.handle()?.create_routing_rules(enable_ipv6).await
}
/// Remove any routing rules created by [`create_routing_rules`].
#[cfg(target_os = "linux")]
- pub fn clear_routing_rules(&mut self) -> Result<(), Error> {
- self.handle()?.clear_routing_rules()
+ pub async fn clear_routing_rules(&mut self) -> Result<(), Error> {
+ self.handle()?.clear_routing_rules().await
}
/// Retrieve a sender directly to the command channel.
pub fn handle(&self) -> Result<RouteManagerHandle, Error> {
if let Some(tx) = &self.manage_tx {
- Ok(RouteManagerHandle {
- runtime: self.runtime.clone(),
- tx: tx.clone(),
- })
+ Ok(RouteManagerHandle { tx: tx.clone() })
} else {
Err(Error::RouteManagerDown)
}
@@ -219,6 +269,6 @@ impl RouteManager {
impl Drop for RouteManager {
fn drop(&mut self) {
- self.stop();
+ self.runtime.clone().block_on(self.stop());
}
}
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index ca5fbb8ea2..93574884f6 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -41,20 +41,17 @@ pub struct RouteManager {
/// Handle to a route manager.
#[derive(Clone)]
pub struct RouteManagerHandle {
- runtime: tokio::runtime::Handle,
tx: UnboundedSender<RouteManagerCommand>,
}
impl RouteManagerHandle {
/// Applies the given routes while the route manager is running.
- pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
.map_err(|_| Error::RouteManagerDown)?;
- self.runtime
- .block_on(response_rx)
- .map_err(|_| Error::ManagerChannelDown)?
+ response_rx.await.map_err(|_| Error::ManagerChannelDown)?
}
}
@@ -67,7 +64,7 @@ pub enum RouteManagerCommand {
impl RouteManager {
/// Creates a new route manager that will apply the provided routes and ensure they exist until
/// it's stopped.
- pub fn new(
+ pub async fn new(
runtime: tokio::runtime::Handle,
required_routes: HashSet<RequiredRoute>,
) -> Result<Self> {
@@ -89,10 +86,7 @@ impl RouteManager {
/// Retrieve a sender directly to the command channel.
pub fn handle(&self) -> Result<RouteManagerHandle> {
if let Some(tx) = &self.manage_tx {
- Ok(RouteManagerHandle {
- runtime: self.runtime.clone(),
- tx: tx.clone(),
- })
+ Ok(RouteManagerHandle { tx: tx.clone() })
} else {
Err(Error::RouteManagerDown)
}
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index e3709b0183..87dbc55101 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -1028,19 +1028,14 @@ mod event_server {
.filter(|route| route.prefix.is_ipv4() || ipv6_enabled)
.collect();
- tokio::task::spawn_blocking(move || {
- if let Err(error) = route_handle.add_routes(routes) {
- log::error!("{}", error.display_chain());
- return Err(tonic::Status::failed_precondition("Failed to add routes"));
- }
- if let Err(error) = route_handle.create_routing_rules(ipv6_enabled) {
- log::error!("{}", error.display_chain());
- return Err(tonic::Status::failed_precondition("Failed to add routes"));
- }
- Ok(())
- })
- .await
- .map_err(|_| tonic::Status::internal("task failed to complete"))??;
+ if let Err(error) = route_handle.add_routes(routes).await {
+ log::error!("{}", error.display_chain());
+ return Err(tonic::Status::failed_precondition("Failed to add routes"));
+ }
+ if let Err(error) = route_handle.create_routing_rules(ipv6_enabled).await {
+ log::error!("{}", error.display_chain());
+ return Err(tonic::Status::failed_precondition("Failed to add routes"));
+ }
}
let tunnel_alias = env
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 4d7fa478b1..899589dbc4 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -253,20 +253,24 @@ impl WireguardMonitor {
}
}
- let setup_iface_routes = move || -> Result<()> {
+ let setup_iface_routes = || -> Result<()> {
#[cfg(target_os = "windows")]
if !crate::winnet::add_device_ip_addresses(&iface_name, &config.tunnel.addresses) {
return Err(Error::SetIpAddressesError);
}
- #[cfg(target_os = "linux")]
- route_handle
- .create_routing_rules(config.enable_ipv6)
- .map_err(Error::SetupRoutingError)?;
+ runtime.block_on(async move {
+ #[cfg(target_os = "linux")]
+ route_handle
+ .create_routing_rules(config.enable_ipv6)
+ .await
+ .map_err(Error::SetupRoutingError)?;
- route_handle
- .add_routes(Self::get_routes(&iface_name, &config))
- .map_err(Error::SetupRoutingError)
+ route_handle
+ .add_routes(Self::get_routes(&iface_name, &config))
+ .await
+ .map_err(Error::SetupRoutingError)
+ })
};
if let Err(error) = setup_iface_routes() {
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index ea8ce8e992..72af2b6f38 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -139,7 +139,10 @@ impl ConnectedState {
log::error!("{}", error.display_chain_with_msg("Failed to clear routes"));
}
#[cfg(target_os = "linux")]
- if let Err(error) = shared_values.route_manager.clear_routing_rules() {
+ if let Err(error) = shared_values
+ .runtime
+ .block_on(shared_values.route_manager.clear_routing_rules())
+ {
log::error!(
"{}",
error.display_chain_with_msg("Failed to clear routing rules")
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 8486ae0a3e..b0c87acdb4 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -204,7 +204,10 @@ impl ConnectingState {
log::error!("{}", error.display_chain_with_msg("Failed to clear routes"));
}
#[cfg(target_os = "linux")]
- if let Err(error) = shared_values.route_manager.clear_routing_rules() {
+ if let Err(error) = shared_values
+ .runtime
+ .block_on(shared_values.route_manager.clear_routing_rules())
+ {
log::error!(
"{}",
error.display_chain_with_msg("Failed to clear routing rules")
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 1bcc6961d9..38496865a4 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -89,18 +89,10 @@ pub async fn spawn(
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
- let mut offline_monitor = offline::spawn_monitor(
- Arc::downgrade(&command_tx),
- #[cfg(target_os = "android")]
- android_context.clone(),
- )
- .await
- .map_err(Error::OfflineMonitorError)?;
- let is_offline = offline_monitor.is_offline().await;
let tun_provider = TunProvider::new(
#[cfg(target_os = "android")]
- android_context,
+ android_context.clone(),
#[cfg(target_os = "android")]
allow_lan,
#[cfg(target_os = "android")]
@@ -112,12 +104,13 @@ pub async fn spawn(
let runtime = tokio::runtime::Handle::current();
let (startup_result_tx, startup_result_rx) = sync_mpsc::channel();
+ let weak_command_tx = Arc::downgrade(&command_tx);
std::thread::spawn(move || {
- let state_machine = TunnelStateMachine::new(
+ let state_machine = runtime.block_on(TunnelStateMachine::new(
runtime.clone(),
+ weak_command_tx,
allow_lan,
block_when_disconnected,
- is_offline,
dns_servers,
allowed_endpoint,
tunnel_parameters_generator,
@@ -127,7 +120,9 @@ pub async fn spawn(
cache_dir,
command_rx,
reset_firewall,
- );
+ #[cfg(target_os = "android")]
+ android_context,
+ ));
let state_machine = match state_machine {
Ok(state_machine) => {
startup_result_tx.send(Ok(())).unwrap();
@@ -144,8 +139,6 @@ pub async fn spawn(
if shutdown_tx.send(()).is_err() {
log::error!("Can't send shutdown completion to daemon");
}
-
- std::mem::drop(offline_monitor);
});
startup_result_rx
@@ -199,11 +192,11 @@ struct TunnelStateMachine {
}
impl TunnelStateMachine {
- fn new(
+ async fn new(
runtime: tokio::runtime::Handle,
+ command_tx: std::sync::Weak<mpsc::UnboundedSender<TunnelCommand>>,
allow_lan: bool,
block_when_disconnected: bool,
- is_offline: bool,
dns_servers: Option<Vec<IpAddr>>,
allowed_endpoint: Endpoint,
tunnel_parameters_generator: impl TunnelParametersGenerator,
@@ -213,6 +206,7 @@ impl TunnelStateMachine {
cache_dir: impl AsRef<Path>,
commands: mpsc::UnboundedReceiver<TunnelCommand>,
reset_firewall: bool,
+ #[cfg(target_os = "android")] android_context: AndroidContext,
) -> Result<Self, Error> {
let args = FirewallArguments {
initialize_blocked: block_when_disconnected || !reset_firewall,
@@ -221,15 +215,36 @@ impl TunnelStateMachine {
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
- let dns_monitor =
- DnsMonitor::new(runtime.clone(), cache_dir).map_err(Error::InitDnsMonitorError)?;
let route_manager = RouteManager::new(runtime.clone(), HashSet::new())
+ .await
.map_err(Error::InitRouteManagerError)?;
+ let dns_monitor = DnsMonitor::new(
+ runtime.clone(),
+ cache_dir,
+ #[cfg(target_os = "linux")]
+ route_manager
+ .handle()
+ .map_err(Error::InitRouteManagerError)?,
+ )
+ .map_err(Error::InitDnsMonitorError)?;
+ let mut offline_monitor = offline::spawn_monitor(
+ command_tx,
+ #[cfg(target_os = "linux")]
+ route_manager
+ .handle()
+ .map_err(Error::InitRouteManagerError)?,
+ #[cfg(target_os = "android")]
+ android_context,
+ )
+ .await
+ .map_err(Error::OfflineMonitorError)?;
+ let is_offline = offline_monitor.is_offline().await;
let mut shared_values = SharedTunnelStateValues {
runtime,
firewall,
dns_monitor,
route_manager,
+ _offline_monitor: offline_monitor,
allow_lan,
block_when_disconnected,
is_offline,
@@ -299,6 +314,7 @@ struct SharedTunnelStateValues {
firewall: Firewall,
dns_monitor: DnsMonitor,
route_manager: RouteManager,
+ _offline_monitor: offline::MonitorHandle,
/// Should LAN access be allowed outside the tunnel.
allow_lan: bool,
/// Should network access be allowed when in the disconnected state.