summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api
diff options
context:
space:
mode:
authorEmīls <emils@mullvad.net>2024-12-09 10:59:51 +0100
committerEmīls <emils@mullvad.net>2025-01-02 10:29:05 +0100
commit3093408a057020fcd912976b892fbb6bc26e6293 (patch)
tree56ac01255fdc7f0c29aa6a8c5a8cdd863ff4f496 /mullvad-api
parent58efb1004f8ca762fe6aa95f541c6dc329ed790e (diff)
downloadmullvadvpn-3093408a057020fcd912976b892fbb6bc26e6293.tar.xz
mullvadvpn-3093408a057020fcd912976b892fbb6bc26e6293.zip
Remove global API endpoint
Diffstat (limited to 'mullvad-api')
-rw-r--r--mullvad-api/Cargo.toml9
-rw-r--r--mullvad-api/include/mullvad-api.h11
-rw-r--r--mullvad-api/src/address_cache.rs29
-rw-r--r--mullvad-api/src/bin/relay_list.rs10
-rw-r--r--mullvad-api/src/ffi/error.rs8
-rw-r--r--mullvad-api/src/ffi/mod.rs106
-rw-r--r--mullvad-api/src/https_client_with_sni.rs28
-rw-r--r--mullvad-api/src/lib.rs152
-rw-r--r--mullvad-api/src/rest.rs4
9 files changed, 215 insertions, 142 deletions
diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml
index fc9d7d899b..005d1a9505 100644
--- a/mullvad-api/Cargo.toml
+++ b/mullvad-api/Cargo.toml
@@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["macros", "time", "rt-multi-thread", "ne
tokio-rustls = { version = "0.26.0", features = ["logging", "tls12", "ring"], default-features = false}
tokio-socks = "0.5.1"
rustls-pemfile = "2.1.3"
+uuid = { version = "1.4.1", features = ["v4"] }
mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" }
mullvad-fs = { path = "../mullvad-fs" }
@@ -50,14 +51,6 @@ mockito = "1.6.1"
[build-dependencies]
cbindgen = { version = "0.24.3", default-features = false }
-[target.'cfg(target_os = "ios")'.dependencies]
-uuid = { version = "1.4.1", features = ["v4"] }
-
[lib]
crate-type = [ "rlib", "staticlib" ]
bench = false
-
-[[test]]
-name = "ffi"
-# required-features = [ "api-override" ]
-features = [ "api-override" ]
diff --git a/mullvad-api/include/mullvad-api.h b/mullvad-api/include/mullvad-api.h
index e0295b20aa..4e5f78aef2 100644
--- a/mullvad-api/include/mullvad-api.h
+++ b/mullvad-api/include/mullvad-api.h
@@ -49,15 +49,16 @@ typedef struct MullvadApiDevice {
* struct.
*
* * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address
- * representation
- * ("143.32.4.32:9090"), the port is mandatory.
+ * representation ("143.32.4.32:9090"), the port is mandatory.
*
* * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be
* used for TLS validation.
+ * * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift.
*/
struct MullvadApiError mullvad_api_client_initialize(struct MullvadApiClient *client_ptr,
const char *api_address_ptr,
- const char *hostname);
+ const char *hostname,
+ bool disable_tls);
/**
* Removes all devices from a given account
@@ -98,8 +99,8 @@ struct MullvadApiError mullvad_api_get_expiry(struct MullvadApiClient client_ptr
* * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the
* account that will have all of it's devices removed.
*
- * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function
- * doesn't return an error, the pointer will be initialized with a valid instance of
+ * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't
+ * return an error, the pointer will be initialized with a valid instance of
* `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices.
*/
struct MullvadApiError mullvad_api_list_devices(struct MullvadApiClient client_ptr,
diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs
index 0898f8da1f..a6a60146b4 100644
--- a/mullvad-api/src/address_cache.rs
+++ b/mullvad-api/src/address_cache.rs
@@ -1,7 +1,6 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
-use super::API;
-use crate::DnsResolver;
+use crate::{ApiEndpoint, DnsResolver};
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
@@ -38,42 +37,42 @@ impl DnsResolver for AddressCache {
#[derive(Clone)]
pub struct AddressCache {
+ hostname: String,
inner: Arc<Mutex<AddressCacheInner>>,
write_path: Option<Arc<Path>>,
}
impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
- pub fn new(write_path: Option<Box<Path>>) -> Self {
- Self::new_inner(API.address(), write_path)
- }
-
- pub fn with_static_addr(address: SocketAddr) -> Self {
- Self::new_inner(address, None)
+ pub fn new(endpoint: &ApiEndpoint, write_path: Option<Box<Path>>) -> Self {
+ Self::new_inner(endpoint.address(), endpoint.host().to_owned(), write_path)
}
/// Initialize cache using `read_path`, and write changes to `write_path`.
- pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ pub async fn from_file(
+ read_path: &Path,
+ write_path: Option<Box<Path>>,
+ hostname: String,
+ ) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
- Ok(Self::new_inner(
- read_address_file(read_path).await?,
- write_path,
- ))
+ let address = read_address_file(read_path).await?;
+ Ok(Self::new_inner(address, hostname, write_path))
}
- fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
+ fn new_inner(address: SocketAddr, hostname: String, write_path: Option<Box<Path>>) -> Self {
let cache = AddressCacheInner::from_address(address);
log::debug!("Using API address: {}", cache.address);
Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(Arc::from),
+ hostname,
}
}
/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
- if hostname.eq_ignore_ascii_case(API.host()) {
+ if hostname.eq_ignore_ascii_case(&self.hostname) {
Some(self.get_address().await)
} else {
None
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
index def32303ea..3ea771cc81 100644
--- a/mullvad-api/src/bin/relay_list.rs
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -2,14 +2,18 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
-use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
+use mullvad_api::{
+ proxy::ApiConnectionMode, rest::Error as RestError, ApiEndpoint, RelayListProxy,
+};
use std::process;
use talpid_types::ErrorExt;
#[tokio::main]
async fn main() {
- let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
- .expect("Failed to load runtime");
+ let runtime = mullvad_api::Runtime::new(
+ tokio::runtime::Handle::current(),
+ &ApiEndpoint::from_env_vars(),
+ );
let relay_list_request =
RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()))
diff --git a/mullvad-api/src/ffi/error.rs b/mullvad-api/src/ffi/error.rs
index 539a6c23a0..66ffc01220 100644
--- a/mullvad-api/src/ffi/error.rs
+++ b/mullvad-api/src/ffi/error.rs
@@ -13,6 +13,7 @@ pub enum MullvadApiErrorKind {
/// MullvadApiErrorKind contains a description and an error kind. If the error kind is
/// `MullvadApiErrorKind` is NoError, the pointer will be nil.
+#[derive(Debug)]
#[repr(C)]
pub struct MullvadApiError {
description: *mut libc::c_char,
@@ -47,6 +48,13 @@ impl MullvadApiError {
}
}
+ pub fn unwrap(&self) {
+ if !matches!(self.kind, MullvadApiErrorKind::NoError) {
+ let desc = unsafe { std::ffi::CStr::from_ptr(self.description) };
+ panic!("API ERROR - {:?} - {}", self.kind, desc.to_str().unwrap());
+ }
+ }
+
pub fn drop(self) {
if self.description.is_null() {
return;
diff --git a/mullvad-api/src/ffi/mod.rs b/mullvad-api/src/ffi/mod.rs
index a68ea40ed6..9677488257 100644
--- a/mullvad-api/src/ffi/mod.rs
+++ b/mullvad-api/src/ffi/mod.rs
@@ -1,3 +1,4 @@
+#![cfg(not(target_os = "android"))]
use std::{
ffi::{CStr, CString},
net::SocketAddr,
@@ -6,8 +7,9 @@ use std::{
};
use crate::{
+ proxy::ApiConnectionMode,
rest::{self, MullvadRestHandle},
- AccountsProxy, DevicesProxy,
+ AccountsProxy, ApiEndpoint, DevicesProxy,
};
mod device;
@@ -48,13 +50,13 @@ impl MullvadApiClient {
struct FfiClient {
tokio_runtime: tokio::runtime::Runtime,
api_runtime: crate::Runtime,
- api_hostname: String,
}
impl FfiClient {
unsafe fn new(
api_address_ptr: *const libc::c_char,
hostname: *const libc::c_char,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> Result<Self, MullvadApiError> {
// SAFETY: addr_str must be a valid pointer to a null-terminated string.
let addr_str = unsafe { string_from_raw_ptr(api_address_ptr)? };
@@ -68,12 +70,15 @@ impl FfiClient {
)
})?;
- // The call site guarantees that
- // api_hostname and api_address will never change after the first call to new.
- std::env::set_var(crate::env::API_HOST_VAR, &api_hostname);
- std::env::set_var(crate::env::API_ADDR_VAR, &addr_str);
- std::env::set_var(crate::env::API_FORCE_DIRECT_VAR, "0");
- std::env::set_var(crate::env::DISABLE_TLS_VAR, "0");
+ let endpoint = ApiEndpoint {
+ host: Some(api_hostname.clone()),
+ address: Some(api_address),
+ #[cfg(feature = "api-override")]
+ force_direct: false,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
+ };
+
let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
runtime_builder.worker_threads(2).enable_all();
@@ -83,14 +88,12 @@ impl FfiClient {
// It is imperative that the REST runtime is created within an async context, otherwise
// ApiAvailability panics.
- let api_runtime = tokio_runtime.block_on(async {
- crate::Runtime::with_static_addr(tokio_runtime.handle().clone(), api_address)
- });
+ let api_runtime = tokio_runtime
+ .block_on(async { crate::Runtime::new(tokio_runtime.handle().clone(), &endpoint) });
let context = FfiClient {
tokio_runtime,
api_runtime,
- api_hostname,
};
Ok(context)
@@ -204,7 +207,7 @@ impl FfiClient {
fn rest_handle(&self) -> MullvadRestHandle {
self.tokio_handle().block_on(async {
self.api_runtime
- .static_mullvad_rest_handle(self.api_hostname.clone())
+ .mullvad_rest_handle(ApiConnectionMode::Direct.into_provider())
})
}
@@ -229,18 +232,31 @@ impl FfiClient {
/// struct.
///
/// * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address
-/// representation
-/// ("143.32.4.32:9090"), the port is mandatory.
+/// representation ("143.32.4.32:9090"), the port is mandatory.
///
/// * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be
/// used for TLS validation.
+/// * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift.
#[no_mangle]
pub unsafe extern "C" fn mullvad_api_client_initialize(
client_ptr: *mut MullvadApiClient,
api_address_ptr: *const libc::c_char,
hostname: *const libc::c_char,
+ disable_tls: bool,
) -> MullvadApiError {
- match unsafe { FfiClient::new(api_address_ptr, hostname) } {
+ #[cfg(not(any(feature = "api-override", test)))]
+ if disable_tls {
+ log::error!("disable_tls has no effect when mullvad-api is built without api-override");
+ }
+
+ match unsafe {
+ FfiClient::new(
+ api_address_ptr,
+ hostname,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
+ )
+ } {
Ok(client) => {
unsafe {
std::ptr::write(client_ptr, MullvadApiClient::new(client));
@@ -306,8 +322,8 @@ pub unsafe extern "C" fn mullvad_api_get_expiry(
/// * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the
/// account that will have all of it's devices removed.
///
-/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function
-/// doesn't return an error, the pointer will be initialized with a valid instance of
+/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't
+/// return an error, the pointer will be initialized with a valid instance of
/// `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices.
#[no_mangle]
pub unsafe extern "C" fn mullvad_api_list_devices(
@@ -443,3 +459,57 @@ unsafe fn string_from_raw_ptr(ptr: *const libc::c_char) -> Result<String, Mullva
})?
.to_owned())
}
+
+#[cfg(test)]
+mod test {
+ use mockito::{Server, ServerGuard};
+ use std::{mem::MaybeUninit, net::Ipv4Addr};
+
+ use super::*;
+ const STAGING_HOSTNAME: &[u8] = b"api-app.stagemole.eu\0";
+
+ #[test]
+ fn test_initialization() {
+ let _ = create_client(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 1));
+ }
+
+ fn create_client(addr: &SocketAddr) -> MullvadApiClient {
+ let mut client = MaybeUninit::<MullvadApiClient>::uninit();
+ let cstr_address = CString::new(addr.to_string()).unwrap();
+ unsafe {
+ mullvad_api_client_initialize(
+ client.as_mut_ptr(),
+ cstr_address.as_ptr().cast(),
+ STAGING_HOSTNAME.as_ptr().cast(),
+ true,
+ )
+ .unwrap();
+ };
+ unsafe { client.assume_init() }
+ }
+
+ #[test]
+ fn test_create_delete_account() {
+ let server = test_server();
+ let client = create_client(&server.socket_address());
+
+ let mut account_buf = vec![0 as libc::c_char; 100];
+ unsafe { mullvad_api_create_account(client, account_buf.as_mut_ptr().cast()).unwrap() };
+ }
+
+ fn test_server() -> ServerGuard {
+ let mut server = Server::new();
+ let expected_create_account_response = br#"{"id":"085df870-0fc2-47cb-9e8c-cb43c1bdaac0","expiry":"2024-12-11T12:56:32+00:00","max_ports":0,"can_add_ports":false,"max_devices":5,"can_add_devices":true,"number":"6705749539195318"}"#;
+ server
+ .mock(
+ "POST",
+ &*("/".to_string() + crate::ACCOUNTS_URL_PREFIX + "/accounts"),
+ )
+ .with_header("content-type", "application/json")
+ .with_status(201)
+ .with_body(expected_create_account_response)
+ .create();
+
+ server
+ }
+}
diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs
index 3dfd168a92..f86c538a67 100644
--- a/mullvad-api/src/https_client_with_sni.rs
+++ b/mullvad-api/src/https_client_with_sni.rs
@@ -41,8 +41,8 @@ use tokio::{
};
use tower::Service;
-#[cfg(feature = "api-override")]
-use crate::{proxy::ConnectionDecorator, API};
+#[cfg(any(feature = "api-override", test))]
+use crate::proxy::ConnectionDecorator;
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
@@ -89,6 +89,7 @@ impl InnerConnectionMode {
hostname: &str,
addr: &SocketAddr,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> Result<ApiConnection, std::io::Error> {
match self {
// Set up a TCP-socket connection.
@@ -101,6 +102,8 @@ impl InnerConnectionMode {
make_proxy_stream,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
)
.await
}
@@ -121,6 +124,8 @@ impl InnerConnectionMode {
make_proxy_stream,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
)
.await
}
@@ -153,6 +158,8 @@ impl InnerConnectionMode {
make_proxy_stream,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
)
.await
}
@@ -168,6 +175,8 @@ impl InnerConnectionMode {
make_proxy_stream,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
)
.await
}
@@ -191,6 +200,7 @@ impl InnerConnectionMode {
hostname: &str,
make_proxy_stream: ProxyFactory,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> Result<ApiConnection, io::Error>
where
ProxyFactory: FnOnce(TcpStream) -> ProxyFuture,
@@ -206,8 +216,8 @@ impl InnerConnectionMode {
let proxy = make_proxy_stream(socket).await?;
- #[cfg(feature = "api-override")]
- if API.disable_tls {
+ #[cfg(any(feature = "api-override", test))]
+ if disable_tls {
return Ok(ApiConnection::new(Box::new(ConnectionDecorator(proxy))));
}
@@ -290,6 +300,8 @@ pub struct HttpsConnectorWithSni {
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls: bool,
}
struct HttpsConnectorWithSniInner {
@@ -304,6 +316,7 @@ impl HttpsConnectorWithSni {
pub fn new(
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> (Self, HttpsConnectorWithSniHandle) {
let (tx, mut rx) = mpsc::unbounded();
let abort_notify = Arc::new(tokio::sync::Notify::new());
@@ -352,6 +365,8 @@ impl HttpsConnectorWithSni {
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
},
HttpsConnectorWithSniHandle { tx },
)
@@ -435,6 +450,9 @@ impl Service<Uri> for HttpsConnectorWithSni {
let socket_bypass_tx = self.socket_bypass_tx.clone();
let dns_resolver = self.dns_resolver.clone();
+ #[cfg(any(feature = "api-override", test))]
+ let disable_tls = self.disable_tls;
+
let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
return Err(io::Error::new(
@@ -460,6 +478,8 @@ impl Service<Uri> for HttpsConnectorWithSni {
&addr,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
);
pin_mut!(stream_fut);
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 3b02e4fe98..a47c708b2e 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -10,14 +10,12 @@ use mullvad_types::{
};
use proxy::{ApiConnectionMode, ConnectionModeProvider};
use std::{
- cell::Cell,
collections::BTreeMap,
future::Future,
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
- ops::Deref,
path::Path,
- sync::{Arc, OnceLock},
+ sync::Arc,
};
use talpid_types::ErrorExt;
@@ -37,7 +35,6 @@ mod address_cache;
pub mod device;
mod relay_list;
-#[cfg(target_os = "ios")]
pub mod ffi;
pub use address_cache::AddressCache;
@@ -70,41 +67,6 @@ const APP_URL_PREFIX: &str = "app/v1";
#[cfg(target_os = "android")]
const GOOGLE_PAYMENTS_URL_PREFIX: &str = "payments/google-play/v1";
-pub static API: LazyManual<ApiEndpoint> = LazyManual::new(ApiEndpoint::from_env_vars);
-
-unsafe impl<T, F: Send> Sync for LazyManual<T, F> where OnceLock<T>: Sync {}
-
-/// A value that is either initialized on access or explicitly.
-pub struct LazyManual<T, F = fn() -> T> {
- cell: OnceLock<T>,
- lazy_fn: Cell<Option<F>>,
-}
-
-impl<T, F> LazyManual<T, F> {
- const fn new(lazy_fn: F) -> Self {
- Self {
- cell: OnceLock::new(),
- lazy_fn: Cell::new(Some(lazy_fn)),
- }
- }
-
- /// Tries to initialize the object. An error is returned if it is
- /// already initialized.
- #[cfg(feature = "api-override")]
- pub fn override_init(&self, val: T) -> Result<(), T> {
- let _ = self.lazy_fn.take();
- self.cell.set(val)
- }
-}
-
-impl<T> Deref for LazyManual<T> {
- type Target = T;
-
- fn deref(&self) -> &Self::Target {
- self.cell.get_or_init(|| (self.lazy_fn.take().unwrap())())
- }
-}
-
pub mod env {
pub const API_HOST_VAR: &str = "MULLVAD_API_HOST";
pub const API_ADDR_VAR: &str = "MULLVAD_API_ADDR";
@@ -113,7 +75,7 @@ pub mod env {
}
/// A hostname and socketaddr to reach the Mullvad REST API over.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct ApiEndpoint {
/// An overriden API hostname. Initialized with the value of the environment
/// variable `MULLVAD_API_HOST` if it has been set.
@@ -132,9 +94,7 @@ pub struct ApiEndpoint {
/// If [`Self::address`] is populated with [`Some(SocketAddr)`], it should
/// always be respected when establishing API connections.
pub address: Option<SocketAddr>,
- #[cfg(feature = "api-override")]
- pub disable_address_cache: bool,
- #[cfg(feature = "api-override")]
+ #[cfg(any(feature = "api-override", test))]
pub disable_tls: bool,
#[cfg(feature = "api-override")]
/// Whether bridges/proxies can be used to access the API or not. This is
@@ -175,7 +135,6 @@ impl ApiEndpoint {
let mut api = ApiEndpoint {
host: None,
address: None,
- disable_address_cache: host_var.is_some() || address_var.is_some(),
disable_tls: false,
force_direct: force_direct
.map(|force_direct| force_direct != "0")
@@ -244,6 +203,11 @@ impl ApiEndpoint {
api
}
+ #[cfg(feature = "api-override")]
+ pub fn should_disable_address_cache(&self) -> bool {
+ self.host.is_some() || self.address.is_some()
+ }
+
/// Returns the endpoint to connect to the API over.
///
/// # Panics
@@ -269,9 +233,31 @@ impl ApiEndpoint {
ApiEndpoint {
host: None,
address: None,
+ #[cfg(test)]
+ disable_tls: false,
}
}
+ /// Returns a new API endpoint with the given host and socket address.
+ pub fn new(
+ host: String,
+ address: SocketAddr,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
+ ) -> Self {
+ Self {
+ host: Some(host),
+ address: Some(address),
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
+ #[cfg(feature = "api-override")]
+ force_direct: false,
+ }
+ }
+
+ pub fn set_addr(&mut self, address: SocketAddr) {
+ self.address = Some(address);
+ }
+
/// Read the [`Self::host`] value, falling back to
/// [`Self::API_HOST_DEFAULT`] as default value if it does not exist.
pub fn host(&self) -> &str {
@@ -342,6 +328,7 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
+ endpoint: ApiEndpoint,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
@@ -362,40 +349,27 @@ pub enum Error {
}
impl Runtime {
- /// Create a new `Runtime`.
- pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
- Self::new_inner(
- handle,
- #[cfg(target_os = "android")]
- None,
- )
- }
-
- #[cfg(target_os = "ios")]
- pub fn with_static_addr(handle: tokio::runtime::Handle, address: SocketAddr) -> Self {
- Runtime {
- handle,
- address_cache: AddressCache::with_static_addr(address),
- api_availability: ApiAvailability::default(),
- }
- }
-
- fn new_inner(
+ /// Will create a new Runtime without a cache with the provided API endpoint.
+ pub fn new(
handle: tokio::runtime::Handle,
+ endpoint: &ApiEndpoint,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
- ) -> Result<Self, Error> {
- Ok(Runtime {
+ ) -> Self {
+ Runtime {
handle,
- address_cache: AddressCache::new(None),
+ address_cache: AddressCache::new(endpoint, None),
api_availability: ApiAvailability::default(),
+ endpoint: endpoint.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
- })
+ }
}
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
+ /// Will try to construct an API endpoint from the environment.
pub async fn with_cache(
+ endpoint: &ApiEndpoint,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
@@ -403,12 +377,13 @@ impl Runtime {
let handle = tokio::runtime::Handle::current();
#[cfg(feature = "api-override")]
- if API.disable_address_cache {
- return Self::new_inner(
+ if endpoint.should_disable_address_cache() {
+ return Ok(Self::new(
handle,
+ endpoint,
#[cfg(target_os = "android")]
socket_bypass_tx,
- );
+ ));
}
let cache_file = cache_dir.join(API_IP_CACHE_FILENAME);
@@ -418,7 +393,13 @@ impl Runtime {
None
};
- let address_cache = match AddressCache::from_file(&cache_file, write_file.clone()).await {
+ let address_cache = match AddressCache::from_file(
+ &cache_file,
+ write_file.clone(),
+ endpoint.host().to_owned(),
+ )
+ .await
+ {
Ok(cache) => cache,
Err(error) => {
if cache_file.exists() {
@@ -429,7 +410,7 @@ impl Runtime {
)
);
}
- AddressCache::new(write_file)
+ AddressCache::new(endpoint, write_file)
}
};
@@ -439,12 +420,14 @@ impl Runtime {
handle,
address_cache,
api_availability,
+ endpoint: endpoint.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}
- /// Returns a request factory initialized to create requests for the master API
+ /// Returns a request factory initialized to create requests for the master API Assumes an API
+ /// endpoint that is constructed from env vars, or uses default values.
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
connection_mode_provider: T,
@@ -454,21 +437,10 @@ impl Runtime {
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
+ #[cfg(any(feature = "api-override", test))]
+ self.endpoint.disable_tls,
);
- let token_store = access::AccessTokenStore::new(service.clone(), API.host());
- let factory = rest::RequestFactory::new(API.host().to_owned(), Some(token_store));
-
- rest::MullvadRestHandle::new(service, factory, self.availability_handle())
- }
-
- /// This is only to be used in test code
- pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
- let service = self.new_request_service(
- ApiConnectionMode::Direct.into_provider(),
- Arc::new(self.address_cache.clone()),
- #[cfg(target_os = "android")]
- self.socket_bypass_tx.clone(),
- );
+ let hostname = self.endpoint.host().to_owned();
let token_store = access::AccessTokenStore::new(service.clone(), hostname.clone());
let factory = rest::RequestFactory::new(hostname, Some(token_store));
@@ -482,6 +454,8 @@ impl Runtime {
Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
+ #[cfg(any(feature = "api-override", test))]
+ false,
)
}
@@ -491,6 +465,7 @@ impl Runtime {
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
self.api_availability.clone(),
@@ -498,6 +473,8 @@ impl Runtime {
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
)
}
@@ -582,7 +559,6 @@ impl AccountsProxy {
}
}
- #[cfg(target_os = "ios")]
pub fn delete_account(
&self,
account: AccountNumber,
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 5b93eea311..cab3bb7e0f 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -154,11 +154,14 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ #[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
+ #[cfg(any(feature = "api-override", test))]
+ disable_tls,
);
connector_handle.set_connection_mode(connection_mode_provider.initial());
@@ -461,7 +464,6 @@ where
}
// Parse unexpected responses and errors
-
let response = response?;
if !self.expected_status.contains(&response.status()) {