diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-06-01 19:43:38 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-06-14 11:42:26 +0200 |
| commit | 54dee1a08229d27329c1ec45a00e9e8b9a2e414c (patch) | |
| tree | 36f1152558af37f1717fddd69e63e28bddfee4d4 | |
| parent | 0e74ce4143661138a76d313df43546dc4174ba1a (diff) | |
| download | mullvadvpn-54dee1a08229d27329c1ec45a00e9e8b9a2e414c.tar.xz mullvadvpn-54dee1a08229d27329c1ec45a00e9e8b9a2e414c.zip | |
Simplify DNS management on Windows to set servers on the tunnel
interface only
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 239 | ||||
| -rw-r--r-- | talpid-core/src/firewall/mod.rs | 12 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/wintun.rs | 22 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/windows/mod.rs | 89 |
5 files changed, 217 insertions, 147 deletions
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index 8986c1d7ab..3e2b5ceca5 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -1,15 +1,10 @@ -use crate::{ - logging::windows::{log_sink, LogSink}, - windows::luid_from_alias, -}; - +use crate::windows::{get_system_dir, guid_from_luid, luid_from_alias, string_from_guid}; use lazy_static::lazy_static; -use std::{env, io, net::IpAddr, path::Path}; +use std::{env, io, net::IpAddr, path::Path, process::Command}; use talpid_types::ErrorExt; -use widestring::WideCString; -use winapi::shared::ifdef::NET_LUID; +use winapi::shared::guiddef::GUID; use winreg::{ - enums::{HKEY_LOCAL_MACHINE, REG_MULTI_SZ}, + enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE, REG_MULTI_SZ}, transaction::Transaction, RegKey, RegValue, }; @@ -20,89 +15,66 @@ lazy_static! { /// Specifies whether to override per-interface DNS resolvers with a global DNS policy. static ref GLOBAL_DNS_CACHE_POLICY: bool = env::var("TALPID_DNS_CACHE_POLICY") .map(|v| v != "0") - .unwrap_or(true); + .unwrap_or(false); } /// Errors that can happen when configuring DNS on Windows. #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { - /// Failure to initialize WinDns. - #[error(display = "Failed to initialize WinDns")] - Initialization, - - /// Failure to deinitialize WinDns. - #[error(display = "Failed to deinitialize WinDns")] - Deinitialization, - - /// Failure to set new DNS servers on the interface. - #[error(display = "Failed to set new DNS servers on interface")] - Setting, - /// Failure to obtain an interface LUID given an alias. #[error(display = "Failed to obtain LUID for the interface alias")] InterfaceLuidError(#[error(source)] io::Error), + /// Failure to obtain an interface GUID. + #[error(display = "Failed to obtain GUID for the interface")] + InterfaceGuidError(#[error(source)] io::Error), + /// Failure to set new DNS servers. #[error(display = "Failed to update dnscache policy config")] UpdateDnsCachePolicy(#[error(source)] io::Error), + + /// Failure to flush DNS cache. + #[error(display = "Failed to execute ipconfig")] + ExecuteIpconfigError(#[error(source)] io::Error), + + /// Failure to flush DNS cache. + #[error(display = "Failed to flush DNS resolver cache")] + FlushResolverCacheError, + + /// Failed to update DNS servers for interface. + #[error(display = "Failed to update interface DNS servers")] + SetResolversError(#[error(source)] io::Error), + + /// Failed to locate system dir. + #[error(display = "Failed to locate the system directory")] + SystemDirError(#[error(source)] io::Error), } -pub struct DnsMonitor {} +pub struct DnsMonitor { + current_guid: Option<GUID>, +} impl super::DnsMonitorT for DnsMonitor { type Error = Error; fn new() -> Result<Self, Error> { - unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? }; - - let mut monitor = DnsMonitor {}; + let mut monitor = DnsMonitor { current_guid: None }; monitor.reset()?; Ok(monitor) } fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> { - let ipv4 = servers - .iter() - .filter(|ip| ip.is_ipv4()) - .map(ip_to_widestring) - .collect::<Vec<_>>(); - let ipv6 = servers - .iter() - .filter(|ip| ip.is_ipv6()) - .map(ip_to_widestring) - .collect::<Vec<_>>(); - - let mut ipv4_address_ptrs = ipv4 - .iter() - .map(|ip_cstr| ip_cstr.as_ptr()) - .collect::<Vec<_>>(); - let mut ipv6_address_ptrs = ipv6 - .iter() - .map(|ip_cstr| ip_cstr.as_ptr()) - .collect::<Vec<_>>(); - - log::trace!("ipv4 ips: {:?} ({})", ipv4, ipv4.len()); - log::trace!("ipv6 ips: {:?} ({})", ipv6, ipv6.len()); - - let luid = luid_from_alias(interface).map_err(Error::InterfaceLuidError)?; - - unsafe { - WinDns_Set( - &luid, - ipv4_address_ptrs.as_mut_ptr(), - ipv4_address_ptrs.len() as u32, - ipv6_address_ptrs.as_mut_ptr(), - ipv6_address_ptrs.len() as u32, - ) - .into_result() - }?; + let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?) + .map_err(Error::InterfaceGuidError)?; + set_dns(&guid, servers)?; + self.current_guid = Some(guid); + flush_dns_cache()?; if *GLOBAL_DNS_CACHE_POLICY { if let Err(error) = set_dns_cache_policy(servers) { log::error!("{}", error.display_chain()); - log::warn!("DNS resolution may be slowed down"); } } @@ -110,16 +82,18 @@ impl super::DnsMonitorT for DnsMonitor { } fn reset(&mut self) -> Result<(), Error> { + let mut result = Ok(()); + + if let Some(guid) = self.current_guid.take() { + result = result.and(set_dns(&guid, &[])).and(flush_dns_cache()); + } + if *GLOBAL_DNS_CACHE_POLICY { - reset_dns_cache_policy() - } else { - Ok(()) + result = result.and(reset_dns_cache_policy()); } - } -} -fn ip_to_widestring(ip: &IpAddr) -> WideCString { - WideCString::from_str_truncate(ip.to_string()) + result + } } impl Drop for DnsMonitor { @@ -132,13 +106,103 @@ impl Drop for DnsMonitor { ); } } + } +} - if unsafe { WinDns_Deinitialize().into_result().is_ok() } { - log::trace!("Successfully deinitialized WinDns"); - } else { - log::error!("Failed to deinitialize WinDns"); +fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> { + let transaction = Transaction::new().map_err(Error::SetResolversError)?; + let result = match set_dns_inner(&transaction, interface, servers) { + Ok(()) => transaction.commit(), + Err(error) => transaction.rollback().and(Err(error)), + }; + result.map_err(Error::SetResolversError) +} + +fn set_dns_inner( + transaction: &Transaction, + interface: &GUID, + servers: &[IpAddr], +) -> io::Result<()> { + let guid_str = string_from_guid(interface); + + config_interface( + transaction, + &guid_str, + "Tcpip", + servers.iter().filter(|addr| addr.is_ipv4()), + )?; + + config_interface( + transaction, + &guid_str, + "Tcpip6", + servers.iter().filter(|addr| addr.is_ipv6()), + )?; + + Ok(()) +} + +fn config_interface<'a>( + transaction: &Transaction, + guid: &str, + service: &str, + nameservers: impl Iterator<Item = &'a IpAddr>, +) -> io::Result<()> { + let nameservers = nameservers + .map(|addr| addr.to_string()) + .collect::<Vec<String>>(); + + let reg_path = + format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,); + let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags( + reg_path, + transaction, + KEY_SET_VALUE, + ) { + Ok(adapter_key) => Ok(adapter_key), + Err(error) => { + if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound { + return Ok(()); + } + Err(error) } + }?; + + if !nameservers.is_empty() { + adapter_key.set_value("NameServer", &nameservers.join(","))?; + } else { + adapter_key.delete_value("NameServer").or_else(|error| { + if error.kind() == io::ErrorKind::NotFound { + Ok(()) + } else { + Err(error) + } + })?; + } + + // Try to disable LLMNR on the interface + if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) { + log::error!( + "{}\nService: {service}", + error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface") + ); } + + Ok(()) +} + +fn flush_dns_cache() -> Result<(), Error> { + let sysdir = get_system_dir().map_err(Error::SystemDirError)?; + let mut ipconfig = Command::new(sysdir.join("ipconfig.exe")); + ipconfig.arg("/flushdns"); + let output = ipconfig.output().map_err(Error::ExecuteIpconfigError)?; + let output = String::from_utf8_lossy(&output.stdout); + // The exit code cannot be trusted + if !output.contains("Successfully flushed") { + log::error!("Failed to flush DNS cache: {}", output); + return Err(Error::FlushResolverCacheError); + } + Ok(()) } fn set_dns_cache_policy(servers: &[IpAddr]) -> Result<(), Error> { @@ -207,32 +271,3 @@ fn reset_dns_cache_policy() -> Result<(), Error> { } } } - -ffi_error!(InitializationResult, Error::Initialization); -ffi_error!(DeinitializationResult, Error::Deinitialization); -ffi_error!(SettingResult, Error::Setting); - -#[allow(non_snake_case)] -extern "stdcall" { - #[link_name = "WinDns_Initialize"] - pub fn WinDns_Initialize( - sink: Option<LogSink>, - sink_context: *const u8, - ) -> InitializationResult; - - // WinDns_Deinitialize: - // - // Call this function once before unloading WINDNS or exiting the process. - #[link_name = "WinDns_Deinitialize"] - pub fn WinDns_Deinitialize() -> DeinitializationResult; - - // Configure which DNS servers should be used and start enforcing these settings. - #[link_name = "WinDns_Set"] - pub fn WinDns_Set( - interface_luid: *const NET_LUID, - v4_ips: *mut *const u16, - v4_n_ips: u32, - v6_ips: *mut *const u16, - v6_n_ips: u32, - ) -> SettingResult; -} diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs index de117d39ad..b23e16b017 100644 --- a/talpid-core/src/firewall/mod.rs +++ b/talpid-core/src/firewall/mod.rs @@ -1,14 +1,13 @@ -#[cfg(unix)] use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; -#[cfg(unix)] use lazy_static::lazy_static; -use std::fmt; #[cfg(not(target_os = "android"))] use std::net::IpAddr; -#[cfg(unix)] -use std::net::{Ipv4Addr, Ipv6Addr}; #[cfg(windows)] use std::path::PathBuf; +use std::{ + fmt, + net::{Ipv4Addr, Ipv6Addr}, +}; use talpid_types::net::{AllowedEndpoint, Endpoint}; #[cfg(target_os = "macos")] @@ -29,7 +28,6 @@ mod imp; pub use self::imp::Error; -#[cfg(unix)] lazy_static! { /// When "allow local network" is enabled the app will allow traffic to and from these networks. pub(crate) static ref ALLOWED_LAN_NETS: [IpNetwork; 6] = [ @@ -83,7 +81,7 @@ const DHCPV6_CLIENT_PORT: u16 = 546; #[cfg(all(unix, not(target_os = "android")))] const ROOT_UID: u32 = 0; -#[cfg(all(unix, not(target_os = "android")))] +#[cfg(any(all(unix, not(target_os = "android")), target_os = "windows"))] /// Returns whether an address belongs to a private subnet. pub fn is_local_address(address: &IpAddr) -> bool { let address = address.clone(); diff --git a/talpid-core/src/tunnel/openvpn/wintun.rs b/talpid-core/src/tunnel/openvpn/wintun.rs index a61b83643d..1746756db4 100644 --- a/talpid-core/src/tunnel/openvpn/wintun.rs +++ b/talpid-core/src/tunnel/openvpn/wintun.rs @@ -1,4 +1,6 @@ -use crate::windows::{get_ip_interface_entry, set_ip_interface_entry, AddressFamily}; +use crate::windows::{ + get_ip_interface_entry, set_ip_interface_entry, string_from_guid, AddressFamily, +}; use lazy_static::lazy_static; use std::{ ffi::CStr, @@ -483,24 +485,6 @@ impl Drop for WintunLoggerHandle { } } -/// Obtain a string representation for a GUID object. -fn string_from_guid(guid: &GUID) -> String { - use std::{ffi::OsString, os::windows::ffi::OsStringExt}; - use winapi::um::combaseapi::StringFromGUID2; - - let mut buffer = [0u16; 40]; - let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) } - as usize; - if length > 0 { - let length = length - 1; - OsString::from_wide(&buffer[0..length]) - .to_string_lossy() - .to_string() - } else { - "".to_string() - } -} - /// Returns the registry key for a network device identified by its GUID. fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result<RegKey> { let net_devs = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_with_flags( diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 80b6de8772..80e5e28957 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -123,7 +123,7 @@ impl ConnectedState { fn set_dns(&self, shared_values: &mut SharedTunnelStateValues) -> Result<(), BoxedError> { let dns_ips = self.get_dns_servers(shared_values); - #[cfg(target_os = "linux")] + #[cfg(any(target_os = "linux", target_os = "windows"))] let dns_ips = &dns_ips .into_iter() .filter(|ip| { diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs index 4ec1fe19f3..115d8f5397 100644 --- a/talpid-core/src/windows/mod.rs +++ b/talpid-core/src/windows/mod.rs @@ -1,34 +1,48 @@ use socket2::SockAddr; use std::{ ffi::{OsStr, OsString}, - fmt, io, mem, + fmt, io, + mem::{self, MaybeUninit}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::windows::{ ffi::{OsStrExt, OsStringExt}, io::RawHandle, }, + path::PathBuf, + ptr, sync::Mutex, time::{Duration, Instant}, }; -use winapi::shared::{ - ifdef::NET_LUID, - in6addr::IN6_ADDR, - inaddr::IN_ADDR, - netioapi::{ - CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, - FreeMibTable, GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, - MibAddInstance, NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, - MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE, +use widestring::WideCStr; +use winapi::{ + shared::{ + guiddef::GUID, + ifdef::NET_LUID, + in6addr::IN6_ADDR, + inaddr::IN_ADDR, + netioapi::{ + CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, + ConvertInterfaceLuidToGuid, FreeMibTable, GetIpInterfaceEntry, + GetUnicastIpAddressEntry, GetUnicastIpAddressTable, MibAddInstance, + NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, + MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE, + }, + nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE}, + ntddndis::NDIS_IF_MAX_STRING_SIZE, + ntdef::FALSE, + winerror::{ERROR_NOT_FOUND, NO_ERROR, S_OK}, + ws2def::{ + AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in, + SOCKADDR_STORAGE as sockaddr_storage, + }, + ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET}, }, - nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE}, - ntddndis::NDIS_IF_MAX_STRING_SIZE, - ntdef::FALSE, - winerror::{ERROR_NOT_FOUND, NO_ERROR}, - ws2def::{ - AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in, - SOCKADDR_STORAGE as sockaddr_storage, + um::{ + combaseapi::{CoTaskMemFree, StringFromGUID2}, + knownfolders::FOLDERID_System, + shlobj::SHGetKnownFolderPath, + winnt::PWSTR, }, - ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET}, }; pub mod window; @@ -350,6 +364,27 @@ pub fn get_unicast_table( Ok(unicast_rows) } +/// Obtain a string representation for a GUID object. +pub fn string_from_guid(guid: &GUID) -> String { + let mut buffer = [0u16; 40]; + let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) } + as usize; + // cannot fail because `buffer` is large enough + assert!(length > 0); + let length = length - 1; + String::from_utf16(&buffer[0..length]).unwrap() +} + +/// Returns the GUID of a network interface given its LUID. +pub fn guid_from_luid(luid: &NET_LUID) -> io::Result<GUID> { + let mut guid = MaybeUninit::zeroed(); + let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) }; + if status != NO_ERROR { + return Err(io::Error::from_raw_os_error(status as i32)); + } + Ok(unsafe { guid.assume_init() }) +} + /// Returns the LUID of an interface given its alias. pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID> { let alias_wide: Vec<u16> = alias @@ -435,6 +470,24 @@ pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAd .ok_or(Error::UnknownAddressFamily(family)) } +/// Returns the system directory, i.e. `%windir%\system32`. +pub fn get_system_dir() -> io::Result<PathBuf> { + let mut folder_path: PWSTR = ptr::null_mut(); + let status = + unsafe { SHGetKnownFolderPath(&FOLDERID_System, 0, ptr::null_mut(), &mut folder_path) }; + let result = if status == S_OK { + let path = unsafe { WideCStr::from_ptr_str(folder_path) }; + Ok(path.to_ustring().to_os_string().into()) + } else { + Err(io::Error::new( + io::ErrorKind::NotFound, + "Cannot find the system directory", + )) + }; + unsafe { CoTaskMemFree(folder_path as *mut _) }; + result +} + /// Casts a struct to a slice of possibly uninitialized bytes. #[cfg(target_os = "windows")] pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8>] { |
