summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2020-09-24 00:03:46 +0100
committerEmīls <emils@mullvad.net>2020-09-25 15:28:25 +0100
commit7ed2c4ab89c1c4676f4a99aeecbd3bdc8ce129ec (patch)
treeb8ac6e850b682ea614f439927d73c6155152464c
parent22ba73b2c753824181cf6f0d378deb49ddfbaf23 (diff)
downloadmullvadvpn-7ed2c4ab89c1c4676f4a99aeecbd3bdc8ce129ec.tar.xz
mullvadvpn-7ed2c4ab89c1c4676f4a99aeecbd3bdc8ce129ec.zip
Use an address cache for reaching the API
-rw-r--r--CHANGELOG.md1
-rw-r--r--mullvad-daemon/src/lib.rs1
-rw-r--r--mullvad-rpc/Cargo.toml2
-rw-r--r--mullvad-rpc/src/address_cache.rs152
-rw-r--r--mullvad-rpc/src/cached_dns_resolver.rs467
-rw-r--r--mullvad-rpc/src/lib.rs60
-rw-r--r--mullvad-rpc/src/rest.rs142
7 files changed, 325 insertions, 500 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 014320e2ad..7a72f205df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th
### Changed
- Open and focus app when opened from context menu instead of toggling the window.
+- Use the API to fetch API IP addresses instead of DNS.
#### Android
- Removed the Quit button.
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 430ace18f0..737ffab185 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -490,6 +490,7 @@ where
tokio::runtime::Handle::current(),
&cache_dir,
)
+ .await
.map_err(Error::InitRpcFactory)?;
let rpc_handle = rpc_runtime.mullvad_rest_handle();
diff --git a/mullvad-rpc/Cargo.toml b/mullvad-rpc/Cargo.toml
index 2c98bc5017..a5acd75be4 100644
--- a/mullvad-rpc/Cargo.toml
+++ b/mullvad-rpc/Cargo.toml
@@ -19,7 +19,7 @@ regex = "1"
serde = "1"
serde_json = "1.0"
hyper-rustls = "0.21"
-tokio = { version = "0.2", features = [ "macros", "time", "rt-threaded", "net", "io-std", "io-driver" ] }
+tokio = { version = "0.2", features = [ "macros", "time", "rt-threaded", "net", "io-std", "io-driver", "fs" ] }
tokio-rustls = "0.14"
urlencoding = "1"
webpki = { version = "0.21", features = [] }
diff --git a/mullvad-rpc/src/address_cache.rs b/mullvad-rpc/src/address_cache.rs
new file mode 100644
index 0000000000..6c08927bc5
--- /dev/null
+++ b/mullvad-rpc/src/address_cache.rs
@@ -0,0 +1,152 @@
+use std::{
+ io,
+ net::{IpAddr, SocketAddr},
+ path::Path,
+ sync::{Arc, Mutex},
+};
+use tokio::{
+ fs,
+ io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
+};
+
+const FALLBACK_API_ADDRESS: (IpAddr, u16) = (crate::API_IP, 443);
+
+#[derive(Clone)]
+pub struct AddressCache {
+ inner: Arc<Mutex<AddressCacheInner>>,
+ cache_path: Option<Arc<Path>>,
+}
+
+impl AddressCache {
+ pub fn new() -> Self {
+ Self {
+ inner: Arc::new(Mutex::new(Default::default())),
+ cache_path: None,
+ }
+ }
+
+ pub async fn with_cache(cache_path: Box<Path>) -> Self {
+ let cache = AddressCacheInner::from_cache_file(&cache_path)
+ .await
+ .unwrap_or_default();
+ let inner = Arc::new(Mutex::new(cache));
+ let cache_path = Some(cache_path.into());
+ Self { inner, cache_path }
+ }
+
+ pub fn get_address(&self) -> SocketAddr {
+ let mut inner = self.inner.lock().unwrap();
+ inner.last_try = Some(inner.choice);
+
+ Self::get_address_inner(&inner)
+ }
+
+ fn get_address_inner(inner: &AddressCacheInner) -> SocketAddr {
+ if inner.addresses.is_empty() {
+ return FALLBACK_API_ADDRESS.into();
+ }
+ *inner
+ .addresses
+ .get(inner.choice % inner.addresses.len())
+ .unwrap_or(&FALLBACK_API_ADDRESS.into())
+ }
+
+ pub fn register_failure(&self, failed_addr: SocketAddr, err: &dyn std::error::Error) {
+ let mut inner = self.inner.lock().unwrap();
+
+ let current_address = Self::get_address_inner(&inner);
+ // Only choose the next server if the current one has been tried before and it failed
+ if failed_addr == current_address
+ && inner
+ .last_try
+ .map(|last_try| last_try == inner.choice)
+ .unwrap_or(false)
+ {
+ log::error!("HTTP request failed: {}, will try next API address", err);
+ inner.choice = inner.choice.wrapping_add(1);
+ }
+ }
+
+ pub async fn set_addresses(&self, addresses: Vec<SocketAddr>) -> io::Result<()> {
+ let should_update = {
+ let mut inner = self.inner.lock().unwrap();
+ if addresses != inner.addresses {
+ inner.addresses = addresses.clone();
+ inner.choice = 0;
+ true
+ } else {
+ false
+ }
+ };
+ if should_update {
+ self.save_to_disk(addresses).await?;
+ }
+ Ok(())
+ }
+
+ async fn save_to_disk(&self, addresses: Vec<SocketAddr>) -> io::Result<()> {
+ if let Some(cache_path) = self.cache_path.as_ref() {
+ let mut file = fs::File::create(cache_path).await?;
+ let mut contents = addresses
+ .iter()
+ .map(ToString::to_string)
+ .collect::<Vec<String>>()
+ .join("\n");
+ contents += "\n";
+
+ file.write_all(contents.as_bytes()).await?;
+ file.sync_data().await?;
+ }
+
+ Ok(())
+ }
+}
+
+impl crate::rest::AddressProvider for AddressCache {
+ fn get_address(&self) -> String {
+ self.get_address().to_string()
+ }
+
+ fn clone_box(&self) -> Box<dyn crate::rest::AddressProvider> {
+ Box::new(self.clone())
+ }
+}
+
+
+struct AddressCacheInner {
+ addresses: Vec<SocketAddr>,
+ choice: usize,
+ last_try: Option<usize>,
+}
+
+impl AddressCacheInner {
+ async fn from_cache_file(path: &Path) -> io::Result<Self> {
+ let file = fs::File::open(path).await?;
+ let mut lines = BufReader::new(file).lines();
+ let mut addresses = vec![];
+ while let Some(line) = lines.next_line().await? {
+ // for line in lines.next_line() {
+ match line.trim().parse() {
+ Ok(address) => addresses.push(address),
+ Err(err) => {
+ log::error!("Failed to parse cached address line: {}", err);
+ }
+ }
+ }
+
+ Ok(Self {
+ addresses,
+ ..Default::default()
+ })
+ }
+}
+
+impl Default for AddressCacheInner {
+ fn default() -> Self {
+ Self {
+ addresses: vec![FALLBACK_API_ADDRESS.into()],
+ choice: 0,
+ last_try: None,
+ }
+ }
+}
diff --git a/mullvad-rpc/src/cached_dns_resolver.rs b/mullvad-rpc/src/cached_dns_resolver.rs
deleted file mode 100644
index 2c5790ef5f..0000000000
--- a/mullvad-rpc/src/cached_dns_resolver.rs
+++ /dev/null
@@ -1,467 +0,0 @@
-use log::{debug, info, warn};
-use std::{
- fs::{self, File},
- io::{self, Write},
- net::{IpAddr, ToSocketAddrs},
- path::{Path, PathBuf},
- sync::mpsc,
- thread,
- time::{Duration, SystemTime, UNIX_EPOCH},
-};
-use talpid_types::ErrorExt;
-
-
-static DNS_TIMEOUT: Duration = Duration::from_secs(2);
-static MAX_CACHE_AGE: Duration = Duration::from_secs(3600);
-static EXPIRED_CACHE_TIMESTAMP: SystemTime = UNIX_EPOCH;
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- /// DNS resolution for a host took too long
- #[error(display = "DNS resolution for \"{}\" timed out", _0)]
- DnsTimeout(String, #[error(source)] mpsc::RecvTimeoutError),
-
- /// DNS resolution for a host didn't return any IP addresses
- #[error(display = "DNS resolution for \"{}\" did not return any IPs", _0)]
- HostNotFound(String),
-
- /// Failed to resolve IP address for host
- #[error(display = "Failed to resolve IP address for \"{}\"", _0)]
- ResolveFailure(String, #[error(source)] io::Error),
-
- /// Unable to read IP cache file
- #[error(display = "Failed to read DNS IP cache file")]
- ReadCacheError(#[error(source)] io::Error),
-
- /// Address loaded from file is invalid
- #[error(display = "Address loaded from file is invalid")]
- ParseCacheError(#[error(source)] std::net::AddrParseError),
-}
-
-
-pub trait DnsResolver {
- fn resolve(&mut self, host: &str) -> Result<IpAddr>;
-}
-
-pub struct SystemDnsResolver;
-
-impl SystemDnsResolver {
- fn resolve_in_background_thread(host: &str) -> mpsc::Receiver<Result<IpAddr>> {
- let host = host.to_owned();
- let (tx, rx) = mpsc::channel();
-
- thread::spawn(move || {
- let _ = tx.send(Self::resolve_hostname(&host));
- });
-
- rx
- }
-
- fn resolve_hostname(host: &str) -> Result<IpAddr> {
- (host, 0)
- .to_socket_addrs()
- .map_err(|e| Error::ResolveFailure(host.to_owned(), e))?
- .next()
- .map(|socket_address| socket_address.ip())
- .ok_or_else(|| Error::HostNotFound(host.to_owned()))
- }
-}
-
-impl DnsResolver for SystemDnsResolver {
- fn resolve(&mut self, host: &str) -> Result<IpAddr> {
- Self::resolve_in_background_thread(host)
- .recv_timeout(DNS_TIMEOUT)
- .map_err(|e| Error::DnsTimeout(host.to_owned(), e))
- .and_then(|result| result)
- }
-}
-
-pub struct CachedDnsResolver<R: DnsResolver = SystemDnsResolver> {
- hostname: String,
- dns_resolver: R,
- cache_file: Option<PathBuf>,
- cached_address: IpAddr,
- last_updated: SystemTime,
-}
-
-impl CachedDnsResolver<SystemDnsResolver> {
- pub fn new(hostname: String, cache_file: Option<PathBuf>, fallback_address: IpAddr) -> Self {
- Self::with_dns_resolver(SystemDnsResolver, hostname, cache_file, fallback_address)
- }
-}
-
-impl<R: DnsResolver> CachedDnsResolver<R> {
- pub fn with_dns_resolver(
- dns_resolver: R,
- hostname: String,
- cache_file: Option<PathBuf>,
- fallback_address: IpAddr,
- ) -> Self {
- let (cached_address, last_updated) = match &cache_file {
- Some(cache_file) => Self::load_initial_cached_address(&cache_file, fallback_address),
- None => (fallback_address, EXPIRED_CACHE_TIMESTAMP),
- };
-
- CachedDnsResolver {
- hostname,
- dns_resolver,
- cache_file,
- cached_address,
- last_updated,
- }
- }
-
- pub fn resolve(&mut self) -> IpAddr {
- if let Ok(cache_age) = self.last_updated.elapsed() {
- if cache_age > MAX_CACHE_AGE {
- self.resolve_into_cache();
- }
- } else {
- warn!("System time changed, assuming cached IP address has expired");
- self.resolve_into_cache();
- }
-
- self.cached_address
- }
-
- fn load_initial_cached_address(
- cache_file: &Path,
- fallback_address: IpAddr,
- ) -> (IpAddr, SystemTime) {
- match Self::load_from_file(cache_file) {
- Ok(previously_cached_address) => match Self::read_file_modification_time(cache_file) {
- Ok(last_updated) => (previously_cached_address, last_updated),
- Err(error) => {
- warn!("Failed to read modification time of file: {}", error);
- (previously_cached_address, EXPIRED_CACHE_TIMESTAMP)
- }
- },
- Err(error) => {
- info!(
- "{}",
- error.display_chain_with_msg(
- "Failed to load previously cached IP address, using fallback"
- )
- );
- (fallback_address, EXPIRED_CACHE_TIMESTAMP)
- }
- }
- }
-
- fn load_from_file(file_path: &Path) -> Result<IpAddr> {
- let address = fs::read_to_string(file_path).map_err(Error::ReadCacheError)?;
- address.trim().parse().map_err(Error::ParseCacheError)
- }
-
- fn read_file_modification_time(cache_file: &Path) -> io::Result<SystemTime> {
- cache_file
- .metadata()
- .and_then(|metadata| metadata.modified())
- }
-
- fn resolve_into_cache(&mut self) {
- debug!("Resolving IP for {}", self.hostname);
- match self.dns_resolver.resolve(&self.hostname) {
- Ok(address) => {
- if Self::is_bogus_address(address) {
- warn!(
- "DNS lookup for {} returned bogus address {}, ignoring",
- self.hostname, address
- );
- return;
- }
-
- debug!("Updating DNS cache for {} with {}", self.hostname, address);
- self.cached_address = address;
- self.last_updated = SystemTime::now();
-
- if let Err(error) = self.update_cache_file() {
- warn!("Failed to update cache file with new IP address: {}", error);
- }
- }
- Err(e) => {
- warn!(
- "{}",
- e.display_chain_with_msg(&format!("Unable to resolve {}", self.hostname))
- );
- }
- }
- }
-
- /// Checks if an IP seems to be a reasonable and routable IP. Used to try to filter out and
- /// ignore invalid IPs returned by poisoned DNS etc.
- fn is_bogus_address(address: IpAddr) -> bool {
- let is_private = match address {
- IpAddr::V4(address) => address.is_private(),
- _ => false,
- };
- address.is_unspecified() || address.is_loopback() || is_private
- }
-
- fn update_cache_file(&mut self) -> io::Result<()> {
- if let Some(cache_file_path) = &self.cache_file {
- let mut cache_file = File::create(cache_file_path)?;
- writeln!(cache_file, "{}", self.cached_address)
- } else {
- Ok(())
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::{
- fs::{self, File},
- io::{Read, Write},
- sync::{
- atomic::{AtomicBool, Ordering},
- Arc,
- },
- };
-
- use super::*;
- use filetime::FileTime;
- use tempfile::TempDir;
-
- #[test]
- fn uses_previously_cached_address() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let mock_resolver = MockDnsResolver::with_address("192.168.1.206".parse().unwrap());
- let mock_resolver_was_called = mock_resolver.was_called_handle();
- let cached_address = "127.0.0.1".parse().unwrap();
-
- write_address(&cache_dir, cached_address);
-
- let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None);
- let address = cache.resolve();
-
- assert!(!mock_resolver_was_called.load(Ordering::Acquire));
- assert_eq!(address, cached_address);
- }
-
- #[test]
- fn old_cache_file_is_updated() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let cached_address = "80.10.20.30".parse().unwrap();
- let mock_address = "90.168.1.206".parse().unwrap();
- let mock_resolver = MockDnsResolver::with_address(mock_address);
-
- let cache_file_path = write_address(&cache_dir, cached_address);
-
- make_file_old(&cache_file_path);
-
- let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None);
- let address = cache.resolve();
-
- assert_eq!(get_cached_address(&cache_dir), address.to_string());
- assert_eq!(address, mock_address);
- }
-
- #[test]
- fn old_cache_file_is_used_if_resolution_fails() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let mock_resolver = MockDnsResolver::that_fails();
- let cached_address = "127.0.0.1".parse().unwrap();
-
- let cache_file_path = write_address(&cache_dir, cached_address);
-
- make_file_old(&cache_file_path);
-
- let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None);
- let address = cache.resolve();
-
- assert_eq!(address, cached_address);
- }
-
- #[test]
- fn caches_resolved_ip() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let mock_address = "80.10.1.206".parse().unwrap();
- let mock_resolver = MockDnsResolver::with_address(mock_address);
-
- let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None);
- let address = cache.resolve();
-
- assert_eq!(address, mock_address);
- assert_eq!(get_cached_address(&cache_dir), address.to_string());
- }
-
- #[test]
- fn resolves_even_if_impossible_to_store_in_cache() {
- let (temp_dir, cache_dir) = create_test_dirs();
- let mock_address = "201.0.1.206".parse().unwrap();
- let mock_resolver = MockDnsResolver::with_address(mock_address);
-
- let mut cache = create_cached_dns_resolver(mock_resolver, &cache_dir, None);
-
- std::mem::drop(temp_dir);
-
- assert_eq!(cache.resolve(), mock_address);
- }
-
- #[test]
- fn uses_fallback_address() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let fallback_address = "192.168.1.31".parse().unwrap();
- let mock_resolver = MockDnsResolver::that_fails();
- let mock_resolver_was_called = mock_resolver.was_called_handle();
-
- let mut cache =
- create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address));
- let address = cache.resolve();
-
- assert!(mock_resolver_was_called.load(Ordering::Acquire));
- assert_eq!(address, fallback_address);
- }
-
- #[test]
- fn ignores_fallback_address_if_resolution_succeeds() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let fallback_address = "200.10.1.31".parse().unwrap();
- let mock_address = "150.10.1.206".parse().unwrap();
- let mock_resolver = MockDnsResolver::with_address(mock_address);
-
- let mut cache =
- create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address));
- let address = cache.resolve();
-
- assert_eq!(address, mock_address);
- }
-
- #[test]
- fn invalid_cache_file_leads_to_fallback_address_usage() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let fallback_address = "160.20.1.31".parse().unwrap();
- let mock_resolver = MockDnsResolver::that_fails();
- let mock_resolver_was_called = mock_resolver.was_called_handle();
-
- write_invalid_address(&cache_dir);
-
- let mut cache =
- create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address));
- let address = cache.resolve();
-
- assert!(mock_resolver_was_called.load(Ordering::Acquire));
- assert_eq!(address, fallback_address);
- }
-
- #[test]
- fn ignores_private_ip() {
- let (_temp_dir, cache_dir) = create_test_dirs();
- let fallback_address = "160.20.1.31".parse().unwrap();
- let mock_address = "10.100.200.1".parse().unwrap();
- let mock_resolver = MockDnsResolver::with_address(mock_address);
-
- let mut cache =
- create_cached_dns_resolver(mock_resolver, &cache_dir, Some(fallback_address));
- let address = cache.resolve();
-
- assert_eq!(address, fallback_address);
- let cache_file_path = cache_dir.join(crate::API_IP_CACHE_FILENAME);
- assert!(!cache_file_path.exists());
- }
-
- fn create_test_dirs() -> (TempDir, PathBuf) {
- let temp_dir = TempDir::new().expect("Failed to create a temporary cache directory");
- let cache_dir = temp_dir.path().join("cache");
-
- fs::create_dir(&cache_dir).unwrap();
-
- (temp_dir, cache_dir)
- }
-
- fn write_invalid_address(dir: &Path) -> PathBuf {
- let file_path = dir.join(crate::API_IP_CACHE_FILENAME);
- let mut file = File::create(&file_path).unwrap();
-
- writeln!(file, "400.30.12.9").unwrap();
-
- file_path
- }
-
- fn write_address(dir: &Path, address: IpAddr) -> PathBuf {
- let file_path = dir.join(crate::API_IP_CACHE_FILENAME);
- let mut file = File::create(&file_path).unwrap();
-
- writeln!(file, "{}", address).unwrap();
-
- file_path
- }
-
- fn make_file_old(file: &Path) {
- let file_metadata = file.metadata().unwrap();
- let last_access_time = FileTime::from_last_access_time(&file_metadata);
- let fake_modification_time = FileTime::from_unix_time(100_000, 0);
-
- filetime::set_file_times(&file, last_access_time, fake_modification_time).unwrap();
- }
-
- fn get_cached_address(cache_dir: &Path) -> String {
- let cache_file_path = cache_dir.join(crate::API_IP_CACHE_FILENAME);
-
- assert!(cache_file_path.exists());
-
- let mut cache_file = File::open(cache_file_path).unwrap();
- let mut cached_address = String::new();
-
- cache_file.read_to_string(&mut cached_address).unwrap();
-
- cached_address.trim().to_string()
- }
-
- fn create_cached_dns_resolver(
- mock_resolver: MockDnsResolver,
- cache_dir: &Path,
- fallback_address: Option<IpAddr>,
- ) -> CachedDnsResolver<MockDnsResolver> {
- let hostname = String::from("dummy.host");
- let cache_file = cache_dir.join(crate::API_IP_CACHE_FILENAME);
- let fallback_address = fallback_address.unwrap_or(IpAddr::from([10, 0, 109, 91]));
-
- CachedDnsResolver::with_dns_resolver(
- mock_resolver,
- hostname,
- Some(cache_file),
- fallback_address,
- )
- }
-
- struct MockDnsResolver {
- address: Option<IpAddr>,
- called: Arc<AtomicBool>,
- }
-
- impl MockDnsResolver {
- pub fn with_address(address: IpAddr) -> Self {
- MockDnsResolver {
- address: Some(address),
- called: Arc::new(AtomicBool::new(false)),
- }
- }
-
- pub fn that_fails() -> Self {
- MockDnsResolver {
- address: None,
- called: Arc::new(AtomicBool::new(false)),
- }
- }
-
- pub fn was_called_handle(&self) -> Arc<AtomicBool> {
- self.called.clone()
- }
- }
-
- impl DnsResolver for MockDnsResolver {
- fn resolve(&mut self, host: &str) -> Result<IpAddr> {
- self.called.store(true, Ordering::Release);
- self.address.ok_or_else(|| {
- Error::ResolveFailure(
- host.to_owned(),
- io::Error::new(io::ErrorKind::Other, "FAILED"),
- )
- })
- }
- }
-}
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 60e39f571e..69f39b6349 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -9,7 +9,7 @@ use mullvad_types::{
use std::{
collections::BTreeMap,
future::Future,
- net::{IpAddr, Ipv4Addr},
+ net::{IpAddr, Ipv4Addr, SocketAddr},
path::Path,
};
use talpid_types::net::wireguard;
@@ -17,13 +17,12 @@ use talpid_types::net::wireguard;
pub mod rest;
-mod cached_dns_resolver;
-use crate::cached_dns_resolver::CachedDnsResolver;
-
mod https_client_with_sni;
use crate::https_client_with_sni::HttpsConnectorWithSni;
+mod address_cache;
mod relay_list;
+use address_cache::AddressCache;
pub use hyper::StatusCode;
pub use relay_list::RelayListProxy;
@@ -40,9 +39,9 @@ const API_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78));
/// A type that helps with the creation of RPC connections.
pub struct MullvadRpcRuntime {
- cached_dns_resolver: CachedDnsResolver,
https_connector: HttpsConnectorWithSni,
handle: tokio::runtime::Handle,
+ address_cache: AddressCache,
}
#[derive(err_derive::Error, Debug)]
@@ -55,24 +54,26 @@ impl MullvadRpcRuntime {
/// Create a new `MullvadRpcRuntime`.
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Ok(MullvadRpcRuntime {
- cached_dns_resolver: CachedDnsResolver::new(API_HOST.to_owned(), None, API_IP),
https_connector: HttpsConnectorWithSni::new(),
handle,
+ address_cache: AddressCache::new(),
})
}
/// Create a new `MullvadRpcRuntime` using the specified cache directory.
- pub fn with_cache_dir(handle: tokio::runtime::Handle, cache_dir: &Path) -> Result<Self, Error> {
+ pub async fn with_cache_dir(
+ handle: tokio::runtime::Handle,
+ cache_dir: &Path,
+ ) -> Result<Self, Error> {
let cache_file = cache_dir.join(API_IP_CACHE_FILENAME);
- let cached_dns_resolver =
- CachedDnsResolver::new(API_HOST.to_owned(), Some(cache_file), API_IP);
+ let address_cache = AddressCache::with_cache(cache_file.into_boxed_path()).await;
let https_connector = HttpsConnectorWithSni::new();
Ok(MullvadRpcRuntime {
- cached_dns_resolver,
https_connector,
handle,
+ address_cache,
})
}
@@ -81,7 +82,11 @@ impl MullvadRpcRuntime {
let mut https_connector = self.https_connector.clone();
https_connector.set_sni_hostname(sni_hostname);
- let service = rest::RequestService::new(https_connector, self.handle.clone());
+ let service = rest::RequestService::new(
+ https_connector,
+ self.handle.clone(),
+ self.address_cache.clone(),
+ );
let handle = service.handle();
self.handle.spawn(service.into_future());
handle
@@ -90,11 +95,13 @@ impl MullvadRpcRuntime {
/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle(&mut self) -> rest::MullvadRestHandle {
let service = self.new_request_service(Some(API_HOST.to_owned()));
- let ip = self.cached_dns_resolver.resolve();
- let factory =
- rest::RequestFactory::new(API_HOST.to_owned(), Some(ip), Some("app".to_owned()));
+ let factory = rest::RequestFactory::new(
+ API_HOST.to_owned(),
+ Box::new(self.address_cache.clone()),
+ Some("app".to_owned()),
+ );
- rest::MullvadRestHandle { service, factory }
+ rest::MullvadRestHandle::new(service, factory, self.address_cache.clone())
}
/// Returns a new request service handle
@@ -412,3 +419,26 @@ impl WireguardKeyProxy {
Ok(())
}
}
+
+#[derive(Clone)]
+pub struct ApiProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+impl ApiProxy {
+ pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> {
+ let service = self.handle.service.clone();
+
+ let response = rest::send_request(
+ &self.handle.factory,
+ service,
+ "/v1/api-addrs",
+ Method::GET,
+ None,
+ StatusCode::OK,
+ )
+ .await?;
+
+ rest::deserialize_body(response).await
+ }
+}
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 025c22eaa8..2bf8511d24 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,3 +1,4 @@
+use crate::address_cache::AddressCache;
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Aborted},
@@ -10,7 +11,14 @@ use hyper::{
header::{self, HeaderValue},
Method, Uri,
};
-use std::{collections::BTreeMap, future::Future, mem, net::IpAddr, str::FromStr, time::Duration};
+use std::{
+ collections::BTreeMap,
+ future::Future,
+ mem,
+ net::{IpAddr, SocketAddr},
+ str::FromStr,
+ time::{Duration, Instant},
+};
use tokio::runtime::Handle;
pub use hyper::StatusCode;
@@ -18,6 +26,11 @@ pub use hyper::StatusCode;
pub type Request = hyper::Request<hyper::Body>;
pub type Response = hyper::Response<hyper::Body>;
+const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60);
+const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60);
+const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+const API_IP_CHECK_ERROR_INTERVAL: Duration = Duration::from_secs(15 * 60);
+
pub type Result<T> = std::result::Result<T, Error>;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
@@ -68,11 +81,12 @@ pub(crate) struct RequestService<C> {
handle: Handle,
next_id: u64,
in_flight_requests: BTreeMap<u64, AbortHandle>,
+ address_cache: AddressCache,
}
impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
/// Constructs a new request service.
- pub fn new(connector: C, handle: Handle) -> RequestService<C> {
+ pub fn new(connector: C, handle: Handle, address_cache: AddressCache) -> RequestService<C> {
let client = Self::new_client(connector.clone());
let (command_tx, command_rx) = mpsc::channel(1);
@@ -84,6 +98,7 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
next_id: 0,
connector,
handle,
+ address_cache,
}
}
@@ -106,11 +121,12 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
let mut tx = self.command_tx.clone();
let timeout = request.timeout();
- let (request_future, abort_handle) = abortable(
- self.client
- .request(request.into_request())
- .map_err(Error::from),
- );
+ let hyper_request = request.into_request();
+ let host_addr = get_request_socket_addr(&hyper_request);
+
+ let (request_future, abort_handle) =
+ abortable(self.client.request(hyper_request).map_err(Error::from));
+ let address_cache = self.address_cache.clone();
let future = async move {
let response =
@@ -119,6 +135,17 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
.map_err(Error::TimeoutError);
let response = flatten_result(flatten_result(response));
+ if let Some(host_addr) = host_addr {
+ if let Err(err) = &response {
+ match err {
+ Error::HyperError(_) | Error::TimeoutError(_) => {
+ address_cache.register_failure(host_addr, err);
+ }
+ _ => (),
+ }
+ }
+ }
+
if completion_tx.send(response).is_err() {
log::trace!(
@@ -165,6 +192,18 @@ impl<C: Connect + Clone + Send + Sync + 'static> RequestService<C> {
}
}
+fn get_request_socket_addr(request: &Request) -> Option<SocketAddr> {
+ let uri = request.uri();
+ let port = uri
+ .port_u16()
+ // Assuming HTTPS always
+ .unwrap_or(443);
+
+ let host_addr = uri.host().and_then(|host| host.parse::<IpAddr>().ok())?;
+
+ Some(SocketAddr::new(host_addr, port))
+}
+
#[derive(Clone)]
/// A handle to interact with a spawned `RequestService`.
@@ -301,18 +340,22 @@ pub struct ErrorResponse {
#[derive(Clone)]
pub struct RequestFactory {
- host: String,
- address: Option<IpAddr>,
+ hostname: String,
+ address_provider: Box<dyn AddressProvider>,
path_prefix: Option<String>,
pub timeout: Duration,
}
impl RequestFactory {
- pub fn new(host: String, address: Option<IpAddr>, path_prefix: Option<String>) -> Self {
+ pub fn new(
+ hostname: String,
+ address_provider: Box<dyn AddressProvider>,
+ path_prefix: Option<String>,
+ ) -> Self {
Self {
- host,
- address,
+ hostname,
+ address_provider,
path_prefix,
timeout: DEFAULT_TIMEOUT,
}
@@ -367,16 +410,13 @@ impl RequestFactory {
.method(method)
.uri(uri)
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
- .header(header::HOST, self.host.clone());
+ .header(header::HOST, self.hostname.clone());
request.body(hyper::Body::empty()).map_err(Error::HttpError)
}
fn get_uri(&self, path: &str) -> Result<Uri> {
- let host: &dyn std::fmt::Display = &self
- .address
- .map(|addr| addr.to_string())
- .unwrap_or_else(|| self.host.clone());
+ let host = self.address_provider.get_address();
let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
let uri = format!("https://{}/{}{}", host, prefix, path);
hyper::Uri::from_str(&uri).map_err(Error::UriError)
@@ -388,6 +428,29 @@ impl RequestFactory {
}
}
+pub trait AddressProvider: Send + Sync {
+ /// Must return a string that represents either a host or a host with port
+ fn get_address(&self) -> String;
+ fn clone_box(&self) -> Box<dyn AddressProvider>;
+}
+
+impl Clone for Box<dyn AddressProvider> {
+ fn clone(&self) -> Self {
+ self.clone_box()
+ }
+}
+
+impl AddressProvider for IpAddr {
+ /// Must return a string that represents either a host or a host with port
+ fn get_address(&self) -> String {
+ self.to_string()
+ }
+
+ fn clone_box(&self) -> Box<dyn AddressProvider> {
+ Box::new(*self)
+ }
+}
+
pub fn get_request<T: serde::de::DeserializeOwned>(
factory: &RequestFactory,
@@ -490,6 +553,51 @@ pub struct MullvadRestHandle {
}
impl MullvadRestHandle {
+ pub(crate) fn new(
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ address_cache: AddressCache,
+ ) -> Self {
+ let handle = Self { service, factory };
+ handle.spawn_api_address_fetcher(address_cache);
+
+ handle
+ }
+
+ fn spawn_api_address_fetcher(&self, address_cache: AddressCache) {
+ let handle = self.clone();
+
+ self.service.spawn(async move {
+ // always start the fetch after 15 minutes
+ let api_proxy = crate::ApiProxy { handle };
+ let mut next_check = Instant::now() + API_IP_CHECK_DELAY;
+
+ let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL;
+ let next_regular_check = || Instant::now() + API_IP_CHECK_INTERVAL;
+
+ let mut interval = tokio::time::interval_at(next_check.into(), TIMER_CHECK_INTERVAL);
+
+ loop {
+ interval.tick().await;
+ if next_check < Instant::now() {
+ match api_proxy.clone().get_api_addrs().await {
+ Ok(new_addrs) => {
+ log::debug!("Fetched new API addresses {:?}, will fetch again in {} hours", new_addrs, API_IP_CHECK_INTERVAL.as_secs() / ( 60 * 60 ));
+ if let Err(err) = address_cache.set_addresses(new_addrs).await {
+ log::error!("Failed to save newly updated API addresses: {}", err);
+ }
+ next_check = next_regular_check();
+ }
+ Err(err) => {
+ log::error!("Failed to fetch new API addresses: {}, will retry again in {} seconds", err, API_IP_CHECK_ERROR_INTERVAL.as_secs());
+ next_check = next_error_check();
+ }
+ }
+ }
+ }
+ });
+ }
+
pub fn service(&self) -> RequestServiceHandle {
self.service.clone()
}