diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-07-21 11:42:00 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-08-10 14:19:24 +0200 |
| commit | faa314e90bb9b3333fdd32c510123a3e0b774882 (patch) | |
| tree | f1d419f530427e9a2cbbe49daa1d00e7f2b89151 | |
| parent | d7223f3ae2bfefb267317a8fe7e7c51524f63f94 (diff) | |
| download | mullvadvpn-faa314e90bb9b3333fdd32c510123a3e0b774882.tar.xz mullvadvpn-faa314e90bb9b3333fdd32c510123a3e0b774882.zip | |
Limit number of concurrent flush attempts
| -rw-r--r-- | talpid-core/src/dns/windows/dnsapi.rs | 94 | ||||
| -rw-r--r-- | talpid-core/src/dns/windows/mod.rs | 6 |
2 files changed, 72 insertions, 28 deletions
diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs index 4a0cf636e5..2d428468ff 100644 --- a/talpid-core/src/dns/windows/dnsapi.rs +++ b/talpid-core/src/dns/windows/dnsapi.rs @@ -1,5 +1,12 @@ use once_cell::sync::OnceCell; -use std::{io, ptr, sync::mpsc, time::Duration}; +use std::{ + io, ptr, + sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc, Arc, + }, + time::Duration, +}; use winapi::{ shared::minwindef::{BOOL, FALSE}, um::libloaderapi::{FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32}, @@ -7,9 +14,11 @@ use winapi::{ type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL; -static FLUSH_RESOLVER_CACHE: OnceCell<FlushResolverCacheFn> = OnceCell::new(); +static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new(); static FLUSH_TIMEOUT: Duration = Duration::from_secs(5); +const MAX_CONCURRENT_FLUSHES: usize = 5; + /// Errors that can happen when configuring DNS on Windows. #[derive(err_derive::Error, Debug)] #[error(no_from)] @@ -26,29 +35,31 @@ pub enum Error { #[error(display = "Call to flush DNS cache failed")] FlushCache, + /// Too many flush attempts in progress. + #[error(display = "Too many flush attempts in progress")] + TooManyFlushAttempts, + /// Flushing the DNS cache timed out. #[error(display = "Timeout while flushing DNS cache")] Timeout, } pub fn flush_resolver_cache() -> Result<(), Error> { - let (tx, rx) = mpsc::channel(); - - std::thread::spawn(move || { - if tx.send(flush_resolver_cache_inner()).is_err() { - log::warn!("Flushing DNS cache completed (delayed)"); - } - }); + DNSAPI_HANDLE + .get_or_try_init(|| DnsApi::new())? + .flush_cache() +} - match rx.recv_timeout(FLUSH_TIMEOUT) { - Ok(result) => result, - // TODO: Can this be a cancelled safely? - Err(_timeout_err) => Err(Error::Timeout), - } +struct DnsApi { + in_flight_flush_count: Arc<AtomicUsize>, + flush_fn: FlushResolverCacheFn, } -fn flush_resolver_cache_inner() -> Result<(), Error> { - let flush_cache = FLUSH_RESOLVER_CACHE.get_or_try_init(|| { +unsafe impl Send for DnsApi {} +unsafe impl Sync for DnsApi {} + +impl DnsApi { + fn new() -> Result<Self, Error> { let handle = unsafe { LoadLibraryExW( b"d\0n\0s\0a\0p\0i\0.\0d\0l\0l\0\0\0" as *const u8 as *const u16, @@ -59,18 +70,55 @@ fn flush_resolver_cache_inner() -> Result<(), Error> { if handle.is_null() { return Err(Error::LoadDll(io::Error::last_os_error())); } - let function_addr = + + let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const _ as *const i8) }; - if function_addr.is_null() { + if flush_fn.is_null() { let error = io::Error::last_os_error(); unsafe { FreeLibrary(handle) }; return Err(Error::GetFunction(error)); } - Ok(unsafe { *(&function_addr as *const _ as *const _) }) - })?; - if unsafe { flush_cache() } == FALSE { - return Err(Error::FlushCache); + Ok(DnsApi { + in_flight_flush_count: Arc::new(AtomicUsize::new(0)), + flush_fn: unsafe { *(&flush_fn as *const _ as *const _) }, + }) + } + + fn flush_cache(&self) -> Result<(), Error> { + if self + .in_flight_flush_count + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| { + if val >= MAX_CONCURRENT_FLUSHES { + return None; + } + Some(val + 1) + }) + .is_err() + { + return Err(Error::TooManyFlushAttempts); + } + + let (tx, rx) = mpsc::channel(); + let flush_count = self.in_flight_flush_count.clone(); + + let flush_fn = self.flush_fn; + + std::thread::spawn(move || { + let result = if unsafe { (flush_fn)() } == FALSE { + Err(Error::FlushCache) + } else { + log::debug!("Flushed DNS resolver cache"); + Ok(()) + }; + let _ = tx.send(result); + + flush_count.fetch_sub(1, Ordering::SeqCst); + }); + + match rx.recv_timeout(FLUSH_TIMEOUT) { + Ok(result) => result, + Err(_timeout_err) => Err(Error::Timeout), + } } - Ok(()) } diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs index 2cb9b74f0b..85d964aca1 100644 --- a/talpid-core/src/dns/windows/mod.rs +++ b/talpid-core/src/dns/windows/mod.rs @@ -23,12 +23,8 @@ pub enum Error { InterfaceGuidError(#[error(source)] io::Error), /// Failure to flush DNS cache. - #[error(display = "Failed to execute ipconfig")] - ExecuteIpconfigError(#[error(source)] io::Error), - - /// Failure to flush DNS cache. #[error(display = "Failed to flush DNS resolver cache")] - FlushResolverCacheError(dnsapi::Error), + FlushResolverCacheError(#[error(source)] dnsapi::Error), /// Failed to update DNS servers for interface. #[error(display = "Failed to update interface DNS servers")] |
