diff options
| -rw-r--r-- | CHANGELOG.md | 1 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-rpc/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-rpc/src/address_cache.rs | 152 | ||||
| -rw-r--r-- | mullvad-rpc/src/cached_dns_resolver.rs | 467 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 60 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 142 |
7 files changed, 325 insertions, 500 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 014320e2ad..7a72f205df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th ### Changed - Open and focus app when opened from context menu instead of toggling the window. +- Use the API to fetch API IP addresses instead of DNS. #### Android - Removed the Quit button. diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 430ace18f0..737ffab185 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -490,6 +490,7 @@ where tokio::runtime::Handle::current(), &cache_dir, ) + .await .map_err(Error::InitRpcFactory)?; let rpc_handle = rpc_runtime.mullvad_rest_handle(); diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml index 2c98bc5017..a5acd75be4 100644 --- a/mullvad-rpc/Cargo.toml +++ b/mullvad-rpc/Cargo.toml @@ -19,7 +19,7 @@ regex = "1" serde = "1" serde_json = "1.0" hyper-rustls = "0.21" -tokio = { version = "0.2", features = [ "macros", "time", "rt-threaded", "net", "io-std", "io-driver" ] } +tokio = { version = "0.2", features = [ "macros", "time", "rt-threaded", "net", "io-std", "io-driver", "fs" ] } tokio-rustls = "0.14" urlencoding = "1" webpki = { version = "0.21", features = [] } diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs new file mode 100644 index 0000000000..6c08927bc5 --- /dev/null +++ b/mullvad-rpc/src/address_cache.rs @@ -0,0 +1,152 @@ +use std::{ + io, + net::{IpAddr, SocketAddr}, + path::Path, + sync::{Arc, Mutex}, +}; +use tokio::{ + fs, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, +}; + +const FALLBACK_API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443); + +#[derive(Clone)] +pub struct AddressCache { + inner: Arc<Mutex<AddressCacheInner>>, + cache_path: Option<Arc<Path>>, +} + +impl AddressCache { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(Default::default())), + cache_path: None, + } + } + + pub async fn with_cache(cache_path: Box<Path>) -> Self { + let cache = AddressCacheInner::from_cache_file(&cache_path) + .await + .unwrap_or_default(); + let inner = Arc::new(Mutex::new(cache)); + let cache_path = Some(cache_path.into()); + Self { inner, cache_path } + } + + pub fn get_address(&self) -> SocketAddr { + let mut inner = self.inner.lock().unwrap(); + inner.last_try = Some(inner.choice); + + Self::get_address_inner(&inner) + } + + fn get_address_inner(inner: &AddressCacheInner) -> SocketAddr { + if inner.addresses.is_empty() { + return FALLBACK_API_ADDRESS.into(); + } + *inner + .addresses + .get(inner.choice % inner.addresses.len()) + .unwrap_or(&FALLBACK_API_ADDRESS.into()) + } + + pub fn register_failure(&self, failed_addr: SocketAddr, err: &dyn std::error::Error) { + let mut inner = self.inner.lock().unwrap(); + + let current_address = Self::get_address_inner(&inner); + // Only choose the next server if the current one has been tried before and it failed + if failed_addr == current_address + && inner + .last_try + .map(|last_try| last_try == inner.choice) + .unwrap_or(false) + { + log::error!("HTTP request failed: {}, will try next API address", err); + inner.choice = inner.choice.wrapping_add(1); + } + } + + pub async fn set_addresses(&self, addresses: Vec<SocketAddr>) -> io::Result<()> { + let should_update = { + let mut inner = self.inner.lock().unwrap(); + if addresses != inner.addresses { + inner.addresses = addresses.clone(); + inner.choice = 0; + true + } else { + false + } + }; + if should_update { + self.save_to_disk(addresses).await?; + } + Ok(()) + } + + async fn save_to_disk(&self, addresses: Vec<SocketAddr>) -> io::Result<()> { + if let Some(cache_path) = self.cache_path.as_ref() { + let mut file = fs::File::create(cache_path).await?; + let mut contents = addresses + .iter() + .map(ToString::to_string) + .collect::<Vec<String>>() + .join("\n"); + contents += "\n"; + + file.write_all(contents.as_bytes()).await?; + file.sync_data().await?; + } + + Ok(()) + } +} + +impl crate::rest::AddressProvider for AddressCache { + fn get_address(&self) -> String { + self.get_address().to_string() + } + + fn clone_box(&self) -> Box<dyn crate::rest::AddressProvider> { + Box::new(self.clone()) + } +} + + +struct AddressCacheInner { + addresses: Vec<SocketAddr>, + choice: usize, + last_try: Option<usize>, +} + +impl AddressCacheInner { + async fn from_cache_file(path: &Path) -> io::Result<Self> { + let file = fs::File::open(path).await?; + let mut lines = BufReader::new(file).lines(); + let mut addresses = vec![]; + while let Some(line) = lines.next_line().await? { + // for line in lines.next_line() { + match line.trim().parse() { + Ok(address) => addresses.push(address), + Err(err) => { + log::error!("Failed to parse cached address line: {}", err); + } + } + } + + Ok(Self { + addresses, + ..Default::default() + }) + } +} + +impl Default for AddressCacheInner { + fn default() -> Self { + Self { + addresses: vec![FALLBACK_API_ADDRESS.into()], + choice: 0, + last_try: None, + } + } +} diff --git a/mullvad-rpc/src/cached_dns_resolver.rs b/mullvad-rpc/src/cached_dns_resolver.rs deleted file mode 100644 index 2c5790ef5f..0000000000 --- a/mullvad-rpc/src/cached_dns_resolver.rs +++ /dev/null @@ -1,467 +0,0 @@ -use log::{debug, info, warn}; -use std::{ - fs::{self, File}, - io::{self, Write}, - net::{IpAddr, ToSocketAddrs}, - path::{Path, PathBuf}, - sync::mpsc, - thread, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; -use talpid_types::ErrorExt; - - -static DNS_TIMEOUT: Duration = Duration::from_secs(2); -static MAX_CACHE_AGE: Duration = Duration::from_secs(3600); -static EXPIRED_CACHE_TIMESTAMP: SystemTime = UNIX_EPOCH; - -pub type Result<T> = std::result::Result<T, Error>; - -#[derive(err_derive::Error, Debug)] -pub enum Error { - /// DNS resolution for a host took too long - #[error(display = "DNS resolution for \"{}\" timed out", _0)] - DnsTimeout(String, #[error(source)] mpsc::RecvTimeoutError), - - /// DNS resolution for a host didn't return any IP addresses - #[error(display = "DNS resolution for \"{}\" did not return any IPs", _0)] - HostNotFound(String), - - /// Failed to resolve IP address for host - #[error(display = "Failed to resolve IP address for \"{}\"", _0)] - ResolveFailure(String, #[error(source)] io::Error), - - /// Unable to read IP cache file - #[error(display = "Failed to read DNS IP cache file")] - ReadCacheError(#[error(source)] io::Error), - - /// Address loaded from file is invalid - #[error(display = "Address loaded from file is invalid")] - ParseCacheError(#[error(source)] std::net::AddrParseError), -} - - -pub trait DnsResolver { - fn resolve(&mut self, host: &str) -> Result<IpAddr>; -} - -pub struct SystemDnsResolver; - -impl SystemDnsResolver { - fn resolve_in_background_thread(host: &str) -> mpsc::Receiver<Result<IpAddr>> { - let host = host.to_owned(); - let (tx, rx) = mpsc::channel(); - - thread::spawn(move || { - let _ = tx.send(Self::resolve_hostname(&host)); - }); - - rx - } - - fn resolve_hostname(host: &str) -> Result<IpAddr> { - (host, 0) - .to_socket_addrs() - .map_err(|e| Error::ResolveFailure(host.to_owned(), e))? - .next() - .map(|socket_address| socket_address.ip()) - .ok_or_else(|| Error::HostNotFound(host.to_owned())) - } -} - -impl DnsResolver for SystemDnsResolver { - fn resolve(&mut self, host: &str) -> Result<IpAddr> { - Self::resolve_in_background_thread(host) - .recv_timeout(DNS_TIMEOUT) - .map_err(|e| Error::DnsTimeout(host.to_owned(), e)) - .and_then(|result| result) - } -} - -pub struct CachedDnsResolver<R: DnsResolver = SystemDnsResolver> { - hostname: String, - dns_resolver: R, - cache_file: Option<PathBuf>, - cached_address: IpAddr, - last_updated: SystemTime, -} - -impl CachedDnsResolver<SystemDnsResolver> { - pub fn new(hostname: String, cache_file: Option<PathBuf>, fallback_address: IpAddr) -> Self { - Self::with_dns_resolver(SystemDnsResolver, hostname, cache_file, fallback_address) - } -} - -impl<R: DnsResolver> CachedDnsResolver<R> { - pub fn with_dns_resolver( - dns_resolver: R, - hostname: String, - cache_file: Option<PathBuf>, - fallback_address: IpAddr, - ) -> Self { - let (cached_address, last_updated) = match &cache_file { - Some(cache_file) => Self::load_initial_cached_address(&cache_file, fallback_address), - None => (fallback_address, EXPIRED_CACHE_TIMESTAMP), - }; - - CachedDnsResolver { - hostname, - dns_resolver, - cache_file, - cached_address, - last_updated, - } - } - - pub fn resolve(&mut self) -> IpAddr { - if let Ok(cache_age) = self.last_updated.elapsed() { - if cache_age > MAX_CACHE_AGE { - self.resolve_into_cache(); - } - } else { - warn!("System time changed, assuming cached IP address has expired"); - self.resolve_into_cache(); - } - - self.cached_address - } - - fn load_initial_cached_address( - cache_file: &Path, - fallback_address: IpAddr, - ) -> (IpAddr, SystemTime) { - match Self::load_from_file(cache_file) { - Ok(previously_cached_address) => match Self::read_file_modification_time(cache_file) { - Ok(last_updated) => (previously_cached_address, last_updated), - Err(error) => { - warn!("Failed to read modification time of file: {}", error); - (previously_cached_address, EXPIRED_CACHE_TIMESTAMP) - } - }, - Err(error) => { - info!( - "{}", - error.display_chain_with_msg( - "Failed to load previously cached IP address, using fallback" - ) - ); - (fallback_address, EXPIRED_CACHE_TIMESTAMP) - } - } - } - - fn load_from_file(file_path: &Path) -> Result<IpAddr> { - let address = fs::read_to_string(file_path).map_err(Error::ReadCacheError)?; - address.trim().parse().map_err(Error::ParseCacheError) - } - - fn read_file_modification_time(cache_file: &Path) -> io::Result<SystemTime> { - cache_file - .metadata() - .and_then(|metadata| metadata.modified()) - } - - fn resolve_into_cache(&mut self) { - debug!("Resolving IP for {}", self.hostname); - match self.dns_resolver.resolve(&self.hostname) { - Ok(address) => { - if Self::is_bogus_address(address) { - warn!( - "DNS lookup for {} returned bogus address {}, ignoring", - self.hostname, address - ); - return; - } - - debug!("Updating DNS cache for {} with {}", self.hostname, address); - self.cached_address = address; - self.last_updated = SystemTime::now(); - - if let Err(error) = self.update_cache_file() { - warn!("Failed to update cache file with new IP address: {}", error); - } - } - Err(e) => { - warn!( - "{}", - e.display_chain_with_msg(&format!("Unable to resolve {}", self.hostname)) - ); - } - } - } - - /// Checks if an IP seems to be a reasonable and routable IP. Used to try to filter out and - /// ignore invalid IPs returned by poisoned DNS etc. - fn is_bogus_address(address: IpAddr) -> bool { - let is_private = match address { - IpAddr::V4(address) => address.is_private(), - _ => false, - }; - address.is_unspecified() || address.is_loopback() || is_private - } - - fn update_cache_file(&mut self) -> io::Result<()> { - if let Some(cache_file_path) = &self.cache_file { - let mut cache_file = File::create(cache_file_path)?; - writeln!(cache_file, "{}", self.cached_address) - } else { - Ok(()) - } - } -} - -#[cfg(test)] -mod tests { - use std::{ - fs::{self, File}, - io::{Read, Write}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - }; - - use super::*; - use filetime::FileTime; - use tempfile::TempDir; - - #[test] - fn uses_previously_cached_address() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let mock_resolver = MockDnsResolver::with_address("192.168.1.206".parse().unwrap()); - let mock_resolver_was_called = mock_resolver.was_called_handle(); - let cached_address = "127.0.0.1".parse().unwrap(); - - write_address(&cache_dir, cached_address); - - let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None); - let address = cache.resolve(); - - assert!(!mock_resolver_was_called.load(Ordering::Acquire)); - assert_eq!(address, cached_address); - } - - #[test] - fn old_cache_file_is_updated() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let cached_address = "80.10.20.30".parse().unwrap(); - let mock_address = "90.168.1.206".parse().unwrap(); - let mock_resolver = MockDnsResolver::with_address(mock_address); - - let cache_file_path = write_address(&cache_dir, cached_address); - - make_file_old(&cache_file_path); - - let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None); - let address = cache.resolve(); - - assert_eq!(get_cached_address(&cache_dir), address.to_string()); - assert_eq!(address, mock_address); - } - - #[test] - fn old_cache_file_is_used_if_resolution_fails() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let mock_resolver = MockDnsResolver::that_fails(); - let cached_address = "127.0.0.1".parse().unwrap(); - - let cache_file_path = write_address(&cache_dir, cached_address); - - make_file_old(&cache_file_path); - - let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None); - let address = cache.resolve(); - - assert_eq!(address, cached_address); - } - - #[test] - fn caches_resolved_ip() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let mock_address = "80.10.1.206".parse().unwrap(); - let mock_resolver = MockDnsResolver::with_address(mock_address); - - let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None); - let address = cache.resolve(); - - assert_eq!(address, mock_address); - assert_eq!(get_cached_address(&cache_dir), address.to_string()); - } - - #[test] - fn resolves_even_if_impossible_to_store_in_cache() { - let (temp_dir, cache_dir) = create_test_dirs(); - let mock_address = "201.0.1.206".parse().unwrap(); - let mock_resolver = MockDnsResolver::with_address(mock_address); - - let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None); - - std::mem::drop(temp_dir); - - assert_eq!(cache.resolve(), mock_address); - } - - #[test] - fn uses_fallback_address() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let fallback_address = "192.168.1.31".parse().unwrap(); - let mock_resolver = MockDnsResolver::that_fails(); - let mock_resolver_was_called = mock_resolver.was_called_handle(); - - let mut cache = - create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address)); - let address = cache.resolve(); - - assert!(mock_resolver_was_called.load(Ordering::Acquire)); - assert_eq!(address, fallback_address); - } - - #[test] - fn ignores_fallback_address_if_resolution_succeeds() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let fallback_address = "200.10.1.31".parse().unwrap(); - let mock_address = "150.10.1.206".parse().unwrap(); - let mock_resolver = MockDnsResolver::with_address(mock_address); - - let mut cache = - create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address)); - let address = cache.resolve(); - - assert_eq!(address, mock_address); - } - - #[test] - fn invalid_cache_file_leads_to_fallback_address_usage() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let fallback_address = "160.20.1.31".parse().unwrap(); - let mock_resolver = MockDnsResolver::that_fails(); - let mock_resolver_was_called = mock_resolver.was_called_handle(); - - write_invalid_address(&cache_dir); - - let mut cache = - create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address)); - let address = cache.resolve(); - - assert!(mock_resolver_was_called.load(Ordering::Acquire)); - assert_eq!(address, fallback_address); - } - - #[test] - fn ignores_private_ip() { - let (_temp_dir, cache_dir) = create_test_dirs(); - let fallback_address = "160.20.1.31".parse().unwrap(); - let mock_address = "10.100.200.1".parse().unwrap(); - let mock_resolver = MockDnsResolver::with_address(mock_address); - - let mut cache = - create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address)); - let address = cache.resolve(); - - assert_eq!(address, fallback_address); - let cache_file_path = cache_dir.join(crate::API_IP_CACHE_FILENAME); - assert!(!cache_file_path.exists()); - } - - fn create_test_dirs() -> (TempDir, PathBuf) { - let temp_dir = TempDir::new().expect("Failed to create a temporary cache directory"); - let cache_dir = temp_dir.path().join("cache"); - - fs::create_dir(&cache_dir).unwrap(); - - (temp_dir, cache_dir) - } - - fn write_invalid_address(dir: &Path) -> PathBuf { - let file_path = dir.join(crate::API_IP_CACHE_FILENAME); - let mut file = File::create(&file_path).unwrap(); - - writeln!(file, "400.30.12.9").unwrap(); - - file_path - } - - fn write_address(dir: &Path, address: IpAddr) -> PathBuf { - let file_path = dir.join(crate::API_IP_CACHE_FILENAME); - let mut file = File::create(&file_path).unwrap(); - - writeln!(file, "{}", address).unwrap(); - - file_path - } - - fn make_file_old(file: &Path) { - let file_metadata = file.metadata().unwrap(); - let last_access_time = FileTime::from_last_access_time(&file_metadata); - let fake_modification_time = FileTime::from_unix_time(100_000, 0); - - filetime::set_file_times(&file, last_access_time, fake_modification_time).unwrap(); - } - - fn get_cached_address(cache_dir: &Path) -> String { - let cache_file_path = cache_dir.join(crate::API_IP_CACHE_FILENAME); - - assert!(cache_file_path.exists()); - - let mut cache_file = File::open(cache_file_path).unwrap(); - let mut cached_address = String::new(); - - cache_file.read_to_string(&mut cached_address).unwrap(); - - cached_address.trim().to_string() - } - - fn create_cached_dns_resolver( - mock_resolver: MockDnsResolver, - cache_dir: &Path, - fallback_address: Option<IpAddr>, - ) -> CachedDnsResolver<MockDnsResolver> { - let hostname = String::from("dummy.host"); - let cache_file = cache_dir.join(crate::API_IP_CACHE_FILENAME); - let fallback_address = fallback_address.unwrap_or(IpAddr::from([10, 0, 109, 91])); - - CachedDnsResolver::with_dns_resolver( - mock_resolver, - hostname, - Some(cache_file), - fallback_address, - ) - } - - struct MockDnsResolver { - address: Option<IpAddr>, - called: Arc<AtomicBool>, - } - - impl MockDnsResolver { - pub fn with_address(address: IpAddr) -> Self { - MockDnsResolver { - address: Some(address), - called: Arc::new(AtomicBool::new(false)), - } - } - - pub fn that_fails() -> Self { - MockDnsResolver { - address: None, - called: Arc::new(AtomicBool::new(false)), - } - } - - pub fn was_called_handle(&self) -> Arc<AtomicBool> { - self.called.clone() - } - } - - impl DnsResolver for MockDnsResolver { - fn resolve(&mut self, host: &str) -> Result<IpAddr> { - self.called.store(true, Ordering::Release); - self.address.ok_or_else(|| { - Error::ResolveFailure( - host.to_owned(), - io::Error::new(io::ErrorKind::Other, "FAILED"), - ) - }) - } - } -} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 60e39f571e..69f39b6349 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -9,7 +9,7 @@ use mullvad_types::{ use std::{ collections::BTreeMap, future::Future, - net::{IpAddr, Ipv4Addr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, }; use talpid_types::net::wireguard; @@ -17,13 +17,12 @@ use talpid_types::net::wireguard; pub mod rest; -mod cached_dns_resolver; -use crate::cached_dns_resolver::CachedDnsResolver; - mod https_client_with_sni; use crate::https_client_with_sni::HttpsConnectorWithSni; +mod address_cache; mod relay_list; +use address_cache::AddressCache; pub use hyper::StatusCode; pub use relay_list::RelayListProxy; @@ -40,9 +39,9 @@ const API_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78)); /// A type that helps with the creation of RPC connections. pub struct MullvadRpcRuntime { - cached_dns_resolver: CachedDnsResolver, https_connector: HttpsConnectorWithSni, handle: tokio::runtime::Handle, + address_cache: AddressCache, } #[derive(err_derive::Error, Debug)] @@ -55,24 +54,26 @@ impl MullvadRpcRuntime { /// Create a new `MullvadRpcRuntime`. pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { Ok(MullvadRpcRuntime { - cached_dns_resolver: CachedDnsResolver::new(API_HOST.to_owned(), None, API_IP), https_connector: HttpsConnectorWithSni::new(), handle, + address_cache: AddressCache::new(), }) } /// Create a new `MullvadRpcRuntime` using the specified cache directory. - pub fn with_cache_dir(handle: tokio::runtime::Handle, cache_dir: &Path) -> Result<Self, Error> { + pub async fn with_cache_dir( + handle: tokio::runtime::Handle, + cache_dir: &Path, + ) -> Result<Self, Error> { let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); - let cached_dns_resolver = - CachedDnsResolver::new(API_HOST.to_owned(), Some(cache_file), API_IP); + let address_cache = AddressCache::with_cache(cache_file.into_boxed_path()).await; let https_connector = HttpsConnectorWithSni::new(); Ok(MullvadRpcRuntime { - cached_dns_resolver, https_connector, handle, + address_cache, }) } @@ -81,7 +82,11 @@ impl MullvadRpcRuntime { let mut https_connector = self.https_connector.clone(); https_connector.set_sni_hostname(sni_hostname); - let service = rest::RequestService::new(https_connector, self.handle.clone()); + let service = rest::RequestService::new( + https_connector, + self.handle.clone(), + self.address_cache.clone(), + ); let handle = service.handle(); self.handle.spawn(service.into_future()); handle @@ -90,11 +95,13 @@ impl MullvadRpcRuntime { /// Returns a request factory initialized to create requests for the master API pub fn mullvad_rest_handle(&mut self) -> rest::MullvadRestHandle { let service = self.new_request_service(Some(API_HOST.to_owned())); - let ip = self.cached_dns_resolver.resolve(); - let factory = - rest::RequestFactory::new(API_HOST.to_owned(), Some(ip), Some("app".to_owned())); + let factory = rest::RequestFactory::new( + API_HOST.to_owned(), + Box::new(self.address_cache.clone()), + Some("app".to_owned()), + ); - rest::MullvadRestHandle { service, factory } + rest::MullvadRestHandle::new(service, factory, self.address_cache.clone()) } /// Returns a new request service handle @@ -412,3 +419,26 @@ impl WireguardKeyProxy { Ok(()) } } + +#[derive(Clone)] +pub struct ApiProxy { + handle: rest::MullvadRestHandle, +} + +impl ApiProxy { + pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> { + let service = self.handle.service.clone(); + + let response = rest::send_request( + &self.handle.factory, + service, + "/v1/api-addrs", + Method::GET, + None, + StatusCode::OK, + ) + .await?; + + rest::deserialize_body(response).await + } +} diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 025c22eaa8..2bf8511d24 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,3 +1,4 @@ +use crate::address_cache::AddressCache; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Aborted}, @@ -10,7 +11,14 @@ use hyper::{ header::{self, HeaderValue}, Method, Uri, }; -use std::{collections::BTreeMap, future::Future, mem, net::IpAddr, str::FromStr, time::Duration}; +use std::{ + collections::BTreeMap, + future::Future, + mem, + net::{IpAddr, SocketAddr}, + str::FromStr, + time::{Duration, Instant}, +}; use tokio::runtime::Handle; pub use hyper::StatusCode; @@ -18,6 +26,11 @@ pub use hyper::StatusCode; pub type Request = hyper::Request<hyper::Body>; pub type Response = hyper::Response<hyper::Body>; +const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60); +const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60); +const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); +const API_IP_CHECK_ERROR_INTERVAL: Duration = Duration::from_secs(15 * 60); + pub type Result<T> = std::result::Result<T, Error>; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); @@ -68,11 +81,12 @@ pub(crate) struct RequestService<C> { handle: Handle, next_id: u64, in_flight_requests: BTreeMap<u64, AbortHandle>, + address_cache: AddressCache, } impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { /// Constructs a new request service. - pub fn new(connector: C, handle: Handle) -> RequestService<C> { + pub fn new(connector: C, handle: Handle, address_cache: AddressCache) -> RequestService<C> { let client = Self::new_client(connector.clone()); let (command_tx, command_rx) = mpsc::channel(1); @@ -84,6 +98,7 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { next_id: 0, connector, handle, + address_cache, } } @@ -106,11 +121,12 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { let mut tx = self.command_tx.clone(); let timeout = request.timeout(); - let (request_future, abort_handle) = abortable( - self.client - .request(request.into_request()) - .map_err(Error::from), - ); + let hyper_request = request.into_request(); + let host_addr = get_request_socket_addr(&hyper_request); + + let (request_future, abort_handle) = + abortable(self.client.request(hyper_request).map_err(Error::from)); + let address_cache = self.address_cache.clone(); let future = async move { let response = @@ -119,6 +135,17 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { .map_err(Error::TimeoutError); let response = flatten_result(flatten_result(response)); + if let Some(host_addr) = host_addr { + if let Err(err) = &response { + match err { + Error::HyperError(_) | Error::TimeoutError(_) => { + address_cache.register_failure(host_addr, err); + } + _ => (), + } + } + } + if completion_tx.send(response).is_err() { log::trace!( @@ -165,6 +192,18 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> { } } +fn get_request_socket_addr(request: &Request) -> Option<SocketAddr> { + let uri = request.uri(); + let port = uri + .port_u16() + // Assuming HTTPS always + .unwrap_or(443); + + let host_addr = uri.host().and_then(|host| host.parse::<IpAddr>().ok())?; + + Some(SocketAddr::new(host_addr, port)) +} + #[derive(Clone)] /// A handle to interact with a spawned `RequestService`. @@ -301,18 +340,22 @@ pub struct ErrorResponse { #[derive(Clone)] pub struct RequestFactory { - host: String, - address: Option<IpAddr>, + hostname: String, + address_provider: Box<dyn AddressProvider>, path_prefix: Option<String>, pub timeout: Duration, } impl RequestFactory { - pub fn new(host: String, address: Option<IpAddr>, path_prefix: Option<String>) -> Self { + pub fn new( + hostname: String, + address_provider: Box<dyn AddressProvider>, + path_prefix: Option<String>, + ) -> Self { Self { - host, - address, + hostname, + address_provider, path_prefix, timeout: DEFAULT_TIMEOUT, } @@ -367,16 +410,13 @@ impl RequestFactory { .method(method) .uri(uri) .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .header(header::HOST, self.host.clone()); + .header(header::HOST, self.hostname.clone()); request.body(hyper::Body::empty()).map_err(Error::HttpError) } fn get_uri(&self, path: &str) -> Result<Uri> { - let host: &dyn std::fmt::Display = &self - .address - .map(|addr| addr.to_string()) - .unwrap_or_else(|| self.host.clone()); + let host = self.address_provider.get_address(); let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or(""); let uri = format!("https://{}/{}{}", host, prefix, path); hyper::Uri::from_str(&uri).map_err(Error::UriError) @@ -388,6 +428,29 @@ impl RequestFactory { } } +pub trait AddressProvider: Send + Sync { + /// Must return a string that represents either a host or a host with port + fn get_address(&self) -> String; + fn clone_box(&self) -> Box<dyn AddressProvider>; +} + +impl Clone for Box<dyn AddressProvider> { + fn clone(&self) -> Self { + self.clone_box() + } +} + +impl AddressProvider for IpAddr { + /// Must return a string that represents either a host or a host with port + fn get_address(&self) -> String { + self.to_string() + } + + fn clone_box(&self) -> Box<dyn AddressProvider> { + Box::new(*self) + } +} + pub fn get_request<T: serde::de::DeserializeOwned>( factory: &RequestFactory, @@ -490,6 +553,51 @@ pub struct MullvadRestHandle { } impl MullvadRestHandle { + pub(crate) fn new( + service: RequestServiceHandle, + factory: RequestFactory, + address_cache: AddressCache, + ) -> Self { + let handle = Self { service, factory }; + handle.spawn_api_address_fetcher(address_cache); + + handle + } + + fn spawn_api_address_fetcher(&self, address_cache: AddressCache) { + let handle = self.clone(); + + self.service.spawn(async move { + // always start the fetch after 15 minutes + let api_proxy = crate::ApiProxy { handle }; + let mut next_check = Instant::now() + API_IP_CHECK_DELAY; + + let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL; + let next_regular_check = || Instant::now() + API_IP_CHECK_INTERVAL; + + let mut interval = tokio::time::interval_at(next_check.into(), TIMER_CHECK_INTERVAL); + + loop { + interval.tick().await; + if next_check < Instant::now() { + match api_proxy.clone().get_api_addrs().await { + Ok(new_addrs) => { + log::debug!("Fetched new API addresses {:?}, will fetch again in {} hours", new_addrs, API_IP_CHECK_INTERVAL.as_secs() / ( 60 * 60 )); + if let Err(err) = address_cache.set_addresses(new_addrs).await { + log::error!("Failed to save newly updated API addresses: {}", err); + } + next_check = next_regular_check(); + } + Err(err) => { + log::error!("Failed to fetch new API addresses: {}, will retry again in {} seconds", err, API_IP_CHECK_ERROR_INTERVAL.as_secs()); + next_check = next_error_check(); + } + } + } + } + }); + } + pub fn service(&self) -> RequestServiceHandle { self.service.clone() } |
