diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-01-09 12:08:46 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-01-10 11:58:51 +0100 |
| commit | f4fd14dea00cf53b7e0bfc32008cee03ed827bff (patch) | |
| tree | 52b11acaf33ffebf5f382be76baa5b593c093ce2 | |
| parent | f7ad580f5e72d12a63014d0ca295fa1b6cff10bd (diff) | |
| download | mullvadvpn-f4fd14dea00cf53b7e0bfc32008cee03ed827bff.tar.xz mullvadvpn-f4fd14dea00cf53b7e0bfc32008cee03ed827bff.zip | |
Keep registry DNS manager implementation as a TALPID_DNS_MODULE option
| -rw-r--r-- | talpid-core/src/dns/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/dnsapi.rs | 132 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 76 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/netsh.rs (renamed from talpid-core/src/dns/windows.rs) | 3 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/tcpip.rs | 156 |
5 files changed, 367 insertions, 2 deletions
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs index 1a08ec265c..39660078c3 100644 --- a/talpid-core/src/dns/mod.rs +++ b/talpid-core/src/dns/mod.rs @@ -20,7 +20,7 @@ mod imp; pub use imp::will_use_nm; #[cfg(windows)] -#[path = "windows.rs"] +#[path = "windows/mod.rs"] mod imp; #[cfg(target_os = "android")] diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs new file mode 100644 index 0000000000..e48f37258e --- /dev/null +++ b/talpid-core/src/dns/windows/dnsapi.rs @@ -0,0 +1,132 @@ +use once_cell::sync::OnceCell; +use std::{ + io, + sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc, Arc, + }, + time::{Duration, Instant}, +}; +use windows_sys::{ + w, + Win32::{ + Foundation::BOOL, + System::LibraryLoader::{ + FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32, + }, + }, +}; + +type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL; + +static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new(); +static FLUSH_TIMEOUT: Duration = Duration::from_secs(5); + +const MAX_CONCURRENT_FLUSHES: usize = 5; + +/// Errors that can happen when configuring DNS on Windows. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to load dnsapi.dll. + #[error(display = "Failed to load dnsapi.dll")] + LoadDll(#[error(source)] io::Error), + + /// Failed to obtain exported function. + #[error(display = "Failed to obtain flush function")] + GetFunction(#[error(source)] io::Error), + + /// Failed to flush the DNS cache. + #[error(display = "Call to flush DNS cache failed")] + FlushCache, + + /// Too many flush attempts in progress. + #[error(display = "Too many flush attempts in progress")] + TooManyFlushAttempts, + + /// Flushing the DNS cache timed out. + #[error(display = "Timeout while flushing DNS cache")] + Timeout, +} + +pub fn flush_resolver_cache() -> Result<(), Error> { + DNSAPI_HANDLE + .get_or_try_init(|| DnsApi::new())? + .flush_cache() +} + +struct DnsApi { + in_flight_flush_count: Arc<AtomicUsize>, + flush_fn: FlushResolverCacheFn, +} + +unsafe impl Send for DnsApi {} +unsafe impl Sync for DnsApi {} + +impl DnsApi { + fn new() -> Result<Self, Error> { + let handle = unsafe { LoadLibraryExW(w!("dnsapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) }; + if handle == 0 { + return Err(Error::LoadDll(io::Error::last_os_error())); + } + + let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const u8) }; + let flush_fn = flush_fn.ok_or_else(|| { + let error = io::Error::last_os_error(); + unsafe { FreeLibrary(handle) }; + Error::GetFunction(error) + })?; + + Ok(DnsApi { + in_flight_flush_count: Arc::new(AtomicUsize::new(0)), + flush_fn: unsafe { *(&flush_fn as *const _ as *const _) }, + }) + } + + fn flush_cache(&self) -> Result<(), Error> { + if self + .in_flight_flush_count + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| { + if val >= MAX_CONCURRENT_FLUSHES { + return None; + } + Some(val + 1) + }) + .is_err() + { + return Err(Error::TooManyFlushAttempts); + } + + let (tx, rx) = mpsc::channel(); + let flush_count = self.in_flight_flush_count.clone(); + + let flush_fn = self.flush_fn; + + std::thread::spawn(move || { + let begin = Instant::now(); + + let result = if unsafe { (flush_fn)() } != 0 { + let elapsed = begin.elapsed(); + if elapsed >= FLUSH_TIMEOUT { + log::warn!( + "Flushing system DNS cache took {} seconds", + elapsed.as_secs() + ); + } else { + log::debug!("Flushed system DNS cache"); + } + Ok(()) + } else { + Err(Error::FlushCache) + }; + let _ = tx.send(result); + + flush_count.fetch_sub(1, Ordering::SeqCst); + }); + + match rx.recv_timeout(FLUSH_TIMEOUT) { + Ok(result) => result, + Err(_timeout_err) => Err(Error::Timeout), + } + } +} diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs new file mode 100644 index 0000000000..10ce85c055 --- /dev/null +++ b/talpid-core/src/dns/windows/mod.rs @@ -0,0 +1,76 @@ +use std::{env, fmt, net::IpAddr}; + +mod dnsapi; +mod netsh; +mod tcpip; + +/// Errors that can happen when configuring DNS on Windows. +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// Failed to set DNS config using the netsh module. + #[error(display = "Error in netsh module")] + Netsh(#[error(source)] netsh::Error), + + /// Failed to set DNS config using the tcpip module. + #[error(display = "Error in tcpip module")] + Tcpip(#[error(source)] tcpip::Error), +} + +pub struct DnsMonitor { + inner: DnsMonitorHolder, +} + +impl super::DnsMonitorT for DnsMonitor { + type Error = Error; + + fn new() -> Result<Self, Error> { + let dns_module = env::var_os("TALPID_DNS_MODULE"); + + let inner = match dns_module.as_ref().and_then(|value| value.to_str()) { + Some("tcpip") => DnsMonitorHolder::Tcpip(tcpip::DnsMonitor::new()?), + Some(_) | None => DnsMonitorHolder::Netsh(netsh::DnsMonitor::new()?), + }; + + log::debug!("DNS monitor: {}", inner); + + Ok(DnsMonitor { inner }) + } + + fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> { + match self.inner { + DnsMonitorHolder::Netsh(ref mut inner) => inner.set(interface, servers)?, + DnsMonitorHolder::Tcpip(ref mut inner) => inner.set(interface, servers)?, + } + Ok(()) + } + + fn reset(&mut self) -> Result<(), Error> { + match self.inner { + DnsMonitorHolder::Netsh(ref mut inner) => inner.reset()?, + DnsMonitorHolder::Tcpip(ref mut inner) => inner.reset()?, + } + Ok(()) + } + + fn reset_before_interface_removal(&mut self) -> Result<(), Error> { + match self.inner { + DnsMonitorHolder::Netsh(ref mut inner) => inner.reset_before_interface_removal()?, + DnsMonitorHolder::Tcpip(ref mut inner) => inner.reset_before_interface_removal()?, + } + Ok(()) + } +} + +enum DnsMonitorHolder { + Netsh(netsh::DnsMonitor), + Tcpip(tcpip::DnsMonitor), +} + +impl fmt::Display for DnsMonitorHolder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DnsMonitorHolder::Netsh(_) => f.write_str("netsh"), + DnsMonitorHolder::Tcpip(_) => f.write_str("TCP/IP registry parameter"), + } + } +} diff --git a/talpid-core/src/dns/windows.rs b/talpid-core/src/dns/windows/netsh.rs index 68742b9266..7c0450f855 100644 --- a/talpid-core/src/dns/windows.rs +++ b/talpid-core/src/dns/windows/netsh.rs @@ -1,3 +1,4 @@ +use crate::dns::DnsMonitorT; use std::{ ffi::OsString, io::{self, Write}, @@ -60,7 +61,7 @@ pub struct DnsMonitor { current_index: Option<u32>, } -impl super::DnsMonitorT for DnsMonitor { +impl DnsMonitorT for DnsMonitor { type Error = Error; fn new() -> Result<Self, Error> { diff --git a/talpid-core/src/dns/windows/tcpip.rs b/talpid-core/src/dns/windows/tcpip.rs new file mode 100644 index 0000000000..f7536eaed1 --- /dev/null +++ b/talpid-core/src/dns/windows/tcpip.rs @@ -0,0 +1,156 @@ +use crate::dns::DnsMonitorT; +use std::{io, net::IpAddr}; +use talpid_types::ErrorExt; +use talpid_windows_net::{guid_from_luid, luid_from_alias}; +use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2}; +use winreg::{ + enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE}, + transaction::Transaction, + RegKey, +}; + +/// Errors that can happen when configuring DNS on Windows. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failure to obtain an interface LUID given an alias. + #[error(display = "Failed to obtain LUID for the interface alias")] + InterfaceLuidError(#[error(source)] io::Error), + + /// Failure to obtain an interface GUID. + #[error(display = "Failed to obtain GUID for the interface")] + InterfaceGuidError(#[error(source)] io::Error), + + /// Failure to flush DNS cache. + #[error(display = "Failed to flush DNS resolver cache")] + FlushResolverCacheError(#[error(source)] super::dnsapi::Error), + + /// Failed to update DNS servers for interface. + #[error(display = "Failed to update interface DNS servers")] + SetResolversError(#[error(source)] io::Error), +} + +pub struct DnsMonitor { + current_guid: Option<GUID>, +} + +impl DnsMonitorT for DnsMonitor { + type Error = Error; + + fn new() -> Result<Self, Error> { + Ok(DnsMonitor { current_guid: None }) + } + + fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> { + let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?) + .map_err(Error::InterfaceGuidError)?; + set_dns(&guid, servers)?; + self.current_guid = Some(guid); + flush_dns_cache()?; + Ok(()) + } + + fn reset(&mut self) -> Result<(), Error> { + if let Some(guid) = self.current_guid.take() { + return set_dns(&guid, &[]).and(flush_dns_cache()); + } + Ok(()) + } +} + +fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> { + let transaction = Transaction::new().map_err(Error::SetResolversError)?; + let result = match set_dns_inner(&transaction, interface, servers) { + Ok(()) => transaction.commit(), + Err(error) => transaction.rollback().and(Err(error)), + }; + result.map_err(Error::SetResolversError) +} + +fn set_dns_inner( + transaction: &Transaction, + interface: &GUID, + servers: &[IpAddr], +) -> io::Result<()> { + let guid_str = string_from_guid(interface); + + config_interface( + transaction, + &guid_str, + "Tcpip", + servers.iter().filter(|addr| addr.is_ipv4()), + )?; + + config_interface( + transaction, + &guid_str, + "Tcpip6", + servers.iter().filter(|addr| addr.is_ipv6()), + )?; + + Ok(()) +} + +fn config_interface<'a>( + transaction: &Transaction, + guid: &str, + service: &str, + nameservers: impl Iterator<Item = &'a IpAddr>, +) -> io::Result<()> { + let nameservers = nameservers + .map(|addr| addr.to_string()) + .collect::<Vec<String>>(); + + let reg_path = + format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,); + let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags( + reg_path, + transaction, + KEY_SET_VALUE, + ) { + Ok(adapter_key) => Ok(adapter_key), + Err(error) => { + if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound { + return Ok(()); + } + Err(error) + } + }?; + + if !nameservers.is_empty() { + adapter_key.set_value("NameServer", &nameservers.join(","))?; + } else { + adapter_key.delete_value("NameServer").or_else(|error| { + if error.kind() == io::ErrorKind::NotFound { + Ok(()) + } else { + Err(error) + } + })?; + } + + // Try to disable LLMNR on the interface + if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) { + log::error!( + "{}\nService: {service}", + error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface") + ); + } + + Ok(()) +} + +fn flush_dns_cache() -> Result<(), Error> { + super::dnsapi::flush_resolver_cache().map_err(Error::FlushResolverCacheError) +} + +/// Obtain a string representation for a GUID object. +fn string_from_guid(guid: &GUID) -> String { + let mut buffer = [0u16; 40]; + let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) } + as usize; + // cannot fail because `buffer` is large enough + assert!(length > 0); + let length = length - 1; + String::from_utf16(&buffer[0..length]).unwrap() +} |
