summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-api/src/address_cache.rs20
-rw-r--r--mullvad-api/src/bin/relay_list.rs6
-rw-r--r--mullvad-api/src/https_client_with_sni.rs34
-rw-r--r--mullvad-api/src/lib.rs66
-rw-r--r--mullvad-api/src/rest.rs3
-rw-r--r--mullvad-daemon/src/android_dns.rs6
-rw-r--r--mullvad-daemon/src/lib.rs15
-rw-r--r--mullvad-problem-report/src/lib.rs3
-rw-r--r--mullvad-setup/src/main.rs4
-rw-r--r--test/test-manager/src/tests/account.rs7
10 files changed, 74 insertions, 90 deletions
diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs
index a48db1a0e2..0898f8da1f 100644
--- a/mullvad-api/src/address_cache.rs
+++ b/mullvad-api/src/address_cache.rs
@@ -1,6 +1,8 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use super::API;
+use crate::DnsResolver;
+use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
@@ -23,6 +25,17 @@ pub enum Error {
Write(#[source] io::Error),
}
+/// A DNS resolver which resolves using `AddressCache`.
+#[async_trait]
+impl DnsResolver for AddressCache {
+ async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
+ self.resolve_hostname(&host)
+ .await
+ .map(|addr| vec![addr])
+ .ok_or(io::Error::other("host does not match API host"))
+ }
+}
+
#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
@@ -42,7 +55,10 @@ impl AddressCache {
/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
- Ok(Self::new_inner(read_address_file(read_path).await?, write_path))
+ Ok(Self::new_inner(
+ read_address_file(read_path).await?,
+ write_path,
+ ))
}
fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
@@ -56,7 +72,7 @@ impl AddressCache {
}
/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
- pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
+ async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index 22190abd63..def32303ea 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
-use mullvad_api::{
- proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
-};
+use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;
#[tokio::main]
async fn main() {
- let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
+ let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
let relay_list_request =
diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs
index 09ce493431..59d1f7fac7 100644
--- a/mullvad-api/src/https_client_with_sni.rs
+++ b/mullvad-api/src/https_client_with_sni.rs
@@ -2,7 +2,7 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
- AddressCache, DnsResolver,
+ DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
@@ -287,7 +287,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
- address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
@@ -305,7 +304,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
- address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
@@ -353,7 +351,6 @@ impl HttpsConnectorWithSni {
HttpsConnectorWithSni {
inner,
sni_hostname,
- address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
@@ -390,13 +387,9 @@ impl HttpsConnectorWithSni {
}
/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
- /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
+ /// Otherwise `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
- async fn resolve_address(
- address_cache: AddressCache,
- dns_resolver: &dyn DnsResolver,
- uri: Uri,
- ) -> io::Result<SocketAddr> {
+ async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;
let hostname = uri.host().ok_or_else(|| {
@@ -407,22 +400,16 @@ impl HttpsConnectorWithSni {
return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT)));
}
- // Preferentially, use cached address.
- //
- if let Some(addr) = address_cache.resolve_hostname(hostname).await {
- return Ok(SocketAddr::new(
- addr.ip(),
- port.unwrap_or_else(|| addr.port()),
- ));
- }
-
- // Use DNS resolution as fallback
- //
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
- Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
+ let port = match (addr.port(), port) {
+ (_, Some(port)) => port,
+ (0, None) => DEFAULT_PORT,
+ (addr_port, None) => addr_port,
+ };
+ Ok(SocketAddr::new(addr.ip(), port))
}
}
@@ -456,7 +443,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
- let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();
let fut = async move {
@@ -468,7 +454,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
let hostname = sni_hostname?;
- let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
+ let addr = Self::resolve_address(&*dns_resolver, uri).await?;
// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 3f67709242..5d46534c44 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -308,7 +308,7 @@ impl ApiEndpoint {
#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
- async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
+ async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>;
}
/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
@@ -316,14 +316,14 @@ pub struct DefaultDnsResolver;
#[async_trait]
impl DnsResolver for DefaultDnsResolver {
- async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
+ async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
- Ok(addrs.map(|addr| addr.ip()).collect())
+ Ok(addrs.collect())
}
}
@@ -332,7 +332,7 @@ pub struct NullDnsResolver;
#[async_trait]
impl DnsResolver for NullDnsResolver {
- async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
+ async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> {
Ok(vec![])
}
}
@@ -342,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
- dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -364,13 +363,9 @@ pub enum Error {
impl Runtime {
/// Create a new `Runtime`.
- pub fn new(
- handle: tokio::runtime::Handle,
- dns_resolver: impl DnsResolver,
- ) -> Result<Self, Error> {
+ pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
- dns_resolver,
#[cfg(target_os = "android")]
None,
)
@@ -381,21 +376,18 @@ impl Runtime {
Runtime {
handle,
address_cache: AddressCache::with_static_addr(address),
- dns_resolver: Arc::new(NullDnsResolver),
api_availability: ApiAvailability::default(),
}
}
fn new_inner(
handle: tokio::runtime::Handle,
- dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
- address_cache: AddressCache::new(None)?,
+ address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
- dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
@@ -404,7 +396,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
- dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
@@ -415,7 +406,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
- dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
@@ -439,7 +429,7 @@ impl Runtime {
)
);
}
- AddressCache::new(write_file)?
+ AddressCache::new(write_file)
}
};
@@ -449,30 +439,11 @@ impl Runtime {
handle,
address_cache,
api_availability,
- dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}
- /// Creates a new request service and returns a handle to it.
- fn new_request_service<T: ConnectionModeProvider + 'static>(
- &self,
- sni_hostname: Option<String>,
- connection_mode_provider: T,
- #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> rest::RequestServiceHandle {
- rest::RequestService::spawn(
- sni_hostname,
- self.api_availability.clone(),
- self.address_cache.clone(),
- connection_mode_provider,
- self.dns_resolver.clone(),
- #[cfg(target_os = "android")]
- socket_bypass_tx,
- )
- }
-
/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
@@ -481,6 +452,7 @@ impl Runtime {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
+ Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -495,6 +467,7 @@ impl Runtime {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
+ Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -505,15 +478,34 @@ impl Runtime {
}
/// Returns a new request service handle
- pub fn rest_handle(&self) -> rest::RequestServiceHandle {
+ pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
+ Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}
+ /// Creates a new request service and returns a handle to it.
+ fn new_request_service<T: ConnectionModeProvider + 'static>(
+ &self,
+ sni_hostname: Option<String>,
+ connection_mode_provider: T,
+ dns_resolver: Arc<dyn DnsResolver>,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> rest::RequestServiceHandle {
+ rest::RequestService::spawn(
+ sni_hostname,
+ self.api_availability.clone(),
+ connection_mode_provider,
+ dns_resolver,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ )
+ }
+
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 54a32f63f9..5c28705158 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -2,7 +2,6 @@
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
access::AccessTokenStore,
- address_cache::AddressCache,
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
@@ -153,14 +152,12 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailability,
- address_cache: AddressCache,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
- address_cache.clone(),
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
diff --git a/mullvad-daemon/src/android_dns.rs b/mullvad-daemon/src/android_dns.rs
index ed44f5dc8c..5cbc9c271a 100644
--- a/mullvad-daemon/src/android_dns.rs
+++ b/mullvad-daemon/src/android_dns.rs
@@ -7,7 +7,7 @@ use hickory_resolver::{
TokioAsyncResolver,
};
use mullvad_api::DnsResolver;
-use std::{io, net::IpAddr};
+use std::{io, net::SocketAddr};
use talpid_core::connectivity_listener::ConnectivityListener;
/// A non-blocking DNS resolver. The default resolver uses `getaddrinfo`, which often prevents the
@@ -27,7 +27,7 @@ impl AndroidDnsResolver {
#[async_trait]
impl DnsResolver for AndroidDnsResolver {
- async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
+ async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
let ips = self
.connectivity_listener
.current_dns_servers()
@@ -44,6 +44,6 @@ impl DnsResolver for AndroidDnsResolver {
.await
.map_err(|err| io::Error::other(format!("lookup_ip failed: {err}")))?;
- Ok(lookup.into_iter().collect())
+ Ok(lookup.into_iter().map(|ip| (ip, 0).into()).collect())
}
}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 79f08e8a0f..4f98c73d01 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -39,8 +39,6 @@ use futures::{
};
use geoip::GeoIpHandler;
use management_interface::ManagementInterfaceServer;
-#[cfg(not(target_os = "android"))]
-use mullvad_api::DefaultDnsResolver;
use mullvad_relay_selector::{RelaySelector, SelectorConfig};
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
@@ -622,10 +620,6 @@ impl Daemon {
mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await;
let api_runtime = mullvad_api::Runtime::with_cache(
- #[cfg(target_os = "android")]
- android_dns::AndroidDnsResolver::new(connectivity_listener.clone()),
- #[cfg(not(target_os = "android"))]
- DefaultDnsResolver,
&cache_dir,
true,
#[cfg(target_os = "android")]
@@ -798,7 +792,7 @@ impl Daemon {
#[cfg(target_os = "android")]
android_context,
#[cfg(target_os = "android")]
- connectivity_listener,
+ connectivity_listener.clone(),
#[cfg(target_os = "linux")]
tunnel_state_machine::LinuxNetworkingIdentifiers {
fwmark: mullvad_types::TUNNEL_FWMARK,
@@ -835,7 +829,12 @@ impl Daemon {
relay_list_updater.update().await;
let location_handler = GeoIpHandler::new(
- api_runtime.rest_handle(),
+ api_runtime.rest_handle(
+ #[cfg(not(target_os = "android"))]
+ mullvad_api::DefaultDnsResolver,
+ #[cfg(target_os = "android")]
+ android_dns::AndroidDnsResolver::new(connectivity_listener),
+ ),
internal_event_tx.clone().to_specialized_sender(),
);
diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs
index 91b790e5f6..270de55f95 100644
--- a/mullvad-problem-report/src/lib.rs
+++ b/mullvad-problem-report/src/lib.rs
@@ -1,4 +1,4 @@
-use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver};
+use mullvad_api::proxy::ApiConnectionMode;
use regex::Regex;
use std::{
borrow::Cow,
@@ -292,7 +292,6 @@ async fn send_problem_report_inner(
) -> Result<(), Error> {
let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect);
let api_runtime = mullvad_api::Runtime::with_cache(
- NullDnsResolver,
cache_dir,
false,
#[cfg(target_os = "android")]
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index 4a444aa63c..d3dfd6de8a 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -1,7 +1,7 @@
use clap::Parser;
use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration};
-use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver, DEVICE_NOT_FOUND};
+use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND};
use mullvad_management_interface::MullvadProxyClient;
use mullvad_types::version::ParsedAppVersion;
use talpid_core::firewall::{self, Firewall};
@@ -152,7 +152,7 @@ async fn remove_device() -> Result<(), Error> {
.await
.map_err(Error::ReadDeviceCacheError)?;
if let Some(device) = state.into_device() {
- let api_runtime = mullvad_api::Runtime::with_cache(NullDnsResolver, &cache_path, false)
+ let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false)
.await
.map_err(Error::RpcInitializationError)?;
diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs
index 45151070a9..7fe14ae58e 100644
--- a/test/test-manager/src/tests/account.rs
+++ b/test/test-manager/src/tests/account.rs
@@ -295,11 +295,8 @@ pub async fn new_device_client() -> anyhow::Result<DevicesProxy> {
..api_endpoint
});
- let api = mullvad_api::Runtime::new(
- tokio::runtime::Handle::current(),
- mullvad_api::DefaultDnsResolver,
- )
- .expect("failed to create api runtime");
+ let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
+ .expect("failed to create api runtime");
let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider());
Ok(DevicesProxy::new(rest_handle))
}