summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2023-01-09 12:08:46 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-01-10 11:58:51 +0100
commitf4fd14dea00cf53b7e0bfc32008cee03ed827bff (patch)
tree52b11acaf33ffebf5f382be76baa5b593c093ce2
parentf7ad580f5e72d12a63014d0ca295fa1b6cff10bd (diff)
downloadmullvadvpn-f4fd14dea00cf53b7e0bfc32008cee03ed827bff.tar.xz
mullvadvpn-f4fd14dea00cf53b7e0bfc32008cee03ed827bff.zip
Keep registry DNS manager implementation as a TALPID_DNS_MODULE option
-rw-r--r--talpid-core/src/dns/mod.rs2
-rw-r--r--talpid-core/src/dns/windows/dnsapi.rs132
-rw-r--r--talpid-core/src/dns/windows/mod.rs76
-rw-r--r--talpid-core/src/dns/windows/netsh.rs (renamed from talpid-core/src/dns/windows.rs)3
-rw-r--r--talpid-core/src/dns/windows/tcpip.rs156
5 files changed, 367 insertions, 2 deletions
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs
index 1a08ec265c..39660078c3 100644
--- a/talpid-core/src/dns/mod.rs
+++ b/talpid-core/src/dns/mod.rs
@@ -20,7 +20,7 @@ mod imp;
pub use imp::will_use_nm;
#[cfg(windows)]
-#[path = "windows.rs"]
+#[path = "windows/mod.rs"]
mod imp;
#[cfg(target_os = "android")]
diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs
new file mode 100644
index 0000000000..e48f37258e
--- /dev/null
+++ b/talpid-core/src/dns/windows/dnsapi.rs
@@ -0,0 +1,132 @@
+use once_cell::sync::OnceCell;
+use std::{
+ io,
+ sync::{
+ atomic::{AtomicUsize, Ordering},
+ mpsc, Arc,
+ },
+ time::{Duration, Instant},
+};
+use windows_sys::{
+ w,
+ Win32::{
+ Foundation::BOOL,
+ System::LibraryLoader::{
+ FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32,
+ },
+ },
+};
+
+type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL;
+
+static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new();
+static FLUSH_TIMEOUT: Duration = Duration::from_secs(5);
+
+const MAX_CONCURRENT_FLUSHES: usize = 5;
+
+/// Errors that can happen when configuring DNS on Windows.
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to load dnsapi.dll.
+ #[error(display = "Failed to load dnsapi.dll")]
+ LoadDll(#[error(source)] io::Error),
+
+ /// Failed to obtain exported function.
+ #[error(display = "Failed to obtain flush function")]
+ GetFunction(#[error(source)] io::Error),
+
+ /// Failed to flush the DNS cache.
+ #[error(display = "Call to flush DNS cache failed")]
+ FlushCache,
+
+ /// Too many flush attempts in progress.
+ #[error(display = "Too many flush attempts in progress")]
+ TooManyFlushAttempts,
+
+ /// Flushing the DNS cache timed out.
+ #[error(display = "Timeout while flushing DNS cache")]
+ Timeout,
+}
+
+pub fn flush_resolver_cache() -> Result<(), Error> {
+ DNSAPI_HANDLE
+ .get_or_try_init(|| DnsApi::new())?
+ .flush_cache()
+}
+
+struct DnsApi {
+ in_flight_flush_count: Arc<AtomicUsize>,
+ flush_fn: FlushResolverCacheFn,
+}
+
+unsafe impl Send for DnsApi {}
+unsafe impl Sync for DnsApi {}
+
+impl DnsApi {
+ fn new() -> Result<Self, Error> {
+ let handle = unsafe { LoadLibraryExW(w!("dnsapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) };
+ if handle == 0 {
+ return Err(Error::LoadDll(io::Error::last_os_error()));
+ }
+
+ let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const u8) };
+ let flush_fn = flush_fn.ok_or_else(|| {
+ let error = io::Error::last_os_error();
+ unsafe { FreeLibrary(handle) };
+ Error::GetFunction(error)
+ })?;
+
+ Ok(DnsApi {
+ in_flight_flush_count: Arc::new(AtomicUsize::new(0)),
+ flush_fn: unsafe { *(&flush_fn as *const _ as *const _) },
+ })
+ }
+
+ fn flush_cache(&self) -> Result<(), Error> {
+ if self
+ .in_flight_flush_count
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
+ if val >= MAX_CONCURRENT_FLUSHES {
+ return None;
+ }
+ Some(val + 1)
+ })
+ .is_err()
+ {
+ return Err(Error::TooManyFlushAttempts);
+ }
+
+ let (tx, rx) = mpsc::channel();
+ let flush_count = self.in_flight_flush_count.clone();
+
+ let flush_fn = self.flush_fn;
+
+ std::thread::spawn(move || {
+ let begin = Instant::now();
+
+ let result = if unsafe { (flush_fn)() } != 0 {
+ let elapsed = begin.elapsed();
+ if elapsed >= FLUSH_TIMEOUT {
+ log::warn!(
+ "Flushing system DNS cache took {} seconds",
+ elapsed.as_secs()
+ );
+ } else {
+ log::debug!("Flushed system DNS cache");
+ }
+ Ok(())
+ } else {
+ Err(Error::FlushCache)
+ };
+ let _ = tx.send(result);
+
+ flush_count.fetch_sub(1, Ordering::SeqCst);
+ });
+
+ match rx.recv_timeout(FLUSH_TIMEOUT) {
+ Ok(result) => result,
+ Err(_timeout_err) => Err(Error::Timeout),
+ }
+ }
+}
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
new file mode 100644
index 0000000000..10ce85c055
--- /dev/null
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -0,0 +1,76 @@
+use std::{env, fmt, net::IpAddr};
+
+mod dnsapi;
+mod netsh;
+mod tcpip;
+
+/// Errors that can happen when configuring DNS on Windows.
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// Failed to set DNS config using the netsh module.
+ #[error(display = "Error in netsh module")]
+ Netsh(#[error(source)] netsh::Error),
+
+ /// Failed to set DNS config using the tcpip module.
+ #[error(display = "Error in tcpip module")]
+ Tcpip(#[error(source)] tcpip::Error),
+}
+
+pub struct DnsMonitor {
+ inner: DnsMonitorHolder,
+}
+
+impl super::DnsMonitorT for DnsMonitor {
+ type Error = Error;
+
+ fn new() -> Result<Self, Error> {
+ let dns_module = env::var_os("TALPID_DNS_MODULE");
+
+ let inner = match dns_module.as_ref().and_then(|value| value.to_str()) {
+ Some("tcpip") => DnsMonitorHolder::Tcpip(tcpip::DnsMonitor::new()?),
+ Some(_) | None => DnsMonitorHolder::Netsh(netsh::DnsMonitor::new()?),
+ };
+
+ log::debug!("DNS monitor: {}", inner);
+
+ Ok(DnsMonitor { inner })
+ }
+
+ fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> {
+ match self.inner {
+ DnsMonitorHolder::Netsh(ref mut inner) => inner.set(interface, servers)?,
+ DnsMonitorHolder::Tcpip(ref mut inner) => inner.set(interface, servers)?,
+ }
+ Ok(())
+ }
+
+ fn reset(&mut self) -> Result<(), Error> {
+ match self.inner {
+ DnsMonitorHolder::Netsh(ref mut inner) => inner.reset()?,
+ DnsMonitorHolder::Tcpip(ref mut inner) => inner.reset()?,
+ }
+ Ok(())
+ }
+
+ fn reset_before_interface_removal(&mut self) -> Result<(), Error> {
+ match self.inner {
+ DnsMonitorHolder::Netsh(ref mut inner) => inner.reset_before_interface_removal()?,
+ DnsMonitorHolder::Tcpip(ref mut inner) => inner.reset_before_interface_removal()?,
+ }
+ Ok(())
+ }
+}
+
+enum DnsMonitorHolder {
+ Netsh(netsh::DnsMonitor),
+ Tcpip(tcpip::DnsMonitor),
+}
+
+impl fmt::Display for DnsMonitorHolder {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ DnsMonitorHolder::Netsh(_) => f.write_str("netsh"),
+ DnsMonitorHolder::Tcpip(_) => f.write_str("TCP/IP registry parameter"),
+ }
+ }
+}
diff --git a/talpid-core/src/dns/windows.rs b/talpid-core/src/dns/windows/netsh.rs
index 68742b9266..7c0450f855 100644
--- a/talpid-core/src/dns/windows.rs
+++ b/talpid-core/src/dns/windows/netsh.rs
@@ -1,3 +1,4 @@
+use crate::dns::DnsMonitorT;
use std::{
ffi::OsString,
io::{self, Write},
@@ -60,7 +61,7 @@ pub struct DnsMonitor {
current_index: Option<u32>,
}
-impl super::DnsMonitorT for DnsMonitor {
+impl DnsMonitorT for DnsMonitor {
type Error = Error;
fn new() -> Result<Self, Error> {
diff --git a/talpid-core/src/dns/windows/tcpip.rs b/talpid-core/src/dns/windows/tcpip.rs
new file mode 100644
index 0000000000..f7536eaed1
--- /dev/null
+++ b/talpid-core/src/dns/windows/tcpip.rs
@@ -0,0 +1,156 @@
+use crate::dns::DnsMonitorT;
+use std::{io, net::IpAddr};
+use talpid_types::ErrorExt;
+use talpid_windows_net::{guid_from_luid, luid_from_alias};
+use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2};
+use winreg::{
+ enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE},
+ transaction::Transaction,
+ RegKey,
+};
+
+/// Errors that can happen when configuring DNS on Windows.
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failure to obtain an interface LUID given an alias.
+ #[error(display = "Failed to obtain LUID for the interface alias")]
+ InterfaceLuidError(#[error(source)] io::Error),
+
+ /// Failure to obtain an interface GUID.
+ #[error(display = "Failed to obtain GUID for the interface")]
+ InterfaceGuidError(#[error(source)] io::Error),
+
+ /// Failure to flush DNS cache.
+ #[error(display = "Failed to flush DNS resolver cache")]
+ FlushResolverCacheError(#[error(source)] super::dnsapi::Error),
+
+ /// Failed to update DNS servers for interface.
+ #[error(display = "Failed to update interface DNS servers")]
+ SetResolversError(#[error(source)] io::Error),
+}
+
+pub struct DnsMonitor {
+ current_guid: Option<GUID>,
+}
+
+impl DnsMonitorT for DnsMonitor {
+ type Error = Error;
+
+ fn new() -> Result<Self, Error> {
+ Ok(DnsMonitor { current_guid: None })
+ }
+
+ fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> {
+ let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?)
+ .map_err(Error::InterfaceGuidError)?;
+ set_dns(&guid, servers)?;
+ self.current_guid = Some(guid);
+ flush_dns_cache()?;
+ Ok(())
+ }
+
+ fn reset(&mut self) -> Result<(), Error> {
+ if let Some(guid) = self.current_guid.take() {
+ return set_dns(&guid, &[]).and(flush_dns_cache());
+ }
+ Ok(())
+ }
+}
+
+fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> {
+ let transaction = Transaction::new().map_err(Error::SetResolversError)?;
+ let result = match set_dns_inner(&transaction, interface, servers) {
+ Ok(()) => transaction.commit(),
+ Err(error) => transaction.rollback().and(Err(error)),
+ };
+ result.map_err(Error::SetResolversError)
+}
+
+fn set_dns_inner(
+ transaction: &Transaction,
+ interface: &GUID,
+ servers: &[IpAddr],
+) -> io::Result<()> {
+ let guid_str = string_from_guid(interface);
+
+ config_interface(
+ transaction,
+ &guid_str,
+ "Tcpip",
+ servers.iter().filter(|addr| addr.is_ipv4()),
+ )?;
+
+ config_interface(
+ transaction,
+ &guid_str,
+ "Tcpip6",
+ servers.iter().filter(|addr| addr.is_ipv6()),
+ )?;
+
+ Ok(())
+}
+
+fn config_interface<'a>(
+ transaction: &Transaction,
+ guid: &str,
+ service: &str,
+ nameservers: impl Iterator<Item = &'a IpAddr>,
+) -> io::Result<()> {
+ let nameservers = nameservers
+ .map(|addr| addr.to_string())
+ .collect::<Vec<String>>();
+
+ let reg_path =
+ format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,);
+ let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags(
+ reg_path,
+ transaction,
+ KEY_SET_VALUE,
+ ) {
+ Ok(adapter_key) => Ok(adapter_key),
+ Err(error) => {
+ if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound {
+ return Ok(());
+ }
+ Err(error)
+ }
+ }?;
+
+ if !nameservers.is_empty() {
+ adapter_key.set_value("NameServer", &nameservers.join(","))?;
+ } else {
+ adapter_key.delete_value("NameServer").or_else(|error| {
+ if error.kind() == io::ErrorKind::NotFound {
+ Ok(())
+ } else {
+ Err(error)
+ }
+ })?;
+ }
+
+ // Try to disable LLMNR on the interface
+ if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) {
+ log::error!(
+ "{}\nService: {service}",
+ error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface")
+ );
+ }
+
+ Ok(())
+}
+
+fn flush_dns_cache() -> Result<(), Error> {
+ super::dnsapi::flush_resolver_cache().map_err(Error::FlushResolverCacheError)
+}
+
+/// Obtain a string representation for a GUID object.
+fn string_from_guid(guid: &GUID) -> String {
+ let mut buffer = [0u16; 40];
+ let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) }
+ as usize;
+ // cannot fail because `buffer` is large enough
+ assert!(length > 0);
+ let length = length - 1;
+ String::from_utf16(&buffer[0..length]).unwrap()
+}