diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-11-19 10:25:44 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-11-22 13:38:16 +0100 |
| commit | f4db85b3a552f60d2454bfa69912c7ced51b41b1 (patch) | |
| tree | c932bc8d75ea3ca6d95dfdd0c3925a171cea9d07 /mullvad-api/src | |
| parent | 8ababf0f77b23f7245a1aed3d8c8c4a5e3c06192 (diff) | |
| download | mullvadvpn-f4db85b3a552f60d2454bfa69912c7ced51b41b1.tar.xz mullvadvpn-f4db85b3a552f60d2454bfa69912c7ced51b41b1.zip | |
Add non-blocking DNS resolver for Android API requests
Diffstat (limited to 'mullvad-api/src')
| -rw-r--r-- | mullvad-api/src/bin/relay_list.rs | 6 | ||||
| -rw-r--r-- | mullvad-api/src/https_client_with_sni.rs | 36 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 52 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 3 |
4 files changed, 76 insertions, 21 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index def32303ea..22190abd63 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,13 +2,15 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; +use mullvad_api::{ + proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy, +}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver) .expect("Failed to load runtime"); let relay_list_request = diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 898927513f..09e198ca3b 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -2,17 +2,14 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, tls_stream::TlsStream, - AddressCache, + AddressCache, DnsResolver, }; use futures::{channel::mpsc, future, pin_mut, StreamExt}; #[cfg(target_os = "android")] use futures::{channel::oneshot, sink::SinkExt}; use http::uri::Scheme; use hyper::Uri; -use hyper_util::{ - client::legacy::connect::dns::{GaiResolver, Name}, - rt::TokioIo, -}; +use hyper_util::rt::TokioIo; use mullvad_encrypted_dns_proxy::{ config::ProxyConfig as EncryptedDNSConfig, Forwarder as EncryptedDNSForwarder, }; @@ -291,6 +288,7 @@ pub struct HttpsConnectorWithSni { sni_hostname: Option<String>, address_cache: AddressCache, abort_notify: Arc<tokio::sync::Notify>, + dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } @@ -307,6 +305,7 @@ impl HttpsConnectorWithSni { pub fn new( sni_hostname: Option<String>, address_cache: AddressCache, + dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> (Self, HttpsConnectorWithSniHandle) { let (tx, mut rx) = mpsc::unbounded(); @@ -355,6 +354,7 @@ impl HttpsConnectorWithSni { sni_hostname, address_cache, abort_notify, + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, }, @@ -388,7 +388,14 @@ impl HttpsConnectorWithSni { .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? } - async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> { + /// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used. + /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback. + /// If the URI contains a port, then that port will be used. + async fn resolve_address( + address_cache: AddressCache, + dns_resolver: &dyn DnsResolver, + uri: Uri, + ) -> io::Result<SocketAddr> { const DEFAULT_PORT: u16 = 443; let hostname = uri.host().ok_or_else(|| { @@ -408,19 +415,13 @@ impl HttpsConnectorWithSni { )); } - // Use getaddrinfo as a fallback + // Use DNS resolution as fallback // - let mut addrs = GaiResolver::new() - .call( - Name::from_str(hostname) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?, - ) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let addrs = dns_resolver.resolve(hostname.to_owned()).await?; let addr = addrs - .next() + .first() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; - Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT))) + Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT))) } } @@ -455,6 +456,7 @@ impl Service<Uri> for HttpsConnectorWithSni { #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); let address_cache = self.address_cache.clone(); + let dns_resolver = self.dns_resolver.clone(); let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { @@ -465,7 +467,7 @@ impl Service<Uri> for HttpsConnectorWithSni { } let hostname = sni_hostname?; - let addr = Self::resolve_address(address_cache, uri).await?; + let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?; // Loop until we have established a connection. This starts over if a new endpoint // is selected while connecting. diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 6b3ac3c951..1f47d600b3 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -1,4 +1,5 @@ #![allow(rustdoc::private_intra_doc_links)] +use async_trait::async_trait; #[cfg(target_os = "android")] use futures::channel::mpsc; #[cfg(target_os = "android")] @@ -12,10 +13,11 @@ use std::{ cell::Cell, collections::BTreeMap, future::Future, + io, net::{IpAddr, Ipv4Addr, SocketAddr}, ops::Deref, path::Path, - sync::OnceLock, + sync::{Arc, OnceLock}, }; use talpid_types::ErrorExt; @@ -304,11 +306,43 @@ impl ApiEndpoint { } } +#[async_trait] +pub trait DnsResolver: 'static + Send + Sync { + async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>; +} + +/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`). +pub struct DefaultDnsResolver; + +#[async_trait] +impl DnsResolver for DefaultDnsResolver { + async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> { + use std::net::ToSocketAddrs; + // Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which + // blocks and either has no timeout or a very long one. + let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs()) + .await + .expect("DNS task panicked")?; + Ok(addrs.map(|addr| addr.ip()).collect()) + } +} + +/// DNS resolver that always returns no results +pub struct NullDnsResolver; + +#[async_trait] +impl DnsResolver for NullDnsResolver { + async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> { + Ok(vec![]) + } +} + /// A type that helps with the creation of API connections. pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, + dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } @@ -323,13 +357,20 @@ pub enum Error { #[error("API availability check failed")] ApiCheckError(#[from] availability::Error), + + #[error("DNS resolution error")] + ResolutionFailed(#[from] std::io::Error), } impl Runtime { /// Create a new `Runtime`. - pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { + pub fn new( + handle: tokio::runtime::Handle, + dns_resolver: impl DnsResolver, + ) -> Result<Self, Error> { Self::new_inner( handle, + dns_resolver, #[cfg(target_os = "android")] None, ) @@ -346,12 +387,14 @@ impl Runtime { fn new_inner( handle: tokio::runtime::Handle, + dns_resolver: impl DnsResolver, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { Ok(Runtime { handle, address_cache: AddressCache::new(None)?, api_availability: ApiAvailability::default(), + dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -360,15 +403,18 @@ impl Runtime { /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. pub async fn with_cache( + dns_resolver: impl DnsResolver, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { let handle = tokio::runtime::Handle::current(); + #[cfg(feature = "api-override")] if API.disable_address_cache { return Self::new_inner( handle, + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, ); @@ -402,6 +448,7 @@ impl Runtime { handle, address_cache, api_availability, + dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -419,6 +466,7 @@ impl Runtime { self.api_availability.clone(), self.address_cache.clone(), connection_mode_provider, + self.dns_resolver.clone(), #[cfg(target_os = "android")] socket_bypass_tx, ) diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index f6098c3b49..54a32f63f9 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -6,6 +6,7 @@ use crate::{ availability::ApiAvailability, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, proxy::ConnectionModeProvider, + DnsResolver, }; use futures::{ channel::{mpsc, oneshot}, @@ -154,11 +155,13 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { api_availability: ApiAvailability, address_cache: AddressCache, connection_mode_provider: T, + dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( sni_hostname, address_cache.clone(), + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), ); |
