diff options
| author | David Lönnhager <david.l@mullvad.net> | 2024-11-27 15:09:05 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2024-12-02 16:00:38 +0100 |
| commit | 94ce8fb753f7441243d9281416632ce7ed4b6cd6 (patch) | |
| tree | 8fa86e98b4b50ecd70ebf6468d222cad8e6cf552 /mullvad-api/src | |
| parent | e07c12a5f14a11051fa086c97cc22413e431a1c8 (diff) | |
| download | mullvadvpn-94ce8fb753f7441243d9281416632ce7ed4b6cd6.tar.xz mullvadvpn-94ce8fb753f7441243d9281416632ce7ed4b6cd6.zip | |
Remove DNS fallback except for conncheck
Diffstat (limited to 'mullvad-api/src')
| -rw-r--r-- | mullvad-api/src/address_cache.rs | 20 | ||||
| -rw-r--r-- | mullvad-api/src/bin/relay_list.rs | 6 | ||||
| -rw-r--r-- | mullvad-api/src/https_client_with_sni.rs | 34 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 66 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 3 |
5 files changed, 59 insertions, 70 deletions
diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index a48db1a0e2..0898f8da1f 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -1,6 +1,8 @@ //! This module keeps track of the last known good API IP address and reads and stores it on disk. use super::API; +use crate::DnsResolver; +use async_trait::async_trait; use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ fs, @@ -23,6 +25,17 @@ pub enum Error { Write(#[source] io::Error), } +/// A DNS resolver which resolves using `AddressCache`. +#[async_trait] +impl DnsResolver for AddressCache { + async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> { + self.resolve_hostname(&host) + .await + .map(|addr| vec![addr]) + .ok_or(io::Error::other("host does not match API host")) + } +} + #[derive(Clone)] pub struct AddressCache { inner: Arc<Mutex<AddressCacheInner>>, @@ -42,7 +55,10 @@ impl AddressCache { /// Initialize cache using `read_path`, and write changes to `write_path`. pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> { log::debug!("Loading API addresses from {}", read_path.display()); - Ok(Self::new_inner(read_address_file(read_path).await?, write_path)) + Ok(Self::new_inner( + read_address_file(read_path).await?, + write_path, + )) } fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self { @@ -56,7 +72,7 @@ impl AddressCache { } /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. - pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> { + async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> { if hostname.eq_ignore_ascii_case(API.host()) { Some(self.get_address().await) } else { diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index 22190abd63..def32303ea 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,15 +2,13 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{ - proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy, -}; +use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver) + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("Failed to load runtime"); let relay_list_request = diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 09ce493431..59d1f7fac7 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -2,7 +2,7 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, tls_stream::TlsStream, - AddressCache, DnsResolver, + DnsResolver, }; use futures::{channel::mpsc, future, pin_mut, StreamExt}; #[cfg(target_os = "android")] @@ -287,7 +287,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode { pub struct HttpsConnectorWithSni { inner: Arc<Mutex<HttpsConnectorWithSniInner>>, sni_hostname: Option<String>, - address_cache: AddressCache, abort_notify: Arc<tokio::sync::Notify>, dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] @@ -305,7 +304,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>); impl HttpsConnectorWithSni { pub fn new( sni_hostname: Option<String>, - address_cache: AddressCache, dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> (Self, HttpsConnectorWithSniHandle) { @@ -353,7 +351,6 @@ impl HttpsConnectorWithSni { HttpsConnectorWithSni { inner, sni_hostname, - address_cache, abort_notify, dns_resolver, #[cfg(target_os = "android")] @@ -390,13 +387,9 @@ impl HttpsConnectorWithSni { } /// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used. - /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback. + /// Otherwise `dns_resolver` will be used as a fallback. /// If the URI contains a port, then that port will be used. - async fn resolve_address( - address_cache: AddressCache, - dns_resolver: &dyn DnsResolver, - uri: Uri, - ) -> io::Result<SocketAddr> { + async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> { const DEFAULT_PORT: u16 = 443; let hostname = uri.host().ok_or_else(|| { @@ -407,22 +400,16 @@ impl HttpsConnectorWithSni { return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT))); } - // Preferentially, use cached address. - // - if let Some(addr) = address_cache.resolve_hostname(hostname).await { - return Ok(SocketAddr::new( - addr.ip(), - port.unwrap_or_else(|| addr.port()), - )); - } - - // Use DNS resolution as fallback - // let addrs = dns_resolver.resolve(hostname.to_owned()).await?; let addr = addrs .first() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; - Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT))) + let port = match (addr.port(), port) { + (_, Some(port)) => port, + (0, None) => DEFAULT_PORT, + (addr_port, None) => addr_port, + }; + Ok(SocketAddr::new(addr.ip(), port)) } } @@ -456,7 +443,6 @@ impl Service<Uri> for HttpsConnectorWithSni { let abort_notify = self.abort_notify.clone(); #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); - let address_cache = self.address_cache.clone(); let dns_resolver = self.dns_resolver.clone(); let fut = async move { @@ -468,7 +454,7 @@ impl Service<Uri> for HttpsConnectorWithSni { } let hostname = sni_hostname?; - let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?; + let addr = Self::resolve_address(&*dns_resolver, uri).await?; // Loop until we have established a connection. This starts over if a new endpoint // is selected while connecting. diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 3f67709242..5d46534c44 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -308,7 +308,7 @@ impl ApiEndpoint { #[async_trait] pub trait DnsResolver: 'static + Send + Sync { - async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>; + async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>; } /// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`). @@ -316,14 +316,14 @@ pub struct DefaultDnsResolver; #[async_trait] impl DnsResolver for DefaultDnsResolver { - async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> { + async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> { use std::net::ToSocketAddrs; // Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which // blocks and either has no timeout or a very long one. let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs()) .await .expect("DNS task panicked")?; - Ok(addrs.map(|addr| addr.ip()).collect()) + Ok(addrs.collect()) } } @@ -332,7 +332,7 @@ pub struct NullDnsResolver; #[async_trait] impl DnsResolver for NullDnsResolver { - async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> { + async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> { Ok(vec![]) } } @@ -342,7 +342,6 @@ pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, - dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } @@ -364,13 +363,9 @@ pub enum Error { impl Runtime { /// Create a new `Runtime`. - pub fn new( - handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, - ) -> Result<Self, Error> { + pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] None, ) @@ -381,21 +376,18 @@ impl Runtime { Runtime { handle, address_cache: AddressCache::with_static_addr(address), - dns_resolver: Arc::new(NullDnsResolver), api_availability: ApiAvailability::default(), } } fn new_inner( handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { Ok(Runtime { handle, - address_cache: AddressCache::new(None)?, + address_cache: AddressCache::new(None), api_availability: ApiAvailability::default(), - dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -404,7 +396,6 @@ impl Runtime { /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. pub async fn with_cache( - dns_resolver: impl DnsResolver, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, @@ -415,7 +406,6 @@ impl Runtime { if API.disable_address_cache { return Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, ); @@ -439,7 +429,7 @@ impl Runtime { ) ); } - AddressCache::new(write_file)? + AddressCache::new(write_file) } }; @@ -449,30 +439,11 @@ impl Runtime { handle, address_cache, api_availability, - dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) } - /// Creates a new request service and returns a handle to it. - fn new_request_service<T: ConnectionModeProvider + 'static>( - &self, - sni_hostname: Option<String>, - connection_mode_provider: T, - #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - ) -> rest::RequestServiceHandle { - rest::RequestService::spawn( - sni_hostname, - self.api_availability.clone(), - self.address_cache.clone(), - connection_mode_provider, - self.dns_resolver.clone(), - #[cfg(target_os = "android")] - socket_bypass_tx, - ) - } - /// Returns a request factory initialized to create requests for the master API pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>( &self, @@ -481,6 +452,7 @@ impl Runtime { let service = self.new_request_service( Some(API.host().to_string()), connection_mode_provider, + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -495,6 +467,7 @@ impl Runtime { let service = self.new_request_service( Some(hostname.clone()), ApiConnectionMode::Direct.into_provider(), + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -505,15 +478,34 @@ impl Runtime { } /// Returns a new request service handle - pub fn rest_handle(&self) -> rest::RequestServiceHandle { + pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle { self.new_request_service( None, ApiConnectionMode::Direct.into_provider(), + Arc::new(dns_resolver), #[cfg(target_os = "android")] None, ) } + /// Creates a new request service and returns a handle to it. + fn new_request_service<T: ConnectionModeProvider + 'static>( + &self, + sni_hostname: Option<String>, + connection_mode_provider: T, + dns_resolver: Arc<dyn DnsResolver>, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> rest::RequestServiceHandle { + rest::RequestService::spawn( + sni_hostname, + self.api_availability.clone(), + connection_mode_provider, + dns_resolver, + #[cfg(target_os = "android")] + socket_bypass_tx, + ) + } + pub fn handle(&mut self) -> &mut tokio::runtime::Handle { &mut self.handle } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 54a32f63f9..5c28705158 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -2,7 +2,6 @@ pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ access::AccessTokenStore, - address_cache::AddressCache, availability::ApiAvailability, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, proxy::ConnectionModeProvider, @@ -153,14 +152,12 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { pub fn spawn( sni_hostname: Option<String>, api_availability: ApiAvailability, - address_cache: AddressCache, connection_mode_provider: T, dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( sni_hostname, - address_cache.clone(), dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), |
