diff options
| author | Emīls <emils@mullvad.net> | 2024-12-09 10:59:51 +0100 |
|---|---|---|
| committer | Emīls <emils@mullvad.net> | 2025-01-02 10:29:05 +0100 |
| commit | 3093408a057020fcd912976b892fbb6bc26e6293 (patch) | |
| tree | 56ac01255fdc7f0c29aa6a8c5a8cdd863ff4f496 /mullvad-api | |
| parent | 58efb1004f8ca762fe6aa95f541c6dc329ed790e (diff) | |
| download | mullvadvpn-3093408a057020fcd912976b892fbb6bc26e6293.tar.xz mullvadvpn-3093408a057020fcd912976b892fbb6bc26e6293.zip | |
Remove global API endpoint
Diffstat (limited to 'mullvad-api')
| -rw-r--r-- | mullvad-api/Cargo.toml | 9 | ||||
| -rw-r--r-- | mullvad-api/include/mullvad-api.h | 11 | ||||
| -rw-r--r-- | mullvad-api/src/address_cache.rs | 29 | ||||
| -rw-r--r-- | mullvad-api/src/bin/relay_list.rs | 10 | ||||
| -rw-r--r-- | mullvad-api/src/ffi/error.rs | 8 | ||||
| -rw-r--r-- | mullvad-api/src/ffi/mod.rs | 106 | ||||
| -rw-r--r-- | mullvad-api/src/https_client_with_sni.rs | 28 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 152 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 4 |
9 files changed, 215 insertions, 142 deletions
diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml index fc9d7d899b..005d1a9505 100644 --- a/mullvad-api/Cargo.toml +++ b/mullvad-api/Cargo.toml @@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["macros", "time", "rt-multi-thread", "ne tokio-rustls = { version = "0.26.0", features = ["logging", "tls12", "ring"], default-features = false} tokio-socks = "0.5.1" rustls-pemfile = "2.1.3" +uuid = { version = "1.4.1", features = ["v4"] } mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" } mullvad-fs = { path = "../mullvad-fs" } @@ -50,14 +51,6 @@ mockito = "1.6.1" [build-dependencies] cbindgen = { version = "0.24.3", default-features = false } -[target.'cfg(target_os = "ios")'.dependencies] -uuid = { version = "1.4.1", features = ["v4"] } - [lib] crate-type = [ "rlib", "staticlib" ] bench = false - -[[test]] -name = "ffi" -# required-features = [ "api-override" ] -features = [ "api-override" ] diff --git a/mullvad-api/include/mullvad-api.h b/mullvad-api/include/mullvad-api.h index e0295b20aa..4e5f78aef2 100644 --- a/mullvad-api/include/mullvad-api.h +++ b/mullvad-api/include/mullvad-api.h @@ -49,15 +49,16 @@ typedef struct MullvadApiDevice { * struct. * * * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address - * representation - * ("143.32.4.32:9090"), the port is mandatory. + * representation ("143.32.4.32:9090"), the port is mandatory. * * * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be * used for TLS validation. + * * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift. */ struct MullvadApiError mullvad_api_client_initialize(struct MullvadApiClient *client_ptr, const char *api_address_ptr, - const char *hostname); + const char *hostname, + bool disable_tls); /** * Removes all devices from a given account @@ -98,8 +99,8 @@ struct MullvadApiError mullvad_api_get_expiry(struct MullvadApiClient client_ptr * * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the * account that will have all of it's devices removed. * - * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function - * doesn't return an error, the pointer will be initialized with a valid instance of + * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't + * return an error, the pointer will be initialized with a valid instance of * `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices. */ struct MullvadApiError mullvad_api_list_devices(struct MullvadApiClient client_ptr, diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index 0898f8da1f..a6a60146b4 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -1,7 +1,6 @@ //! This module keeps track of the last known good API IP address and reads and stores it on disk. -use super::API; -use crate::DnsResolver; +use crate::{ApiEndpoint, DnsResolver}; use async_trait::async_trait; use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ @@ -38,42 +37,42 @@ impl DnsResolver for AddressCache { #[derive(Clone)] pub struct AddressCache { + hostname: String, inner: Arc<Mutex<AddressCacheInner>>, write_path: Option<Arc<Path>>, } impl AddressCache { /// Initialize cache using the hardcoded address, and write changes to `write_path`. - pub fn new(write_path: Option<Box<Path>>) -> Self { - Self::new_inner(API.address(), write_path) - } - - pub fn with_static_addr(address: SocketAddr) -> Self { - Self::new_inner(address, None) + pub fn new(endpoint: &ApiEndpoint, write_path: Option<Box<Path>>) -> Self { + Self::new_inner(endpoint.address(), endpoint.host().to_owned(), write_path) } /// Initialize cache using `read_path`, and write changes to `write_path`. - pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> { + pub async fn from_file( + read_path: &Path, + write_path: Option<Box<Path>>, + hostname: String, + ) -> Result<Self, Error> { log::debug!("Loading API addresses from {}", read_path.display()); - Ok(Self::new_inner( - read_address_file(read_path).await?, - write_path, - )) + let address = read_address_file(read_path).await?; + Ok(Self::new_inner(address, hostname, write_path)) } - fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self { + fn new_inner(address: SocketAddr, hostname: String, write_path: Option<Box<Path>>) -> Self { let cache = AddressCacheInner::from_address(address); log::debug!("Using API address: {}", cache.address); Self { inner: Arc::new(Mutex::new(cache)), write_path: write_path.map(Arc::from), + hostname, } } /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> { - if hostname.eq_ignore_ascii_case(API.host()) { + if hostname.eq_ignore_ascii_case(&self.hostname) { Some(self.get_address().await) } else { None diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index def32303ea..3ea771cc81 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,14 +2,18 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; +use mullvad_api::{ + proxy::ApiConnectionMode, rest::Error as RestError, ApiEndpoint, RelayListProxy, +}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) - .expect("Failed to load runtime"); + let runtime = mullvad_api::Runtime::new( + tokio::runtime::Handle::current(), + &ApiEndpoint::from_env_vars(), + ); let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider())) diff --git a/mullvad-api/src/ffi/error.rs b/mullvad-api/src/ffi/error.rs index 539a6c23a0..66ffc01220 100644 --- a/mullvad-api/src/ffi/error.rs +++ b/mullvad-api/src/ffi/error.rs @@ -13,6 +13,7 @@ pub enum MullvadApiErrorKind { /// MullvadApiErrorKind contains a description and an error kind. If the error kind is /// `MullvadApiErrorKind` is NoError, the pointer will be nil. +#[derive(Debug)] #[repr(C)] pub struct MullvadApiError { description: *mut libc::c_char, @@ -47,6 +48,13 @@ impl MullvadApiError { } } + pub fn unwrap(&self) { + if !matches!(self.kind, MullvadApiErrorKind::NoError) { + let desc = unsafe { std::ffi::CStr::from_ptr(self.description) }; + panic!("API ERROR - {:?} - {}", self.kind, desc.to_str().unwrap()); + } + } + pub fn drop(self) { if self.description.is_null() { return; diff --git a/mullvad-api/src/ffi/mod.rs b/mullvad-api/src/ffi/mod.rs index a68ea40ed6..9677488257 100644 --- a/mullvad-api/src/ffi/mod.rs +++ b/mullvad-api/src/ffi/mod.rs @@ -1,3 +1,4 @@ +#![cfg(not(target_os = "android"))] use std::{ ffi::{CStr, CString}, net::SocketAddr, @@ -6,8 +7,9 @@ use std::{ }; use crate::{ + proxy::ApiConnectionMode, rest::{self, MullvadRestHandle}, - AccountsProxy, DevicesProxy, + AccountsProxy, ApiEndpoint, DevicesProxy, }; mod device; @@ -48,13 +50,13 @@ impl MullvadApiClient { struct FfiClient { tokio_runtime: tokio::runtime::Runtime, api_runtime: crate::Runtime, - api_hostname: String, } impl FfiClient { unsafe fn new( api_address_ptr: *const libc::c_char, hostname: *const libc::c_char, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result<Self, MullvadApiError> { // SAFETY: addr_str must be a valid pointer to a null-terminated string. let addr_str = unsafe { string_from_raw_ptr(api_address_ptr)? }; @@ -68,12 +70,15 @@ impl FfiClient { ) })?; - // The call site guarantees that - // api_hostname and api_address will never change after the first call to new. - std::env::set_var(crate::env::API_HOST_VAR, &api_hostname); - std::env::set_var(crate::env::API_ADDR_VAR, &addr_str); - std::env::set_var(crate::env::API_FORCE_DIRECT_VAR, "0"); - std::env::set_var(crate::env::DISABLE_TLS_VAR, "0"); + let endpoint = ApiEndpoint { + host: Some(api_hostname.clone()), + address: Some(api_address), + #[cfg(feature = "api-override")] + force_direct: false, + #[cfg(any(feature = "api-override", test))] + disable_tls, + }; + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); runtime_builder.worker_threads(2).enable_all(); @@ -83,14 +88,12 @@ impl FfiClient { // It is imperative that the REST runtime is created within an async context, otherwise // ApiAvailability panics. - let api_runtime = tokio_runtime.block_on(async { - crate::Runtime::with_static_addr(tokio_runtime.handle().clone(), api_address) - }); + let api_runtime = tokio_runtime + .block_on(async { crate::Runtime::new(tokio_runtime.handle().clone(), &endpoint) }); let context = FfiClient { tokio_runtime, api_runtime, - api_hostname, }; Ok(context) @@ -204,7 +207,7 @@ impl FfiClient { fn rest_handle(&self) -> MullvadRestHandle { self.tokio_handle().block_on(async { self.api_runtime - .static_mullvad_rest_handle(self.api_hostname.clone()) + .mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()) }) } @@ -229,18 +232,31 @@ impl FfiClient { /// struct. /// /// * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address -/// representation -/// ("143.32.4.32:9090"), the port is mandatory. +/// representation ("143.32.4.32:9090"), the port is mandatory. /// /// * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be /// used for TLS validation. +/// * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift. #[no_mangle] pub unsafe extern "C" fn mullvad_api_client_initialize( client_ptr: *mut MullvadApiClient, api_address_ptr: *const libc::c_char, hostname: *const libc::c_char, + disable_tls: bool, ) -> MullvadApiError { - match unsafe { FfiClient::new(api_address_ptr, hostname) } { + #[cfg(not(any(feature = "api-override", test)))] + if disable_tls { + log::error!("disable_tls has no effect when mullvad-api is built without api-override"); + } + + match unsafe { + FfiClient::new( + api_address_ptr, + hostname, + #[cfg(any(feature = "api-override", test))] + disable_tls, + ) + } { Ok(client) => { unsafe { std::ptr::write(client_ptr, MullvadApiClient::new(client)); @@ -306,8 +322,8 @@ pub unsafe extern "C" fn mullvad_api_get_expiry( /// * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the /// account that will have all of it's devices removed. /// -/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function -/// doesn't return an error, the pointer will be initialized with a valid instance of +/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't +/// return an error, the pointer will be initialized with a valid instance of /// `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices. #[no_mangle] pub unsafe extern "C" fn mullvad_api_list_devices( @@ -443,3 +459,57 @@ unsafe fn string_from_raw_ptr(ptr: *const libc::c_char) -> Result<String, Mullva })? .to_owned()) } + +#[cfg(test)] +mod test { + use mockito::{Server, ServerGuard}; + use std::{mem::MaybeUninit, net::Ipv4Addr}; + + use super::*; + const STAGING_HOSTNAME: &[u8] = b"api-app.stagemole.eu\0"; + + #[test] + fn test_initialization() { + let _ = create_client(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 1)); + } + + fn create_client(addr: &SocketAddr) -> MullvadApiClient { + let mut client = MaybeUninit::<MullvadApiClient>::uninit(); + let cstr_address = CString::new(addr.to_string()).unwrap(); + unsafe { + mullvad_api_client_initialize( + client.as_mut_ptr(), + cstr_address.as_ptr().cast(), + STAGING_HOSTNAME.as_ptr().cast(), + true, + ) + .unwrap(); + }; + unsafe { client.assume_init() } + } + + #[test] + fn test_create_delete_account() { + let server = test_server(); + let client = create_client(&server.socket_address()); + + let mut account_buf = vec![0 as libc::c_char; 100]; + unsafe { mullvad_api_create_account(client, account_buf.as_mut_ptr().cast()).unwrap() }; + } + + fn test_server() -> ServerGuard { + let mut server = Server::new(); + let expected_create_account_response = br#"{"id":"085df870-0fc2-47cb-9e8c-cb43c1bdaac0","expiry":"2024-12-11T12:56:32+00:00","max_ports":0,"can_add_ports":false,"max_devices":5,"can_add_devices":true,"number":"6705749539195318"}"#; + server + .mock( + "POST", + &*("/".to_string() + crate::ACCOUNTS_URL_PREFIX + "/accounts"), + ) + .with_header("content-type", "application/json") + .with_status(201) + .with_body(expected_create_account_response) + .create(); + + server + } +} diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 3dfd168a92..f86c538a67 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -41,8 +41,8 @@ use tokio::{ }; use tower::Service; -#[cfg(feature = "api-override")] -use crate::{proxy::ConnectionDecorator, API}; +#[cfg(any(feature = "api-override", test))] +use crate::proxy::ConnectionDecorator; const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); @@ -89,6 +89,7 @@ impl InnerConnectionMode { hostname: &str, addr: &SocketAddr, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result<ApiConnection, std::io::Error> { match self { // Set up a TCP-socket connection. @@ -101,6 +102,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -121,6 +124,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -153,6 +158,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -168,6 +175,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -191,6 +200,7 @@ impl InnerConnectionMode { hostname: &str, make_proxy_stream: ProxyFactory, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result<ApiConnection, io::Error> where ProxyFactory: FnOnce(TcpStream) -> ProxyFuture, @@ -206,8 +216,8 @@ impl InnerConnectionMode { let proxy = make_proxy_stream(socket).await?; - #[cfg(feature = "api-override")] - if API.disable_tls { + #[cfg(any(feature = "api-override", test))] + if disable_tls { return Ok(ApiConnection::new(Box::new(ConnectionDecorator(proxy)))); } @@ -290,6 +300,8 @@ pub struct HttpsConnectorWithSni { dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] + disable_tls: bool, } struct HttpsConnectorWithSniInner { @@ -304,6 +316,7 @@ impl HttpsConnectorWithSni { pub fn new( dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> (Self, HttpsConnectorWithSniHandle) { let (tx, mut rx) = mpsc::unbounded(); let abort_notify = Arc::new(tokio::sync::Notify::new()); @@ -352,6 +365,8 @@ impl HttpsConnectorWithSni { dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, }, HttpsConnectorWithSniHandle { tx }, ) @@ -435,6 +450,9 @@ impl Service<Uri> for HttpsConnectorWithSni { let socket_bypass_tx = self.socket_bypass_tx.clone(); let dns_resolver = self.dns_resolver.clone(); + #[cfg(any(feature = "api-override", test))] + let disable_tls = self.disable_tls; + let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { return Err(io::Error::new( @@ -460,6 +478,8 @@ impl Service<Uri> for HttpsConnectorWithSni { &addr, #[cfg(target_os = "android")] socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + disable_tls, ); pin_mut!(stream_fut); diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 3b02e4fe98..a47c708b2e 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -10,14 +10,12 @@ use mullvad_types::{ }; use proxy::{ApiConnectionMode, ConnectionModeProvider}; use std::{ - cell::Cell, collections::BTreeMap, future::Future, io, net::{IpAddr, Ipv4Addr, SocketAddr}, - ops::Deref, path::Path, - sync::{Arc, OnceLock}, + sync::Arc, }; use talpid_types::ErrorExt; @@ -37,7 +35,6 @@ mod address_cache; pub mod device; mod relay_list; -#[cfg(target_os = "ios")] pub mod ffi; pub use address_cache::AddressCache; @@ -70,41 +67,6 @@ const APP_URL_PREFIX: &str = "app/v1"; #[cfg(target_os = "android")] const GOOGLE_PAYMENTS_URL_PREFIX: &str = "payments/google-play/v1"; -pub static API: LazyManual<ApiEndpoint> = LazyManual::new(ApiEndpoint::from_env_vars); - -unsafe impl<T, F: Send> Sync for LazyManual<T, F> where OnceLock<T>: Sync {} - -/// A value that is either initialized on access or explicitly. -pub struct LazyManual<T, F = fn() -> T> { - cell: OnceLock<T>, - lazy_fn: Cell<Option<F>>, -} - -impl<T, F> LazyManual<T, F> { - const fn new(lazy_fn: F) -> Self { - Self { - cell: OnceLock::new(), - lazy_fn: Cell::new(Some(lazy_fn)), - } - } - - /// Tries to initialize the object. An error is returned if it is - /// already initialized. - #[cfg(feature = "api-override")] - pub fn override_init(&self, val: T) -> Result<(), T> { - let _ = self.lazy_fn.take(); - self.cell.set(val) - } -} - -impl<T> Deref for LazyManual<T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.cell.get_or_init(|| (self.lazy_fn.take().unwrap())()) - } -} - pub mod env { pub const API_HOST_VAR: &str = "MULLVAD_API_HOST"; pub const API_ADDR_VAR: &str = "MULLVAD_API_ADDR"; @@ -113,7 +75,7 @@ pub mod env { } /// A hostname and socketaddr to reach the Mullvad REST API over. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ApiEndpoint { /// An overriden API hostname. Initialized with the value of the environment /// variable `MULLVAD_API_HOST` if it has been set. @@ -132,9 +94,7 @@ pub struct ApiEndpoint { /// If [`Self::address`] is populated with [`Some(SocketAddr)`], it should /// always be respected when establishing API connections. pub address: Option<SocketAddr>, - #[cfg(feature = "api-override")] - pub disable_address_cache: bool, - #[cfg(feature = "api-override")] + #[cfg(any(feature = "api-override", test))] pub disable_tls: bool, #[cfg(feature = "api-override")] /// Whether bridges/proxies can be used to access the API or not. This is @@ -175,7 +135,6 @@ impl ApiEndpoint { let mut api = ApiEndpoint { host: None, address: None, - disable_address_cache: host_var.is_some() || address_var.is_some(), disable_tls: false, force_direct: force_direct .map(|force_direct| force_direct != "0") @@ -244,6 +203,11 @@ impl ApiEndpoint { api } + #[cfg(feature = "api-override")] + pub fn should_disable_address_cache(&self) -> bool { + self.host.is_some() || self.address.is_some() + } + /// Returns the endpoint to connect to the API over. /// /// # Panics @@ -269,9 +233,31 @@ impl ApiEndpoint { ApiEndpoint { host: None, address: None, + #[cfg(test)] + disable_tls: false, } } + /// Returns a new API endpoint with the given host and socket address. + pub fn new( + host: String, + address: SocketAddr, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, + ) -> Self { + Self { + host: Some(host), + address: Some(address), + #[cfg(any(feature = "api-override", test))] + disable_tls, + #[cfg(feature = "api-override")] + force_direct: false, + } + } + + pub fn set_addr(&mut self, address: SocketAddr) { + self.address = Some(address); + } + /// Read the [`Self::host`] value, falling back to /// [`Self::API_HOST_DEFAULT`] as default value if it does not exist. pub fn host(&self) -> &str { @@ -342,6 +328,7 @@ pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, + endpoint: ApiEndpoint, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, } @@ -362,40 +349,27 @@ pub enum Error { } impl Runtime { - /// Create a new `Runtime`. - pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { - Self::new_inner( - handle, - #[cfg(target_os = "android")] - None, - ) - } - - #[cfg(target_os = "ios")] - pub fn with_static_addr(handle: tokio::runtime::Handle, address: SocketAddr) -> Self { - Runtime { - handle, - address_cache: AddressCache::with_static_addr(address), - api_availability: ApiAvailability::default(), - } - } - - fn new_inner( + /// Will create a new Runtime without a cache with the provided API endpoint. + pub fn new( handle: tokio::runtime::Handle, + endpoint: &ApiEndpoint, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, - ) -> Result<Self, Error> { - Ok(Runtime { + ) -> Self { + Runtime { handle, - address_cache: AddressCache::new(None), + address_cache: AddressCache::new(endpoint, None), api_availability: ApiAvailability::default(), + endpoint: endpoint.clone(), #[cfg(target_os = "android")] socket_bypass_tx, - }) + } } /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. + /// Will try to construct an API endpoint from the environment. pub async fn with_cache( + endpoint: &ApiEndpoint, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, @@ -403,12 +377,13 @@ impl Runtime { let handle = tokio::runtime::Handle::current(); #[cfg(feature = "api-override")] - if API.disable_address_cache { - return Self::new_inner( + if endpoint.should_disable_address_cache() { + return Ok(Self::new( handle, + endpoint, #[cfg(target_os = "android")] socket_bypass_tx, - ); + )); } let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); @@ -418,7 +393,13 @@ impl Runtime { None }; - let address_cache = match AddressCache::from_file(&cache_file, write_file.clone()).await { + let address_cache = match AddressCache::from_file( + &cache_file, + write_file.clone(), + endpoint.host().to_owned(), + ) + .await + { Ok(cache) => cache, Err(error) => { if cache_file.exists() { @@ -429,7 +410,7 @@ impl Runtime { ) ); } - AddressCache::new(write_file) + AddressCache::new(endpoint, write_file) } }; @@ -439,12 +420,14 @@ impl Runtime { handle, address_cache, api_availability, + endpoint: endpoint.clone(), #[cfg(target_os = "android")] socket_bypass_tx, }) } - /// Returns a request factory initialized to create requests for the master API + /// Returns a request factory initialized to create requests for the master API Assumes an API + /// endpoint that is constructed from env vars, or uses default values. pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>( &self, connection_mode_provider: T, @@ -454,21 +437,10 @@ impl Runtime { Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + self.endpoint.disable_tls, ); - let token_store = access::AccessTokenStore::new(service.clone(), API.host()); - let factory = rest::RequestFactory::new(API.host().to_owned(), Some(token_store)); - - rest::MullvadRestHandle::new(service, factory, self.availability_handle()) - } - - /// This is only to be used in test code - pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { - let service = self.new_request_service( - ApiConnectionMode::Direct.into_provider(), - Arc::new(self.address_cache.clone()), - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ); + let hostname = self.endpoint.host().to_owned(); let token_store = access::AccessTokenStore::new(service.clone(), hostname.clone()); let factory = rest::RequestFactory::new(hostname, Some(token_store)); @@ -482,6 +454,8 @@ impl Runtime { Arc::new(dns_resolver), #[cfg(target_os = "android")] None, + #[cfg(any(feature = "api-override", test))] + false, ) } @@ -491,6 +465,7 @@ impl Runtime { connection_mode_provider: T, dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> rest::RequestServiceHandle { rest::RequestService::spawn( self.api_availability.clone(), @@ -498,6 +473,8 @@ impl Runtime { dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) } @@ -582,7 +559,6 @@ impl AccountsProxy { } } - #[cfg(target_os = "ios")] pub fn delete_account( &self, account: AccountNumber, diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 5b93eea311..cab3bb7e0f 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -154,11 +154,14 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> { connection_mode_provider: T, dns_resolver: Arc<dyn DnsResolver>, #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + disable_tls, ); connector_handle.set_connection_mode(connection_mode_provider.initial()); @@ -461,7 +464,6 @@ where } // Parse unexpected responses and errors - let response = response?; if !self.expected_status.contains(&response.status()) { |
