diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-10-22 14:54:18 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-10-22 14:54:18 +0200 |
| commit | 545d41ee38f8fd4c49622b9a51558144f62d6d89 (patch) | |
| tree | 6a94d0d5e555ee018852ab20aa937f3b68f56f83 | |
| parent | 99abef8ff4a66c11b2b3abd00c8d2a6c07df7a3f (diff) | |
| parent | 5e96ccd96d615ffe0d2d173e0cb0daafa58f4709 (diff) | |
| download | mullvadvpn-545d41ee38f8fd4c49622b9a51558144f62d6d89.tar.xz mullvadvpn-545d41ee38f8fd4c49622b9a51558144f62d6d89.zip | |
Merge branch 'fix-local-dns-resolution'
| -rw-r--r-- | talpid-core/Cargo.toml | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 154 |
2 files changed, 150 insertions, 6 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 66826bacdc..eafceb2854 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -71,7 +71,7 @@ tun = "0.5" [target.'cfg(windows)'.dependencies] widestring = "0.4" -winreg = "0.7" +winreg = { version = "0.7", features = ["transactions"] } winapi = { version = "0.3.6", features = ["handleapi", "ifdef", "libloaderapi", "netioapi", "synchapi", "winbase", "winuser"] } socket2 = "0.3" pnet_packet = "0.26" diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index 3cb3fdf47e..bf3807c9b6 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -1,14 +1,25 @@ use crate::logging::windows::{log_sink, LogSink}; -use log::{error, trace}; -use std::{net::IpAddr, path::Path}; +use log::{error, trace, warn}; +use std::{ffi::OsString, io, iter, mem, net::IpAddr, os::windows::ffi::OsStrExt, path::Path, ptr}; +use talpid_types::ErrorExt; use widestring::WideCString; +use winapi::um::{ + libloaderapi::{GetModuleHandleW, GetProcAddress}, + winnt::RTL_OSVERSIONINFOW, +}; +use winreg::{ + enums::{HKEY_LOCAL_MACHINE, REG_MULTI_SZ}, + transaction::Transaction, + RegKey, RegValue, +}; mod system_state; use self::system_state::SystemStateWriter; const DNS_STATE_FILENAME: &'static str = "dns-state-backup"; +const DNS_CACHE_POLICY_GUID: &str = "{d57d2750-f971-408e-8e55-cfddb37e60ae}"; /// Errors that can happen when configuring DNS on Windows. #[derive(err_derive::Error, Debug)] @@ -21,9 +32,13 @@ pub enum Error { #[error(display = "Failed to deinitialize WinDns")] Deinitialization, - /// Failure to set new DNS servers. - #[error(display = "Failed to set new DNS servers")] + /// Failure to set new DNS servers on the interface. + #[error(display = "Failed to set new DNS servers on interface")] Setting, + + /// Failure to set new DNS servers. + #[error(display = "Failed to update dnscache policy config")] + UpdateDnsCachePolicy(#[error(source)] io::Error), } pub struct DnsMonitor {} @@ -34,6 +49,15 @@ impl super::DnsMonitorT for DnsMonitor { fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> { unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? }; + if is_minimum_windows10() { + if let Err(error) = reset_dns_cache_policy() { + error!( + "{}", + error.display_chain_with_msg("Failed to reset DNS cache policy") + ); + } + } + let backup_writer = SystemStateWriter::new( cache_dir .as_ref() @@ -77,11 +101,24 @@ impl super::DnsMonitorT for DnsMonitor { ipv6_address_ptrs.len() as u32, ) .into_result() + }?; + + if is_minimum_windows10() { + if let Err(error) = set_dns_cache_policy(servers) { + error!("{}", error.display_chain()); + warn!("DNS resolution may be slowed down"); + } } + + Ok(()) } fn reset(&mut self) -> Result<(), Error> { - Ok(()) + if is_minimum_windows10() { + reset_dns_cache_policy() + } else { + Ok(()) + } } } @@ -91,6 +128,15 @@ fn ip_to_widestring(ip: &IpAddr) -> WideCString { impl Drop for DnsMonitor { fn drop(&mut self) { + if is_minimum_windows10() { + if let Err(error) = reset_dns_cache_policy() { + warn!( + "{}", + error.display_chain_with_msg("Failed to reset DNS cache policy") + ); + } + } + if unsafe { WinDns_Deinitialize().into_result().is_ok() } { trace!("Successfully deinitialized WinDns"); } else { @@ -99,6 +145,104 @@ impl Drop for DnsMonitor { } } +fn set_dns_cache_policy(servers: &[IpAddr]) -> Result<(), Error> { + let transaction = Transaction::new()?; + match set_dns_cache_policy_inner(&transaction, servers) { + Ok(()) => { + transaction.commit()?; + Ok(()) + } + Err(error) => { + transaction.rollback()?; + Err(error) + } + } +} + +fn set_dns_cache_policy_inner(transaction: &Transaction, servers: &[IpAddr]) -> Result<(), Error> { + let dns_cache_parameters = RegKey::predef(HKEY_LOCAL_MACHINE) + .open_subkey(r#"SYSTEM\CurrentControlSet\Services\DnsCache\Parameters"#)?; + + let policy_path = Path::new("DnsPolicyConfig").join(DNS_CACHE_POLICY_GUID); + let (policy_config, _) = + dns_cache_parameters.create_subkey_transacted(policy_path, transaction)?; + + // Enable only the "Generic DNS server" option + policy_config.set_value("ConfigOptions", &0x08u32)?; + let server_list: Vec<String> = servers.iter().map(|server| server.to_string()).collect(); + policy_config.set_value("GenericDNSServers", &server_list.join(";"))?; + policy_config.set_value("IPSECCARestriction", &"")?; + policy_config.set_raw_value( + "Name", + &RegValue { + // utf16 string: ".\0\0" + bytes: [0x2e, 0, 0, 0, 0, 0].to_vec(), + vtype: REG_MULTI_SZ, + }, + )?; + policy_config.set_value("Version", &2u32)?; + + Ok(()) +} + +fn reset_dns_cache_policy() -> Result<(), Error> { + let dns_cache_parameters = RegKey::predef(HKEY_LOCAL_MACHINE) + .open_subkey(r#"SYSTEM\CurrentControlSet\Services\DnsCache\Parameters"#)?; + let policy_path = Path::new("DnsPolicyConfig").join(DNS_CACHE_POLICY_GUID); + match dns_cache_parameters.delete_subkey_all(policy_path) { + Ok(()) => Ok(()), + Err(error) => { + if error.kind() == io::ErrorKind::NotFound { + Ok(()) + } else { + Err(Error::UpdateDnsCachePolicy(error)) + } + } + } +} + +fn is_minimum_windows10() -> bool { + match is_minimum_windows10_inner() { + Ok(result) => result, + Err(error) => { + error!( + "{}", + error.display_chain_with_msg("OS version check failed") + ); + false + } + } +} + +fn is_minimum_windows10_inner() -> Result<bool, io::Error> { + let rtl_get_version: extern "stdcall" fn(*mut RTL_OSVERSIONINFOW); + + let module_name: Vec<u16> = OsString::from("ntdll") + .as_os_str() + .encode_wide() + .chain(iter::once(0u16)) + .collect(); + + let ntdll = unsafe { GetModuleHandleW(module_name.as_ptr()) }; + if ntdll == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + + let function_address = + unsafe { GetProcAddress(ntdll, b"RtlGetVersion\0" as *const _ as *const i8) }; + if function_address == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + + rtl_get_version = unsafe { mem::transmute(function_address) }; + + let mut version_info: RTL_OSVERSIONINFOW = unsafe { std::mem::zeroed() }; + version_info.dwOSVersionInfoSize = mem::size_of_val(&version_info) as u32; + rtl_get_version(&mut version_info); + + Ok(version_info.dwMajorVersion >= 10) +} + ffi_error!(InitializationResult, Error::Initialization); ffi_error!(DeinitializationResult, Error::Deinitialization); |
