summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-06-01 19:43:38 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-06-14 11:42:26 +0200
commit54dee1a08229d27329c1ec45a00e9e8b9a2e414c (patch)
tree36f1152558af37f1717fddd69e63e28bddfee4d4
parent0e74ce4143661138a76d313df43546dc4174ba1a (diff)
downloadmullvadvpn-54dee1a08229d27329c1ec45a00e9e8b9a2e414c.tar.xz
mullvadvpn-54dee1a08229d27329c1ec45a00e9e8b9a2e414c.zip
Simplify DNS management on Windows to set servers on the tunnel
interface only
-rw-r--r--talpid-core/src/dns/windows/mod.rs239
-rw-r--r--talpid-core/src/firewall/mod.rs12
-rw-r--r--talpid-core/src/tunnel/openvpn/wintun.rs22
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs2
-rw-r--r--talpid-core/src/windows/mod.rs89
5 files changed, 217 insertions, 147 deletions
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index 8986c1d7ab..3e2b5ceca5 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -1,15 +1,10 @@
-use crate::{
- logging::windows::{log_sink, LogSink},
- windows::luid_from_alias,
-};
-
+use crate::windows::{get_system_dir, guid_from_luid, luid_from_alias, string_from_guid};
use lazy_static::lazy_static;
-use std::{env, io, net::IpAddr, path::Path};
+use std::{env, io, net::IpAddr, path::Path, process::Command};
use talpid_types::ErrorExt;
-use widestring::WideCString;
-use winapi::shared::ifdef::NET_LUID;
+use winapi::shared::guiddef::GUID;
use winreg::{
- enums::{HKEY_LOCAL_MACHINE, REG_MULTI_SZ},
+ enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE, REG_MULTI_SZ},
transaction::Transaction,
RegKey, RegValue,
};
@@ -20,89 +15,66 @@ lazy_static! {
/// Specifies whether to override per-interface DNS resolvers with a global DNS policy.
static ref GLOBAL_DNS_CACHE_POLICY: bool = env::var("TALPID_DNS_CACHE_POLICY")
.map(|v| v != "0")
- .unwrap_or(true);
+ .unwrap_or(false);
}
/// Errors that can happen when configuring DNS on Windows.
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
- /// Failure to initialize WinDns.
- #[error(display = "Failed to initialize WinDns")]
- Initialization,
-
- /// Failure to deinitialize WinDns.
- #[error(display = "Failed to deinitialize WinDns")]
- Deinitialization,
-
- /// Failure to set new DNS servers on the interface.
- #[error(display = "Failed to set new DNS servers on interface")]
- Setting,
-
/// Failure to obtain an interface LUID given an alias.
#[error(display = "Failed to obtain LUID for the interface alias")]
InterfaceLuidError(#[error(source)] io::Error),
+ /// Failure to obtain an interface GUID.
+ #[error(display = "Failed to obtain GUID for the interface")]
+ InterfaceGuidError(#[error(source)] io::Error),
+
/// Failure to set new DNS servers.
#[error(display = "Failed to update dnscache policy config")]
UpdateDnsCachePolicy(#[error(source)] io::Error),
+
+ /// Failure to flush DNS cache.
+ #[error(display = "Failed to execute ipconfig")]
+ ExecuteIpconfigError(#[error(source)] io::Error),
+
+ /// Failure to flush DNS cache.
+ #[error(display = "Failed to flush DNS resolver cache")]
+ FlushResolverCacheError,
+
+ /// Failed to update DNS servers for interface.
+ #[error(display = "Failed to update interface DNS servers")]
+ SetResolversError(#[error(source)] io::Error),
+
+ /// Failed to locate system dir.
+ #[error(display = "Failed to locate the system directory")]
+ SystemDirError(#[error(source)] io::Error),
}
-pub struct DnsMonitor {}
+pub struct DnsMonitor {
+ current_guid: Option<GUID>,
+}
impl super::DnsMonitorT for DnsMonitor {
type Error = Error;
fn new() -> Result<Self, Error> {
- unsafe { WinDns_Initialize(Some(log_sink), b"WinDns\0".as_ptr()).into_result()? };
-
- let mut monitor = DnsMonitor {};
+ let mut monitor = DnsMonitor { current_guid: None };
monitor.reset()?;
Ok(monitor)
}
fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> {
- let ipv4 = servers
- .iter()
- .filter(|ip| ip.is_ipv4())
- .map(ip_to_widestring)
- .collect::<Vec<_>>();
- let ipv6 = servers
- .iter()
- .filter(|ip| ip.is_ipv6())
- .map(ip_to_widestring)
- .collect::<Vec<_>>();
-
- let mut ipv4_address_ptrs = ipv4
- .iter()
- .map(|ip_cstr| ip_cstr.as_ptr())
- .collect::<Vec<_>>();
- let mut ipv6_address_ptrs = ipv6
- .iter()
- .map(|ip_cstr| ip_cstr.as_ptr())
- .collect::<Vec<_>>();
-
- log::trace!("ipv4 ips: {:?} ({})", ipv4, ipv4.len());
- log::trace!("ipv6 ips: {:?} ({})", ipv6, ipv6.len());
-
- let luid = luid_from_alias(interface).map_err(Error::InterfaceLuidError)?;
-
- unsafe {
- WinDns_Set(
- &luid,
- ipv4_address_ptrs.as_mut_ptr(),
- ipv4_address_ptrs.len() as u32,
- ipv6_address_ptrs.as_mut_ptr(),
- ipv6_address_ptrs.len() as u32,
- )
- .into_result()
- }?;
+ let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?)
+ .map_err(Error::InterfaceGuidError)?;
+ set_dns(&guid, servers)?;
+ self.current_guid = Some(guid);
+ flush_dns_cache()?;
if *GLOBAL_DNS_CACHE_POLICY {
if let Err(error) = set_dns_cache_policy(servers) {
log::error!("{}", error.display_chain());
- log::warn!("DNS resolution may be slowed down");
}
}
@@ -110,16 +82,18 @@ impl super::DnsMonitorT for DnsMonitor {
}
fn reset(&mut self) -> Result<(), Error> {
+ let mut result = Ok(());
+
+ if let Some(guid) = self.current_guid.take() {
+ result = result.and(set_dns(&guid, &[])).and(flush_dns_cache());
+ }
+
if *GLOBAL_DNS_CACHE_POLICY {
- reset_dns_cache_policy()
- } else {
- Ok(())
+ result = result.and(reset_dns_cache_policy());
}
- }
-}
-fn ip_to_widestring(ip: &IpAddr) -> WideCString {
- WideCString::from_str_truncate(ip.to_string())
+ result
+ }
}
impl Drop for DnsMonitor {
@@ -132,13 +106,103 @@ impl Drop for DnsMonitor {
);
}
}
+ }
+}
- if unsafe { WinDns_Deinitialize().into_result().is_ok() } {
- log::trace!("Successfully deinitialized WinDns");
- } else {
- log::error!("Failed to deinitialize WinDns");
+fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> {
+ let transaction = Transaction::new().map_err(Error::SetResolversError)?;
+ let result = match set_dns_inner(&transaction, interface, servers) {
+ Ok(()) => transaction.commit(),
+ Err(error) => transaction.rollback().and(Err(error)),
+ };
+ result.map_err(Error::SetResolversError)
+}
+
+fn set_dns_inner(
+ transaction: &Transaction,
+ interface: &GUID,
+ servers: &[IpAddr],
+) -> io::Result<()> {
+ let guid_str = string_from_guid(interface);
+
+ config_interface(
+ transaction,
+ &guid_str,
+ "Tcpip",
+ servers.iter().filter(|addr| addr.is_ipv4()),
+ )?;
+
+ config_interface(
+ transaction,
+ &guid_str,
+ "Tcpip6",
+ servers.iter().filter(|addr| addr.is_ipv6()),
+ )?;
+
+ Ok(())
+}
+
+fn config_interface<'a>(
+ transaction: &Transaction,
+ guid: &str,
+ service: &str,
+ nameservers: impl Iterator<Item = &'a IpAddr>,
+) -> io::Result<()> {
+ let nameservers = nameservers
+ .map(|addr| addr.to_string())
+ .collect::<Vec<String>>();
+
+ let reg_path =
+ format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,);
+ let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags(
+ reg_path,
+ transaction,
+ KEY_SET_VALUE,
+ ) {
+ Ok(adapter_key) => Ok(adapter_key),
+ Err(error) => {
+ if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound {
+ return Ok(());
+ }
+ Err(error)
}
+ }?;
+
+ if !nameservers.is_empty() {
+ adapter_key.set_value("NameServer", &nameservers.join(","))?;
+ } else {
+ adapter_key.delete_value("NameServer").or_else(|error| {
+ if error.kind() == io::ErrorKind::NotFound {
+ Ok(())
+ } else {
+ Err(error)
+ }
+ })?;
+ }
+
+ // Try to disable LLMNR on the interface
+ if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) {
+ log::error!(
+ "{}\nService: {service}",
+ error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface")
+ );
}
+
+ Ok(())
+}
+
+fn flush_dns_cache() -> Result<(), Error> {
+ let sysdir = get_system_dir().map_err(Error::SystemDirError)?;
+ let mut ipconfig = Command::new(sysdir.join("ipconfig.exe"));
+ ipconfig.arg("/flushdns");
+ let output = ipconfig.output().map_err(Error::ExecuteIpconfigError)?;
+ let output = String::from_utf8_lossy(&output.stdout);
+ // The exit code cannot be trusted
+ if !output.contains("Successfully flushed") {
+ log::error!("Failed to flush DNS cache: {}", output);
+ return Err(Error::FlushResolverCacheError);
+ }
+ Ok(())
}
fn set_dns_cache_policy(servers: &[IpAddr]) -> Result<(), Error> {
@@ -207,32 +271,3 @@ fn reset_dns_cache_policy() -> Result<(), Error> {
}
}
}
-
-ffi_error!(InitializationResult, Error::Initialization);
-ffi_error!(DeinitializationResult, Error::Deinitialization);
-ffi_error!(SettingResult, Error::Setting);
-
-#[allow(non_snake_case)]
-extern "stdcall" {
- #[link_name = "WinDns_Initialize"]
- pub fn WinDns_Initialize(
- sink: Option<LogSink>,
- sink_context: *const u8,
- ) -> InitializationResult;
-
- // WinDns_Deinitialize:
- //
- // Call this function once before unloading WINDNS or exiting the process.
- #[link_name = "WinDns_Deinitialize"]
- pub fn WinDns_Deinitialize() -> DeinitializationResult;
-
- // Configure which DNS servers should be used and start enforcing these settings.
- #[link_name = "WinDns_Set"]
- pub fn WinDns_Set(
- interface_luid: *const NET_LUID,
- v4_ips: *mut *const u16,
- v4_n_ips: u32,
- v6_ips: *mut *const u16,
- v6_n_ips: u32,
- ) -> SettingResult;
-}
diff --git a/talpid-core/src/firewall/mod.rs b/talpid-core/src/firewall/mod.rs
index de117d39ad..b23e16b017 100644
--- a/talpid-core/src/firewall/mod.rs
+++ b/talpid-core/src/firewall/mod.rs
@@ -1,14 +1,13 @@
-#[cfg(unix)]
use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
-#[cfg(unix)]
use lazy_static::lazy_static;
-use std::fmt;
#[cfg(not(target_os = "android"))]
use std::net::IpAddr;
-#[cfg(unix)]
-use std::net::{Ipv4Addr, Ipv6Addr};
#[cfg(windows)]
use std::path::PathBuf;
+use std::{
+ fmt,
+ net::{Ipv4Addr, Ipv6Addr},
+};
use talpid_types::net::{AllowedEndpoint, Endpoint};
#[cfg(target_os = "macos")]
@@ -29,7 +28,6 @@ mod imp;
pub use self::imp::Error;
-#[cfg(unix)]
lazy_static! {
/// When "allow local network" is enabled the app will allow traffic to and from these networks.
pub(crate) static ref ALLOWED_LAN_NETS: [IpNetwork; 6] = [
@@ -83,7 +81,7 @@ const DHCPV6_CLIENT_PORT: u16 = 546;
#[cfg(all(unix, not(target_os = "android")))]
const ROOT_UID: u32 = 0;
-#[cfg(all(unix, not(target_os = "android")))]
+#[cfg(any(all(unix, not(target_os = "android")), target_os = "windows"))]
/// Returns whether an address belongs to a private subnet.
pub fn is_local_address(address: &IpAddr) -> bool {
let address = address.clone();
diff --git a/talpid-core/src/tunnel/openvpn/wintun.rs b/talpid-core/src/tunnel/openvpn/wintun.rs
index a61b83643d..1746756db4 100644
--- a/talpid-core/src/tunnel/openvpn/wintun.rs
+++ b/talpid-core/src/tunnel/openvpn/wintun.rs
@@ -1,4 +1,6 @@
-use crate::windows::{get_ip_interface_entry, set_ip_interface_entry, AddressFamily};
+use crate::windows::{
+ get_ip_interface_entry, set_ip_interface_entry, string_from_guid, AddressFamily,
+};
use lazy_static::lazy_static;
use std::{
ffi::CStr,
@@ -483,24 +485,6 @@ impl Drop for WintunLoggerHandle {
}
}
-/// Obtain a string representation for a GUID object.
-fn string_from_guid(guid: &GUID) -> String {
- use std::{ffi::OsString, os::windows::ffi::OsStringExt};
- use winapi::um::combaseapi::StringFromGUID2;
-
- let mut buffer = [0u16; 40];
- let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) }
- as usize;
- if length > 0 {
- let length = length - 1;
- OsString::from_wide(&buffer[0..length])
- .to_string_lossy()
- .to_string()
- } else {
- "".to_string()
- }
-}
-
/// Returns the registry key for a network device identified by its GUID.
fn find_adapter_registry_key(find_guid: &str, permissions: REGSAM) -> io::Result<RegKey> {
let net_devs = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_with_flags(
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 80b6de8772..80e5e28957 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -123,7 +123,7 @@ impl ConnectedState {
fn set_dns(&self, shared_values: &mut SharedTunnelStateValues) -> Result<(), BoxedError> {
let dns_ips = self.get_dns_servers(shared_values);
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", target_os = "windows"))]
let dns_ips = &dns_ips
.into_iter()
.filter(|ip| {
diff --git a/talpid-core/src/windows/mod.rs b/talpid-core/src/windows/mod.rs
index 4ec1fe19f3..115d8f5397 100644
--- a/talpid-core/src/windows/mod.rs
+++ b/talpid-core/src/windows/mod.rs
@@ -1,34 +1,48 @@
use socket2::SockAddr;
use std::{
ffi::{OsStr, OsString},
- fmt, io, mem,
+ fmt, io,
+ mem::{self, MaybeUninit},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::windows::{
ffi::{OsStrExt, OsStringExt},
io::RawHandle,
},
+ path::PathBuf,
+ ptr,
sync::Mutex,
time::{Duration, Instant},
};
-use winapi::shared::{
- ifdef::NET_LUID,
- in6addr::IN6_ADDR,
- inaddr::IN_ADDR,
- netioapi::{
- CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias,
- FreeMibTable, GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable,
- MibAddInstance, NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW,
- MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE,
+use widestring::WideCStr;
+use winapi::{
+ shared::{
+ guiddef::GUID,
+ ifdef::NET_LUID,
+ in6addr::IN6_ADDR,
+ inaddr::IN_ADDR,
+ netioapi::{
+ CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias,
+ ConvertInterfaceLuidToGuid, FreeMibTable, GetIpInterfaceEntry,
+ GetUnicastIpAddressEntry, GetUnicastIpAddressTable, MibAddInstance,
+ NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW,
+ MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE,
+ },
+ nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE},
+ ntddndis::NDIS_IF_MAX_STRING_SIZE,
+ ntdef::FALSE,
+ winerror::{ERROR_NOT_FOUND, NO_ERROR, S_OK},
+ ws2def::{
+ AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in,
+ SOCKADDR_STORAGE as sockaddr_storage,
+ },
+ ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET},
},
- nldef::{IpDadStatePreferred, IpDadStateTentative, NL_DAD_STATE},
- ntddndis::NDIS_IF_MAX_STRING_SIZE,
- ntdef::FALSE,
- winerror::{ERROR_NOT_FOUND, NO_ERROR},
- ws2def::{
- AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN as sockaddr_in,
- SOCKADDR_STORAGE as sockaddr_storage,
+ um::{
+ combaseapi::{CoTaskMemFree, StringFromGUID2},
+ knownfolders::FOLDERID_System,
+ shlobj::SHGetKnownFolderPath,
+ winnt::PWSTR,
},
- ws2ipdef::{SOCKADDR_IN6_LH as sockaddr_in6, SOCKADDR_INET},
};
pub mod window;
@@ -350,6 +364,27 @@ pub fn get_unicast_table(
Ok(unicast_rows)
}
+/// Obtain a string representation for a GUID object.
+pub fn string_from_guid(guid: &GUID) -> String {
+ let mut buffer = [0u16; 40];
+ let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) }
+ as usize;
+ // cannot fail because `buffer` is large enough
+ assert!(length > 0);
+ let length = length - 1;
+ String::from_utf16(&buffer[0..length]).unwrap()
+}
+
+/// Returns the GUID of a network interface given its LUID.
+pub fn guid_from_luid(luid: &NET_LUID) -> io::Result<GUID> {
+ let mut guid = MaybeUninit::zeroed();
+ let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) };
+ if status != NO_ERROR {
+ return Err(io::Error::from_raw_os_error(status as i32));
+ }
+ Ok(unsafe { guid.assume_init() })
+}
+
/// Returns the LUID of an interface given its alias.
pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID> {
let alias_wide: Vec<u16> = alias
@@ -435,6 +470,24 @@ pub fn try_socketaddr_from_inet_sockaddr(addr: SOCKADDR_INET) -> Result<SocketAd
.ok_or(Error::UnknownAddressFamily(family))
}
+/// Returns the system directory, i.e. `%windir%\system32`.
+pub fn get_system_dir() -> io::Result<PathBuf> {
+ let mut folder_path: PWSTR = ptr::null_mut();
+ let status =
+ unsafe { SHGetKnownFolderPath(&FOLDERID_System, 0, ptr::null_mut(), &mut folder_path) };
+ let result = if status == S_OK {
+ let path = unsafe { WideCStr::from_ptr_str(folder_path) };
+ Ok(path.to_ustring().to_os_string().into())
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::NotFound,
+ "Cannot find the system directory",
+ ))
+ };
+ unsafe { CoTaskMemFree(folder_path as *mut _) };
+ result
+}
+
/// Casts a struct to a slice of possibly uninitialized bytes.
#[cfg(target_os = "windows")]
pub fn as_uninit_byte_slice<T: Copy + Sized>(value: &T) -> &[mem::MaybeUninit<u8>] {