diff options
| author | David Lönnhager <david.l@mullvad.net> | 2023-08-07 15:45:27 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2023-08-08 11:10:06 +0200 |
| commit | 147dc31583c514bbc1d0bcc2b128e2a9b788e6a7 (patch) | |
| tree | 8e76096ca52be7e9eb86606e88e351d51d801d51 | |
| parent | 79bf21804b0ddf7783910903fafb12d9c4f5da4a (diff) | |
| download | mullvadvpn-147dc31583c514bbc1d0bcc2b128e2a9b788e6a7.tar.xz mullvadvpn-147dc31583c514bbc1d0bcc2b128e2a9b788e6a7.zip | |
Simplify dnsapi module
| -rw-r--r-- | talpid-core/src/dns/windows/dnsapi.rs | 61 |
1 files changed, 14 insertions, 47 deletions
diff --git a/talpid-core/src/dns/windows/dnsapi.rs b/talpid-core/src/dns/windows/dnsapi.rs index 9554c874db..b9458a70b9 100644 --- a/talpid-core/src/dns/windows/dnsapi.rs +++ b/talpid-core/src/dns/windows/dnsapi.rs @@ -1,26 +1,14 @@ -use once_cell::sync::OnceCell; use std::{ - io, sync::{ atomic::{AtomicUsize, Ordering}, - mpsc, Arc, + mpsc, Arc, OnceLock, }, time::{Duration, Instant}, }; -use windows_sys::{ - w, - Win32::{ - Foundation::BOOL, - System::LibraryLoader::{ - FreeLibrary, GetProcAddress, LoadLibraryExW, LOAD_LIBRARY_SEARCH_SYSTEM32, - }, - }, -}; +use windows_sys::Win32::Foundation::BOOL; -type FlushResolverCacheFn = unsafe extern "stdcall" fn() -> BOOL; - -static DNSAPI_HANDLE: OnceCell<DnsApi> = OnceCell::new(); static FLUSH_TIMEOUT: Duration = Duration::from_secs(5); +static DNSAPI_HANDLE: OnceLock<DnsApi> = OnceLock::new(); const MAX_CONCURRENT_FLUSHES: usize = 5; @@ -28,14 +16,6 @@ const MAX_CONCURRENT_FLUSHES: usize = 5; #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { - /// Failed to load dnsapi.dll. - #[error(display = "Failed to load dnsapi.dll")] - LoadDll(#[error(source)] io::Error), - - /// Failed to obtain exported function. - #[error(display = "Failed to obtain flush function")] - GetFunction(#[error(source)] io::Error), - /// Failed to flush the DNS cache. #[error(display = "Call to flush DNS cache failed")] FlushCache, @@ -50,35 +30,18 @@ pub enum Error { } pub fn flush_resolver_cache() -> Result<(), Error> { - DNSAPI_HANDLE.get_or_try_init(DnsApi::new)?.flush_cache() + DNSAPI_HANDLE.get_or_init(DnsApi::new).flush_cache() } struct DnsApi { in_flight_flush_count: Arc<AtomicUsize>, - flush_fn: FlushResolverCacheFn, } -unsafe impl Send for DnsApi {} -unsafe impl Sync for DnsApi {} - impl DnsApi { - fn new() -> Result<Self, Error> { - let handle = unsafe { LoadLibraryExW(w!("dnsapi.dll"), 0, LOAD_LIBRARY_SEARCH_SYSTEM32) }; - if handle == 0 { - return Err(Error::LoadDll(io::Error::last_os_error())); - } - - let flush_fn = unsafe { GetProcAddress(handle, b"DnsFlushResolverCache\0" as *const u8) }; - let flush_fn = flush_fn.ok_or_else(|| { - let error = io::Error::last_os_error(); - unsafe { FreeLibrary(handle) }; - Error::GetFunction(error) - })?; - - Ok(DnsApi { + fn new() -> Self { + DnsApi { in_flight_flush_count: Arc::new(AtomicUsize::new(0)), - flush_fn: unsafe { *(&flush_fn as *const _ as *const _) }, - }) + } } fn flush_cache(&self) -> Result<(), Error> { @@ -97,12 +60,10 @@ impl DnsApi { let (tx, rx) = mpsc::channel(); let flush_count = self.in_flight_flush_count.clone(); - let flush_fn = self.flush_fn; - std::thread::spawn(move || { let begin = Instant::now(); - let result = if unsafe { (flush_fn)() } != 0 { + let result = if unsafe { (DnsFlushResolverCache)() } != 0 { let elapsed = begin.elapsed(); if elapsed >= FLUSH_TIMEOUT { log::warn!( @@ -127,3 +88,9 @@ impl DnsApi { } } } + +#[link(name = "dnsapi")] +extern "system" { + // Flushes the DNS resolver cache + pub fn DnsFlushResolverCache() -> BOOL; +} |
