summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2024-11-27 15:09:05 +0100
committerDavid Lönnhager <david.l@mullvad.net>2024-12-02 16:00:38 +0100
commit94ce8fb753f7441243d9281416632ce7ed4b6cd6 (patch)
tree8fa86e98b4b50ecd70ebf6468d222cad8e6cf552 /mullvad-api/src
parente07c12a5f14a11051fa086c97cc22413e431a1c8 (diff)
downloadmullvadvpn-94ce8fb753f7441243d9281416632ce7ed4b6cd6.tar.xz
mullvadvpn-94ce8fb753f7441243d9281416632ce7ed4b6cd6.zip
Remove DNS fallback except for conncheck
Diffstat (limited to 'mullvad-api/src')
-rw-r--r--mullvad-api/src/address_cache.rs20
-rw-r--r--mullvad-api/src/bin/relay_list.rs6
-rw-r--r--mullvad-api/src/https_client_with_sni.rs34
-rw-r--r--mullvad-api/src/lib.rs66
-rw-r--r--mullvad-api/src/rest.rs3
5 files changed, 59 insertions, 70 deletions
diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs
index a48db1a0e2..0898f8da1f 100644
--- a/mullvad-api/src/address_cache.rs
+++ b/mullvad-api/src/address_cache.rs
@@ -1,6 +1,8 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use super::API;
+use crate::DnsResolver;
+use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
@@ -23,6 +25,17 @@ pub enum Error {
Write(#[source] io::Error),
}
+/// A DNS resolver which resolves using `AddressCache`.
+#[async_trait]
+impl DnsResolver for AddressCache {
+ async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
+ self.resolve_hostname(&host)
+ .await
+ .map(|addr| vec![addr])
+ .ok_or(io::Error::other("host does not match API host"))
+ }
+}
+
#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
@@ -42,7 +55,10 @@ impl AddressCache {
/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
- Ok(Self::new_inner(read_address_file(read_path).await?, write_path))
+ Ok(Self::new_inner(
+ read_address_file(read_path).await?,
+ write_path,
+ ))
}
fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
@@ -56,7 +72,7 @@ impl AddressCache {
}
/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
- pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
+ async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index 22190abd63..def32303ea 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
-use mullvad_api::{
- proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
-};
+use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;
#[tokio::main]
async fn main() {
- let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
+ let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
let relay_list_request =
diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs
index 09ce493431..59d1f7fac7 100644
--- a/mullvad-api/src/https_client_with_sni.rs
+++ b/mullvad-api/src/https_client_with_sni.rs
@@ -2,7 +2,7 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
- AddressCache, DnsResolver,
+ DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
@@ -287,7 +287,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
- address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
@@ -305,7 +304,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
- address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
@@ -353,7 +351,6 @@ impl HttpsConnectorWithSni {
HttpsConnectorWithSni {
inner,
sni_hostname,
- address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
@@ -390,13 +387,9 @@ impl HttpsConnectorWithSni {
}
/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
- /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
+ /// Otherwise `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
- async fn resolve_address(
- address_cache: AddressCache,
- dns_resolver: &dyn DnsResolver,
- uri: Uri,
- ) -> io::Result<SocketAddr> {
+ async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;
let hostname = uri.host().ok_or_else(|| {
@@ -407,22 +400,16 @@ impl HttpsConnectorWithSni {
return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT)));
}
- // Preferentially, use cached address.
- //
- if let Some(addr) = address_cache.resolve_hostname(hostname).await {
- return Ok(SocketAddr::new(
- addr.ip(),
- port.unwrap_or_else(|| addr.port()),
- ));
- }
-
- // Use DNS resolution as fallback
- //
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
- Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
+ let port = match (addr.port(), port) {
+ (_, Some(port)) => port,
+ (0, None) => DEFAULT_PORT,
+ (addr_port, None) => addr_port,
+ };
+ Ok(SocketAddr::new(addr.ip(), port))
}
}
@@ -456,7 +443,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
- let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();
let fut = async move {
@@ -468,7 +454,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}
let hostname = sni_hostname?;
- let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
+ let addr = Self::resolve_address(&*dns_resolver, uri).await?;
// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 3f67709242..5d46534c44 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -308,7 +308,7 @@ impl ApiEndpoint {
#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
- async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
+ async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>;
}
/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
@@ -316,14 +316,14 @@ pub struct DefaultDnsResolver;
#[async_trait]
impl DnsResolver for DefaultDnsResolver {
- async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
+ async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
- Ok(addrs.map(|addr| addr.ip()).collect())
+ Ok(addrs.collect())
}
}
@@ -332,7 +332,7 @@ pub struct NullDnsResolver;
#[async_trait]
impl DnsResolver for NullDnsResolver {
- async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
+ async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> {
Ok(vec![])
}
}
@@ -342,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
- dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -364,13 +363,9 @@ pub enum Error {
impl Runtime {
/// Create a new `Runtime`.
- pub fn new(
- handle: tokio::runtime::Handle,
- dns_resolver: impl DnsResolver,
- ) -> Result<Self, Error> {
+ pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
- dns_resolver,
#[cfg(target_os = "android")]
None,
)
@@ -381,21 +376,18 @@ impl Runtime {
Runtime {
handle,
address_cache: AddressCache::with_static_addr(address),
- dns_resolver: Arc::new(NullDnsResolver),
api_availability: ApiAvailability::default(),
}
}
fn new_inner(
handle: tokio::runtime::Handle,
- dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
- address_cache: AddressCache::new(None)?,
+ address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
- dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
@@ -404,7 +396,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
- dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
@@ -415,7 +406,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
- dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
@@ -439,7 +429,7 @@ impl Runtime {
)
);
}
- AddressCache::new(write_file)?
+ AddressCache::new(write_file)
}
};
@@ -449,30 +439,11 @@ impl Runtime {
handle,
address_cache,
api_availability,
- dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}
- /// Creates a new request service and returns a handle to it.
- fn new_request_service<T: ConnectionModeProvider + 'static>(
- &self,
- sni_hostname: Option<String>,
- connection_mode_provider: T,
- #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> rest::RequestServiceHandle {
- rest::RequestService::spawn(
- sni_hostname,
- self.api_availability.clone(),
- self.address_cache.clone(),
- connection_mode_provider,
- self.dns_resolver.clone(),
- #[cfg(target_os = "android")]
- socket_bypass_tx,
- )
- }
-
/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
@@ -481,6 +452,7 @@ impl Runtime {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
+ Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -495,6 +467,7 @@ impl Runtime {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
+ Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
@@ -505,15 +478,34 @@ impl Runtime {
}
/// Returns a new request service handle
- pub fn rest_handle(&self) -> rest::RequestServiceHandle {
+ pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
+ Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}
+ /// Creates a new request service and returns a handle to it.
+ fn new_request_service<T: ConnectionModeProvider + 'static>(
+ &self,
+ sni_hostname: Option<String>,
+ connection_mode_provider: T,
+ dns_resolver: Arc<dyn DnsResolver>,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> rest::RequestServiceHandle {
+ rest::RequestService::spawn(
+ sni_hostname,
+ self.api_availability.clone(),
+ connection_mode_provider,
+ dns_resolver,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ )
+ }
+
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 54a32f63f9..5c28705158 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -2,7 +2,6 @@
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
access::AccessTokenStore,
- address_cache::AddressCache,
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
@@ -153,14 +152,12 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailability,
- address_cache: AddressCache,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
- address_cache.clone(),
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),