summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-12-12 15:05:24 +0100
committerDavid Lönnhager <david.l@mullvad.net>2023-01-10 11:58:51 +0100
commitc72be74adbde77f3e6f33a038e5b870108a52196 (patch)
tree1f3a13c0a6998b36973a4b7c4c3ba78b950cf579
parentc0b55ed2b99f9dbfd437899be4a99d1c96f55ca4 (diff)
downloadmullvadvpn-c72be74adbde77f3e6f33a038e5b870108a52196.tar.xz
mullvadvpn-c72be74adbde77f3e6f33a038e5b870108a52196.zip
Revert to using 'netsh' instead of registry
-rw-r--r--talpid-core/src/dns/mod.rs2
-rw-r--r--talpid-core/src/dns/windows.rs211
-rw-r--r--talpid-core/src/dns/windows/dnsapi.rs132
-rw-r--r--talpid-core/src/dns/windows/mod.rs157
-rw-r--r--talpid-windows-net/src/net.rs15
5 files changed, 225 insertions, 292 deletions
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs
index ca06251f55..ae902da09e 100644
--- a/talpid-core/src/dns/mod.rs
+++ b/talpid-core/src/dns/mod.rs
@@ -20,7 +20,7 @@ mod imp;
pub use imp::will_use_nm;
#[cfg(windows)]
-#[path = "windows/mod.rs"]
+#[path = "windows.rs"]
mod imp;
#[cfg(target_os = "android")]
diff --git a/talpid-core/src/dns/windows.rs b/talpid-core/src/dns/windows.rs
new file mode 100644
index 0000000000..adfe3e01e9
--- /dev/null
+++ b/talpid-core/src/dns/windows.rs
@@ -0,0 +1,211 @@
+use std::{
+ ffi::OsString,
+ io::{self, Write},
+ net::IpAddr,
+ os::windows::prelude::{AsRawHandle, OsStringExt},
+ path::PathBuf,
+ process::{Child, Command, ExitStatus, Stdio},
+ time::Duration,
+};
+use talpid_types::{net::IpVersion, ErrorExt};
+use talpid_windows_net::{index_from_luid, luid_from_alias};
+use windows_sys::Win32::{
+ Foundation::{MAX_PATH, WAIT_OBJECT_0, WAIT_TIMEOUT},
+ System::{
+ SystemInformation::GetSystemDirectoryW, Threading::WaitForSingleObject,
+ WindowsProgramming::INFINITE,
+ },
+};
+
+const NETSH_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// Errors that can happen when configuring DNS on Windows.
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failure to obtain an interface LUID given an alias.
+ #[error(display = "Failed to obtain LUID for the interface alias")]
+ InterfaceLuidError(#[error(source)] io::Error),
+
+ /// Failure to obtain an interface index.
+ #[error(display = "Failed to obtain index of the interface")]
+ InterfaceIndexError(#[error(source)] io::Error),
+
+ /// Failure to spawn netsh subprocess.
+ #[error(display = "Failed to spawn 'netsh'")]
+ SpawnNetsh(#[error(source)] io::Error),
+
+ /// Failure to spawn netsh subprocess.
+ #[error(display = "Failed to obtain system directory")]
+ GetSystemDir(#[error(source)] io::Error),
+
+ /// Failure to write to stdin.
+ #[error(display = "Failed to write to stdin for 'netsh'")]
+ NetshInput(#[error(source)] io::Error),
+
+ /// Failure to wait for netsh result.
+ #[error(display = "Failed to wait for 'netsh'")]
+ WaitNetsh(#[error(source)] io::Error),
+
+ /// netsh returned a non-zero status.
+ #[error(display = "'netsh' returned an error: {:?}", _0)]
+ NetshError(Option<i32>),
+
+ /// netsh did not return in a timely manner.
+ #[error(display = "'netsh' took too long to complete")]
+ NetshTimeout,
+}
+
+pub struct DnsMonitor {
+ current_index: Option<u32>,
+}
+
+impl super::DnsMonitorT for DnsMonitor {
+ type Error = Error;
+
+ fn new() -> Result<Self, Error> {
+ Ok(DnsMonitor {
+ current_index: None,
+ })
+ }
+
+ fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> {
+ let interface_luid = luid_from_alias(interface).map_err(Error::InterfaceLuidError)?;
+ let interface_index =
+ index_from_luid(&interface_luid).map_err(Error::InterfaceIndexError)?;
+
+ self.current_index = Some(interface_index);
+
+ let mut added_ipv4_server = false;
+ let mut added_ipv6_server = false;
+
+ let mut netsh_input = String::new();
+
+ for server in servers {
+ let is_additional_server;
+
+ if server.is_ipv4() {
+ is_additional_server = added_ipv4_server;
+ added_ipv4_server = true;
+ } else {
+ is_additional_server = added_ipv6_server;
+ added_ipv6_server = true;
+ };
+
+ if is_additional_server {
+ netsh_input.push_str(&create_netsh_add_command(interface_index, server));
+ } else {
+ netsh_input.push_str(&create_netsh_set_command(interface_index, server));
+ }
+ }
+
+ if !added_ipv4_server {
+ netsh_input.push_str(&create_netsh_flush_command(interface_index, IpVersion::V4));
+ }
+ if !added_ipv6_server {
+ netsh_input.push_str(&create_netsh_flush_command(interface_index, IpVersion::V6));
+ }
+
+ run_netsh_with_timeout(netsh_input, NETSH_TIMEOUT)?;
+
+ Ok(())
+ }
+
+ fn reset(&mut self) -> Result<(), Error> {
+ if let Some(index) = self.current_index.take() {
+ let mut netsh_input = String::new();
+ netsh_input.push_str(&create_netsh_flush_command(index, IpVersion::V4));
+ netsh_input.push_str(&create_netsh_flush_command(index, IpVersion::V6));
+
+ if let Err(error) = run_netsh_with_timeout(netsh_input, NETSH_TIMEOUT) {
+ log::error!("{}", error.display_chain_with_msg("Failed to reset DNS"));
+ }
+ }
+ Ok(())
+ }
+}
+
+fn run_netsh_with_timeout(netsh_input: String, timeout: Duration) -> Result<(), Error> {
+ log::debug!("running netsh:\n{}", netsh_input);
+
+ let sysdir = get_system_dir().map_err(Error::GetSystemDir)?;
+ let mut netsh = Command::new(sysdir.join(r"netsh.exe"));
+
+ let mut subproc = netsh
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .spawn()
+ .map_err(Error::SpawnNetsh)?;
+
+ let mut stdin = subproc.stdin.take().unwrap();
+ stdin
+ .write_all(netsh_input.as_bytes())
+ .map_err(Error::NetshInput)?;
+ drop(stdin);
+
+ match wait_for_child(&mut subproc, timeout) {
+ Ok(Some(status)) => {
+ if !status.success() {
+ return Err(Error::NetshError(status.code()));
+ }
+ Ok(())
+ }
+ Ok(None) => {
+ let _ = subproc.kill();
+ Err(Error::NetshTimeout)
+ }
+ Err(error) => Err(Error::WaitNetsh(error)),
+ }
+}
+
+fn wait_for_child(subproc: &mut Child, timeout: Duration) -> io::Result<Option<ExitStatus>> {
+ let dur_millis = u32::try_from(timeout.as_millis()).unwrap_or(INFINITE);
+
+ let subproc_handle = subproc.as_raw_handle();
+ match unsafe { WaitForSingleObject(subproc_handle as isize, dur_millis) } {
+ WAIT_OBJECT_0 => subproc.try_wait(),
+ WAIT_TIMEOUT => Ok(None),
+ _error => Err(io::Error::last_os_error()),
+ }
+}
+
+fn create_netsh_set_command(interface_index: u32, server: &IpAddr) -> String {
+ // Set primary DNS server:
+ // netsh interface ipv4 set dnsservers name="Mullvad" source=static address=10.64.0.1
+ // validate=no
+
+ let interface_type = if server.is_ipv4() { "ipv4" } else { "ipv6" };
+ format!("interface {interface_type} set dnsservers name={interface_index} source=static address={server} validate=no\r\n")
+}
+
+fn create_netsh_add_command(interface_index: u32, server: &IpAddr) -> String {
+ // Add DNS server:
+ // netsh interface ipv4 add dnsservers name="Mullvad" address=10.64.0.2 validate=no
+
+ let interface_type = if server.is_ipv4() { "ipv4" } else { "ipv6" };
+ format!("interface {interface_type} add dnsservers name={interface_index} address={server} validate=no\r\n")
+}
+
+fn create_netsh_flush_command(interface_index: u32, ip_version: IpVersion) -> String {
+ // Flush DNS settings:
+ // netsh interface ipv4 set dnsservers name="Mullvad" source=static address=none validate=no
+
+ let interface_type = match ip_version {
+ IpVersion::V4 => "ipv4",
+ IpVersion::V6 => "ipv6",
+ };
+
+ format!("interface {interface_type} set dnsservers name={interface_index} source=static address=none validate=no\r\n")
+}
+
+fn get_system_dir() -> io::Result<PathBuf> {
+ let mut sysdir = [0u16; MAX_PATH as usize + 1];
+ let len = unsafe { GetSystemDirectoryW(sysdir.as_mut_ptr(), (sysdir.len() - 1) as u32) };
+ if len == 0 {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(PathBuf::from(OsString::from_wide(
+ &sysdir[0..(len as usize)],
+ )))
+}
diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs
deleted file mode 100644
index e48f37258e..0000000000
--- a/talpid-core/src/dns/windows/dnsapi.rs
+++ /dev/null
@@ -1,132 +0,0 @@
-use once_cell::sync::OnceCell;
-use std::{
- io,
- sync::{
- atomic::{AtomicUsize, Ordering},
- mpsc, Arc,
- },
- time::{Duration, Instant},
-};
-use windows_sys::{
- w,
- Win32::{
- Foundation::BOOL,
- System::LibraryLoader::{
- FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32,
- },
- },
-};
-
-type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL;
-
-static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new();
-static FLUSH_TIMEOUT: Duration = Duration::from_secs(5);
-
-const MAX_CONCURRENT_FLUSHES: usize = 5;
-
-/// Errors that can happen when configuring DNS on Windows.
-#[derive(err_derive::Error, Debug)]
-#[error(no_from)]
-pub enum Error {
- /// Failed to load dnsapi.dll.
- #[error(display = "Failed to load dnsapi.dll")]
- LoadDll(#[error(source)] io::Error),
-
- /// Failed to obtain exported function.
- #[error(display = "Failed to obtain flush function")]
- GetFunction(#[error(source)] io::Error),
-
- /// Failed to flush the DNS cache.
- #[error(display = "Call to flush DNS cache failed")]
- FlushCache,
-
- /// Too many flush attempts in progress.
- #[error(display = "Too many flush attempts in progress")]
- TooManyFlushAttempts,
-
- /// Flushing the DNS cache timed out.
- #[error(display = "Timeout while flushing DNS cache")]
- Timeout,
-}
-
-pub fn flush_resolver_cache() -> Result<(), Error> {
- DNSAPI_HANDLE
- .get_or_try_init(|| DnsApi::new())?
- .flush_cache()
-}
-
-struct DnsApi {
- in_flight_flush_count: Arc<AtomicUsize>,
- flush_fn: FlushResolverCacheFn,
-}
-
-unsafe impl Send for DnsApi {}
-unsafe impl Sync for DnsApi {}
-
-impl DnsApi {
- fn new() -> Result<Self, Error> {
- let handle = unsafe { LoadLibraryExW(w!("dnsapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) };
- if handle == 0 {
- return Err(Error::LoadDll(io::Error::last_os_error()));
- }
-
- let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const u8) };
- let flush_fn = flush_fn.ok_or_else(|| {
- let error = io::Error::last_os_error();
- unsafe { FreeLibrary(handle) };
- Error::GetFunction(error)
- })?;
-
- Ok(DnsApi {
- in_flight_flush_count: Arc::new(AtomicUsize::new(0)),
- flush_fn: unsafe { *(&flush_fn as *const _ as *const _) },
- })
- }
-
- fn flush_cache(&self) -> Result<(), Error> {
- if self
- .in_flight_flush_count
- .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
- if val >= MAX_CONCURRENT_FLUSHES {
- return None;
- }
- Some(val + 1)
- })
- .is_err()
- {
- return Err(Error::TooManyFlushAttempts);
- }
-
- let (tx, rx) = mpsc::channel();
- let flush_count = self.in_flight_flush_count.clone();
-
- let flush_fn = self.flush_fn;
-
- std::thread::spawn(move || {
- let begin = Instant::now();
-
- let result = if unsafe { (flush_fn)() } != 0 {
- let elapsed = begin.elapsed();
- if elapsed >= FLUSH_TIMEOUT {
- log::warn!(
- "Flushing system DNS cache took {} seconds",
- elapsed.as_secs()
- );
- } else {
- log::debug!("Flushed system DNS cache");
- }
- Ok(())
- } else {
- Err(Error::FlushCache)
- };
- let _ = tx.send(result);
-
- flush_count.fetch_sub(1, Ordering::SeqCst);
- });
-
- match rx.recv_timeout(FLUSH_TIMEOUT) {
- Ok(result) => result,
- Err(_timeout_err) => Err(Error::Timeout),
- }
- }
-}
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
deleted file mode 100644
index 9b780b1f54..0000000000
--- a/talpid-core/src/dns/windows/mod.rs
+++ /dev/null
@@ -1,157 +0,0 @@
-use std::{io, net::IpAddr};
-use talpid_types::ErrorExt;
-use talpid_windows_net::{guid_from_luid, luid_from_alias};
-use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2};
-use winreg::{
- enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE},
- transaction::Transaction,
- RegKey,
-};
-
-mod dnsapi;
-
-/// Errors that can happen when configuring DNS on Windows.
-#[derive(err_derive::Error, Debug)]
-#[error(no_from)]
-pub enum Error {
- /// Failure to obtain an interface LUID given an alias.
- #[error(display = "Failed to obtain LUID for the interface alias")]
- InterfaceLuidError(#[error(source)] io::Error),
-
- /// Failure to obtain an interface GUID.
- #[error(display = "Failed to obtain GUID for the interface")]
- InterfaceGuidError(#[error(source)] io::Error),
-
- /// Failure to flush DNS cache.
- #[error(display = "Failed to flush DNS resolver cache")]
- FlushResolverCacheError(#[error(source)] dnsapi::Error),
-
- /// Failed to update DNS servers for interface.
- #[error(display = "Failed to update interface DNS servers")]
- SetResolversError(#[error(source)] io::Error),
-}
-
-pub struct DnsMonitor {
- current_guid: Option<GUID>,
-}
-
-impl super::DnsMonitorT for DnsMonitor {
- type Error = Error;
-
- fn new() -> Result<Self, Error> {
- Ok(DnsMonitor { current_guid: None })
- }
-
- fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> {
- let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?)
- .map_err(Error::InterfaceGuidError)?;
- set_dns(&guid, servers)?;
- self.current_guid = Some(guid);
- flush_dns_cache()?;
- Ok(())
- }
-
- fn reset(&mut self) -> Result<(), Error> {
- if let Some(guid) = self.current_guid.take() {
- return set_dns(&guid, &[]).and(flush_dns_cache());
- }
- Ok(())
- }
-}
-
-fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> {
- let transaction = Transaction::new().map_err(Error::SetResolversError)?;
- let result = match set_dns_inner(&transaction, interface, servers) {
- Ok(()) => transaction.commit(),
- Err(error) => transaction.rollback().and(Err(error)),
- };
- result.map_err(Error::SetResolversError)
-}
-
-fn set_dns_inner(
- transaction: &Transaction,
- interface: &GUID,
- servers: &[IpAddr],
-) -> io::Result<()> {
- let guid_str = string_from_guid(interface);
-
- config_interface(
- transaction,
- &guid_str,
- "Tcpip",
- servers.iter().filter(|addr| addr.is_ipv4()),
- )?;
-
- config_interface(
- transaction,
- &guid_str,
- "Tcpip6",
- servers.iter().filter(|addr| addr.is_ipv6()),
- )?;
-
- Ok(())
-}
-
-fn config_interface<'a>(
- transaction: &Transaction,
- guid: &str,
- service: &str,
- nameservers: impl Iterator<Item = &'a IpAddr>,
-) -> io::Result<()> {
- let nameservers = nameservers
- .map(|addr| addr.to_string())
- .collect::<Vec<String>>();
-
- let reg_path =
- format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,);
- let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags(
- reg_path,
- transaction,
- KEY_SET_VALUE,
- ) {
- Ok(adapter_key) => Ok(adapter_key),
- Err(error) => {
- if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound {
- return Ok(());
- }
- Err(error)
- }
- }?;
-
- if !nameservers.is_empty() {
- adapter_key.set_value("NameServer", &nameservers.join(","))?;
- } else {
- adapter_key.delete_value("NameServer").or_else(|error| {
- if error.kind() == io::ErrorKind::NotFound {
- Ok(())
- } else {
- Err(error)
- }
- })?;
- }
-
- // Try to disable LLMNR on the interface
- if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) {
- log::error!(
- "{}\nService: {service}",
- error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface")
- );
- }
-
- Ok(())
-}
-
-fn flush_dns_cache() -> Result<(), Error> {
- dnsapi::flush_resolver_cache().map_err(Error::FlushResolverCacheError)
-}
-
-/// Obtain a string representation for a GUID object.
-fn string_from_guid(guid: &GUID) -> String {
- let mut buffer = [0u16; 40];
- let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) }
- as usize;
- // cannot fail because `buffer` is large enough
- assert!(length > 0);
- let length = length - 1;
- String::from_utf16(&buffer[0..length]).unwrap()
-}
diff --git a/talpid-windows-net/src/net.rs b/talpid-windows-net/src/net.rs
index 6a11cf7018..18a3f7748c 100644
--- a/talpid-windows-net/src/net.rs
+++ b/talpid-windows-net/src/net.rs
@@ -17,8 +17,9 @@ use windows_sys::{
NetworkManagement::{
IpHelper::{
CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias,
- ConvertInterfaceLuidToGuid, CreateUnicastIpAddressEntry, FreeMibTable,
- GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable,
+ ConvertInterfaceLuidToGuid, ConvertInterfaceLuidToIndex,
+ CreateUnicastIpAddressEntry, FreeMibTable, GetIpInterfaceEntry,
+ GetUnicastIpAddressEntry, GetUnicastIpAddressTable,
InitializeUnicastIpAddressEntry, MibAddInstance, NotifyIpInterfaceChange,
SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, MIB_UNICASTIPADDRESS_ROW,
MIB_UNICASTIPADDRESS_TABLE,
@@ -377,6 +378,16 @@ pub fn get_unicast_table(
Ok(unicast_rows)
}
+/// Returns the index of a network interface given its LUID.
+pub fn index_from_luid(luid: &NET_LUID_LH) -> io::Result<u32> {
+ let mut index = 0u32;
+ let status = unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) };
+ if status != NO_ERROR as i32 {
+ return Err(io::Error::from_raw_os_error(status as i32));
+ }
+ Ok(index)
+}
+
/// Returns the GUID of a network interface given its LUID.
pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> {
let mut guid = MaybeUninit::zeroed();