diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-12-12 15:05:24 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-01-10 11:58:51 +0100 |
| commit | c72be74adbde77f3e6f33a038e5b870108a52196 (patch) | |
| tree | 1f3a13c0a6998b36973a4b7c4c3ba78b950cf579 | |
| parent | c0b55ed2b99f9dbfd437899be4a99d1c96f55ca4 (diff) | |
| download | mullvadvpn-c72be74adbde77f3e6f33a038e5b870108a52196.tar.xz mullvadvpn-c72be74adbde77f3e6f33a038e5b870108a52196.zip | |
Revert to using 'netsh' instead of registry
| -rw-r--r-- | talpid-core/src/dns/mod.rs | 2 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows.rs | 211 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/dnsapi.rs | 132 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 157 | ||||
| -rw-r--r-- | talpid-windows-net/src/net.rs | 15 |
5 files changed, 225 insertions, 292 deletions
diff --git a/talpid-core/src/dns/mod.rs b/talpid-core/src/dns/mod.rs index ca06251f55..ae902da09e 100644 --- a/talpid-core/src/dns/mod.rs +++ b/talpid-core/src/dns/mod.rs @@ -20,7 +20,7 @@ mod imp; pub use imp::will_use_nm; #[cfg(windows)] -#[path = "windows/mod.rs"] +#[path = "windows.rs"] mod imp; #[cfg(target_os = "android")] diff --git a/talpid-core/src/dns/windows.rs b/talpid-core/src/dns/windows.rs new file mode 100644 index 0000000000..adfe3e01e9 --- /dev/null +++ b/talpid-core/src/dns/windows.rs @@ -0,0 +1,211 @@ +use std::{ + ffi::OsString, + io::{self, Write}, + net::IpAddr, + os::windows::prelude::{AsRawHandle, OsStringExt}, + path::PathBuf, + process::{Child, Command, ExitStatus, Stdio}, + time::Duration, +}; +use talpid_types::{net::IpVersion, ErrorExt}; +use talpid_windows_net::{index_from_luid, luid_from_alias}; +use windows_sys::Win32::{ + Foundation::{MAX_PATH, WAIT_OBJECT_0, WAIT_TIMEOUT}, + System::{ + SystemInformation::GetSystemDirectoryW, Threading::WaitForSingleObject, + WindowsProgramming::INFINITE, + }, +}; + +const NETSH_TIMEOUT: Duration = Duration::from_secs(10); + +/// Errors that can happen when configuring DNS on Windows. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failure to obtain an interface LUID given an alias. + #[error(display = "Failed to obtain LUID for the interface alias")] + InterfaceLuidError(#[error(source)] io::Error), + + /// Failure to obtain an interface index. + #[error(display = "Failed to obtain index of the interface")] + InterfaceIndexError(#[error(source)] io::Error), + + /// Failure to spawn netsh subprocess. + #[error(display = "Failed to spawn 'netsh'")] + SpawnNetsh(#[error(source)] io::Error), + + /// Failure to spawn netsh subprocess. + #[error(display = "Failed to obtain system directory")] + GetSystemDir(#[error(source)] io::Error), + + /// Failure to write to stdin. + #[error(display = "Failed to write to stdin for 'netsh'")] + NetshInput(#[error(source)] io::Error), + + /// Failure to wait for netsh result. + #[error(display = "Failed to wait for 'netsh'")] + WaitNetsh(#[error(source)] io::Error), + + /// netsh returned a non-zero status. + #[error(display = "'netsh' returned an error: {:?}", _0)] + NetshError(Option<i32>), + + /// netsh did not return in a timely manner. + #[error(display = "'netsh' took too long to complete")] + NetshTimeout, +} + +pub struct DnsMonitor { + current_index: Option<u32>, +} + +impl super::DnsMonitorT for DnsMonitor { + type Error = Error; + + fn new() -> Result<Self, Error> { + Ok(DnsMonitor { + current_index: None, + }) + } + + fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> { + let interface_luid = luid_from_alias(interface).map_err(Error::InterfaceLuidError)?; + let interface_index = + index_from_luid(&interface_luid).map_err(Error::InterfaceIndexError)?; + + self.current_index = Some(interface_index); + + let mut added_ipv4_server = false; + let mut added_ipv6_server = false; + + let mut netsh_input = String::new(); + + for server in servers { + let is_additional_server; + + if server.is_ipv4() { + is_additional_server = added_ipv4_server; + added_ipv4_server = true; + } else { + is_additional_server = added_ipv6_server; + added_ipv6_server = true; + }; + + if is_additional_server { + netsh_input.push_str(&create_netsh_add_command(interface_index, server)); + } else { + netsh_input.push_str(&create_netsh_set_command(interface_index, server)); + } + } + + if !added_ipv4_server { + netsh_input.push_str(&create_netsh_flush_command(interface_index, IpVersion::V4)); + } + if !added_ipv6_server { + netsh_input.push_str(&create_netsh_flush_command(interface_index, IpVersion::V6)); + } + + run_netsh_with_timeout(netsh_input, NETSH_TIMEOUT)?; + + Ok(()) + } + + fn reset(&mut self) -> Result<(), Error> { + if let Some(index) = self.current_index.take() { + let mut netsh_input = String::new(); + netsh_input.push_str(&create_netsh_flush_command(index, IpVersion::V4)); + netsh_input.push_str(&create_netsh_flush_command(index, IpVersion::V6)); + + if let Err(error) = run_netsh_with_timeout(netsh_input, NETSH_TIMEOUT) { + log::error!("{}", error.display_chain_with_msg("Failed to reset DNS")); + } + } + Ok(()) + } +} + +fn run_netsh_with_timeout(netsh_input: String, timeout: Duration) -> Result<(), Error> { + log::debug!("running netsh:\n{}", netsh_input); + + let sysdir = get_system_dir().map_err(Error::GetSystemDir)?; + let mut netsh = Command::new(sysdir.join(r"netsh.exe")); + + let mut subproc = netsh + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(Error::SpawnNetsh)?; + + let mut stdin = subproc.stdin.take().unwrap(); + stdin + .write_all(netsh_input.as_bytes()) + .map_err(Error::NetshInput)?; + drop(stdin); + + match wait_for_child(&mut subproc, timeout) { + Ok(Some(status)) => { + if !status.success() { + return Err(Error::NetshError(status.code())); + } + Ok(()) + } + Ok(None) => { + let _ = subproc.kill(); + Err(Error::NetshTimeout) + } + Err(error) => Err(Error::WaitNetsh(error)), + } +} + +fn wait_for_child(subproc: &mut Child, timeout: Duration) -> io::Result<Option<ExitStatus>> { + let dur_millis = u32::try_from(timeout.as_millis()).unwrap_or(INFINITE); + + let subproc_handle = subproc.as_raw_handle(); + match unsafe { WaitForSingleObject(subproc_handle as isize, dur_millis) } { + WAIT_OBJECT_0 => subproc.try_wait(), + WAIT_TIMEOUT => Ok(None), + _error => Err(io::Error::last_os_error()), + } +} + +fn create_netsh_set_command(interface_index: u32, server: &IpAddr) -> String { + // Set primary DNS server: + // netsh interface ipv4 set dnsservers name="Mullvad" source=static address=10.64.0.1 + // validate=no + + let interface_type = if server.is_ipv4() { "ipv4" } else { "ipv6" }; + format!("interface {interface_type} set dnsservers name={interface_index} source=static address={server} validate=no\r\n") +} + +fn create_netsh_add_command(interface_index: u32, server: &IpAddr) -> String { + // Add DNS server: + // netsh interface ipv4 add dnsservers name="Mullvad" address=10.64.0.2 validate=no + + let interface_type = if server.is_ipv4() { "ipv4" } else { "ipv6" }; + format!("interface {interface_type} add dnsservers name={interface_index} address={server} validate=no\r\n") +} + +fn create_netsh_flush_command(interface_index: u32, ip_version: IpVersion) -> String { + // Flush DNS settings: + // netsh interface ipv4 set dnsservers name="Mullvad" source=static address=none validate=no + + let interface_type = match ip_version { + IpVersion::V4 => "ipv4", + IpVersion::V6 => "ipv6", + }; + + format!("interface {interface_type} set dnsservers name={interface_index} source=static address=none validate=no\r\n") +} + +fn get_system_dir() -> io::Result<PathBuf> { + let mut sysdir = [0u16; MAX_PATH as usize + 1]; + let len = unsafe { GetSystemDirectoryW(sysdir.as_mut_ptr(), (sysdir.len() - 1) as u32) }; + if len == 0 { + return Err(io::Error::last_os_error()); + } + Ok(PathBuf::from(OsString::from_wide( + &sysdir[0..(len as usize)], + ))) +} diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs deleted file mode 100644 index e48f37258e..0000000000 --- a/talpid-core/src/dns/windows/dnsapi.rs +++ /dev/null @@ -1,132 +0,0 @@ -use once_cell::sync::OnceCell; -use std::{ - io, - sync::{ - atomic::{AtomicUsize, Ordering}, - mpsc, Arc, - }, - time::{Duration, Instant}, -}; -use windows_sys::{ - w, - Win32::{ - Foundation::BOOL, - System::LibraryLoader::{ - FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32, - }, - }, -}; - -type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL; - -static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new(); -static FLUSH_TIMEOUT: Duration = Duration::from_secs(5); - -const MAX_CONCURRENT_FLUSHES: usize = 5; - -/// Errors that can happen when configuring DNS on Windows. -#[derive(err_derive::Error, Debug)] -#[error(no_from)] -pub enum Error { - /// Failed to load dnsapi.dll. - #[error(display = "Failed to load dnsapi.dll")] - LoadDll(#[error(source)] io::Error), - - /// Failed to obtain exported function. - #[error(display = "Failed to obtain flush function")] - GetFunction(#[error(source)] io::Error), - - /// Failed to flush the DNS cache. - #[error(display = "Call to flush DNS cache failed")] - FlushCache, - - /// Too many flush attempts in progress. - #[error(display = "Too many flush attempts in progress")] - TooManyFlushAttempts, - - /// Flushing the DNS cache timed out. - #[error(display = "Timeout while flushing DNS cache")] - Timeout, -} - -pub fn flush_resolver_cache() -> Result<(), Error> { - DNSAPI_HANDLE - .get_or_try_init(|| DnsApi::new())? - .flush_cache() -} - -struct DnsApi { - in_flight_flush_count: Arc<AtomicUsize>, - flush_fn: FlushResolverCacheFn, -} - -unsafe impl Send for DnsApi {} -unsafe impl Sync for DnsApi {} - -impl DnsApi { - fn new() -> Result<Self, Error> { - let handle = unsafe { LoadLibraryExW(w!("dnsapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) }; - if handle == 0 { - return Err(Error::LoadDll(io::Error::last_os_error())); - } - - let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const u8) }; - let flush_fn = flush_fn.ok_or_else(|| { - let error = io::Error::last_os_error(); - unsafe { FreeLibrary(handle) }; - Error::GetFunction(error) - })?; - - Ok(DnsApi { - in_flight_flush_count: Arc::new(AtomicUsize::new(0)), - flush_fn: unsafe { *(&flush_fn as *const _ as *const _) }, - }) - } - - fn flush_cache(&self) -> Result<(), Error> { - if self - .in_flight_flush_count - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| { - if val >= MAX_CONCURRENT_FLUSHES { - return None; - } - Some(val + 1) - }) - .is_err() - { - return Err(Error::TooManyFlushAttempts); - } - - let (tx, rx) = mpsc::channel(); - let flush_count = self.in_flight_flush_count.clone(); - - let flush_fn = self.flush_fn; - - std::thread::spawn(move || { - let begin = Instant::now(); - - let result = if unsafe { (flush_fn)() } != 0 { - let elapsed = begin.elapsed(); - if elapsed >= FLUSH_TIMEOUT { - log::warn!( - "Flushing system DNS cache took {} seconds", - elapsed.as_secs() - ); - } else { - log::debug!("Flushed system DNS cache"); - } - Ok(()) - } else { - Err(Error::FlushCache) - }; - let _ = tx.send(result); - - flush_count.fetch_sub(1, Ordering::SeqCst); - }); - - match rx.recv_timeout(FLUSH_TIMEOUT) { - Ok(result) => result, - Err(_timeout_err) => Err(Error::Timeout), - } - } -} diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs deleted file mode 100644 index 9b780b1f54..0000000000 --- a/talpid-core/src/dns/windows/mod.rs +++ /dev/null @@ -1,157 +0,0 @@ -use std::{io, net::IpAddr}; -use talpid_types::ErrorExt; -use talpid_windows_net::{guid_from_luid, luid_from_alias}; -use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2}; -use winreg::{ - enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE}, - transaction::Transaction, - RegKey, -}; - -mod dnsapi; - -/// Errors that can happen when configuring DNS on Windows. -#[derive(err_derive::Error, Debug)] -#[error(no_from)] -pub enum Error { - /// Failure to obtain an interface LUID given an alias. - #[error(display = "Failed to obtain LUID for the interface alias")] - InterfaceLuidError(#[error(source)] io::Error), - - /// Failure to obtain an interface GUID. - #[error(display = "Failed to obtain GUID for the interface")] - InterfaceGuidError(#[error(source)] io::Error), - - /// Failure to flush DNS cache. - #[error(display = "Failed to flush DNS resolver cache")] - FlushResolverCacheError(#[error(source)] dnsapi::Error), - - /// Failed to update DNS servers for interface. - #[error(display = "Failed to update interface DNS servers")] - SetResolversError(#[error(source)] io::Error), -} - -pub struct DnsMonitor { - current_guid: Option<GUID>, -} - -impl super::DnsMonitorT for DnsMonitor { - type Error = Error; - - fn new() -> Result<Self, Error> { - Ok(DnsMonitor { current_guid: None }) - } - - fn set(&mut self, interface: &str, servers: &[IpAddr]) -> Result<(), Error> { - let guid = guid_from_luid(&luid_from_alias(interface).map_err(Error::InterfaceLuidError)?) - .map_err(Error::InterfaceGuidError)?; - set_dns(&guid, servers)?; - self.current_guid = Some(guid); - flush_dns_cache()?; - Ok(()) - } - - fn reset(&mut self) -> Result<(), Error> { - if let Some(guid) = self.current_guid.take() { - return set_dns(&guid, &[]).and(flush_dns_cache()); - } - Ok(()) - } -} - -fn set_dns(interface: &GUID, servers: &[IpAddr]) -> Result<(), Error> { - let transaction = Transaction::new().map_err(Error::SetResolversError)?; - let result = match set_dns_inner(&transaction, interface, servers) { - Ok(()) => transaction.commit(), - Err(error) => transaction.rollback().and(Err(error)), - }; - result.map_err(Error::SetResolversError) -} - -fn set_dns_inner( - transaction: &Transaction, - interface: &GUID, - servers: &[IpAddr], -) -> io::Result<()> { - let guid_str = string_from_guid(interface); - - config_interface( - transaction, - &guid_str, - "Tcpip", - servers.iter().filter(|addr| addr.is_ipv4()), - )?; - - config_interface( - transaction, - &guid_str, - "Tcpip6", - servers.iter().filter(|addr| addr.is_ipv6()), - )?; - - Ok(()) -} - -fn config_interface<'a>( - transaction: &Transaction, - guid: &str, - service: &str, - nameservers: impl Iterator<Item = &'a IpAddr>, -) -> io::Result<()> { - let nameservers = nameservers - .map(|addr| addr.to_string()) - .collect::<Vec<String>>(); - - let reg_path = - format!(r#"SYSTEM\CurrentControlSet\Services\{service}\Parameters\Interfaces\{guid}"#,); - let adapter_key = match RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey_transacted_with_flags( - reg_path, - transaction, - KEY_SET_VALUE, - ) { - Ok(adapter_key) => Ok(adapter_key), - Err(error) => { - if nameservers.is_empty() && error.kind() == io::ErrorKind::NotFound { - return Ok(()); - } - Err(error) - } - }?; - - if !nameservers.is_empty() { - adapter_key.set_value("NameServer", &nameservers.join(","))?; - } else { - adapter_key.delete_value("NameServer").or_else(|error| { - if error.kind() == io::ErrorKind::NotFound { - Ok(()) - } else { - Err(error) - } - })?; - } - - // Try to disable LLMNR on the interface - if let Err(error) = adapter_key.set_value("EnableMulticast", &0u32) { - log::error!( - "{}\nService: {service}", - error.display_chain_with_msg("Failed to disable LLMNR on the tunnel interface") - ); - } - - Ok(()) -} - -fn flush_dns_cache() -> Result<(), Error> { - dnsapi::flush_resolver_cache().map_err(Error::FlushResolverCacheError) -} - -/// Obtain a string representation for a GUID object. -fn string_from_guid(guid: &GUID) -> String { - let mut buffer = [0u16; 40]; - let length = unsafe { StringFromGUID2(guid, &mut buffer[0] as *mut _, buffer.len() as i32 - 1) } - as usize; - // cannot fail because `buffer` is large enough - assert!(length > 0); - let length = length - 1; - String::from_utf16(&buffer[0..length]).unwrap() -} diff --git a/talpid-windows-net/src/net.rs b/talpid-windows-net/src/net.rs index 6a11cf7018..18a3f7748c 100644 --- a/talpid-windows-net/src/net.rs +++ b/talpid-windows-net/src/net.rs @@ -17,8 +17,9 @@ use windows_sys::{ NetworkManagement::{ IpHelper::{ CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias, - ConvertInterfaceLuidToGuid, CreateUnicastIpAddressEntry, FreeMibTable, - GetIpInterfaceEntry, GetUnicastIpAddressEntry, GetUnicastIpAddressTable, + ConvertInterfaceLuidToGuid, ConvertInterfaceLuidToIndex, + CreateUnicastIpAddressEntry, FreeMibTable, GetIpInterfaceEntry, + GetUnicastIpAddressEntry, GetUnicastIpAddressTable, InitializeUnicastIpAddressEntry, MibAddInstance, NotifyIpInterfaceChange, SetIpInterfaceEntry, MIB_IPINTERFACE_ROW, MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE, @@ -377,6 +378,16 @@ pub fn get_unicast_table( Ok(unicast_rows) } +/// Returns the index of a network interface given its LUID. +pub fn index_from_luid(luid: &NET_LUID_LH) -> io::Result<u32> { + let mut index = 0u32; + let status = unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) }; + if status != NO_ERROR as i32 { + return Err(io::Error::from_raw_os_error(status as i32)); + } + Ok(index) +} + /// Returns the GUID of a network interface given its LUID. pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> { let mut guid = MaybeUninit::zeroed(); |
