summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-10-22 14:54:18 +0200
committerDavid Lönnhager <david.l@mullvad.net>2020-10-22 14:54:18 +0200
commit545d41ee38f8fd4c49622b9a51558144f62d6d89 (patch)
tree6a94d0d5e555ee018852ab20aa937f3b68f56f83
parent99abef8ff4a66c11b2b3abd00c8d2a6c07df7a3f (diff)
parent5e96ccd96d615ffe0d2d173e0cb0daafa58f4709 (diff)
downloadmullvadvpn-545d41ee38f8fd4c49622b9a51558144f62d6d89.tar.xz
mullvadvpn-545d41ee38f8fd4c49622b9a51558144f62d6d89.zip
Merge branch 'fix-local-dns-resolution'
-rw-r--r--talpid-core/Cargo.toml2
-rw-r--r--talpid-core/src/dns/windows/mod.rs154
2 files changed, 150 insertions, 6 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 66826bacdc..eafceb2854 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -71,7 +71,7 @@ tun = "0.5"
[target.'cfg(windows)'.dependencies]
widestring = "0.4"
-winreg = "0.7"
+winreg = { version = "0.7", features = ["transactions"] }
winapi = { version = "0.3.6", features = ["handleapi", "ifdef", "libloaderapi", "netioapi", "synchapi", "winbase", "winuser"] }
socket2 = "0.3"
pnet_packet = "0.26"
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index 3cb3fdf47e..bf3807c9b6 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -1,14 +1,25 @@
use crate::logging::windows::{log_sink, LogSink};
-use log::{error, trace};
-use std::{net::IpAddr, path::Path};
+use log::{error, trace, warn};
+use std::{ffi::OsString, io, iter, mem, net::IpAddr, os::windows::ffi::OsStrExt, path::Path, ptr};
+use talpid_types::ErrorExt;
use widestring::WideCString;
+use winapi::um::{
+ libloaderapi::{GetModuleHandleW, GetProcAddress},
+ winnt::RTL_OSVERSIONINFOW,
+};
+use winreg::{
+ enums::{HKEY_LOCAL_MACHINE, REG_MULTI_SZ},
+ transaction::Transaction,
+ RegKey, RegValue,
+};
mod system_state;
use self::system_state::SystemStateWriter;
const DNS_STATE_FILENAME: &'static str = "dns-state-backup";
+const DNS_CACHE_POLICY_GUID: &str = "{d57d2750-f971-408e-8e55-cfddb37e60ae}";
/// Errors that can happen when configuring DNS on Windows.
#[derive(err_derive::Error, Debug)]
@@ -21,9 +32,13 @@ pub enum Error {
#[error(display = "Failed to deinitialize WinDns")]
Deinitialization,
- /// Failure to set new DNS servers.
- #[error(display = "Failed to set new DNS servers")]
+ /// Failure to set new DNS servers on the interface.
+ #[error(display = "Failed to set new DNS servers on interface")]
Setting,
+
+ /// Failure to set new DNS servers.
+ #[error(display = "Failed to update dnscache policy config")]
+ UpdateDnsCachePolicy(#[error(source)] io::Error),
}
pub struct DnsMonitor {}
@@ -34,6 +49,15 @@ impl super::DnsMonitorT for DnsMonitor {
fn new(cache_dir: impl AsRef<Path>) -> Result<Self, Error> {
unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? };
+ if is_minimum_windows10() {
+ if let Err(error) = reset_dns_cache_policy() {
+ error!(
+ "{}",
+ error.display_chain_with_msg("Failed to reset DNS cache policy")
+ );
+ }
+ }
+
let backup_writer = SystemStateWriter::new(
cache_dir
.as_ref()
@@ -77,11 +101,24 @@ impl super::DnsMonitorT for DnsMonitor {
ipv6_address_ptrs.len() as u32,
)
.into_result()
+ }?;
+
+ if is_minimum_windows10() {
+ if let Err(error) = set_dns_cache_policy(servers) {
+ error!("{}", error.display_chain());
+ warn!("DNS resolution may be slowed down");
+ }
}
+
+ Ok(())
}
fn reset(&mut self) -> Result<(), Error> {
- Ok(())
+ if is_minimum_windows10() {
+ reset_dns_cache_policy()
+ } else {
+ Ok(())
+ }
}
}
@@ -91,6 +128,15 @@ fn ip_to_widestring(ip: &IpAddr) -> WideCString {
impl Drop for DnsMonitor {
fn drop(&mut self) {
+ if is_minimum_windows10() {
+ if let Err(error) = reset_dns_cache_policy() {
+ warn!(
+ "{}",
+ error.display_chain_with_msg("Failed to reset DNS cache policy")
+ );
+ }
+ }
+
if unsafe { WinDns_Deinitialize().into_result().is_ok() } {
trace!("Successfully deinitialized WinDns");
} else {
@@ -99,6 +145,104 @@ impl Drop for DnsMonitor {
}
}
+fn set_dns_cache_policy(servers: &[IpAddr]) -> Result<(), Error> {
+ let transaction = Transaction::new()?;
+ match set_dns_cache_policy_inner(&transaction, servers) {
+ Ok(()) => {
+ transaction.commit()?;
+ Ok(())
+ }
+ Err(error) => {
+ transaction.rollback()?;
+ Err(error)
+ }
+ }
+}
+
+fn set_dns_cache_policy_inner(transaction: &Transaction, servers: &[IpAddr]) -> Result<(), Error> {
+ let dns_cache_parameters = RegKey::predef(HKEY_LOCAL_MACHINE)
+ .open_subkey(r#"SYSTEM\CurrentControlSet\Services\DnsCache\Parameters"#)?;
+
+ let policy_path = Path::new("DnsPolicyConfig").join(DNS_CACHE_POLICY_GUID);
+ let (policy_config, _) =
+ dns_cache_parameters.create_subkey_transacted(policy_path, transaction)?;
+
+ // Enable only the "Generic DNS server" option
+ policy_config.set_value("ConfigOptions", &0x08u32)?;
+ let server_list: Vec<String> = servers.iter().map(|server| server.to_string()).collect();
+ policy_config.set_value("GenericDNSServers", &server_list.join(";"))?;
+ policy_config.set_value("IPSECCARestriction", &"")?;
+ policy_config.set_raw_value(
+ "Name",
+ &RegValue {
+ // utf16 string: ".\0\0"
+ bytes: [0x2e, 0, 0, 0, 0, 0].to_vec(),
+ vtype: REG_MULTI_SZ,
+ },
+ )?;
+ policy_config.set_value("Version", &2u32)?;
+
+ Ok(())
+}
+
+fn reset_dns_cache_policy() -> Result<(), Error> {
+ let dns_cache_parameters = RegKey::predef(HKEY_LOCAL_MACHINE)
+ .open_subkey(r#"SYSTEM\CurrentControlSet\Services\DnsCache\Parameters"#)?;
+ let policy_path = Path::new("DnsPolicyConfig").join(DNS_CACHE_POLICY_GUID);
+ match dns_cache_parameters.delete_subkey_all(policy_path) {
+ Ok(()) => Ok(()),
+ Err(error) => {
+ if error.kind() == io::ErrorKind::NotFound {
+ Ok(())
+ } else {
+ Err(Error::UpdateDnsCachePolicy(error))
+ }
+ }
+ }
+}
+
+fn is_minimum_windows10() -> bool {
+ match is_minimum_windows10_inner() {
+ Ok(result) => result,
+ Err(error) => {
+ error!(
+ "{}",
+ error.display_chain_with_msg("OS version check failed")
+ );
+ false
+ }
+ }
+}
+
+fn is_minimum_windows10_inner() -> Result<bool, io::Error> {
+ let rtl_get_version: extern "stdcall" fn(*mut RTL_OSVERSIONINFOW);
+
+ let module_name: Vec<u16> = OsString::from("ntdll")
+ .as_os_str()
+ .encode_wide()
+ .chain(iter::once(0u16))
+ .collect();
+
+ let ntdll = unsafe { GetModuleHandleW(module_name.as_ptr()) };
+ if ntdll == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+
+ let function_address =
+ unsafe { GetProcAddress(ntdll, b"RtlGetVersion\0" as *const _ as *const i8) };
+ if function_address == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+
+ rtl_get_version = unsafe { mem::transmute(function_address) };
+
+ let mut version_info: RTL_OSVERSIONINFOW = unsafe { std::mem::zeroed() };
+ version_info.dwOSVersionInfoSize = mem::size_of_val(&version_info) as u32;
+ rtl_get_version(&mut version_info);
+
+ Ok(version_info.dwMajorVersion >= 10)
+}
+
ffi_error!(InitializationResult, Error::Initialization);
ffi_error!(DeinitializationResult, Error::Deinitialization);