diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-11-28 12:56:18 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-11-28 12:56:18 +0100 |
| commit | 760d987a9422725c71b9154289b768c2ec87e60c (patch) | |
| tree | 394e4d55cd514db1ebfd8668cac471e6df1bd450 | |
| parent | 5699d3f30333a7cc90eefb987b6c7e79ac14f423 (diff) | |
| parent | 4766d2857570999838f698dd38b75130399cb08e (diff) | |
| download | mullvadvpn-760d987a9422725c71b9154289b768c2ec87e60c.tar.xz mullvadvpn-760d987a9422725c71b9154289b768c2ec87e60c.zip | |
Merge branch 'android-api-override'
| -rw-r--r-- | Cargo.lock | 2 | ||||
| -rw-r--r-- | README.md | 2 | ||||
| -rw-r--r-- | android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/ApiEndpoint.kt | 13 | ||||
| -rw-r--r-- | android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt | 8 | ||||
| -rwxr-xr-x | build-apk.sh | 2 | ||||
| -rw-r--r-- | mullvad-api/Cargo.toml | 2 | ||||
| -rw-r--r-- | mullvad-api/src/https_client_with_sni.rs | 33 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 114 | ||||
| -rw-r--r-- | mullvad-api/src/proxy.rs | 99 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 24 | ||||
| -rw-r--r-- | mullvad-api/src/tls_stream.rs | 23 | ||||
| -rw-r--r-- | mullvad-jni/Cargo.toml | 4 | ||||
| -rw-r--r-- | mullvad-jni/src/lib.rs | 109 |
13 files changed, 354 insertions, 81 deletions
diff --git a/Cargo.lock b/Cargo.lock index 700c6cd464..79083db6d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,9 +1570,9 @@ dependencies = [ "http", "hyper", "ipnetwork", - "lazy_static", "log", "mullvad-types", + "once_cell", "regex", "rustls-pemfile", "serde", @@ -154,6 +154,8 @@ See [this](Release.md) for instructions on how to make a new release. * `MULLVAD_API_ADDR` - Set the IP address and port to use in API requests. E.g. `10.10.1.2:443`. +* `MULLVAD_API_DISABLE_TLS` - Use plain HTTP for API requests. + ### Setting environment variables #### Windows diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/ApiEndpoint.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/ApiEndpoint.kt new file mode 100644 index 0000000000..df40bfac4d --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/model/ApiEndpoint.kt @@ -0,0 +1,13 @@ +package net.mullvad.mullvadvpn.model + +import android.os.Parcelable +import java.net.InetSocketAddress +import kotlinx.parcelize.Parcelize + +@Parcelize +data class ApiEndpoint( + val address: InetSocketAddress, + val disableAddressCache: Boolean, + val disableTls: Boolean, + val forceDirectConnection: Boolean +) : Parcelable diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt index 8d983ad883..aac23cee25 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadDaemon.kt @@ -2,6 +2,7 @@ package net.mullvad.mullvadvpn.service import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.asSharedFlow +import net.mullvad.mullvadvpn.model.ApiEndpoint import net.mullvad.mullvadvpn.model.AppVersionInfo import net.mullvad.mullvadvpn.model.Device import net.mullvad.mullvadvpn.model.DeviceEvent @@ -38,7 +39,9 @@ class MullvadDaemon(vpnService: MullvadVpnService) { init { System.loadLibrary("mullvad_jni") - initialize(vpnService, vpnService.cacheDir.absolutePath, vpnService.filesDir.absolutePath) + initialize( + vpnService, vpnService.cacheDir.absolutePath, vpnService.filesDir.absolutePath, null + ) onSettingsChange.notify(getSettings()) @@ -176,7 +179,8 @@ class MullvadDaemon(vpnService: MullvadVpnService) { private external fun initialize( vpnService: MullvadVpnService, cacheDirectory: String, - resourceDirectory: String + resourceDirectory: String, + apiEndpoint: ApiEndpoint? ) private external fun deinitialize() diff --git a/build-apk.sh b/build-apk.sh index 027060dfb9..75b14c7d03 100755 --- a/build-apk.sh +++ b/build-apk.sh @@ -30,7 +30,7 @@ while [ ! -z "${1:-""}" ]; do BUNDLE_TASK="bundleDebug" BUILT_APK_SUFFIX="-debug" FILE_SUFFIX="-debug" - CARGO_ARGS="" + CARGO_ARGS="--features api-override" elif [[ "${1:-""}" == "--fdroid" ]]; then GRADLE_BUILD_TYPE="fdroid" GRADLE_TASK="assembleFdroid" diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml index e02ac22e11..6699ab432c 100644 --- a/mullvad-api/Cargo.toml +++ b/mullvad-api/Cargo.toml @@ -25,7 +25,7 @@ serde_json = "1.0" tokio = { version = "1.8", features = ["macros", "time", "rt-multi-thread", "net", "io-std", "io-util", "fs"] } tokio-rustls = "0.23" rustls-pemfile = "0.2" -lazy_static = "1.1.0" +once_cell = "1.13" mullvad-types = { path = "../mullvad-types" } talpid-types = { path = "../talpid-types" } diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 0d85ddd790..7202257686 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -40,6 +40,9 @@ use tokio::{ time::timeout, }; +#[cfg(feature = "api-override")] +use crate::{proxy::ConnectionDecorator, API}; + const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); #[derive(Clone)] @@ -215,18 +218,23 @@ impl HttpsConnectorWithSni { } async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> { + const DEFAULT_PORT: u16 = 443; + let hostname = uri.host().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") })?; - let port = uri.port_u16().unwrap_or(443); + let port = uri.port_u16(); if let Ok(addr) = hostname.parse::<IpAddr>() { - return Ok(SocketAddr::new(addr, port)); + return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT))); } // Preferentially, use cached address. // if let Some(addr) = address_cache.resolve_hostname(hostname).await { - return Ok(SocketAddr::new(addr.ip(), port)); + return Ok(SocketAddr::new( + addr.ip(), + port.unwrap_or_else(|| addr.port()), + )); } // Use getaddrinfo as a fallback @@ -241,7 +249,7 @@ impl HttpsConnectorWithSni { let addr = addrs .next() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; - Ok(SocketAddr::new(addr.ip(), port)) + Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT))) } } @@ -303,8 +311,13 @@ impl Service<Uri> for HttpsConnectorWithSni { socket_bypass_tx.clone(), ) .await?; + #[cfg(feature = "api-override")] + if API.disable_tls { + return Ok::<_, io::Error>(ApiConnection::new(Box::new(socket))); + } + let tls_stream = TlsStream::connect_https(socket, &hostname).await?; - Ok::<_, io::Error>(ApiConnection::Direct(Box::new(tls_stream))) + Ok::<_, io::Error>(ApiConnection::new(Box::new(tls_stream))) } InnerConnectionMode::Proxied(proxy_config) => { let socket = Self::open_socket( @@ -319,8 +332,16 @@ impl Service<Uri> for HttpsConnectorWithSni { &ServerConfig::from(proxy_config), addr, ); + + #[cfg(feature = "api-override")] + if API.disable_tls { + return Ok(ApiConnection::new(Box::new(ConnectionDecorator( + proxy, + )))); + } + let tls_stream = TlsStream::connect_https(proxy, &hostname).await?; - Ok(ApiConnection::Proxied(Box::new(tls_stream))) + Ok(ApiConnection::new(Box::new(tls_stream))) } } }; diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 25a53d016c..5872f3af71 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -9,11 +9,14 @@ use mullvad_types::{ account::{AccountToken, VoucherSubmission}, version::AppVersion, }; +use once_cell::sync::OnceCell; use proxy::ApiConnectionMode; use std::{ + cell::Cell, collections::BTreeMap, future::Future, net::{IpAddr, Ipv4Addr, SocketAddr}, + ops::Deref, path::Path, }; use talpid_types::ErrorExt; @@ -62,15 +65,52 @@ pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt"; const ACCOUNTS_URL_PREFIX: &str = "accounts/v1"; const APP_URL_PREFIX: &str = "app/v1"; -lazy_static::lazy_static! { - static ref API: ApiEndpoint = ApiEndpoint::get(); +pub static API: LazyManual<ApiEndpoint> = LazyManual::new(ApiEndpoint::from_env_vars); + +unsafe impl<T, F: Send> Sync for LazyManual<T, F> where OnceCell<T>: Sync {} + +/// A value that is either initialized on access or explicitly. +pub struct LazyManual<T, F = fn() -> T> { + cell: OnceCell<T>, + lazy_fn: Cell<Option<F>>, +} + +impl<T, F> LazyManual<T, F> { + const fn new(lazy_fn: F) -> Self { + Self { + cell: OnceCell::new(), + lazy_fn: Cell::new(Some(lazy_fn)), + } + } + + /// Tries to initialize the object. An error is returned if it is + /// already initialized. + #[cfg(feature = "api-override")] + pub fn override_init(&self, val: T) -> Result<(), T> { + let _ = self.lazy_fn.take(); + self.cell.set(val) + } +} + +impl<T> Deref for LazyManual<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.cell.get_or_init(|| (self.lazy_fn.take().unwrap())()) + } } /// A hostname and socketaddr to reach the Mullvad REST API over. -struct ApiEndpoint { - host: String, - addr: SocketAddr, - disable_address_cache: bool, +#[derive(Debug)] +pub struct ApiEndpoint { + pub host: String, + pub addr: SocketAddr, + #[cfg(feature = "api-override")] + pub disable_address_cache: bool, + #[cfg(feature = "api-override")] + pub disable_tls: bool, + #[cfg(feature = "api-override")] + pub force_direct_connection: bool, } impl ApiEndpoint { @@ -80,7 +120,7 @@ impl ApiEndpoint { /// /// Panics if `MULLVAD_API_ADDR` has invalid contents or if only one of /// `MULLVAD_API_ADDR` or `MULLVAD_API_HOST` has been set but not the other. - fn get() -> ApiEndpoint { + pub fn from_env_vars() -> ApiEndpoint { const API_HOST_DEFAULT: &str = "api.mullvad.net"; const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(45, 83, 223, 196)); const API_PORT_DEFAULT: u16 = 443; @@ -96,29 +136,60 @@ impl ApiEndpoint { let host_var = read_var("MULLVAD_API_HOST"); let address_var = read_var("MULLVAD_API_ADDR"); + let disable_tls_var = read_var("MULLVAD_API_DISABLE_TLS"); + #[cfg_attr(not(feature = "api-override"), allow(unused_mut))] let mut api = ApiEndpoint { host: API_HOST_DEFAULT.to_owned(), addr: SocketAddr::new(API_IP_DEFAULT, API_PORT_DEFAULT), + #[cfg(feature = "api-override")] disable_address_cache: false, + #[cfg(feature = "api-override")] + disable_tls: false, + #[cfg(feature = "api-override")] + force_direct_connection: false, }; - if cfg!(feature = "api-override") { - match (host_var, address_var) { - (None, None) => (), - (Some(_), None) => panic!("MULLVAD_API_HOST is set, but not MULLVAD_API_ADDR"), - (None, Some(_)) => panic!("MULLVAD_API_ADDR is set, but not MULLVAD_API_HOST"), - (Some(user_host), Some(user_addr)) => { - api.host = user_host; - api.addr = user_addr - .parse() - .expect("MULLVAD_API_ADDR is not a valid socketaddr"); - api.disable_address_cache = true; - log::debug!("Overriding API. Using {} at {}", api.host, api.addr); + #[cfg(feature = "api-override")] + { + use std::net::ToSocketAddrs; + + if host_var.is_none() && address_var.is_none() { + if disable_tls_var.is_some() { + log::warn!("MULLVAD_API_DISABLE_TLS is ignored since MULLVAD_API_HOST and MULLVAD_API_ADDR are not set"); } + return api; } - } else if host_var.is_some() || address_var.is_some() { - log::warn!("MULLVAD_API_HOST and MULLVAD_API_ADDR are ignored in production builds"); + + let scheme = if let Some(disable_tls_var) = disable_tls_var { + api.disable_tls = disable_tls_var != "0"; + "http://" + } else { + "https://" + }; + + if let Some(user_host) = host_var { + api.host = user_host; + } + if let Some(user_addr) = address_var { + api.addr = user_addr + .parse() + .expect("MULLVAD_API_ADDR is not a valid socketaddr"); + } else { + log::warn!("Resolving API IP from MULLVAD_API_HOST"); + api.addr = format!("{}:{}", api.host, API_PORT_DEFAULT) + .to_socket_addrs() + .expect("failed to resolve API host") + .next() + .expect("API host yielded 0 addresses"); + } + api.disable_address_cache = true; + api.force_direct_connection = true; + log::debug!("Overriding API. Using {} at {scheme}{}", api.host, api.addr); + } + #[cfg(not(feature = "api-override"))] + if host_var.is_some() || address_var.is_some() || disable_tls_var.is_some() { + log::warn!("These variables are ignored in production builds: MULLVAD_API_HOST, MULLVAD_API_ADDR, MULLVAD_API_DISABLE_TLS"); } api } @@ -189,6 +260,7 @@ impl Runtime { #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, ) -> Result<Self, Error> { let handle = tokio::runtime::Handle::current(); + #[cfg(feature = "api-override")] if API.disable_address_cache { return Self::new_inner( handle, diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs index 2f3764e7e6..fa1da913ba 100644 --- a/mullvad-api/src/proxy.rs +++ b/mullvad-api/src/proxy.rs @@ -1,8 +1,6 @@ -use crate::tls_stream::TlsStream; use futures::Stream; -use hyper::client::connect::{Connected, Connection}; +use hyper::client::connect::Connected; use serde::{Deserialize, Serialize}; -use shadowsocks::relay::tcprelay::ProxyClientStream; use std::{ fmt, io, net::SocketAddr, @@ -14,7 +12,6 @@ use talpid_types::{net::openvpn::ShadowsocksProxySettings, ErrorExt}; use tokio::{ fs, io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, - net::TcpStream, }; const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json"; @@ -130,57 +127,93 @@ impl ApiConnectionMode { } } -/// Stream that is either a regular TLS stream or TLS via shadowsocks -pub enum ApiConnection { - Direct(Box<TlsStream<TcpStream>>), - Proxied(Box<TlsStream<ProxyClientStream<TcpStream>>>), +/// Implements `hyper::client::connect::Connection` by wrapping a type. +pub struct ConnectionDecorator<T: AsyncRead + AsyncWrite>(pub T); + +impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for ConnectionDecorator<T> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ConnectionDecorator<T> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl<T: AsyncRead + AsyncWrite> hyper::client::connect::Connection for ConnectionDecorator<T> { + fn connected(&self) -> Connected { + Connected::new() + } +} + +trait Connection: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send {} + +impl<T: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send> Connection + for T +{ +} + +/// Stream that represents a Mullvad API connection +pub struct ApiConnection(Box<dyn Connection>); + +impl ApiConnection { + pub fn new< + T: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send + 'static, + >( + conn: Box<T>, + ) -> Self { + Self(conn) + } } impl AsyncRead for ApiConnection { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { - match Pin::get_mut(self) { - ApiConnection::Direct(s) => Pin::new(s).poll_read(cx, buf), - ApiConnection::Proxied(s) => Pin::new(s).poll_read(cx, buf), - } + Pin::new(&mut self.0).poll_read(cx, buf) } } impl AsyncWrite for ApiConnection { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - match Pin::get_mut(self) { - ApiConnection::Direct(s) => Pin::new(s).poll_write(cx, buf), - ApiConnection::Proxied(s) => Pin::new(s).poll_write(cx, buf), - } + Pin::new(&mut self.0).poll_write(cx, buf) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { - match Pin::get_mut(self) { - ApiConnection::Direct(s) => Pin::new(s).poll_flush(cx), - ApiConnection::Proxied(s) => Pin::new(s).poll_flush(cx), - } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_flush(cx) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { - match Pin::get_mut(self) { - ApiConnection::Direct(s) => Pin::new(s).poll_shutdown(cx), - ApiConnection::Proxied(s) => Pin::new(s).poll_shutdown(cx), - } + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_shutdown(cx) } } -impl Connection for ApiConnection { +impl hyper::client::connect::Connection for ApiConnection { fn connected(&self) -> Connected { - match self { - ApiConnection::Direct(s) => s.connected(), - ApiConnection::Proxied(s) => s.connected(), - } + self.0.connected() } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index c80f01049a..1aaba487a7 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -26,6 +26,9 @@ use std::{ }; use talpid_types::ErrorExt; +#[cfg(feature = "api-override")] +use crate::API; + pub use hyper::StatusCode; pub type Request = hyper::Request<hyper::Body>; @@ -145,7 +148,14 @@ impl< socket_bypass_tx.clone(), ); - if let Some(config) = proxy_config_provider.next().await { + #[cfg(feature = "api-override")] + let force_direct_connection = API.force_direct_connection; + #[cfg(not(feature = "api-override"))] + let force_direct_connection = false; + + if force_direct_connection { + log::debug!("API proxies are disabled"); + } else if let Some(config) = proxy_config_provider.next().await { connector_handle.set_connection_mode(config); } @@ -214,6 +224,12 @@ impl< self.connector_handle.reset(); } RequestCommand::NextApiConfig => { + #[cfg(feature = "api-override")] + if API.force_direct_connection { + log::debug!("Ignoring API connection mode"); + return; + } + if let Some(new_config) = self.proxy_config_provider.next().await { let endpoint = match new_config.get_endpoint() { Some(endpoint) => endpoint, @@ -619,9 +635,11 @@ impl MullvadRestHandle { availability, token_store, }; - if !super::API.disable_address_cache { - handle.spawn_api_address_fetcher(address_cache); + #[cfg(feature = "api-override")] + if API.disable_address_cache { + return handle; } + handle.spawn_api_address_fetcher(address_cache); handle } diff --git a/mullvad-api/src/tls_stream.rs b/mullvad-api/src/tls_stream.rs index cad0268ac3..ac3b4c2e24 100644 --- a/mullvad-api/src/tls_stream.rs +++ b/mullvad-api/src/tls_stream.rs @@ -7,6 +7,7 @@ use std::{ }; use hyper::client::connect::{Connected, Connection}; +use once_cell::sync::Lazy; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::{ rustls::{self, ClientConfig, ServerName}, @@ -24,18 +25,16 @@ where S: AsyncRead + AsyncWrite + Unpin, { pub async fn connect_https(stream: S, domain: &str) -> io::Result<TlsStream<S>> { - lazy_static::lazy_static! { - static ref TLS_CONFIG: Arc<ClientConfig> = { - let config = ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_root_certificates(read_cert_store()) - .with_no_client_auth(); - Arc::new(config) - }; - } + static TLS_CONFIG: Lazy<Arc<ClientConfig>> = Lazy::new(|| { + let config = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(read_cert_store()) + .with_no_client_auth(); + Arc::new(config) + }); let connector = TlsConnector::from(TLS_CONFIG.clone()); diff --git a/mullvad-jni/Cargo.toml b/mullvad-jni/Cargo.toml index dcf3546a95..b4974e7493 100644 --- a/mullvad-jni/Cargo.toml +++ b/mullvad-jni/Cargo.toml @@ -7,6 +7,10 @@ license = "GPL-3.0" edition = "2021" publish = false +[features] +# Allow the API server to use to be configured +api-override = ["mullvad-api/api-override"] + [lib] crate_type = ["cdylib"] diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 938d73e033..b7e0815bfa 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -39,6 +39,9 @@ use std::{ }; use talpid_types::{android::AndroidContext, ErrorExt}; +#[cfg(feature = "api-override")] +use std::net::{IpAddr, SocketAddr}; + const LOG_FILENAME: &str = "daemon.log"; static DAEMON_INSTANCE_COUNT: AtomicUsize = AtomicUsize::new(0); @@ -199,18 +202,40 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_initial vpnService: JObject<'_>, cacheDirectory: JObject<'_>, resourceDirectory: JObject<'_>, + apiEndpoint: JObject<'_>, ) { let env = JnixEnv::from(env); let cache_dir = PathBuf::from(String::from_java(&env, cacheDirectory)); let resource_dir = PathBuf::from(String::from_java(&env, resourceDirectory)); + let api_endpoint = if !apiEndpoint.is_null() { + #[cfg(feature = "api-override")] + { + Some(api_endpoint_from_java(&env, apiEndpoint)) + } + #[cfg(not(feature = "api-override"))] + { + log::warn!("apiEndpoint will be ignored since 'api-override' is not enabled"); + None + } + } else { + None + }; + match start_logging(&resource_dir) { Ok(()) => { version::log_version(); LOAD_CLASSES.call_once(|| env.preload_classes(classes::CLASSES.iter().cloned())); - if let Err(error) = initialize(&env, &this, &vpnService, cache_dir, resource_dir) { + if let Err(error) = initialize( + &env, + &this, + &vpnService, + cache_dir, + resource_dir, + api_endpoint, + ) { log::error!("{}", error.display_chain()); } } @@ -220,6 +245,75 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_initial } } +#[cfg(feature = "api-override")] +fn api_endpoint_from_java(env: &JnixEnv<'_>, object: JObject<'_>) -> mullvad_api::ApiEndpoint { + let mut endpoint = mullvad_api::ApiEndpoint::from_env_vars(); + + let address = env + .call_method(object, "component1", "()Ljava/net/InetSocketAddress;", &[]) + .expect("missing ApiEndpoint.address") + .l() + .expect("ApiEndpoint.address is not an InetSocketAddress"); + + endpoint.addr = + try_socketaddr_from_java(env, address).expect("received unresolved InetSocketAddress"); + endpoint.disable_address_cache = env + .call_method(object, "component2", "()Z", &[]) + .expect("missing ApiEndpoint.disableAddressCache") + .z() + .expect("ApiEndpoint.disableAddressCache is not a bool"); + endpoint.disable_tls = env + .call_method(object, "component3", "()Z", &[]) + .expect("missing ApiEndpoint.disableTls") + .z() + .expect("ApiEndpoint.disableTls is not a bool"); + endpoint.force_direct_connection = env + .call_method(object, "component4", "()Z", &[]) + .expect("missing ApiEndpoint.forceDirectConnection") + .z() + .expect("ApiEndpoint.forceDirectConnection is not a bool"); + + endpoint +} + +/// Converts InetSocketAddress to a SocketAddr. Return `None` if the +/// hostname is unresolved. +#[cfg(feature = "api-override")] +fn try_socketaddr_from_java(env: &JnixEnv<'_>, address: JObject<'_>) -> Option<SocketAddr> { + let class = env.get_class("java/net/InetSocketAddress"); + + let method_id = env + .get_method_id(&class, "getAddress", "()Ljava/net/InetAddress;") + .expect("Failed to get method ID for InetSocketAddress.getAddress()"); + let return_type = JavaType::Object("java/net/InetAddress".to_owned()); + + let ip_addr = env + .call_method_unchecked(address, method_id, return_type, &[]) + .expect("Failed to call InetSocketAddress.getAddress()") + .l() + .expect("Call to InetSocketAddress.getAddress() did not return an object"); + + if ip_addr.is_null() { + return None; + } + + let method_id = env + .get_method_id(&class, "getPort", "()I") + .expect("Failed to get method ID for InetSocketAddress.getPort()"); + let return_type = JavaType::Primitive(Primitive::Int); + + let port = env + .call_method_unchecked(address, method_id, return_type, &[]) + .expect("Failed to call InetSocketAddress.getPort()") + .i() + .expect("Call to InetSocketAddress.getPort() did not return an int"); + + Some(SocketAddr::new( + IpAddr::from_java(env, ip_addr), + u16::try_from(port).expect("invalid port"), + )) +} + fn start_logging(log_dir: &Path) -> Result<(), String> { unsafe { LOG_START.call_once(|| LOG_INIT_RESULT = Some(initialize_logging(log_dir))); @@ -246,6 +340,7 @@ fn initialize( vpn_service: &JObject<'_>, cache_dir: PathBuf, resource_dir: PathBuf, + api_endpoint: Option<mullvad_api::ApiEndpoint>, ) -> Result<(), Error> { let android_context = create_android_context(env, *vpn_service)?; let daemon_command_channel = DaemonCommandChannel::new(); @@ -256,6 +351,7 @@ fn initialize( this, cache_dir, resource_dir, + api_endpoint, daemon_command_channel, android_context, )?; @@ -282,6 +378,9 @@ fn spawn_daemon( this: &JObject<'_>, cache_dir: PathBuf, resource_dir: PathBuf, + #[cfg_attr(not(feature = "api-override"), allow(unused_variables))] api_endpoint: Option< + mullvad_api::ApiEndpoint, + >, command_channel: DaemonCommandChannel, android_context: AndroidContext, ) -> Result<(), Error> { @@ -306,6 +405,14 @@ fn spawn_daemon( ); } + #[cfg(feature = "api-override")] + if let Some(api_endpoint) = api_endpoint { + log::debug!("Overriding API endpoint: {api_endpoint:?}"); + if mullvad_api::API.override_init(api_endpoint).is_err() { + log::warn!("Ignoring API settings (already initialized)"); + } + } + let daemon = runtime.block_on(Daemon::start( Some(resource_dir.clone()), resource_dir.clone(), |
