diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-05-21 10:53:18 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-06-09 15:55:07 +0200 |
| commit | f8c4a9879afb26482108ae02f3e65682d7d7ab83 (patch) | |
| tree | 654461afe3da1014b547c0cfe2a7f93f1aff6ccc | |
| parent | 6d9a1ea3a2502b0d4b3145231b7922984d77632a (diff) | |
| download | mullvadvpn-f8c4a9879afb26482108ae02f3e65682d7d7ab83.tar.xz mullvadvpn-f8c4a9879afb26482108ae02f3e65682d7d7ab83.zip | |
Infer and monitor interfaces for DNS config
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-core/src/dns/android.rs | 5 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/mod.rs | 28 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/routing.rs | 259 | ||||
| -rw-r--r-- | talpid-core/src/dns/linux/systemd_resolved.rs | 271 | ||||
| -rw-r--r-- | talpid-core/src/dns/macos.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/mod.rs | 9 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/firewall/mod.rs | 3 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 3 | ||||
| -rw-r--r-- | talpid-dbus/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-dbus/src/systemd_resolved.rs | 156 |
12 files changed, 663 insertions, 77 deletions
diff --git a/Cargo.lock b/Cargo.lock index c1548a9905..1a24bf4a7f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2592,6 +2592,7 @@ dependencies = [ "libc", "log", "talpid-types", + "tokio", ] [[package]] diff --git a/talpid-core/src/dns/android.rs b/talpid-core/src/dns/android.rs index 032960ae1e..517c311113 100644 --- a/talpid-core/src/dns/android.rs +++ b/talpid-core/src/dns/android.rs @@ -10,7 +10,10 @@ pub struct DnsMonitor; impl super::DnsMonitorT for DnsMonitor { type Error = Error; - fn new(_cache_dir: impl AsRef<Path>) -> Result<Self, Self::Error> { + fn new( + _handle: tokio::runtime::Handle, + _cache_dir: impl AsRef<Path>, + ) -> Result<Self, Self::Error> { Ok(DnsMonitor) } diff --git a/talpid-core/src/dns/linux/mod.rs b/talpid-core/src/dns/linux/mod.rs index 96f8402da7..62827211bd 100644 --- a/talpid-core/src/dns/linux/mod.rs +++ b/talpid-core/src/dns/linux/mod.rs @@ -1,5 +1,6 @@ mod network_manager; mod resolvconf; +mod routing; mod static_resolv_conf; pub(self) mod systemd_resolved; @@ -39,28 +40,32 @@ pub enum Error { } pub struct DnsMonitor { + handle: tokio::runtime::Handle, inner: Option<DnsMonitorHolder>, } impl super::DnsMonitorT for DnsMonitor { type Error = Error; - fn new(_cache_dir: impl AsRef<Path>) -> Result<Self> { - Ok(DnsMonitor { inner: None }) + fn new(handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> { + Ok(DnsMonitor { + handle, + inner: None, + }) } fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<()> { self.reset()?; // Creating a new DNS monitor for each set, in case the system changed how it manages DNS. let mut inner = DnsMonitorHolder::new()?; - inner.set(interface, servers)?; + inner.set(&self.handle, interface, servers)?; self.inner = Some(inner); Ok(()) } fn reset(&mut self) -> Result<()> { if let Some(mut inner) = self.inner.take() { - inner.reset()?; + inner.reset(&self.handle)?; } Ok(()) } @@ -120,7 +125,12 @@ impl DnsMonitorHolder { .map_err(|_| Error::NoDnsMonitor) } - fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<()> { + fn set( + &mut self, + handle: &tokio::runtime::Handle, + interface: &str, + servers: &[IpAddr], + ) -> Result<()> { use self::DnsMonitorHolder::*; match self { Resolvconf(ref mut resolvconf) => resolvconf.set_dns(interface, servers)?, @@ -128,7 +138,7 @@ impl DnsMonitorHolder { static_resolv_conf.set_dns(servers.to_vec())? } SystemdResolved(ref mut systemd_resolved) => { - systemd_resolved.set_dns(interface, &servers)? + handle.block_on(systemd_resolved.set_dns(interface, &servers))? } NetworkManager(ref mut network_manager) => { network_manager.set_dns(interface, servers)? @@ -137,12 +147,14 @@ impl DnsMonitorHolder { Ok(()) } - fn reset(&mut self) -> Result<()> { + fn reset(&mut self, handle: &tokio::runtime::Handle) -> Result<()> { use self::DnsMonitorHolder::*; match self { Resolvconf(ref mut resolvconf) => resolvconf.reset()?, StaticResolvConf(ref mut static_resolv_conf) => static_resolv_conf.reset()?, - SystemdResolved(ref mut systemd_resolved) => systemd_resolved.reset()?, + SystemdResolved(ref mut systemd_resolved) => { + handle.block_on(systemd_resolved.reset())? + } NetworkManager(ref mut network_manager) => network_manager.reset()?, } Ok(()) diff --git a/talpid-core/src/dns/linux/routing.rs b/talpid-core/src/dns/linux/routing.rs new file mode 100644 index 0000000000..5c90b57a33 --- /dev/null +++ b/talpid-core/src/dns/linux/routing.rs @@ -0,0 +1,259 @@ +use futures::{ + channel::mpsc::UnboundedSender, future::abortable, FutureExt, StreamExt, TryStream, + TryStreamExt, +}; +use netlink_packet_core::{NetlinkPayload, NLM_F_REQUEST}; +use netlink_packet_route::{ + rtnl::route::nlas::Nla as RouteNla, NetlinkMessage, RouteFlags, RouteMessage, RtnlMessage, +}; +use rtnetlink::{ + constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE, RTMGRP_NOTIFY}, + sys::SocketAddr, + Handle, IpVersion, +}; +use std::{ + collections::BTreeMap, + fmt, io, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; +use talpid_types::ErrorExt; + +pub type Result<T> = std::result::Result<T, Error>; + +const PUBLIC_INTERNET_ADDRESS_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)); +const PUBLIC_INTERNET_ADDRESS_V6: IpAddr = + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)); + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + #[error(display = "Failed to get a route for an arbitrary IP address")] + GetRouteError(#[error(source)] failure::Compat<rtnetlink::Error>), + + #[error(display = "Failed to connect to bind to netlink socket")] + BindError(#[error(source)] io::Error), + + #[error(display = "No netlink response for route query")] + NoRouteError, + + #[error(display = "Route is missing an output interface")] + RouteNoInterfaceError, +} + +pub struct DnsRouteMonitor { + _handle: rtnetlink::Handle, + stop_tx: Option<futures::channel::oneshot::Sender<()>>, +} + +impl Drop for DnsRouteMonitor { + fn drop(&mut self) { + if let Some(stop_tx) = self.stop_tx.take() { + let _ = stop_tx.send(()); + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct DnsConfig { + pub interface: u32, + pub resolvers: Vec<IpAddr>, +} + +impl fmt::Display for DnsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "interface index {}, resolvers:", self.interface)?; + for server in &self.resolvers { + write!(f, " {}", server)?; + } + Ok(()) + } +} + +pub async fn spawn_monitor( + destinations: Vec<IpAddr>, + update_tx: UnboundedSender<BTreeMap<u32, DnsConfig>>, +) -> Result<(DnsRouteMonitor, BTreeMap<u32, DnsConfig>)> { + let (mut connection, handle, messages) = + rtnetlink::new_connection().expect("Failed to create a netlink connection"); + + let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE | RTMGRP_NOTIFY; + let addr = SocketAddr::new(0, mgroup_flags); + + connection + .socket_mut() + .bind(&addr) + .map_err(Error::BindError)?; + + let (abortable_connection, abort_connection) = abortable(connection); + tokio::spawn(abortable_connection); + + let (stop_tx, stop_rx) = futures::channel::oneshot::channel(); + + let monitor = DnsRouteMonitor { + _handle: handle.clone(), + stop_tx: Some(stop_tx), + }; + + let mut last_config = setup_configurations(&handle, &destinations).await?; + let initial_config = last_config.clone(); + + tokio::spawn(async move { + let mut messages = messages.fuse(); + let mut stop_rx = stop_rx.fuse(); + loop { + futures::select! { + _new_message = messages.next() => { + match setup_configurations(&handle, &destinations).await { + Ok(new_config) => { + if last_config != new_config { + last_config = new_config.clone(); + if update_tx.unbounded_send(new_config).is_err() { + log::trace!("Stopping DNS monitor: channel is closed"); + break; + } + } + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to determine new DNS interface settings" + ) + ); + } + } + }, + _ = stop_rx => break, + } + } + abort_connection.abort(); + }); + + Ok((monitor, initial_config)) +} + +async fn setup_configurations( + handle: &Handle, + destinations: &[IpAddr], +) -> Result<BTreeMap<u32, DnsConfig>> { + let mut interface_to_destinations = BTreeMap::<u32, DnsConfig>::new(); + for destination in destinations { + let interface = if destination.is_loopback() { + get_default_route_interface(handle, get_ip_version(destination), true).await? + } else { + if crate::firewall::is_local_address(&destination) { + get_destination_interface(handle, destination, true).await? + } else { + get_default_route_interface(handle, get_ip_version(destination), false).await? + } + }; + match interface { + Some(iface) => { + if let Some(config) = interface_to_destinations.get_mut(&iface) { + config.resolvers.push(*destination); + } else { + interface_to_destinations.insert( + iface, + DnsConfig { + interface: iface, + resolvers: vec![*destination], + }, + ); + } + } + None => { + log::trace!( + "Ignoring DNS server that did not match to any interface: {}", + destination + ); + } + } + } + + Ok(interface_to_destinations) +} + +async fn get_default_route_interface( + handle: &Handle, + ip_version: IpVersion, + set_mark: bool, +) -> Result<Option<u32>> { + match ip_version { + IpVersion::V4 => { + get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V4, set_mark).await + } + IpVersion::V6 => { + get_destination_interface(handle, &PUBLIC_INTERNET_ADDRESS_V6, set_mark).await + } + } +} + +async fn get_destination_interface( + handle: &Handle, + destination: &IpAddr, + set_mark: bool, +) -> Result<Option<u32>> { + let mut request = handle.route().get(get_ip_version(destination)); + let octets = match destination { + IpAddr::V4(address) => address.octets().to_vec(), + IpAddr::V6(address) => address.octets().to_vec(), + }; + let message = request.message_mut(); + if set_mark { + message + .nlas + .push(RouteNla::Mark(crate::linux::TUNNEL_FW_MARK)); + } + message.header.destination_prefix_length = 8u8 * (octets.len() as u8); + message.header.flags = RouteFlags::RTM_F_FIB_MATCH; + message.nlas.push(RouteNla::Destination(octets)); + let mut stream = execute_route_get_request(handle.clone(), message.clone()); + match stream.try_next().await { + Ok(Some(route_msg)) => { + for nla in &route_msg.nlas { + if let RouteNla::Oif(interface) = nla { + return Ok(Some(*interface)); + } + } + Err(Error::RouteNoInterfaceError) + } + Ok(None) => Err(Error::NoRouteError), + Err(rtnetlink::Error::NetlinkError(nl_err)) if nl_err.code == -libc::ENETUNREACH => { + Ok(None) + } + Err(err) => Err(Error::GetRouteError(failure::Fail::compat(err))), + } +} + +pub fn execute_route_get_request( + mut handle: Handle, + message: RouteMessage, +) -> impl TryStream<Ok = RouteMessage, Error = rtnetlink::Error> { + use futures::future::{self, Either}; + use rtnetlink::Error; + + let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); + req.header.flags = NLM_F_REQUEST; + + match handle.request(req) { + Ok(response) => Either::Left(response.map(move |msg| { + let (header, payload) = msg.into_parts(); + match payload { + NetlinkPayload::InnerMessage(RtnlMessage::NewRoute(msg)) => Ok(msg), + NetlinkPayload::Error(err) => Err(Error::NetlinkError(err)), + _ => Err(Error::UnexpectedMessage(NetlinkMessage::new( + header, payload, + ))), + } + })), + Err(e) => Either::Right(future::err::<RouteMessage, Error>(e).into_stream()), + } +} + +fn get_ip_version(addr: &IpAddr) -> IpVersion { + if addr.is_ipv4() { + IpVersion::V4 + } else { + IpVersion::V6 + } +} diff --git a/talpid-core/src/dns/linux/systemd_resolved.rs b/talpid-core/src/dns/linux/systemd_resolved.rs index 29a8feaf85..66141b5177 100644 --- a/talpid-core/src/dns/linux/systemd_resolved.rs +++ b/talpid-core/src/dns/linux/systemd_resolved.rs @@ -1,13 +1,16 @@ use crate::linux::{iface_index, IfaceIndexLookupError}; +use futures::{channel::mpsc, StreamExt}; use std::{ + collections::BTreeMap, net::IpAddr, sync::{ atomic::{AtomicBool, Ordering}, - Arc, + Arc, Mutex, }, thread, }; -use talpid_dbus::systemd_resolved::{DnsState, SystemdResolved as DbusInterface}; +use talpid_dbus::systemd_resolved::{AsyncHandle, DnsState, SystemdResolved as DbusInterface}; +use talpid_types::ErrorExt; pub(crate) use talpid_dbus::systemd_resolved::Error as SystemdDbusError; @@ -20,55 +23,199 @@ pub enum Error { #[error(display = "Failed to resolve interface index with error {}", _0)] InterfaceNameError(#[error(source)] IfaceIndexLookupError), -} -pub struct SystemdResolved { - pub dbus_interface: DbusInterface, - state: Option<SetConfigState>, + #[error(display = "Failed to spawn DNS interface monitor")] + SpawnInterfaceMonitor(#[error(source)] super::routing::Error), } -struct SetConfigState { - dns_config: Arc<DnsState>, - watcher_thread: thread::JoinHandle<()>, - watcher_should_shutdown: Arc<AtomicBool>, +use super::routing::{DnsConfig, DnsRouteMonitor}; + +pub struct SystemdResolved { + pub dbus_interface: AsyncHandle, + current_config: Arc<Mutex<BTreeMap<u32, DnsConfig>>>, + initial_states: Arc<Mutex<BTreeMap<u32, DnsState>>>, + tunnel_index: u32, + route_monitor: Option<(DnsRouteMonitor, tokio::task::JoinHandle<()>)>, + watcher: Option<(thread::JoinHandle<()>, Arc<AtomicBool>)>, } impl SystemdResolved { pub fn new() -> Result<Self> { - let dbus_interface = DbusInterface::new()?; + let dbus_interface = DbusInterface::new()?.async_handle(); let systemd_resolved = SystemdResolved { dbus_interface, - state: None, + current_config: Arc::new(Mutex::new(BTreeMap::new())), + initial_states: Arc::new(Mutex::new(BTreeMap::new())), + tunnel_index: 0, + route_monitor: None, + watcher: None, }; Ok(systemd_resolved) } - pub fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> { - let iface_index = iface_index(interface_name)?; - let dns_state = self.dbus_interface.set_dns(iface_index, servers)?; - let dns_config = Arc::new(dns_state); + pub async fn set_dns(&mut self, interface_name: &str, servers: &[IpAddr]) -> Result<()> { + let (update_tx, mut update_rx) = mpsc::unbounded(); + let (monitor, initial_config) = super::routing::spawn_monitor(servers.to_vec(), update_tx) + .await + .map_err(Error::SpawnInterfaceMonitor)?; + let tunnel_index = iface_index(interface_name)?; + self.tunnel_index = tunnel_index; + let mut last_result = Ok(()); - let (watcher_thread, watcher_should_shutdown) = - self.spawn_watcher_thread(dns_config.clone()); - self.state = Some(SetConfigState { - dns_config, - watcher_thread, - watcher_should_shutdown, - }); + { + let mut initial_states = self.initial_states.lock().unwrap(); + for (iface_index, iface_config) in &initial_config { + let initial_state = match self.dbus_interface.get_dns(*iface_index).await { + Ok(state) => state, + Err(error) => { + last_result = Err(Error::SystemdResolvedError(error)); + break; + } + }; + if let Err(error) = self + .dbus_interface + .set_dns(*iface_index, iface_config.resolvers.clone()) + .await + { + last_result = Err(Error::SystemdResolvedError(error)); + break; + } + initial_states.insert(*iface_index, initial_state); + } + } + + if last_result.is_ok() { + if has_only_tunnel_config(&initial_config, tunnel_index) { + if let Err(error) = self + .dbus_interface + .set_domains(tunnel_index, &[(".", true)]) + .await + { + last_result = Err(Error::SystemdResolvedError(error)); + } + } else { + if let Err(error) = self.dbus_interface.set_domains(tunnel_index, &[]).await { + last_result = Err(Error::SystemdResolvedError(error)); + } + } + } + + if let Err(error) = last_result { + let _ = self.reset(); + return Err(error); + } + + { + *self.current_config.lock().unwrap() = initial_config; + } + + let ignore_config_changes = Arc::new(AtomicBool::new(false)); + + self.watcher = Some(self.spawn_watcher_thread( + tunnel_index, + self.current_config.clone(), + ignore_config_changes.clone(), + )); + + let dbus_interface = self.dbus_interface.clone(); + let initial_states = self.initial_states.clone(); + let current_config = self.current_config.clone(); + let join_handle = tokio::spawn(async move { + while let Some(new_config) = update_rx.next().await { + let mut new_initial_states = { initial_states.lock().unwrap().clone() }; + + let disable_watcher = ignore_config_changes.clone(); + disable_watcher.store(true, Ordering::Release); + + // Revert interfaces no longer in use + let keys = new_initial_states.keys().cloned().collect::<Vec<u32>>(); + for iface in keys { + if !new_config.contains_key(&iface) { + log::debug!("Reverting DNS config on interface {}", iface); + if let Err(err) = dbus_interface + .set_dns_state(new_initial_states[&iface].clone()) + .await + { + log::error!("Failed to revert interface config: {}", err); + } + new_initial_states.remove(&iface); + } + } + + for (iface, config) in &new_config { + if tunnel_index == *iface { + // All public addresses (plus the gateway) will be assigned + // to the tunnel: we can assume nothing has changed. + continue; + } + + // Store new interfaces + if !new_initial_states.contains_key(iface) { + let initial_state = match dbus_interface.get_dns(*iface).await { + Ok(state) => state, + Err(error) => { + log::error!( + "Failed to get resolvers: {}\n{}", + config, + error.display_chain() + ); + continue; + } + }; + new_initial_states.insert(*iface, initial_state); + } + if let Err(error) = dbus_interface + .set_dns(*iface, config.resolvers.clone()) + .await + { + log::error!( + "Failed to set resolvers: {}\n{}", + config, + error.display_chain() + ); + } + } + + let tunnel_domains = if has_only_tunnel_config(&new_config, tunnel_index) { + &[(".", true)][..] + } else { + &[][..] + }; + if let Err(error) = dbus_interface + .set_domains(tunnel_index, tunnel_domains) + .await + { + log::error!( + "Failed to set DNS domains on tunnel interface\n{}", + error.display_chain() + ); + } + + { + *current_config.lock().unwrap() = new_config.clone(); + *initial_states.lock().unwrap() = new_initial_states; + } + + disable_watcher.store(false, Ordering::Release); + } + }); + self.route_monitor = Some((monitor, join_handle)); Ok(()) } fn spawn_watcher_thread( &mut self, - dns_state: Arc<DnsState>, + tunnel_index: u32, + current_config: Arc<Mutex<BTreeMap<u32, DnsConfig>>>, + disable_watcher: Arc<AtomicBool>, ) -> (thread::JoinHandle<()>, Arc<AtomicBool>) { - let dbus_interface = self.dbus_interface.clone(); + let dbus_interface = self.dbus_interface.handle().clone(); let should_shutdown = Arc::new(AtomicBool::new(false)); let watch_shutdown = should_shutdown.clone(); let callback_shutdown = should_shutdown.clone(); @@ -78,16 +225,34 @@ impl SystemdResolved { if callback_shutdown.clone().load(Ordering::Acquire) { return; } - let mut current_servers: Vec<IpAddr> = new_servers - .into_iter() - .filter(|server| server.iface_index == dns_state.interface_index as i32) + if disable_watcher.clone().load(Ordering::Acquire) { + return; + } + let configs = current_config.lock().unwrap(); + let mut anything_changed = false; + for (iface, config) in &*configs { + let current_servers: Vec<IpAddr> = new_servers + .iter() + .filter(|server| server.iface_index == *iface as i32) .map(|server| server.address) .collect(); - current_servers.sort(); - if current_servers != *dns_state.set_servers { - log::debug!("DNS config for tunnel interface changed, currently applied servers - {:?}", current_servers); - if let Err(err) = dbus_interface.set_dns(dns_state.interface_index, &dns_state.set_servers) { - log::error!("Failed to re-apply DNS config - {}", err); + if current_servers != config.resolvers { + log::trace!("DNS config for interface {} changed, currently applied servers - {:?}", iface, current_servers); + if let Err(err) = dbus_interface.set_dns(*iface, config.resolvers.clone()) + { + log::error!("Failed to re-apply DNS config - {}", err); + } + anything_changed = true; + } + } + if anything_changed { + let result = if has_only_tunnel_config(&configs, tunnel_index) { + dbus_interface.set_domains(tunnel_index, &[(".", true)]) + } else { + dbus_interface.set_domains(tunnel_index, &[]) + }; + if let Err(err) = result { + log::error!("Failed to re-apply DNS domains - {}", err); } } }, @@ -100,23 +265,41 @@ impl SystemdResolved { (watcher_thread, should_shutdown) } - pub fn reset(&mut self) -> Result<()> { - if let Some(SetConfigState { - dns_config, - watcher_thread, - watcher_should_shutdown, - }) = self.state.take() - { + pub async fn reset(&mut self) -> Result<()> { + if let Some((watcher_thread, watcher_should_shutdown)) = self.watcher.take() { watcher_should_shutdown.store(true, Ordering::Release); - if let Err(err) = self.dbus_interface.revert_link(&dns_config) { - log::error!("Failed to revert interface config: {}", err); - } - if watcher_thread.join().is_err() { log::error!("DNS watcher thread panicked!"); } } + if let Some((monitor, join_handle)) = self.route_monitor.take() { + std::mem::drop(monitor); + let _ = join_handle.await; + } + + let mut initial_states = self.initial_states.lock().unwrap(); + for (iface, state) in &*initial_states { + let result = if *iface == self.tunnel_index { + self.dbus_interface.revert_link(state.clone()).await + } else { + self.dbus_interface.set_dns_state(state.clone()).await + }; + if let Err(err) = result { + log::error!( + "{}", + err.display_chain_with_msg("Failed to revert interface config") + ); + } + } + initial_states.clear(); + + self.current_config.lock().unwrap().clear(); + Ok(()) } } + +fn has_only_tunnel_config(configs: &BTreeMap<u32, DnsConfig>, tunnel_index: u32) -> bool { + configs.len() == 1 && configs.contains_key(&tunnel_index) +} diff --git a/talpid-core/src/dns/macos.rs b/talpid-core/src/dns/macos.rs index 4c735e9e44..44ba949f59 100644 --- a/talpid-core/src/dns/macos.rs +++ b/talpid-core/src/dns/macos.rs @@ -140,7 +140,7 @@ impl super::DnsMonitorT for DnsMonitor { /// DNS settings for all network interfaces. If any changes occur it will instantly reset /// the DNS settings for that interface back to the last server list set to this instance /// with `set_dns`. - fn new(_cache_dir: impl AsRef<Path>) -> Result<Self> { + fn new(_handle: tokio::runtime::Handle, _cache_dir: impl AsRef<Path>) -> Result<Self> { let state = Arc::new(Mutex::new(None)); Self::spawn(state.clone())?; Ok(DnsMonitor { diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs index 987f3f1092..b07ce13eac 100644 --- a/talpid-core/src/dns/mod.rs +++ b/talpid-core/src/dns/mod.rs @@ -28,9 +28,9 @@ pub struct DnsMonitor { impl DnsMonitor { /// Returns a new `DnsMonitor` that can set and monitor the system DNS. - pub fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> { + pub fn new(handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> { Ok(DnsMonitor { - inner: imp::DnsMonitor::new(cache_dir)?, + inner: imp::DnsMonitor::new(handle, cache_dir)?, }) } @@ -58,7 +58,10 @@ impl DnsMonitor { trait DnsMonitorT: Sized { type Error: std::error::Error; - fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Self::Error>; + fn new( + handle: tokio::runtime::Handle, + cache_dir: impl AsRef<Path>, + ) -> Result<Self, Self::Error>; fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Self::Error>; diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index 59f6ee349d..90ef7552b7 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -50,7 +50,7 @@ pub struct DnsMonitor {} impl super::DnsMonitorT for DnsMonitor { type Error = Error; - fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> { + fn new(_handle: tokio::runtime::Handle, cache_dir: impl AsRef<Path>) -> Result<Self, Error> { unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? }; let backup_writer = SystemStateWriter::new( diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index a3d4333d6e..557ae01412 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -83,7 +83,8 @@ const DHCPV6_CLIENT_PORT: u16 = 546; #[cfg(all(unix, not(target_os = "android")))] -fn is_local_address(address: &IpAddr) -> bool { +/// Returns whether an address belongs to a private subnet. +pub fn is_local_address(address: &IpAddr) -> bool { let address = address.clone(); (&*ALLOWED_LAN_NETS) .iter() diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 1f26f2fa04..1bcc6961d9 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -221,7 +221,8 @@ impl TunnelStateMachine { }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; - let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?; + let dns_monitor = + DnsMonitor::new(runtime.clone(), cache_dir).map_err(Error::InitDnsMonitorError)?; let route_manager = RouteManager::new(runtime.clone(), HashSet::new()) .map_err(Error::InitRouteManagerError)?; let mut shared_values = SharedTunnelStateValues { diff --git a/talpid-dbus/Cargo.toml b/talpid-dbus/Cargo.toml index 5685daf922..7ae6403254 100644 --- a/talpid-dbus/Cargo.toml +++ b/talpid-dbus/Cargo.toml @@ -12,3 +12,4 @@ lazy_static = "1.0" log = "0.4" libc = "0.2" talpid-types = { path = "../talpid-types" } +tokio = { version = "0.2", features = [ "blocking" ] } diff --git a/talpid-dbus/src/systemd_resolved.rs b/talpid-dbus/src/systemd_resolved.rs index 2fd2f14956..2a5b7b2b02 100644 --- a/talpid-dbus/src/systemd_resolved.rs +++ b/talpid-dbus/src/systemd_resolved.rs @@ -52,6 +52,9 @@ pub enum Error { #[error(display = "Failed to remove a match for DNS config updates")] DnsUpdateRemoveMatchError(#[error(source)] dbus::Error), + + #[error(display = "Async D-Bus task failed")] + AsyncTaskError(#[error(source)] tokio::task::JoinError), } lazy_static! { @@ -73,6 +76,7 @@ const RPC_TIMEOUT: Duration = Duration::from_secs(1); const LINK_INTERFACE: &str = "org.freedesktop.resolve1.Link"; const MANAGER_INTERFACE: &str = "org.freedesktop.resolve1.Manager"; +const DNS_DOMAINS: &str = "Domains"; const DNS_SERVERS: &str = "DNS"; const GET_LINK_METHOD: &str = "GetLink"; const SET_DNS_METHOD: &str = "SetDNS"; @@ -84,12 +88,18 @@ pub struct SystemdResolved { pub dbus_connection: Arc<SyncConnection>, } +#[derive(Clone)] pub struct DnsState { pub interface_path: dbus::Path<'static>, pub interface_index: u32, pub set_servers: Vec<IpAddr>, } +#[derive(Clone)] +pub struct AsyncHandle { + dbus_interface: SystemdResolved, +} + impl SystemdResolved { pub fn new() -> Result<Self> { let dbus_connection = crate::get_connection().map_err(Error::ConnectDBus)?; @@ -218,14 +228,29 @@ impl SystemdResolved { ) } - pub fn set_dns(&self, interface_index: u32, servers: &[IpAddr]) -> Result<DnsState> { + pub fn get_dns(&self, interface_index: u32) -> Result<DnsState> { let link_object_path = self .fetch_link(interface_index) .map_err(|e| Error::GetLinkError(Box::new(e)))?; + let set_servers = self.get_link_dns(&link_object_path)?; + + Ok(DnsState { + interface_path: link_object_path, + interface_index, + set_servers, + }) + } - let mut set_servers = servers.to_vec(); - set_servers.sort(); - self.set_link_dns(&link_object_path, servers)?; + pub fn set_dns_state(&self, state: DnsState) -> Result<()> { + self.set_link_dns(&state.interface_path, &state.set_servers) + } + + pub fn set_dns(&self, interface_index: u32, servers: Vec<IpAddr>) -> Result<DnsState> { + let set_servers = servers.to_vec(); + let link_object_path = self + .fetch_link(interface_index) + .map_err(|e| Error::GetLinkError(Box::new(e)))?; + self.set_link_dns(&link_object_path, &servers)?; Ok(DnsState { interface_path: link_object_path, interface_index, @@ -233,6 +258,20 @@ impl SystemdResolved { }) } + pub fn get_domains(&self, interface_index: u32) -> Result<Vec<(String, bool)>> { + let link_object_path = self + .fetch_link(interface_index) + .map_err(|e| Error::GetLinkError(Box::new(e)))?; + self.get_link_dns_domains(&link_object_path) + } + + pub fn set_domains(&self, interface_index: u32, domains: &[(&str, bool)]) -> Result<()> { + let link_object_path = self + .fetch_link(interface_index) + .map_err(|e| Error::GetLinkError(Box::new(e)))?; + self.set_link_dns_domains(&link_object_path, domains) + } + fn fetch_link(&self, interface_index: u32) -> Result<dbus::Path<'static>> { self.as_manager_object() .method_call( @@ -244,6 +283,21 @@ impl SystemdResolved { .map(|result: (dbus::Path<'static>,)| result.0) } + fn get_link_dns<'a, 'b: 'a>( + &'a self, + link_object_path: &'b dbus::Path<'static>, + ) -> Result<Vec<IpAddr>> { + let servers: Vec<(i32, Vec<u8>)> = self + .as_link_object(link_object_path.clone()) + .get(LINK_INTERFACE, DNS_SERVERS) + .map_err(Error::DBusRpcError)?; + + Ok(servers + .into_iter() + .filter_map(|(_family, addr)| ip_from_bytes(&addr)) + .collect()) + } + fn set_link_dns<'a, 'b: 'a>( &'a self, link_object_path: &'b dbus::Path<'static>, @@ -255,20 +309,18 @@ impl SystemdResolved { .collect::<Vec<_>>(); self.as_link_object(link_object_path.clone()) .method_call(LINK_INTERFACE, SET_DNS_METHOD, (servers,)) - .map_err(Error::DBusRpcError)?; - - // set the search domain to catch all DNS requests, forces the link to be the prefered - // resolver, otherwise systemd-resolved will use other interfaces to do DNS lookups - let dns_domains: &[_] = &[(&".", true)]; + .map_err(Error::DBusRpcError) + } - Proxy::new( - RESOLVED_BUS, - link_object_path, - RPC_TIMEOUT, - &*self.dbus_connection, - ) - .method_call(LINK_INTERFACE, SET_DOMAINS_METHOD, (dns_domains,)) - .map_err(Error::SetDomainsError) + fn get_link_dns_domains<'a, 'b: 'a>( + &'a self, + link_object_path: &'b dbus::Path<'static>, + ) -> Result<Vec<(String, bool)>> { + let domains: Vec<(String, bool)> = self + .as_link_object(link_object_path.clone()) + .get(LINK_INTERFACE, DNS_DOMAINS) + .map_err(Error::DBusRpcError)?; + Ok(domains) } pub fn revert_link(&mut self, dns_state: &DnsState) -> std::result::Result<(), dbus::Error> { @@ -289,6 +341,21 @@ impl SystemdResolved { } } + fn set_link_dns_domains<'a, 'b: 'a>( + &'a self, + link_object_path: &'b dbus::Path<'static>, + domains: &[(&str, bool)], + ) -> Result<()> { + Proxy::new( + RESOLVED_BUS, + link_object_path, + RPC_TIMEOUT, + &*self.dbus_connection, + ) + .method_call(LINK_INTERFACE, SET_DOMAINS_METHOD, (domains,)) + .map_err(Error::SetDomainsError) + } + pub fn watch_dns_changes< F: FnMut(Vec<DnsServer>) + Send + Sync + 'static, S: Fn() -> bool + Clone + Send + Sync + 'static, @@ -338,6 +405,10 @@ impl SystemdResolved { .remove_match(dns_matcher) .map_err(Error::DnsUpdateRemoveMatchError) } + + pub fn async_handle(&self) -> AsyncHandle { + AsyncHandle::new(self.clone()) + } } #[derive(Debug)] @@ -398,3 +469,54 @@ fn ip_from_bytes(bytes: &[u8]) -> Option<IpAddr> { _ => None, } } + +impl AsyncHandle { + fn new(dbus_interface: SystemdResolved) -> Self { + Self { dbus_interface } + } + + pub async fn get_dns(&self, interface_index: u32) -> Result<DnsState> { + let interface = self.dbus_interface.clone(); + tokio::task::spawn_blocking(move || interface.get_dns(interface_index)) + .await + .map_err(Error::AsyncTaskError)? + } + + pub async fn set_dns_state(&self, state: DnsState) -> Result<()> { + let interface = self.dbus_interface.clone(); + tokio::task::spawn_blocking(move || interface.set_dns_state(state)) + .await + .map_err(Error::AsyncTaskError)? + } + + pub async fn set_dns(&self, interface_index: u32, servers: Vec<IpAddr>) -> Result<DnsState> { + let interface = self.dbus_interface.clone(); + tokio::task::spawn_blocking(move || interface.set_dns(interface_index, servers)) + .await + .map_err(Error::AsyncTaskError)? + } + + pub async fn set_domains( + &self, + interface_index: u32, + domains: &[(&'static str, bool)], + ) -> Result<()> { + let interface = self.dbus_interface.clone(); + let domains = domains.to_vec(); + tokio::task::spawn_blocking(move || interface.set_domains(interface_index, &domains)) + .await + .map_err(Error::AsyncTaskError)? + } + + pub async fn revert_link(&self, state: DnsState) -> Result<()> { + let mut interface = self.dbus_interface.clone(); + tokio::task::spawn_blocking(move || interface.revert_link(&state)) + .await + .map_err(Error::AsyncTaskError)? + .map_err(Error::DBusRpcError) + } + + pub fn handle(&self) -> &SystemdResolved { + &self.dbus_interface + } +} |
