summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-05-21 10:53:18 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-06-09 15:55:07 +0200
commitf8c4a9879afb26482108ae02f3e65682d7d7ab83 (patch)
tree654461afe3da1014b547c0cfe2a7f93f1aff6ccc
parent6d9a1ea3a2502b0d4b3145231b7922984d77632a (diff)
downloadmullvadvpn-f8c4a9879afb26482108ae02f3e65682d7d7ab83.tar.xz
mullvadvpn-f8c4a9879afb26482108ae02f3e65682d7d7ab83.zip
Infer and monitor interfaces for DNS config
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-core/src/dns/android.rs5
-rw-r--r--talpid-core/src/dns/linux/mod.rs28
-rw-r--r--talpid-core/src/dns/linux/routing.rs259
-rw-r--r--talpid-core/src/dns/linux/systemd_resolved.rs271
-rw-r--r--talpid-core/src/dns/macos.rs2
-rw-r--r--talpid-core/src/dns/mod.rs9
-rw-r--r--talpid-core/src/dns/windows/mod.rs2
-rw-r--r--talpid-core/src/firewall/mod.rs3
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs3
-rw-r--r--talpid-dbus/Cargo.toml1
-rw-r--r--talpid-dbus/src/systemd_resolved.rs156
12 files changed, 663 insertions, 77 deletions
diff --git a/Cargo.lock b/Cargo.lock
index c1548a9905..1a24bf4a7f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2592,6 +2592,7 @@ dependencies = [
"libc",
"log",
"talpid-types",
+ "tokio",
]
[[package]]
diff --git a/talpid-core/src/dns/android.rs b/talpid-core/src/dns/android.rs
index 032960ae1e..517c311113 100644
--- a/talpid-core/src/dns/android.rs
+++ b/talpid-core/src/dns/android.rs
@@ -10,7 +10,10 @@ pub struct DnsMonitor;
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
- fn new(_cache_dir: impl AsRef<Path>) -> Result<Self, Self::Error> {
+ fn new(
+ _handle: tokio::runtime::Handle,
+ _cache_dir: impl AsRef<Path>,
+ ) -> Result<Self, Self::Error> {
Ok(DnsMonitor)
}
diff --git a/talpid-core/src/dns/linux/mod.rs b/talpid-core/src/dns/linux/mod.rs
index 96f8402da7..62827211bd 100644
--- a/talpid-core/src/dns/linux/mod.rs
+++ b/talpid-core/src/dns/linux/mod.rs
@@ -1,5 +1,6 @@
mod network_manager;
mod resolvconf;
+mod routing;
mod static_resolv_conf;
pub(self) mod systemd_resolved;
@@ -39,28 +40,32 @@ pub enum Error {
}
pub struct DnsMonitor {
+ handle: tokio::runtime::Handle,
inner: Option<DnsMonitorHolder>,
}
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
- fn new(_cache_dir: impl AsRef<Path>) -> Result<Self> {
- Ok(DnsMonitor { inner: None })
+ fn new(handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> {
+ Ok(DnsMonitor {
+ handle,
+ inner: None,
+ })
}
fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<()> {
self.reset()?;
// Creating a new DNS monitor for each set, in case the system changed how it manages DNS.
let mut inner = DnsMonitorHolder::new()?;
- inner.set(interface, servers)?;
+ inner.set(&self.handle, interface, servers)?;
self.inner = Some(inner);
Ok(())
}
fn reset(&mut self) -> Result<()> {
if let Some(mut inner) = self.inner.take() {
- inner.reset()?;
+ inner.reset(&self.handle)?;
}
Ok(())
}
@@ -120,7 +125,12 @@ impl DnsMonitorHolder {
.map_err(|_| Error::NoDnsMonitor)
}
- fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<()> {
+ fn set(
+ &mut self,
+ handle: &tokio::runtime::Handle,
+ interface: &str,
+ servers: &[IpAddr],
+ ) -> Result<()> {
use self::DnsMonitorHolder::*;
match self {
Resolvconf(ref mut resolvconf) => resolvconf.set_dns(interface, servers)?,
@@ -128,7 +138,7 @@ impl DnsMonitorHolder {
static_resolv_conf.set_dns(servers.to_vec())?
}
SystemdResolved(ref mut systemd_resolved) => {
- systemd_resolved.set_dns(interface, &servers)?
+ handle.block_on(systemd_resolved.set_dns(interface, &servers))?
}
NetworkManager(ref mut network_manager) => {
network_manager.set_dns(interface, servers)?
@@ -137,12 +147,14 @@ impl DnsMonitorHolder {
Ok(())
}
- fn reset(&mut self) -> Result<()> {
+ fn reset(&mut self, handle: &tokio::runtime::Handle) -> Result<()> {
use self::DnsMonitorHolder::*;
match self {
Resolvconf(ref mut resolvconf) => resolvconf.reset()?,
StaticResolvConf(ref mut static_resolv_conf) => static_resolv_conf.reset()?,
- SystemdResolved(ref mut systemd_resolved) => systemd_resolved.reset()?,
+ SystemdResolved(ref mut systemd_resolved) => {
+ handle.block_on(systemd_resolved.reset())?
+ }
NetworkManager(ref mut network_manager) => network_manager.reset()?,
}
Ok(())
diff --git a/talpid-core/src/dns/linux/routing.rs b/talpid-core/src/dns/linux/routing.rs
new file mode 100644
index 0000000000..5c90b57a33
--- /dev/null
+++ b/talpid-core/src/dns/linux/routing.rs
@@ -0,0 +1,259 @@
+use futures::{
+ channel::mpsc::UnboundedSender, future::abortable, FutureExt, StreamExt, TryStream,
+ TryStreamExt,
+};
+use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST};
+use netlink_packet_route::{
+ rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage,
+};
+use rtnetlink::{
+ constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_NOTIFY},
+ sys::SocketAddr,
+ Handle, IpVersion,
+};
+use std::{
+ collections::BTreeMap,
+ fmt, io,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
+};
+use talpid_types::ErrorExt;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+const PUBLIC_INTERNET_ADDRESS_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6));
+const PUBLIC_INTERNET_ADDRESS_V6: IpAddr =
+ IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ #[error(display = "Failed to get a route for an arbitrary IP address")]
+ GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>),
+
+ #[error(display = "Failed to connect to bind to netlink socket")]
+ BindError(#[error(source)] io::Error),
+
+ #[error(display = "No netlink response for route query")]
+ NoRouteError,
+
+ #[error(display = "Route is missing an output interface")]
+ RouteNoInterfaceError,
+}
+
+pub struct DnsRouteMonitor {
+ _handle: rtnetlink::Handle,
+ stop_tx: Option<futures::channel::oneshot::Sender<()>>,
+}
+
+impl Drop for DnsRouteMonitor {
+ fn drop(&mut self) {
+ if let Some(stop_tx) = self.stop_tx.take() {
+ let _ = stop_tx.send(());
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub struct DnsConfig {
+ pub interface: u32,
+ pub resolvers: Vec<IpAddr>,
+}
+
+impl fmt::Display for DnsConfig {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "interface index {}, resolvers:", self.interface)?;
+ for server in &self.resolvers {
+ write!(f, " {}", server)?;
+ }
+ Ok(())
+ }
+}
+
+pub async fn spawn_monitor(
+ destinations: Vec<IpAddr>,
+ update_tx: UnboundedSender<BTreeMap<u32, DnsConfig>>,
+) -> Result<(DnsRouteMonitor, BTreeMap<u32, DnsConfig>)> {
+ let (mut connection, handle, messages) =
+ rtnetlink::new_connection().expect("Failed to create a netlink connection");
+
+ let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_NOTIFY;
+ let addr = SocketAddr::new(0, mgroup_flags);
+
+ connection
+ .socket_mut()
+ .bind(&addr)
+ .map_err(Error::BindError)?;
+
+ let (abortable_connection, abort_connection) = abortable(connection);
+ tokio::spawn(abortable_connection);
+
+ let (stop_tx, stop_rx) = futures::channel::oneshot::channel();
+
+ let monitor = DnsRouteMonitor {
+ _handle: handle.clone(),
+ stop_tx: Some(stop_tx),
+ };
+
+ let mut last_config = setup_configurations(&handle, &destinations).await?;
+ let initial_config = last_config.clone();
+
+ tokio::spawn(async move {
+ let mut messages = messages.fuse();
+ let mut stop_rx = stop_rx.fuse();
+ loop {
+ futures::select! {
+ _new_message = messages.next() => {
+ match setup_configurations(&handle, &destinations).await {
+ Ok(new_config) => {
+ if last_config != new_config {
+ last_config = new_config.clone();
+ if update_tx.unbounded_send(new_config).is_err() {
+ log::trace!("Stopping DNS monitor: channel is closed");
+ break;
+ }
+ }
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to determine new DNS interface settings"
+ )
+ );
+ }
+ }
+ },
+ _ = stop_rx => break,
+ }
+ }
+ abort_connection.abort();
+ });
+
+ Ok((monitor, initial_config))
+}
+
+async fn setup_configurations(
+ handle: &Handle,
+ destinations: &[IpAddr],
+) -> Result<BTreeMap<u32, DnsConfig>> {
+ let mut interface_to_destinations = BTreeMap::<u32, DnsConfig>::new();
+ for destination in destinations {
+ let interface = if destination.is_loopback() {
+ get_default_route_interface(handle, get_ip_version(destination), true).await?
+ } else {
+ if crate::firewall::is_local_address(&destination) {
+ get_destination_interface(handle, destination, true).await?
+ } else {
+ get_default_route_interface(handle, get_ip_version(destination), false).await?
+ }
+ };
+ match interface {
+ Some(iface) => {
+ if let Some(config) = interface_to_destinations.get_mut(&iface) {
+ config.resolvers.push(*destination);
+ } else {
+ interface_to_destinations.insert(
+ iface,
+ DnsConfig {
+ interface: iface,
+ resolvers: vec![*destination],
+ },
+ );
+ }
+ }
+ None => {
+ log::trace!(
+ "Ignoring DNS server that did not match to any interface: {}",
+ destination
+ );
+ }
+ }
+ }
+
+ Ok(interface_to_destinations)
+}
+
+async fn get_default_route_interface(
+ handle: &Handle,
+ ip_version: IpVersion,
+ set_mark: bool,
+) -> Result<Option<u32>> {
+ match ip_version {
+ IpVersion::V4 => {
+ get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V4, set_mark).await
+ }
+ IpVersion::V6 => {
+ get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V6, set_mark).await
+ }
+ }
+}
+
+async fn get_destination_interface(
+ handle: &Handle,
+ destination: &IpAddr,
+ set_mark: bool,
+) -> Result<Option<u32>> {
+ let mut request = handle.route().get(get_ip_version(destination));
+ let octets = match destination {
+ IpAddr::V4(address) => address.octets().to_vec(),
+ IpAddr::V6(address) => address.octets().to_vec(),
+ };
+ let message = request.message_mut();
+ if set_mark {
+ message
+ .nlas
+ .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK));
+ }
+ message.header.destination_prefix_length = 8u8 * (octets.len() as u8);
+ message.header.flags = RouteFlags::RTM_F_FIB_MATCH;
+ message.nlas.push(RouteNla::Destination(octets));
+ let mut stream = execute_route_get_request(handle.clone(), message.clone());
+ match stream.try_next().await {
+ Ok(Some(route_msg)) => {
+ for nla in &route_msg.nlas {
+ if let RouteNla::Oif(interface) = nla {
+ return Ok(Some(*interface));
+ }
+ }
+ Err(Error::RouteNoInterfaceError)
+ }
+ Ok(None) => Err(Error::NoRouteError),
+ Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => {
+ Ok(None)
+ }
+ Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))),
+ }
+}
+
+pub fn execute_route_get_request(
+ mut handle: Handle,
+ message: RouteMessage,
+) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> {
+ use futures::future::{self, Either};
+ use rtnetlink::Error;
+
+ let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message));
+ req.header.flags = NLM_F_REQUEST;
+
+ match handle.request(req) {
+ Ok(response) => Either::Left(response.map(move |msg| {
+ let (header, payload) = msg.into_parts();
+ match payload {
+ NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg),
+ NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)),
+ _ => Err(Error::UnexpectedMessage(NetlinkMessage::new(
+ header, payload,
+ ))),
+ }
+ })),
+ Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()),
+ }
+}
+
+fn get_ip_version(addr: &IpAddr) -> IpVersion {
+ if addr.is_ipv4() {
+ IpVersion::V4
+ } else {
+ IpVersion::V6
+ }
+}
diff --git a/talpid-core/src/dns/linux/systemd_resolved.rs b/talpid-core/src/dns/linux/systemd_resolved.rs
index 29a8feaf85..66141b5177 100644
--- a/talpid-core/src/dns/linux/systemd_resolved.rs
+++ b/talpid-core/src/dns/linux/systemd_resolved.rs
@@ -1,13 +1,16 @@
use crate::linux::{iface_index, IfaceIndexLookupError};
+use futures::{channel::mpsc, StreamExt};
use std::{
+ collections::BTreeMap,
net::IpAddr,
sync::{
atomic::{AtomicBool, Ordering},
- Arc,
+ Arc, Mutex,
},
thread,
};
-use talpid_dbus::systemd_resolved::{DnsState, SystemdResolved as DbusInterface};
+use talpid_dbus::systemd_resolved::{AsyncHandle, DnsState, SystemdResolved as DbusInterface};
+use talpid_types::ErrorExt;
pub(crate) use talpid_dbus::systemd_resolved::Error as SystemdDbusError;
@@ -20,55 +23,199 @@ pub enum Error {
#[error(display = "Failed to resolve interface index with error {}", _0)]
InterfaceNameError(#[error(source)] IfaceIndexLookupError),
-}
-pub struct SystemdResolved {
- pub dbus_interface: DbusInterface,
- state: Option<SetConfigState>,
+ #[error(display = "Failed to spawn DNS interface monitor")]
+ SpawnInterfaceMonitor(#[error(source)] super::routing::Error),
}
-struct SetConfigState {
- dns_config: Arc<DnsState>,
- watcher_thread: thread::JoinHandle<()>,
- watcher_should_shutdown: Arc<AtomicBool>,
+use super::routing::{DnsConfig, DnsRouteMonitor};
+
+pub struct SystemdResolved {
+ pub dbus_interface: AsyncHandle,
+ current_config: Arc<Mutex<BTreeMap<u32, DnsConfig>>>,
+ initial_states: Arc<Mutex<BTreeMap<u32, DnsState>>>,
+ tunnel_index: u32,
+ route_monitor: Option<(DnsRouteMonitor, tokio::task::JoinHandle<()>)>,
+ watcher: Option<(thread::JoinHandle<()>, Arc<AtomicBool>)>,
}
impl SystemdResolved {
pub fn new() -> Result<Self> {
- let dbus_interface = DbusInterface::new()?;
+ let dbus_interface = DbusInterface::new()?.async_handle();
let systemd_resolved = SystemdResolved {
dbus_interface,
- state: None,
+ current_config: Arc::new(Mutex::new(BTreeMap::new())),
+ initial_states: Arc::new(Mutex::new(BTreeMap::new())),
+ tunnel_index: 0,
+ route_monitor: None,
+ watcher: None,
};
Ok(systemd_resolved)
}
- pub fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> {
- let iface_index = iface_index(interface_name)?;
- let dns_state = self.dbus_interface.set_dns(iface_index, servers)?;
- let dns_config = Arc::new(dns_state);
+ pub async fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> {
+ let (update_tx, mut update_rx) = mpsc::unbounded();
+ let (monitor, initial_config) = super::routing::spawn_monitor(servers.to_vec(), update_tx)
+ .await
+ .map_err(Error::SpawnInterfaceMonitor)?;
+ let tunnel_index = iface_index(interface_name)?;
+ self.tunnel_index = tunnel_index;
+ let mut last_result = Ok(());
- let (watcher_thread, watcher_should_shutdown) =
- self.spawn_watcher_thread(dns_config.clone());
- self.state = Some(SetConfigState {
- dns_config,
- watcher_thread,
- watcher_should_shutdown,
- });
+ {
+ let mut initial_states = self.initial_states.lock().unwrap();
+ for (iface_index, iface_config) in &initial_config {
+ let initial_state = match self.dbus_interface.get_dns(*iface_index).await {
+ Ok(state) => state,
+ Err(error) => {
+ last_result = Err(Error::SystemdResolvedError(error));
+ break;
+ }
+ };
+ if let Err(error) = self
+ .dbus_interface
+ .set_dns(*iface_index, iface_config.resolvers.clone())
+ .await
+ {
+ last_result = Err(Error::SystemdResolvedError(error));
+ break;
+ }
+ initial_states.insert(*iface_index, initial_state);
+ }
+ }
+
+ if last_result.is_ok() {
+ if has_only_tunnel_config(&initial_config, tunnel_index) {
+ if let Err(error) = self
+ .dbus_interface
+ .set_domains(tunnel_index, &[(".", true)])
+ .await
+ {
+ last_result = Err(Error::SystemdResolvedError(error));
+ }
+ } else {
+ if let Err(error) = self.dbus_interface.set_domains(tunnel_index, &[]).await {
+ last_result = Err(Error::SystemdResolvedError(error));
+ }
+ }
+ }
+
+ if let Err(error) = last_result {
+ let _ = self.reset();
+ return Err(error);
+ }
+
+ {
+ *self.current_config.lock().unwrap() = initial_config;
+ }
+
+ let ignore_config_changes = Arc::new(AtomicBool::new(false));
+
+ self.watcher = Some(self.spawn_watcher_thread(
+ tunnel_index,
+ self.current_config.clone(),
+ ignore_config_changes.clone(),
+ ));
+
+ let dbus_interface = self.dbus_interface.clone();
+ let initial_states = self.initial_states.clone();
+ let current_config = self.current_config.clone();
+ let join_handle = tokio::spawn(async move {
+ while let Some(new_config) = update_rx.next().await {
+ let mut new_initial_states = { initial_states.lock().unwrap().clone() };
+
+ let disable_watcher = ignore_config_changes.clone();
+ disable_watcher.store(true, Ordering::Release);
+
+ // Revert interfaces no longer in use
+ let keys = new_initial_states.keys().cloned().collect::<Vec<u32>>();
+ for iface in keys {
+ if !new_config.contains_key(&iface) {
+ log::debug!("Reverting DNS config on interface {}", iface);
+ if let Err(err) = dbus_interface
+ .set_dns_state(new_initial_states[&iface].clone())
+ .await
+ {
+ log::error!("Failed to revert interface config: {}", err);
+ }
+ new_initial_states.remove(&iface);
+ }
+ }
+
+ for (iface, config) in &new_config {
+ if tunnel_index == *iface {
+ // All public addresses (plus the gateway) will be assigned
+ // to the tunnel: we can assume nothing has changed.
+ continue;
+ }
+
+ // Store new interfaces
+ if !new_initial_states.contains_key(iface) {
+ let initial_state = match dbus_interface.get_dns(*iface).await {
+ Ok(state) => state,
+ Err(error) => {
+ log::error!(
+ "Failed to get resolvers: {}\n{}",
+ config,
+ error.display_chain()
+ );
+ continue;
+ }
+ };
+ new_initial_states.insert(*iface, initial_state);
+ }
+ if let Err(error) = dbus_interface
+ .set_dns(*iface, config.resolvers.clone())
+ .await
+ {
+ log::error!(
+ "Failed to set resolvers: {}\n{}",
+ config,
+ error.display_chain()
+ );
+ }
+ }
+
+ let tunnel_domains = if has_only_tunnel_config(&new_config, tunnel_index) {
+ &[(".", true)][..]
+ } else {
+ &[][..]
+ };
+ if let Err(error) = dbus_interface
+ .set_domains(tunnel_index, tunnel_domains)
+ .await
+ {
+ log::error!(
+ "Failed to set DNS domains on tunnel interface\n{}",
+ error.display_chain()
+ );
+ }
+
+ {
+ *current_config.lock().unwrap() = new_config.clone();
+ *initial_states.lock().unwrap() = new_initial_states;
+ }
+
+ disable_watcher.store(false, Ordering::Release);
+ }
+ });
+ self.route_monitor = Some((monitor, join_handle));
Ok(())
}
fn spawn_watcher_thread(
&mut self,
- dns_state: Arc<DnsState>,
+ tunnel_index: u32,
+ current_config: Arc<Mutex<BTreeMap<u32, DnsConfig>>>,
+ disable_watcher: Arc<AtomicBool>,
) -> (thread::JoinHandle<()>, Arc<AtomicBool>) {
- let dbus_interface = self.dbus_interface.clone();
+ let dbus_interface = self.dbus_interface.handle().clone();
let should_shutdown = Arc::new(AtomicBool::new(false));
let watch_shutdown = should_shutdown.clone();
let callback_shutdown = should_shutdown.clone();
@@ -78,16 +225,34 @@ impl SystemdResolved {
if callback_shutdown.clone().load(Ordering::Acquire) {
return;
}
- let mut current_servers: Vec<IpAddr> = new_servers
- .into_iter()
- .filter(|server| server.iface_index == dns_state.interface_index as i32)
+ if disable_watcher.clone().load(Ordering::Acquire) {
+ return;
+ }
+ let configs = current_config.lock().unwrap();
+ let mut anything_changed = false;
+ for (iface, config) in &*configs {
+ let current_servers: Vec<IpAddr> = new_servers
+ .iter()
+ .filter(|server| server.iface_index == *iface as i32)
.map(|server| server.address)
.collect();
- current_servers.sort();
- if current_servers != *dns_state.set_servers {
- log::debug!("DNS config for tunnel interface changed, currently applied servers - {:?}", current_servers);
- if let Err(err) = dbus_interface.set_dns(dns_state.interface_index, &dns_state.set_servers) {
- log::error!("Failed to re-apply DNS config - {}", err);
+ if current_servers != config.resolvers {
+ log::trace!("DNS config for interface {} changed, currently applied servers - {:?}", iface, current_servers);
+ if let Err(err) = dbus_interface.set_dns(*iface, config.resolvers.clone())
+ {
+ log::error!("Failed to re-apply DNS config - {}", err);
+ }
+ anything_changed = true;
+ }
+ }
+ if anything_changed {
+ let result = if has_only_tunnel_config(&configs, tunnel_index) {
+ dbus_interface.set_domains(tunnel_index, &[(".", true)])
+ } else {
+ dbus_interface.set_domains(tunnel_index, &[])
+ };
+ if let Err(err) = result {
+ log::error!("Failed to re-apply DNS domains - {}", err);
}
}
},
@@ -100,23 +265,41 @@ impl SystemdResolved {
(watcher_thread, should_shutdown)
}
- pub fn reset(&mut self) -> Result<()> {
- if let Some(SetConfigState {
- dns_config,
- watcher_thread,
- watcher_should_shutdown,
- }) = self.state.take()
- {
+ pub async fn reset(&mut self) -> Result<()> {
+ if let Some((watcher_thread, watcher_should_shutdown)) = self.watcher.take() {
watcher_should_shutdown.store(true, Ordering::Release);
- if let Err(err) = self.dbus_interface.revert_link(&dns_config) {
- log::error!("Failed to revert interface config: {}", err);
- }
-
if watcher_thread.join().is_err() {
log::error!("DNS watcher thread panicked!");
}
}
+ if let Some((monitor, join_handle)) = self.route_monitor.take() {
+ std::mem::drop(monitor);
+ let _ = join_handle.await;
+ }
+
+ let mut initial_states = self.initial_states.lock().unwrap();
+ for (iface, state) in &*initial_states {
+ let result = if *iface == self.tunnel_index {
+ self.dbus_interface.revert_link(state.clone()).await
+ } else {
+ self.dbus_interface.set_dns_state(state.clone()).await
+ };
+ if let Err(err) = result {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("Failed to revert interface config")
+ );
+ }
+ }
+ initial_states.clear();
+
+ self.current_config.lock().unwrap().clear();
+
Ok(())
}
}
+
+fn has_only_tunnel_config(configs: &BTreeMap<u32, DnsConfig>, tunnel_index: u32) -> bool {
+ configs.len() == 1 && configs.contains_key(&tunnel_index)
+}
diff --git a/talpid-core/src/dns/macos.rs b/talpid-core/src/dns/macos.rs
index 4c735e9e44..44ba949f59 100644
--- a/talpid-core/src/dns/macos.rs
+++ b/talpid-core/src/dns/macos.rs
@@ -140,7 +140,7 @@ impl super::DnsMonitorT for DnsMonitor {
/// DNS settings for all network interfaces. If any changes occur it will instantly reset
/// the DNS settings for that interface back to the last server list set to this instance
/// with `set_dns`.
- fn new(_cache_dir: impl AsRef<Path>) -> Result<Self> {
+ fn new(_handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> {
let state = Arc::new(Mutex::new(None));
Self::spawn(state.clone())?;
Ok(DnsMonitor {
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs
index 987f3f1092..b07ce13eac 100644
--- a/talpid-core/src/dns/mod.rs
+++ b/talpid-core/src/dns/mod.rs
@@ -28,9 +28,9 @@ pub struct DnsMonitor {
impl DnsMonitor {
/// Returns a new `DnsMonitor` that can set and monitor the system DNS.
- pub fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
+ pub fn new(handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
Ok(DnsMonitor {
- inner: imp::DnsMonitor::new(cache_dir)?,
+ inner: imp::DnsMonitor::new(handle, cache_dir)?,
})
}
@@ -58,7 +58,10 @@ impl DnsMonitor {
trait DnsMonitorT: Sized {
type Error: std::error::Error;
- fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Self::Error>;
+ fn new(
+ handle: tokio::runtime::Handle,
+ cache_dir: impl AsRef<Path>,
+ ) -> Result<Self, Self::Error>;
fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Self::Error>;
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index 59f6ee349d..90ef7552b7 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -50,7 +50,7 @@ pub struct DnsMonitor {}
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
- fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
+ fn new(_handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? };
let backup_writer = SystemStateWriter::new(
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index a3d4333d6e..557ae01412 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -83,7 +83,8 @@ const DHCPV6_CLIENT_PORT: u16 = 546;
#[cfg(all(unix, not(target_os = "android")))]
-fn is_local_address(address: &IpAddr) -> bool {
+/// Returns whether an address belongs to a private subnet.
+pub fn is_local_address(address: &IpAddr) -> bool {
let address = address.clone();
(&*ALLOWED_LAN_NETS)
.iter()
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 1f26f2fa04..1bcc6961d9 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -221,7 +221,8 @@ impl TunnelStateMachine {
};
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
- let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?;
+ let dns_monitor =
+ DnsMonitor::new(runtime.clone(), cache_dir).map_err(Error::InitDnsMonitorError)?;
let route_manager = RouteManager::new(runtime.clone(), HashSet::new())
.map_err(Error::InitRouteManagerError)?;
let mut shared_values = SharedTunnelStateValues {
diff --git a/talpid-dbus/Cargo.toml b/talpid-dbus/Cargo.toml
index 5685daf922..7ae6403254 100644
--- a/talpid-dbus/Cargo.toml
+++ b/talpid-dbus/Cargo.toml
@@ -12,3 +12,4 @@ lazy_static = "1.0"
log = "0.4"
libc = "0.2"
talpid-types = { path = "../talpid-types" }
+tokio = { version = "0.2", features = [ "blocking" ] }
diff --git a/talpid-dbus/src/systemd_resolved.rs b/talpid-dbus/src/systemd_resolved.rs
index 2fd2f14956..2a5b7b2b02 100644
--- a/talpid-dbus/src/systemd_resolved.rs
+++ b/talpid-dbus/src/systemd_resolved.rs
@@ -52,6 +52,9 @@ pub enum Error {
#[error(display = "Failed to remove a match for DNS config updates")]
DnsUpdateRemoveMatchError(#[error(source)] dbus::Error),
+
+ #[error(display = "Async D-Bus task failed")]
+ AsyncTaskError(#[error(source)] tokio::task::JoinError),
}
lazy_static! {
@@ -73,6 +76,7 @@ const RPC_TIMEOUT: Duration = Duration::from_secs(1);
const LINK_INTERFACE: &str = "org.freedesktop.resolve1.Link";
const MANAGER_INTERFACE: &str = "org.freedesktop.resolve1.Manager";
+const DNS_DOMAINS: &str = "Domains";
const DNS_SERVERS: &str = "DNS";
const GET_LINK_METHOD: &str = "GetLink";
const SET_DNS_METHOD: &str = "SetDNS";
@@ -84,12 +88,18 @@ pub struct SystemdResolved {
pub dbus_connection: Arc<SyncConnection>,
}
+#[derive(Clone)]
pub struct DnsState {
pub interface_path: dbus::Path<'static>,
pub interface_index: u32,
pub set_servers: Vec<IpAddr>,
}
+#[derive(Clone)]
+pub struct AsyncHandle {
+ dbus_interface: SystemdResolved,
+}
+
impl SystemdResolved {
pub fn new() -> Result<Self> {
let dbus_connection = crate::get_connection().map_err(Error::ConnectDBus)?;
@@ -218,14 +228,29 @@ impl SystemdResolved {
)
}
- pub fn set_dns(&self, interface_index: u32, servers: &[IpAddr]) -> Result<DnsState> {
+ pub fn get_dns(&self, interface_index: u32) -> Result<DnsState> {
let link_object_path = self
.fetch_link(interface_index)
.map_err(|e| Error::GetLinkError(Box::new(e)))?;
+ let set_servers = self.get_link_dns(&link_object_path)?;
+
+ Ok(DnsState {
+ interface_path: link_object_path,
+ interface_index,
+ set_servers,
+ })
+ }
- let mut set_servers = servers.to_vec();
- set_servers.sort();
- self.set_link_dns(&link_object_path, servers)?;
+ pub fn set_dns_state(&self, state: DnsState) -> Result<()> {
+ self.set_link_dns(&state.interface_path, &state.set_servers)
+ }
+
+ pub fn set_dns(&self, interface_index: u32, servers: Vec<IpAddr>) -> Result<DnsState> {
+ let set_servers = servers.to_vec();
+ let link_object_path = self
+ .fetch_link(interface_index)
+ .map_err(|e| Error::GetLinkError(Box::new(e)))?;
+ self.set_link_dns(&link_object_path, &servers)?;
Ok(DnsState {
interface_path: link_object_path,
interface_index,
@@ -233,6 +258,20 @@ impl SystemdResolved {
})
}
+ pub fn get_domains(&self, interface_index: u32) -> Result<Vec<(String, bool)>> {
+ let link_object_path = self
+ .fetch_link(interface_index)
+ .map_err(|e| Error::GetLinkError(Box::new(e)))?;
+ self.get_link_dns_domains(&link_object_path)
+ }
+
+ pub fn set_domains(&self, interface_index: u32, domains: &[(&str, bool)]) -> Result<()> {
+ let link_object_path = self
+ .fetch_link(interface_index)
+ .map_err(|e| Error::GetLinkError(Box::new(e)))?;
+ self.set_link_dns_domains(&link_object_path, domains)
+ }
+
fn fetch_link(&self, interface_index: u32) -> Result<dbus::Path<'static>> {
self.as_manager_object()
.method_call(
@@ -244,6 +283,21 @@ impl SystemdResolved {
.map(|result: (dbus::Path<'static>,)| result.0)
}
+ fn get_link_dns<'a, 'b: 'a>(
+ &'a self,
+ link_object_path: &'b dbus::Path<'static>,
+ ) -> Result<Vec<IpAddr>> {
+ let servers: Vec<(i32, Vec<u8>)> = self
+ .as_link_object(link_object_path.clone())
+ .get(LINK_INTERFACE, DNS_SERVERS)
+ .map_err(Error::DBusRpcError)?;
+
+ Ok(servers
+ .into_iter()
+ .filter_map(|(_family, addr)| ip_from_bytes(&addr))
+ .collect())
+ }
+
fn set_link_dns<'a, 'b: 'a>(
&'a self,
link_object_path: &'b dbus::Path<'static>,
@@ -255,20 +309,18 @@ impl SystemdResolved {
.collect::<Vec<_>>();
self.as_link_object(link_object_path.clone())
.method_call(LINK_INTERFACE, SET_DNS_METHOD, (servers,))
- .map_err(Error::DBusRpcError)?;
-
- // set the search domain to catch all DNS requests, forces the link to be the prefered
- // resolver, otherwise systemd-resolved will use other interfaces to do DNS lookups
- let dns_domains: &[_] = &[(&".", true)];
+ .map_err(Error::DBusRpcError)
+ }
- Proxy::new(
- RESOLVED_BUS,
- link_object_path,
- RPC_TIMEOUT,
- &*self.dbus_connection,
- )
- .method_call(LINK_INTERFACE, SET_DOMAINS_METHOD, (dns_domains,))
- .map_err(Error::SetDomainsError)
+ fn get_link_dns_domains<'a, 'b: 'a>(
+ &'a self,
+ link_object_path: &'b dbus::Path<'static>,
+ ) -> Result<Vec<(String, bool)>> {
+ let domains: Vec<(String, bool)> = self
+ .as_link_object(link_object_path.clone())
+ .get(LINK_INTERFACE, DNS_DOMAINS)
+ .map_err(Error::DBusRpcError)?;
+ Ok(domains)
}
pub fn revert_link(&mut self, dns_state: &DnsState) -> std::result::Result<(), dbus::Error> {
@@ -289,6 +341,21 @@ impl SystemdResolved {
}
}
+ fn set_link_dns_domains<'a, 'b: 'a>(
+ &'a self,
+ link_object_path: &'b dbus::Path<'static>,
+ domains: &[(&str, bool)],
+ ) -> Result<()> {
+ Proxy::new(
+ RESOLVED_BUS,
+ link_object_path,
+ RPC_TIMEOUT,
+ &*self.dbus_connection,
+ )
+ .method_call(LINK_INTERFACE, SET_DOMAINS_METHOD, (domains,))
+ .map_err(Error::SetDomainsError)
+ }
+
pub fn watch_dns_changes<
F: FnMut(Vec<DnsServer>) + Send + Sync + 'static,
S: Fn() -> bool + Clone + Send + Sync + 'static,
@@ -338,6 +405,10 @@ impl SystemdResolved {
.remove_match(dns_matcher)
.map_err(Error::DnsUpdateRemoveMatchError)
}
+
+ pub fn async_handle(&self) -> AsyncHandle {
+ AsyncHandle::new(self.clone())
+ }
}
#[derive(Debug)]
@@ -398,3 +469,54 @@ fn ip_from_bytes(bytes: &[u8]) -> Option<IpAddr> {
_ => None,
}
}
+
+impl AsyncHandle {
+ fn new(dbus_interface: SystemdResolved) -> Self {
+ Self { dbus_interface }
+ }
+
+ pub async fn get_dns(&self, interface_index: u32) -> Result<DnsState> {
+ let interface = self.dbus_interface.clone();
+ tokio::task::spawn_blocking(move || interface.get_dns(interface_index))
+ .await
+ .map_err(Error::AsyncTaskError)?
+ }
+
+ pub async fn set_dns_state(&self, state: DnsState) -> Result<()> {
+ let interface = self.dbus_interface.clone();
+ tokio::task::spawn_blocking(move || interface.set_dns_state(state))
+ .await
+ .map_err(Error::AsyncTaskError)?
+ }
+
+ pub async fn set_dns(&self, interface_index: u32, servers: Vec<IpAddr>) -> Result<DnsState> {
+ let interface = self.dbus_interface.clone();
+ tokio::task::spawn_blocking(move || interface.set_dns(interface_index, servers))
+ .await
+ .map_err(Error::AsyncTaskError)?
+ }
+
+ pub async fn set_domains(
+ &self,
+ interface_index: u32,
+ domains: &[(&'static str, bool)],
+ ) -> Result<()> {
+ let interface = self.dbus_interface.clone();
+ let domains = domains.to_vec();
+ tokio::task::spawn_blocking(move || interface.set_domains(interface_index, &domains))
+ .await
+ .map_err(Error::AsyncTaskError)?
+ }
+
+ pub async fn revert_link(&self, state: DnsState) -> Result<()> {
+ let mut interface = self.dbus_interface.clone();
+ tokio::task::spawn_blocking(move || interface.revert_link(&state))
+ .await
+ .map_err(Error::AsyncTaskError)?
+ .map_err(Error::DBusRpcError)
+ }
+
+ pub fn handle(&self) -> &SystemdResolved {
+ &self.dbus_interface
+ }
+}