summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-11-19 10:25:44 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-11-22 13:38:16 +0100
commitf4db85b3a552f60d2454bfa69912c7ced51b41b1 (patch)
treec932bc8d75ea3ca6d95dfdd0c3925a171cea9d07 /mullvad-api/src
parent8ababf0f77b23f7245a1aed3d8c8c4a5e3c06192 (diff)
downloadmullvadvpn-f4db85b3a552f60d2454bfa69912c7ced51b41b1.tar.xz
mullvadvpn-f4db85b3a552f60d2454bfa69912c7ced51b41b1.zip
Add non-blocking DNS resolver for Android API requests
Diffstat (limited to 'mullvad-api/src')
-rw-r--r--mullvad-api/src/bin/relay_list.rs6
-rw-r--r--mullvad-api/src/https_client_with_sni.rs36
-rw-r--r--mullvad-api/src/lib.rs52
-rw-r--r--mullvad-api/src/rest.rs3
4 files changed, 76 insertions, 21 deletions
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index def32303ea..22190abd63 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -2,13 +2,15 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
-use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
+use mullvad_api::{
+ proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
+};
use std::process;
use talpid_types::ErrorExt;
#[tokio::main]
async fn main() {
- let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
+ let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
.expect("Failed to load runtime");
let relay_list_request =
diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs
index 898927513f..09e198ca3b 100644
--- a/mullvad-api/src/https_client_with_sni.rs
+++ b/mullvad-api/src/https_client_with_sni.rs
@@ -2,17 +2,14 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
- AddressCache,
+ AddressCache, DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::Uri;
-use hyper_util::{
- client::legacy::connect::dns::{GaiResolver, Name},
- rt::TokioIo,
-};
+use hyper_util::rt::TokioIo;
use mullvad_encrypted_dns_proxy::{
config::ProxyConfig as EncryptedDNSConfig, Forwarder as EncryptedDNSForwarder,
};
@@ -291,6 +288,7 @@ pub struct HttpsConnectorWithSni {
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
+ dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -307,6 +305,7 @@ impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
+ dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
let (tx, mut rx) = mpsc::unbounded();
@@ -355,6 +354,7 @@ impl HttpsConnectorWithSni {
sni_hostname,
address_cache,
abort_notify,
+ dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
},
@@ -388,7 +388,14 @@ impl HttpsConnectorWithSni {
.map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
}
- async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> {
+ /// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
+ /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
+ /// If the URI contains a port, then that port will be used.
+ async fn resolve_address(
+ address_cache: AddressCache,
+ dns_resolver: &dyn DnsResolver,
+ uri: Uri,
+ ) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;
let hostname = uri.host().ok_or_else(|| {
@@ -408,19 +415,13 @@ impl HttpsConnectorWithSni {
));
}
- // Use getaddrinfo as a fallback
+ // Use DNS resolution as fallback
//
- let mut addrs = GaiResolver::new()
- .call(
- Name::from_str(hostname)
- .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
- )
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+ let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
- .next()
+ .first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
- Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT)))
+ Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
}
}
@@ -455,6 +456,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
+ let dns_resolver = self.dns_resolver.clone();
let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
@@ -465,7 +467,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
let hostname = sni_hostname?;
- let addr = Self::resolve_address(address_cache, uri).await?;
+ let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 6b3ac3c951..1f47d600b3 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -1,4 +1,5 @@
#![allow(rustdoc::private_intra_doc_links)]
+use async_trait::async_trait;
#[cfg(target_os = "android")]
use futures::channel::mpsc;
#[cfg(target_os = "android")]
@@ -12,10 +13,11 @@ use std::{
cell::Cell,
collections::BTreeMap,
future::Future,
+ io,
net::{IpAddr, Ipv4Addr, SocketAddr},
ops::Deref,
path::Path,
- sync::OnceLock,
+ sync::{Arc, OnceLock},
};
use talpid_types::ErrorExt;
@@ -304,11 +306,43 @@ impl ApiEndpoint {
}
}
+#[async_trait]
+pub trait DnsResolver: 'static + Send + Sync {
+ async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
+}
+
+/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
+pub struct DefaultDnsResolver;
+
+#[async_trait]
+impl DnsResolver for DefaultDnsResolver {
+ async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
+ use std::net::ToSocketAddrs;
+ // Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
+ // blocks and either has no timeout or a very long one.
+ let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
+ .await
+ .expect("DNS task panicked")?;
+ Ok(addrs.map(|addr| addr.ip()).collect())
+ }
+}
+
+/// DNS resolver that always returns no results
+pub struct NullDnsResolver;
+
+#[async_trait]
+impl DnsResolver for NullDnsResolver {
+ async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
+ Ok(vec![])
+ }
+}
+
/// A type that helps with the creation of API connections.
pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
+ dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -323,13 +357,20 @@ pub enum Error {
#[error("API availability check failed")]
ApiCheckError(#[from] availability::Error),
+
+ #[error("DNS resolution error")]
+ ResolutionFailed(#[from] std::io::Error),
}
impl Runtime {
/// Create a new `Runtime`.
- pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
+ pub fn new(
+ handle: tokio::runtime::Handle,
+ dns_resolver: impl DnsResolver,
+ ) -> Result<Self, Error> {
Self::new_inner(
handle,
+ dns_resolver,
#[cfg(target_os = "android")]
None,
)
@@ -346,12 +387,14 @@ impl Runtime {
fn new_inner(
handle: tokio::runtime::Handle,
+ dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
api_availability: ApiAvailability::default(),
+ dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
@@ -360,15 +403,18 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
+ dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
let handle = tokio::runtime::Handle::current();
+
#[cfg(feature = "api-override")]
if API.disable_address_cache {
return Self::new_inner(
handle,
+ dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
@@ -402,6 +448,7 @@ impl Runtime {
handle,
address_cache,
api_availability,
+ dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
@@ -419,6 +466,7 @@ impl Runtime {
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
+ self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index f6098c3b49..54a32f63f9 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -6,6 +6,7 @@ use crate::{
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
+ DnsResolver,
};
use futures::{
channel::{mpsc, oneshot},
@@ -154,11 +155,13 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
+ dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
address_cache.clone(),
+ dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
);