summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-07-21 11:42:00 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-08-10 14:19:24 +0200
commitfaa314e90bb9b3333fdd32c510123a3e0b774882 (patch)
treef1d419f530427e9a2cbbe49daa1d00e7f2b89151
parentd7223f3ae2bfefb267317a8fe7e7c51524f63f94 (diff)
downloadmullvadvpn-faa314e90bb9b3333fdd32c510123a3e0b774882.tar.xz
mullvadvpn-faa314e90bb9b3333fdd32c510123a3e0b774882.zip
Limit number of concurrent flush attempts
-rw-r--r--talpid-core/src/dns/windows/dnsapi.rs94
-rw-r--r--talpid-core/src/dns/windows/mod.rs6
2 files changed, 72 insertions, 28 deletions
diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs
index 4a0cf636e5..2d428468ff 100644
--- a/talpid-core/src/dns/windows/dnsapi.rs
+++ b/talpid-core/src/dns/windows/dnsapi.rs
@@ -1,5 +1,12 @@
use once_cell::sync::OnceCell;
-use std::{io, ptr, sync::mpsc, time::Duration};
+use std::{
+ io, ptr,
+ sync::{
+ atomic::{AtomicUsize, Ordering},
+ mpsc, Arc,
+ },
+ time::Duration,
+};
use winapi::{
shared::minwindef::{BOOL, FALSE},
um::libloaderapi::{FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32},
@@ -7,9 +14,11 @@ use winapi::{
type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL;
-static FLUSH_RESOLVER_CACHE: OnceCell<FlushResolverCacheFn> = OnceCell::new();
+static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new();
static FLUSH_TIMEOUT: Duration = Duration::from_secs(5);
+const MAX_CONCURRENT_FLUSHES: usize = 5;
+
/// Errors that can happen when configuring DNS on Windows.
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
@@ -26,29 +35,31 @@ pub enum Error {
#[error(display = "Call to flush DNS cache failed")]
FlushCache,
+ /// Too many flush attempts in progress.
+ #[error(display = "Too many flush attempts in progress")]
+ TooManyFlushAttempts,
+
/// Flushing the DNS cache timed out.
#[error(display = "Timeout while flushing DNS cache")]
Timeout,
}
pub fn flush_resolver_cache() -> Result<(), Error> {
- let (tx, rx) = mpsc::channel();
-
- std::thread::spawn(move || {
- if tx.send(flush_resolver_cache_inner()).is_err() {
- log::warn!("Flushing DNS cache completed (delayed)");
- }
- });
+ DNSAPI_HANDLE
+ .get_or_try_init(|| DnsApi::new())?
+ .flush_cache()
+}
- match rx.recv_timeout(FLUSH_TIMEOUT) {
- Ok(result) => result,
- // TODO: Can this be a cancelled safely?
- Err(_timeout_err) => Err(Error::Timeout),
- }
+struct DnsApi {
+ in_flight_flush_count: Arc<AtomicUsize>,
+ flush_fn: FlushResolverCacheFn,
}
-fn flush_resolver_cache_inner() -> Result<(), Error> {
- let flush_cache = FLUSH_RESOLVER_CACHE.get_or_try_init(|| {
+unsafe impl Send for DnsApi {}
+unsafe impl Sync for DnsApi {}
+
+impl DnsApi {
+ fn new() -> Result<Self, Error> {
let handle = unsafe {
LoadLibraryExW(
b"d\0n\0s\0a\0p\0i\0.\0d\0l\0l\0\0\0" as *const u8 as *const u16,
@@ -59,18 +70,55 @@ fn flush_resolver_cache_inner() -> Result<(), Error> {
if handle.is_null() {
return Err(Error::LoadDll(io::Error::last_os_error()));
}
- let function_addr =
+
+ let flush_fn =
unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const _ as *const i8) };
- if function_addr.is_null() {
+ if flush_fn.is_null() {
let error = io::Error::last_os_error();
unsafe { FreeLibrary(handle) };
return Err(Error::GetFunction(error));
}
- Ok(unsafe { *(&function_addr as *const _ as *const _) })
- })?;
- if unsafe { flush_cache() } == FALSE {
- return Err(Error::FlushCache);
+ Ok(DnsApi {
+ in_flight_flush_count: Arc::new(AtomicUsize::new(0)),
+ flush_fn: unsafe { *(&flush_fn as *const _ as *const _) },
+ })
+ }
+
+ fn flush_cache(&self) -> Result<(), Error> {
+ if self
+ .in_flight_flush_count
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
+ if val >= MAX_CONCURRENT_FLUSHES {
+ return None;
+ }
+ Some(val + 1)
+ })
+ .is_err()
+ {
+ return Err(Error::TooManyFlushAttempts);
+ }
+
+ let (tx, rx) = mpsc::channel();
+ let flush_count = self.in_flight_flush_count.clone();
+
+ let flush_fn = self.flush_fn;
+
+ std::thread::spawn(move || {
+ let result = if unsafe { (flush_fn)() } == FALSE {
+ Err(Error::FlushCache)
+ } else {
+ log::debug!("Flushed DNS resolver cache");
+ Ok(())
+ };
+ let _ = tx.send(result);
+
+ flush_count.fetch_sub(1, Ordering::SeqCst);
+ });
+
+ match rx.recv_timeout(FLUSH_TIMEOUT) {
+ Ok(result) => result,
+ Err(_timeout_err) => Err(Error::Timeout),
+ }
}
- Ok(())
}
diff --git a/talpid-core/src/dns/windows/mod.rs b/talpid-core/src/dns/windows/mod.rs
index 2cb9b74f0b..85d964aca1 100644
--- a/talpid-core/src/dns/windows/mod.rs
+++ b/talpid-core/src/dns/windows/mod.rs
@@ -23,12 +23,8 @@ pub enum Error {
InterfaceGuidError(#[error(source)] io::Error),
/// Failure to flush DNS cache.
- #[error(display = "Failed to execute ipconfig")]
- ExecuteIpconfigError(#[error(source)] io::Error),
-
- /// Failure to flush DNS cache.
#[error(display = "Failed to flush DNS resolver cache")]
- FlushResolverCacheError(dnsapi::Error),
+ FlushResolverCacheError(#[error(source)] dnsapi::Error),
/// Failed to update DNS servers for interface.
#[error(display = "Failed to update interface DNS servers")]