summaryrefslogtreecommitdiffhomepage
path: root/mullvad-api/src
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-03-23 16:51:43 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-23 16:51:43 +0100
commitb1910fc14517d284c05aeded77d6f34e3949e6ef (patch)
treef762f61e1ed40bb4be8558122f3bdc4fff59570f /mullvad-api/src
parent74d76ff8f071d96ffbc6fbb8d92ec961b839f416 (diff)
parent918e6588d2276122e391fb426be3245d0664e245 (diff)
downloadmullvadvpn-b1910fc14517d284c05aeded77d6f34e3949e6ef.tar.xz
mullvadvpn-b1910fc14517d284c05aeded77d6f34e3949e6ef.zip
Merge branch 'rename-rpc-crate'
Diffstat (limited to 'mullvad-api/src')
-rw-r--r--mullvad-api/src/abortable_stream.rs199
-rw-r--r--mullvad-api/src/access.rs110
-rw-r--r--mullvad-api/src/address_cache.rs118
-rw-r--r--mullvad-api/src/availability.rs170
-rw-r--r--mullvad-api/src/bin/relay_list.rs41
-rw-r--r--mullvad-api/src/device.rs196
-rw-r--r--mullvad-api/src/https_client_with_sni.rs351
-rw-r--r--mullvad-api/src/lib.rs530
-rw-r--r--mullvad-api/src/proxy.rs204
-rw-r--r--mullvad-api/src/relay_list.rs375
-rw-r--r--mullvad-api/src/rest.rs694
-rw-r--r--mullvad-api/src/tls_stream.rs122
12 files changed, 3110 insertions, 0 deletions
diff --git a/mullvad-api/src/abortable_stream.rs b/mullvad-api/src/abortable_stream.rs
new file mode 100644
index 0000000000..af217c5768
--- /dev/null
+++ b/mullvad-api/src/abortable_stream.rs
@@ -0,0 +1,199 @@
+//! Wrapper around a stream to make it abortable. This allows in-flight requests to be cancelled
+//! immediately instead of after the socket times out.
+
+use futures::channel::oneshot;
+use hyper::client::connect::{Connected, Connection};
+use std::{
+ future::Future,
+ io,
+ pin::Pin,
+ sync::{Arc, Mutex},
+ task::{Context, Poll},
+};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+#[derive(err_derive::Error, Debug)]
+#[error(display = "Stream is closed")]
+pub struct Aborted(());
+
+#[derive(Clone, Debug)]
+pub struct AbortableStreamHandle {
+ tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
+}
+
+impl AbortableStreamHandle {
+ pub fn close(self) {
+ if let Some(tx) = self.tx.lock().unwrap().take() {
+ let _ = tx.send(());
+ }
+ }
+
+ /// Returns whether the stream has already stopped on its own.
+ pub fn is_closed(&self) -> bool {
+ self.tx
+ .lock()
+ .unwrap()
+ .as_ref()
+ .map(|tx| tx.is_canceled())
+ .unwrap_or(true)
+ }
+}
+
+pub struct AbortableStream<S: Unpin> {
+ stream: S,
+ shutdown_rx: oneshot::Receiver<()>,
+}
+
+impl<S> AbortableStream<S>
+where
+ S: Unpin + Send + 'static,
+{
+ pub fn new(stream: S) -> (Self, AbortableStreamHandle) {
+ let (tx, rx) = oneshot::channel();
+ let stream_handle = AbortableStreamHandle {
+ tx: Arc::new(Mutex::new(Some(tx))),
+ };
+ (
+ Self {
+ stream,
+ shutdown_rx: rx,
+ },
+ stream_handle,
+ )
+ }
+}
+
+impl<S> AsyncWrite for AbortableStream<S>
+where
+ S: AsyncWrite + Unpin + Send + 'static,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ Aborted(()),
+ )));
+ }
+ Pin::new(&mut self.stream).poll_write(cx, buf)
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ Aborted(()),
+ )));
+ }
+ Pin::new(&mut self.stream).poll_flush(cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.stream).poll_shutdown(cx)
+ }
+}
+
+impl<S> AsyncRead for AbortableStream<S>
+where
+ S: AsyncRead + Unpin + Send + 'static,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ Aborted(()),
+ )));
+ }
+ Pin::new(&mut self.stream).poll_read(cx, buf)
+ }
+}
+
+impl<S> Connection for AbortableStream<S>
+where
+ S: Connection + Unpin,
+{
+ fn connected(&self) -> Connected {
+ self.stream.connected()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use std::time::Duration;
+ use tokio::io::AsyncReadExt;
+
+ /// Test whether the abort handle stops the stream.
+ #[test]
+ fn test_abort() {
+ let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
+
+ let (client, _server) = tokio::io::duplex(64);
+
+ runtime.block_on(async move {
+ let (mut stream, abort_handle) = AbortableStream::new(client);
+
+ let stream_task = tokio::spawn(async move {
+ let mut buf = vec![];
+ stream.read_to_end(&mut buf).await
+ });
+
+ abort_handle.close();
+ let result = tokio::time::timeout(Duration::from_secs(1), stream_task)
+ .await
+ .unwrap();
+ assert!(
+ matches!(result, Ok(Err(error)) if error.kind() == io::ErrorKind::ConnectionReset)
+ );
+ });
+ }
+
+ /// Test the `AbortableStreamHandle::is_closed` method when explicitly closed.
+ #[test]
+ fn test_shutdown_signal() {
+ let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
+
+ let (client, _server) = tokio::io::duplex(64);
+
+ runtime.block_on(async move {
+ let (_stream, abort_handle) = AbortableStream::new(client);
+ let abort_handle_2 = abort_handle.clone();
+ assert!(!abort_handle_2.is_closed());
+ abort_handle.close();
+ assert!(abort_handle_2.is_closed());
+ });
+ }
+
+ /// Test the `AbortableStreamHandle::is_closed` method when the stream stops on its own.
+ #[test]
+ fn test_shutdown_signal_normal() {
+ let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime");
+
+ let (client, server) = tokio::io::duplex(64);
+
+ runtime.block_on(async move {
+ let (mut stream, abort_handle) = AbortableStream::new(client);
+
+ assert!(!abort_handle.is_closed());
+
+ let stream_task = tokio::spawn(async move {
+ drop(server);
+ let mut buf = vec![];
+ stream.read_to_end(&mut buf).await
+ });
+
+ assert!(tokio::time::timeout(Duration::from_secs(1), stream_task)
+ .await
+ .unwrap()
+ .is_ok());
+ assert!(abort_handle.is_closed());
+ });
+ }
+}
diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs
new file mode 100644
index 0000000000..d95a5319c2
--- /dev/null
+++ b/mullvad-api/src/access.rs
@@ -0,0 +1,110 @@
+use crate::{
+ rest,
+ rest::{RequestFactory, RequestServiceHandle},
+};
+use hyper::StatusCode;
+use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
+use std::{
+ collections::HashMap,
+ sync::{Arc, Mutex},
+};
+use talpid_types::ErrorExt;
+
+pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1";
+
+#[derive(Clone)]
+pub struct AccessTokenProxy {
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
+}
+
+impl AccessTokenProxy {
+ pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
+ Self {
+ service,
+ factory,
+ access_from_account: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ /// Obtain access token for an account, requesting a new one from the API if necessary.
+ pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
+ let existing_token = {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .get(account.as_str())
+ .cloned()
+ };
+ if let Some(access_token) = existing_token {
+ if access_token.is_expired() {
+ log::debug!("Replacing expired access token");
+ return self.request_new_token(account.clone()).await;
+ }
+ log::trace!("Using stored access token");
+ return Ok(access_token.access_token.clone());
+ }
+ self.request_new_token(account.clone()).await
+ }
+
+ /// Remove an access token if the API response calls for it.
+ pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
+ if let Err(rest::Error::ApiError(_status, code)) = response {
+ if code == crate::INVALID_ACCESS_TOKEN {
+ log::debug!("Dropping invalid access token");
+ self.remove_token(account);
+ }
+ }
+ }
+
+ /// Removes a stored access token.
+ fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .remove(account)
+ .map(|v| v.access_token)
+ }
+
+ async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
+ log::debug!("Fetching access token for an account");
+ let access_token = self
+ .fetch_access_token(account.clone())
+ .await
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain access token")
+ );
+ error
+ })?;
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .insert(account, access_token.clone());
+ Ok(access_token.access_token)
+ }
+
+ async fn fetch_access_token(
+ &self,
+ account_token: AccountToken,
+ ) -> Result<AccessTokenData, rest::Error> {
+ #[derive(serde::Serialize)]
+ struct AccessTokenRequest {
+ account_number: String,
+ }
+ let request = AccessTokenRequest {
+ account_number: account_token,
+ };
+
+ let service = self.service.clone();
+
+ let rest_request = self
+ .factory
+ .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?;
+ let response = service.request(rest_request).await?;
+ let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(response).await
+ }
+}
diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs
new file mode 100644
index 0000000000..3b6fcba074
--- /dev/null
+++ b/mullvad-api/src/address_cache.rs
@@ -0,0 +1,118 @@
+use super::API;
+use std::{io, net::SocketAddr, path::Path, sync::Arc};
+use tokio::{
+ fs,
+ io::{AsyncReadExt, AsyncWriteExt},
+ sync::Mutex,
+};
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ #[error(display = "Failed to open the address cache file")]
+ OpenAddressCache(#[error(source)] io::Error),
+
+ #[error(display = "Failed to read the address cache file")]
+ ReadAddressCache(#[error(source)] io::Error),
+
+ #[error(display = "Failed to parse the address cache file")]
+ ParseAddressCache,
+
+ #[error(display = "Failed to update the address cache file")]
+ WriteAddressCache(#[error(source)] io::Error),
+
+ #[error(display = "The address cache is empty")]
+ EmptyAddressCache,
+}
+
+#[derive(Clone)]
+pub struct AddressCache {
+ inner: Arc<Mutex<AddressCacheInner>>,
+ write_path: Option<Arc<Path>>,
+}
+
+impl AddressCache {
+ /// Initialize cache using the hardcoded address, and write changes to `write_path`.
+ pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ Self::new_inner(API.addr, write_path)
+ }
+
+ /// Initialize cache using `read_path`, and write changes to `write_path`.
+ pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ log::debug!("Loading API addresses from {}", read_path.display());
+ Self::new_inner(read_address_file(read_path).await?, write_path)
+ }
+
+ fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> {
+ let cache = AddressCacheInner::from_address(address);
+ log::debug!("Using API address: {}", cache.address);
+
+ let address_cache = Self {
+ inner: Arc::new(Mutex::new(cache)),
+ write_path: write_path.map(|cache| Arc::from(cache)),
+ };
+ Ok(address_cache)
+ }
+
+ /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
+ pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
+ if hostname.eq_ignore_ascii_case(&API.host) {
+ Some(self.get_address().await)
+ } else {
+ None
+ }
+ }
+
+ /// Returns the currently selected address.
+ pub async fn get_address(&self) -> SocketAddr {
+ self.inner.lock().await.address
+ }
+
+ pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> {
+ let mut inner = self.inner.lock().await;
+ if address != inner.address {
+ self.save_to_disk(&address).await?;
+ inner.address = address;
+ }
+ Ok(())
+ }
+
+ async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> {
+ let write_path = match self.write_path.as_ref() {
+ Some(write_path) => write_path,
+ None => return Ok(()),
+ };
+
+ let temp_path = write_path.with_file_name("api-cache.temp");
+
+ let mut file = fs::File::create(&temp_path).await?;
+ let mut contents = address.to_string();
+ contents += "\n";
+ file.write_all(contents.as_bytes()).await?;
+ file.sync_data().await?;
+
+ fs::rename(&temp_path, write_path).await
+ }
+}
+
+#[derive(Clone, PartialEq, Eq)]
+struct AddressCacheInner {
+ address: SocketAddr,
+}
+
+impl AddressCacheInner {
+ fn from_address(address: SocketAddr) -> Self {
+ Self { address }
+ }
+}
+
+async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> {
+ let mut file = fs::File::open(path)
+ .await
+ .map_err(|error| Error::OpenAddressCache(error))?;
+ let mut address = String::new();
+ file.read_to_string(&mut address)
+ .await
+ .map_err(Error::ReadAddressCache)?;
+ address.trim().parse().map_err(|_| Error::ParseAddressCache)
+}
diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs
new file mode 100644
index 0000000000..2cf40cf53b
--- /dev/null
+++ b/mullvad-api/src/availability.rs
@@ -0,0 +1,170 @@
+use std::{
+ future::Future,
+ sync::{Arc, Mutex},
+};
+use tokio::sync::broadcast;
+
+const CHANNEL_CAPACITY: usize = 100;
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ /// The [`ApiAvailability`] instance was dropped, or the receiver lagged behind.
+ #[error(display = "API availability instance was dropped")]
+ Interrupted(#[error(source)] broadcast::error::RecvError),
+}
+
+#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)]
+pub struct State {
+ suspended: bool,
+ pause_background: bool,
+ offline: bool,
+}
+
+impl State {
+ pub fn is_suspended(&self) -> bool {
+ self.suspended
+ }
+
+ pub fn is_background_paused(&self) -> bool {
+ self.offline || self.pause_background || self.suspended
+ }
+
+ pub fn is_offline(&self) -> bool {
+ self.offline
+ }
+}
+
+pub struct ApiAvailability {
+ state: Arc<Mutex<State>>,
+ tx: broadcast::Sender<State>,
+}
+
+impl ApiAvailability {
+ pub fn new(initial_state: State) -> Self {
+ let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
+ let state = Arc::new(Mutex::new(initial_state));
+ ApiAvailability { state, tx }
+ }
+
+ pub fn get_state(&self) -> State {
+ *self.state.lock().unwrap()
+ }
+
+ pub fn handle(&self) -> ApiAvailabilityHandle {
+ ApiAvailabilityHandle {
+ state: self.state.clone(),
+ tx: self.tx.clone(),
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ApiAvailabilityHandle {
+ state: Arc<Mutex<State>>,
+ tx: broadcast::Sender<State>,
+}
+
+impl ApiAvailabilityHandle {
+ pub fn suspend(&self) {
+ log::debug!("Suspending API requests");
+ let mut state = self.state.lock().unwrap();
+ if !state.suspended {
+ state.suspended = true;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn unsuspend(&self) {
+ log::debug!("Unsuspending API requests");
+ let mut state = self.state.lock().unwrap();
+ if state.suspended {
+ state.suspended = false;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn pause_background(&self) {
+ log::debug!("Pausing background API requests");
+ let mut state = self.state.lock().unwrap();
+ if !state.pause_background {
+ state.pause_background = true;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn resume_background(&self) {
+ log::debug!("Resuming background API requests");
+ let mut state = self.state.lock().unwrap();
+ if state.pause_background {
+ state.pause_background = false;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn set_offline(&self, offline: bool) {
+ if offline {
+ log::debug!("Pausing API requests due to being offline");
+ } else {
+ log::debug!("Resuming API requests due to coming online");
+ }
+ let mut state = self.state.lock().unwrap();
+ if state.offline != offline {
+ state.offline = offline;
+ let _ = self.tx.send(*state);
+ }
+ }
+
+ pub fn get_state(&self) -> State {
+ *self.state.lock().unwrap()
+ }
+
+ pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> {
+ self.wait_for_state(|state| !state.is_suspended())
+ }
+
+ pub fn when_bg_resumes<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> {
+ let wait_task = self.wait_for_state(|state| !state.is_background_paused());
+ async move {
+ let _ = wait_task.await;
+ task.await
+ }
+ }
+
+ pub fn wait_background(&self) -> impl Future<Output = Result<(), Error>> {
+ self.wait_for_state(|state| !state.is_background_paused())
+ }
+
+ pub fn when_online<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> {
+ let wait_task = self.wait_for_state(|state| !state.is_offline());
+ async move {
+ let _ = wait_task.await;
+ task.await
+ }
+ }
+
+ pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> {
+ self.wait_for_state(|state| !state.is_offline())
+ }
+
+ fn wait_for_state(
+ &self,
+ state_ready: impl Fn(State) -> bool,
+ ) -> impl Future<Output = Result<(), Error>> {
+ let mut rx = self.tx.subscribe();
+ let state = self.state.clone();
+
+ async move {
+ let current_state = { *state.lock().unwrap() };
+ if state_ready(current_state) {
+ return Ok(());
+ }
+
+ loop {
+ let new_state = rx.recv().await?;
+ if state_ready(new_state) {
+ return Ok(());
+ }
+ }
+ }
+ }
+}
diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs
new file mode 100644
index 0000000000..2139e51f54
--- /dev/null
+++ b/mullvad-api/src/bin/relay_list.rs
@@ -0,0 +1,41 @@
+//! Fetches and prints the full relay list in JSON.
+//! Used by the installer artifact packer to bundle the latest available
+//! relay list at the time of creating the installer.
+
+use mullvad_api::{self, proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
+use std::process;
+use talpid_types::ErrorExt;
+
+#[tokio::main]
+async fn main() {
+ let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
+ .expect("Failed to load runtime");
+
+ let relay_list_request = RelayListProxy::new(
+ runtime
+ .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat(), |_| async { true })
+ .await,
+ )
+ .relay_list(None)
+ .await;
+
+ let relay_list = match relay_list_request {
+ Ok(relay_list) => relay_list,
+ Err(RestError::TimeoutError(_)) => {
+ eprintln!("Request timed out");
+ process::exit(2);
+ }
+ Err(e @ RestError::DeserializeError(_)) => {
+ eprintln!(
+ "{}",
+ e.display_chain_with_msg("Failed to deserialize relay list")
+ );
+ process::exit(3);
+ }
+ Err(e) => {
+ eprintln!("{}", e.display_chain_with_msg("Failed to fetch relay list"));
+ process::exit(1);
+ }
+ };
+ println!("{}", serde_json::to_string_pretty(&relay_list).unwrap());
+}
diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs
new file mode 100644
index 0000000000..de572aa20d
--- /dev/null
+++ b/mullvad-api/src/device.rs
@@ -0,0 +1,196 @@
+use http::{Method, StatusCode};
+use mullvad_types::{
+ account::AccountToken,
+ device::{Device, DeviceId, DeviceName, DevicePort},
+};
+use std::future::Future;
+use talpid_types::net::wireguard;
+
+use crate::rest;
+
+use super::ACCOUNTS_URL_PREFIX;
+
+#[derive(Clone)]
+pub struct DevicesProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+#[derive(serde::Deserialize)]
+struct DeviceResponse {
+ id: DeviceId,
+ name: DeviceName,
+ pubkey: wireguard::PublicKey,
+ ipv4_address: ipnetwork::Ipv4Network,
+ ipv6_address: ipnetwork::Ipv6Network,
+ ports: Vec<DevicePort>,
+}
+
+impl DevicesProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn create(
+ &self,
+ account: AccountToken,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct DeviceSubmission {
+ pubkey: wireguard::PublicKey,
+ }
+
+ let submission = DeviceSubmission { pubkey };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account)),
+ &[StatusCode::CREATED],
+ )
+ .await;
+
+ let response: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ id,
+ name,
+ pubkey,
+ ipv4_address,
+ ipv6_address,
+ ports,
+ ..
+ } = response;
+
+ Ok((
+ Device {
+ id,
+ name,
+ pubkey,
+ ports,
+ },
+ mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ },
+ ))
+ }
+ }
+
+ pub fn get(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<Device, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn list(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn remove(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<(), rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::DELETE,
+ Some((access_proxy, account)),
+ &[StatusCode::NO_CONTENT],
+ )
+ .await;
+
+ response?;
+ Ok(())
+ }
+ }
+
+ pub fn replace_wg_key(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct RotateDevicePubkey {
+ pubkey: wireguard::PublicKey,
+ }
+ let req_body = RotateDevicePubkey { pubkey };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id),
+ Method::PUT,
+ &req_body,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ ipv4_address,
+ ipv6_address,
+ ..
+ } = updated_device;
+ Ok(mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ })
+ }
+ }
+}
diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs
new file mode 100644
index 0000000000..409492712e
--- /dev/null
+++ b/mullvad-api/src/https_client_with_sni.rs
@@ -0,0 +1,351 @@
+use crate::{
+ abortable_stream::{AbortableStream, AbortableStreamHandle},
+ proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
+ tls_stream::TlsStream,
+ AddressCache,
+};
+use futures::{channel::mpsc, future, pin_mut, StreamExt};
+#[cfg(target_os = "android")]
+use futures::{channel::oneshot, sink::SinkExt};
+use http::uri::Scheme;
+use hyper::{
+ client::connect::dns::{GaiResolver, Name},
+ service::Service,
+ Uri,
+};
+use shadowsocks::{
+ config::ServerType,
+ context::{Context as SsContext, SharedContext},
+ crypto::v1::CipherKind,
+ relay::tcprelay::ProxyClientStream,
+ ServerConfig,
+};
+#[cfg(target_os = "android")]
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::{
+ fmt,
+ future::Future,
+ io,
+ net::{IpAddr, SocketAddr},
+ pin::Pin,
+ str::{self, FromStr},
+ sync::{Arc, Mutex},
+ task::{Context, Poll},
+ time::Duration,
+};
+use talpid_types::ErrorExt;
+
+use tokio::{
+ net::{TcpSocket, TcpStream},
+ time::timeout,
+};
+
+const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
+
+#[derive(Clone)]
+pub struct HttpsConnectorWithSniHandle {
+ tx: mpsc::UnboundedSender<HttpsConnectorRequest>,
+}
+
+impl HttpsConnectorWithSniHandle {
+ /// Stop all streams produced by this connector
+ pub fn reset(&self) {
+ let _ = self.tx.unbounded_send(HttpsConnectorRequest::Reset);
+ }
+
+ /// Change the proxy settings for the connector
+ pub fn set_connection_mode(&self, proxy: ApiConnectionMode) {
+ let _ = self
+ .tx
+ .unbounded_send(HttpsConnectorRequest::SetConnectionMode(proxy));
+ }
+}
+
+enum HttpsConnectorRequest {
+ Reset,
+ SetConnectionMode(ApiConnectionMode),
+}
+
+#[derive(Clone)]
+enum InnerConnectionMode {
+ /// Connect directly to the target.
+ Direct,
+ /// Connect to the destination via a proxy.
+ Proxied(ParsedShadowsocksConfig),
+}
+
+#[derive(Clone)]
+struct ParsedShadowsocksConfig {
+ peer: SocketAddr,
+ password: String,
+ cipher: CipherKind,
+}
+
+impl From<ParsedShadowsocksConfig> for ServerConfig {
+ fn from(config: ParsedShadowsocksConfig) -> Self {
+ ServerConfig::new(config.peer, config.password, config.cipher)
+ }
+}
+
+#[derive(err_derive::Error, Debug)]
+enum ProxyConfigError {
+ #[error(display = "Unrecognized cipher selected: {}", _0)]
+ InvalidCipher(String),
+}
+
+impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
+ type Error = ProxyConfigError;
+
+ fn try_from(config: ApiConnectionMode) -> Result<Self, Self::Error> {
+ Ok(match config {
+ ApiConnectionMode::Direct => InnerConnectionMode::Direct,
+ ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(config)) => {
+ InnerConnectionMode::Proxied(ParsedShadowsocksConfig {
+ peer: config.peer,
+ password: config.password,
+ cipher: CipherKind::from_str(&config.cipher)
+ .map_err(|_| ProxyConfigError::InvalidCipher(config.cipher))?,
+ })
+ }
+ })
+ }
+}
+
+/// A Connector for the `https` scheme.
+#[derive(Clone)]
+pub struct HttpsConnectorWithSni {
+ inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
+ sni_hostname: Option<String>,
+ address_cache: AddressCache,
+ abort_notify: Arc<tokio::sync::Notify>,
+ proxy_context: SharedContext,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+}
+
+struct HttpsConnectorWithSniInner {
+ stream_handles: Vec<AbortableStreamHandle>,
+ proxy_config: InnerConnectionMode,
+}
+
+#[cfg(target_os = "android")]
+pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
+
+impl HttpsConnectorWithSni {
+ pub fn new(
+ sni_hostname: Option<String>,
+ address_cache: AddressCache,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> (Self, HttpsConnectorWithSniHandle) {
+ let (tx, mut rx) = mpsc::unbounded();
+ let abort_notify = Arc::new(tokio::sync::Notify::new());
+ let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner {
+ stream_handles: vec![],
+ proxy_config: InnerConnectionMode::Direct,
+ }));
+
+ let inner_copy = inner.clone();
+ let notify = abort_notify.clone();
+ tokio::spawn(async move {
+ // Handle requests by `HttpsConnectorWithSniHandle`s
+ while let Some(request) = rx.next().await {
+ let handles = {
+ let mut inner = inner_copy.lock().unwrap();
+
+ if let HttpsConnectorRequest::SetConnectionMode(config) = request {
+ match InnerConnectionMode::try_from(config) {
+ Ok(config) => {
+ inner.proxy_config = config;
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to parse new API proxy config"
+ )
+ );
+ }
+ }
+ }
+
+ std::mem::take(&mut inner.stream_handles)
+ };
+ for handle in handles {
+ handle.close();
+ }
+ notify.notify_waiters();
+ }
+ });
+
+ (
+ HttpsConnectorWithSni {
+ inner,
+ sni_hostname,
+ address_cache,
+ abort_notify,
+ proxy_context: SsContext::new_shared(ServerType::Local),
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ },
+ HttpsConnectorWithSniHandle { tx },
+ )
+ }
+
+ async fn open_socket(
+ addr: SocketAddr,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> std::io::Result<TcpStream> {
+ let socket = match addr {
+ SocketAddr::V4(_) => TcpSocket::new_v4()?,
+ SocketAddr::V6(_) => TcpSocket::new_v6()?,
+ };
+
+ #[cfg(target_os = "android")]
+ if let Some(mut tx) = socket_bypass_tx {
+ let (done_tx, done_rx) = oneshot::channel();
+ let _ = tx.send((socket.as_raw_fd(), done_tx)).await;
+ if let Err(_) = done_rx.await {
+ log::error!("Failed to bypass socket, connection might fail");
+ }
+ }
+
+ timeout(CONNECT_TIMEOUT, socket.connect(addr))
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
+ }
+
+ async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> {
+ let hostname = uri.host().ok_or(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "invalid url, missing host",
+ ))?;
+ let port = uri.port_u16().unwrap_or(443);
+ if let Some(addr) = hostname.parse::<IpAddr>().ok() {
+ return Ok(SocketAddr::new(addr, port));
+ }
+
+ // Preferentially, use cached address.
+ //
+ if let Some(addr) = address_cache.resolve_hostname(hostname).await {
+ return Ok(SocketAddr::new(addr.ip(), port));
+ }
+
+ // Use getaddrinfo as a fallback
+ //
+ let mut addrs = GaiResolver::new()
+ .call(
+ Name::from_str(&hostname)
+ .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
+ )
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+ let addr = addrs
+ .next()
+ .ok_or(io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
+ Ok(SocketAddr::new(addr.ip(), port))
+ }
+}
+
+impl fmt::Debug for HttpsConnectorWithSni {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("HttpsConnectorWithSni").finish()
+ }
+}
+
+impl Service<Uri> for HttpsConnectorWithSni {
+ type Response = AbortableStream<ApiConnection>;
+ type Error = io::Error;
+ type Future =
+ Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
+
+ fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ let mut inner = self.inner.lock().unwrap();
+ inner.stream_handles.retain(|handle| !handle.is_closed());
+ Poll::Ready(Ok(()))
+ }
+
+ fn call(&mut self, uri: Uri) -> Self::Future {
+ let sni_hostname = self
+ .sni_hostname
+ .clone()
+ .or_else(|| uri.host().map(str::to_owned))
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
+ });
+ let inner = self.inner.clone();
+ let abort_notify = self.abort_notify.clone();
+ let proxy_context = self.proxy_context.clone();
+ #[cfg(target_os = "android")]
+ let socket_bypass_tx = self.socket_bypass_tx.clone();
+ let address_cache = self.address_cache.clone();
+
+ let fut = async move {
+ if uri.scheme() != Some(&Scheme::HTTPS) {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "invalid url, not https",
+ ));
+ }
+
+ let hostname = sni_hostname?;
+ let addr = Self::resolve_address(address_cache, uri).await?;
+
+ // Loop until we have established a connection. This starts over if a new endpoint
+ // is selected while connecting.
+ let stream = loop {
+ let notify = abort_notify.notified();
+ let config = { inner.lock().unwrap().proxy_config.clone() };
+ let stream_fut = async {
+ match config {
+ InnerConnectionMode::Direct => {
+ let socket = Self::open_socket(
+ addr,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx.clone(),
+ )
+ .await?;
+ let tls_stream = TlsStream::connect_https(socket, &hostname).await?;
+ Ok::<_, io::Error>(ApiConnection::Direct(tls_stream))
+ }
+ InnerConnectionMode::Proxied(proxy_config) => {
+ let socket = Self::open_socket(
+ proxy_config.peer,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx.clone(),
+ )
+ .await?;
+ let proxy = ProxyClientStream::from_stream(
+ proxy_context.clone(),
+ socket,
+ &ServerConfig::from(proxy_config),
+ addr,
+ );
+ let tls_stream = TlsStream::connect_https(proxy, &hostname).await?;
+ Ok(ApiConnection::Proxied(tls_stream))
+ }
+ }
+ };
+
+ pin_mut!(stream_fut);
+ pin_mut!(notify);
+
+ // Wait for connection. Abort and retry if we switched to a different server.
+ if let future::Either::Left((stream, _)) = future::select(stream_fut, notify).await
+ {
+ break stream?;
+ }
+ };
+
+ let (stream, socket_handle) = AbortableStream::new(stream);
+
+ {
+ let mut inner = inner.lock().unwrap();
+ inner.stream_handles.push(socket_handle);
+ }
+
+ Ok(stream)
+ };
+
+ Box::pin(fut)
+ }
+}
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
new file mode 100644
index 0000000000..40742fe41d
--- /dev/null
+++ b/mullvad-api/src/lib.rs
@@ -0,0 +1,530 @@
+#![deny(rust_2018_idioms)]
+
+use chrono::{offset::Utc, DateTime};
+#[cfg(target_os = "android")]
+use futures::channel::mpsc;
+use futures::Stream;
+use hyper::Method;
+use mullvad_types::{
+ account::{AccountToken, VoucherSubmission},
+ version::AppVersion,
+};
+use proxy::ApiConnectionMode;
+use std::{
+ collections::BTreeMap,
+ future::Future,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::Path,
+};
+use talpid_types::ErrorExt;
+
+pub mod availability;
+use availability::{ApiAvailability, ApiAvailabilityHandle};
+pub mod rest;
+
+mod abortable_stream;
+mod https_client_with_sni;
+pub mod proxy;
+mod tls_stream;
+#[cfg(target_os = "android")]
+pub use crate::https_client_with_sni::SocketBypassRequest;
+
+mod access;
+mod address_cache;
+pub mod device;
+mod relay_list;
+pub use address_cache::AddressCache;
+pub use device::DevicesProxy;
+pub use hyper::StatusCode;
+pub use relay_list::RelayListProxy;
+
+/// Error code returned by the Mullvad API if the voucher has alreaby been used.
+pub const VOUCHER_USED: &str = "VOUCHER_USED";
+
+/// Error code returned by the Mullvad API if the voucher code is invalid.
+pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER";
+
+/// Error code returned by the Mullvad API if the account token is invalid.
+pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT";
+
+/// Error code returned by the Mullvad API if the access token is invalid.
+pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN";
+
+pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED";
+pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE";
+
+pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt";
+
+const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1";
+const APP_URL_PREFIX: &str = "app/v1";
+
+lazy_static::lazy_static! {
+ static ref API: ApiEndpoint = ApiEndpoint::get();
+}
+
+/// A hostname and socketaddr to reach the Mullvad REST API over.
+struct ApiEndpoint {
+ host: String,
+ addr: SocketAddr,
+ disable_address_cache: bool,
+}
+
+impl ApiEndpoint {
+ /// Returns the endpoint to connect to the API over.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `MULLVAD_API_ADDR` has invalid contents or if only one of
+ /// `MULLVAD_API_ADDR` or `MULLVAD_API_HOST` has been set but not the other.
+ fn get() -> ApiEndpoint {
+ const API_HOST_DEFAULT: &str = "api.mullvad.net";
+ const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78));
+ const API_PORT_DEFAULT: u16 = 443;
+
+ fn read_var(key: &'static str) -> Option<String> {
+ use std::env;
+ match env::var(key) {
+ Ok(v) => Some(v),
+ Err(env::VarError::NotPresent) => None,
+ Err(env::VarError::NotUnicode(_)) => panic!("{} does not contain valid UTF-8", key),
+ }
+ }
+
+ let host_var = read_var("MULLVAD_API_HOST");
+ let address_var = read_var("MULLVAD_API_ADDR");
+
+ let mut api = ApiEndpoint {
+ host: API_HOST_DEFAULT.to_owned(),
+ addr: SocketAddr::new(API_IP_DEFAULT, API_PORT_DEFAULT),
+ disable_address_cache: false,
+ };
+
+ if cfg!(feature = "api-override") {
+ match (host_var, address_var) {
+ (None, None) => (),
+ (Some(_), None) => panic!("MULLVAD_API_HOST is set, but not MULLVAD_API_ADDR"),
+ (None, Some(_)) => panic!("MULLVAD_API_ADDR is set, but not MULLVAD_API_HOST"),
+ (Some(user_host), Some(user_addr)) => {
+ api.host = user_host;
+ api.addr = user_addr
+ .parse()
+ .expect("MULLVAD_API_ADDR is not a valid socketaddr");
+ api.disable_address_cache = true;
+ log::debug!("Overriding API. Using {} at {}", api.host, api.addr);
+ }
+ }
+ } else {
+ if host_var.is_some() || address_var.is_some() {
+ log::warn!(
+ "MULLVAD_API_HOST and MULLVAD_API_ADDR are ignored in production builds"
+ );
+ }
+ }
+ api
+ }
+}
+
+/// A type that helps with the creation of API connections.
+pub struct Runtime {
+ handle: tokio::runtime::Handle,
+ pub address_cache: AddressCache,
+ api_availability: availability::ApiAvailability,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+}
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "Failed to construct a rest client")]
+ RestError(#[error(source)] rest::Error),
+
+ #[error(display = "Failed to load address cache")]
+ AddressCacheError(#[error(source)] address_cache::Error),
+
+ #[error(display = "API availability check failed")]
+ ApiCheckError(#[error(source)] availability::Error),
+}
+
+/// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`.
+/// It should return a future that determines whether to reject the new endpoint or not.
+pub trait ApiEndpointUpdateCallback: Fn(SocketAddr) -> Self::AcceptedNewEndpoint {
+ type AcceptedNewEndpoint: Future<Output = bool> + Send;
+}
+
+impl<U, T: Future<Output = bool> + Send> ApiEndpointUpdateCallback for U
+where
+ U: Fn(SocketAddr) -> T,
+{
+ type AcceptedNewEndpoint = T;
+}
+
+impl Runtime {
+ /// Create a new `Runtime`.
+ pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
+ Self::new_inner(
+ handle,
+ #[cfg(target_os = "android")]
+ None,
+ )
+ }
+
+ fn new_inner(
+ handle: tokio::runtime::Handle,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> Result<Self, Error> {
+ Ok(Runtime {
+ handle,
+ address_cache: AddressCache::new(None)?,
+ api_availability: ApiAvailability::new(availability::State::default()),
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ })
+ }
+
+ /// Create a new `Runtime` using the specified directories.
+ /// Try to use the cache directory first, and fall back on the bundled address otherwise.
+ pub async fn with_cache(
+ cache_dir: &Path,
+ write_changes: bool,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> Result<Self, Error> {
+ let handle = tokio::runtime::Handle::current();
+ if API.disable_address_cache {
+ return Self::new_inner(
+ handle,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ );
+ }
+
+ let cache_file = cache_dir.join(API_IP_CACHE_FILENAME);
+ let write_file = if write_changes {
+ Some(cache_file.clone().into_boxed_path())
+ } else {
+ None
+ };
+
+ let address_cache = match AddressCache::from_file(&cache_file, write_file.clone()).await {
+ Ok(cache) => cache,
+ Err(error) => {
+ if cache_file.exists() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to load cached API addresses. Falling back on bundled address"
+ )
+ );
+ }
+ AddressCache::new(write_file)?
+ }
+ };
+
+ Ok(Runtime {
+ handle,
+ address_cache,
+ api_availability: ApiAvailability::new(availability::State::default()),
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ })
+ }
+
+ /// Creates a new request service and returns a handle to it.
+ async fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
+ &self,
+ sni_hostname: Option<String>,
+ proxy_provider: T,
+ new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> rest::RequestServiceHandle {
+ let service_handle = rest::RequestService::new(
+ sni_hostname,
+ self.api_availability.handle(),
+ self.address_cache.clone(),
+ proxy_provider,
+ new_address_callback,
+ #[cfg(target_os = "android")]
+ socket_bypass_tx,
+ )
+ .await;
+ service_handle
+ }
+
+ /// Returns a request factory initialized to create requests for the master API
+ pub async fn mullvad_rest_handle<
+ T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
+ >(
+ &self,
+ proxy_provider: T,
+ new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static,
+ ) -> rest::MullvadRestHandle {
+ let service = self
+ .new_request_service(
+ Some(API.host.clone()),
+ proxy_provider,
+ new_address_callback,
+ #[cfg(target_os = "android")]
+ self.socket_bypass_tx.clone(),
+ )
+ .await;
+ let factory = rest::RequestFactory::new(API.host.clone(), None);
+
+ rest::MullvadRestHandle::new(
+ service,
+ factory,
+ self.address_cache.clone(),
+ self.availability_handle(),
+ )
+ }
+
+ /// Returns a new request service handle
+ pub async fn rest_handle(&mut self) -> rest::RequestServiceHandle {
+ self.new_request_service(
+ None,
+ ApiConnectionMode::Direct.into_repeat(),
+ |_| async { true },
+ #[cfg(target_os = "android")]
+ None,
+ )
+ .await
+ }
+
+ pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
+ &mut self.handle
+ }
+
+ pub fn availability_handle(&self) -> ApiAvailabilityHandle {
+ self.api_availability.handle()
+ }
+}
+
+#[derive(Clone)]
+pub struct AccountsProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+#[derive(serde::Deserialize)]
+struct AccountResponse {
+ number: AccountToken,
+ expiry: DateTime<Utc>,
+}
+
+impl AccountsProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn get_expiry(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let account: AccountResponse = rest::deserialize_body(response?).await?;
+ Ok(account.expiry)
+ }
+ }
+
+ pub fn create_account(&mut self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
+ let service = self.handle.service.clone();
+ let response = rest::send_request(
+ &self.handle.factory,
+ service,
+ &format!("{}/accounts", ACCOUNTS_URL_PREFIX),
+ Method::POST,
+ None,
+ &[StatusCode::CREATED],
+ );
+
+ async move {
+ let account: AccountResponse = rest::deserialize_body(response.await?).await?;
+ Ok(account.number)
+ }
+ }
+
+ pub fn submit_voucher(
+ &mut self,
+ account_token: AccountToken,
+ voucher_code: String,
+ ) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> {
+ #[derive(serde::Serialize)]
+ struct VoucherSubmission {
+ voucher_code: String,
+ }
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ let submission = VoucherSubmission { voucher_code };
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/submit-voucher", APP_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account_token)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn get_www_auth_token(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<String, rest::Error>> {
+ #[derive(serde::Deserialize)]
+ struct AuthTokenResponse {
+ auth_token: String,
+ }
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/www-auth-token", APP_URL_PREFIX),
+ Method::POST,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
+ Ok(response.auth_token)
+ }
+ }
+}
+
+pub struct ProblemReportProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+impl ProblemReportProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn problem_report(
+ &self,
+ email: &str,
+ message: &str,
+ log: &str,
+ metadata: &BTreeMap<String, String>,
+ ) -> impl Future<Output = Result<(), rest::Error>> {
+ #[derive(serde::Serialize)]
+ struct ProblemReport {
+ address: String,
+ message: String,
+ log: String,
+ metadata: BTreeMap<String, String>,
+ }
+
+ let report = ProblemReport {
+ address: email.to_owned(),
+ message: message.to_owned(),
+ log: log.to_owned(),
+ metadata: metadata.clone(),
+ };
+
+ let service = self.handle.service.clone();
+
+ let request = rest::send_json_request(
+ &self.handle.factory,
+ service,
+ &format!("{}/problem-report", APP_URL_PREFIX),
+ Method::POST,
+ &report,
+ None,
+ &[StatusCode::NO_CONTENT],
+ );
+
+ async move {
+ request.await?;
+ Ok(())
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct AppVersionProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+#[derive(serde::Deserialize, Debug)]
+pub struct AppVersionResponse {
+ pub supported: bool,
+ pub latest: AppVersion,
+ pub latest_stable: Option<AppVersion>,
+ pub latest_beta: AppVersion,
+}
+
+impl AppVersionProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn version_check(
+ &self,
+ app_version: AppVersion,
+ platform: &str,
+ platform_version: String,
+ ) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> {
+ let service = self.handle.service.clone();
+
+ let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version);
+ let request = self.handle.factory.request(&path, Method::GET);
+
+ async move {
+ let mut request = request?;
+ request.add_header("M-Platform-Version", &platform_version)?;
+
+ let response = service.request(request).await?;
+ let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(parsed_response).await
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct ApiProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+impl ApiProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> {
+ let service = self.handle.service.clone();
+
+ let response = rest::send_request(
+ &self.handle.factory,
+ service,
+ &format!("{}/api-addrs", APP_URL_PREFIX),
+ Method::GET,
+ None,
+ &[StatusCode::OK],
+ )
+ .await?;
+
+ rest::deserialize_body(response).await
+ }
+}
diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs
new file mode 100644
index 0000000000..009a1960dc
--- /dev/null
+++ b/mullvad-api/src/proxy.rs
@@ -0,0 +1,204 @@
+use crate::tls_stream::TlsStream;
+use futures::Stream;
+use hyper::client::connect::{Connected, Connection};
+use rand::{distributions::Alphanumeric, Rng};
+use serde::{Deserialize, Serialize};
+use shadowsocks::relay::tcprelay::ProxyClientStream;
+use std::{
+ fmt, io,
+ net::SocketAddr,
+ path::Path,
+ pin::Pin,
+ task::{self, Poll},
+};
+use talpid_types::{net::openvpn::ShadowsocksProxySettings, ErrorExt};
+use tokio::{
+ fs,
+ io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
+ net::TcpStream,
+};
+
+const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json";
+
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
+pub enum ApiConnectionMode {
+ /// Connect directly to the target.
+ Direct,
+ /// Connect to the destination via a proxy.
+ Proxied(ProxyConfig),
+}
+
+impl fmt::Display for ApiConnectionMode {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ ApiConnectionMode::Direct => write!(f, "unproxied"),
+ ApiConnectionMode::Proxied(settings) => settings.fmt(f),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
+pub enum ProxyConfig {
+ Shadowsocks(ShadowsocksProxySettings),
+}
+
+impl fmt::Display for ProxyConfig {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ // TODO: Do not hardcode TCP
+ ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer),
+ }
+ }
+}
+
+impl ApiConnectionMode {
+ /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`.
+ /// This returns `ApiConnectionMode::Direct` if reading from disk fails for any reason.
+ pub async fn try_from_cache(cache_dir: &Path) -> Self {
+ Self::from_cache(cache_dir).await.unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to read API endpoint cache")
+ );
+ ApiConnectionMode::Direct
+ })
+ }
+
+ /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`.
+ /// If the file does not exist, this returns `Ok(ApiConnectionMode::Direct)`.
+ async fn from_cache(cache_dir: &Path) -> io::Result<Self> {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ match fs::read_to_string(path).await {
+ Ok(s) => serde_json::from_str(&s).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(&format!(
+ "Failed to deserialize \"{}\"",
+ CURRENT_CONFIG_FILENAME
+ ))
+ );
+ io::Error::new(io::ErrorKind::Other, "deserialization failed")
+ }),
+ Err(error) => {
+ if error.kind() == io::ErrorKind::NotFound {
+ Ok(ApiConnectionMode::Direct)
+ } else {
+ Err(error)
+ }
+ }
+ }
+ }
+
+ /// Stores this config to `CURRENT_CONFIG_FILENAME`.
+ /// The content is saved to a temporary file first, which ensures that
+ /// consumers of the file never end up with partial content.
+ pub async fn save(&self, cache_dir: &Path) -> io::Result<()> {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ let mut temp_ext = String::from("temp");
+ temp_ext.push_str(
+ &rand::thread_rng()
+ .sample_iter(&Alphanumeric)
+ .take(5)
+ .map(char::from)
+ .collect::<String>(),
+ );
+ let temp_path = path.with_extension(temp_ext);
+
+ {
+ let mut file = fs::File::create(&temp_path).await?;
+ let json = serde_json::to_string_pretty(self)
+ .map_err(|_| io::Error::new(io::ErrorKind::Other, "serialization failed"))?;
+ file.write_all(json.as_bytes()).await?;
+ file.write_all(b"\n").await?;
+ file.sync_data().await?;
+ }
+
+ fs::rename(&temp_path, path).await
+ }
+
+ /// Attempts to remove `CURRENT_CONFIG_FILENAME`, if it exists.
+ pub async fn try_delete_cache(cache_dir: &Path) {
+ let path = cache_dir.join(CURRENT_CONFIG_FILENAME);
+ if let Err(err) = fs::remove_file(path).await {
+ if err.kind() != std::io::ErrorKind::NotFound {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("Failed to remove old API config")
+ );
+ }
+ }
+ }
+
+ /// Returns the remote address, or `None` for `ApiConnectionMode::Direct`.
+ pub fn get_endpoint(&self) -> Option<SocketAddr> {
+ match self {
+ ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(ss)) => Some(ss.peer),
+ ApiConnectionMode::Direct => None,
+ }
+ }
+
+ pub fn is_proxy(&self) -> bool {
+ *self != ApiConnectionMode::Direct
+ }
+
+ /// Convenience function that returns a stream that repeats
+ /// this config forever.
+ pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> {
+ futures::stream::repeat(self)
+ }
+}
+
+/// Stream that is either a regular TLS stream or TLS via shadowsocks
+pub enum ApiConnection {
+ Direct(TlsStream<TcpStream>),
+ Proxied(TlsStream<ProxyClientStream<TcpStream>>),
+}
+
+impl AsyncRead for ApiConnection {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ match Pin::get_mut(self) {
+ ApiConnection::Direct(s) => Pin::new(s).poll_read(cx, buf),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_read(cx, buf),
+ }
+ }
+}
+
+impl AsyncWrite for ApiConnection {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ match Pin::get_mut(self) {
+ ApiConnection::Direct(s) => Pin::new(s).poll_write(cx, buf),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_write(cx, buf),
+ }
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ match Pin::get_mut(self) {
+ ApiConnection::Direct(s) => Pin::new(s).poll_flush(cx),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_flush(cx),
+ }
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ match Pin::get_mut(self) {
+ ApiConnection::Direct(s) => Pin::new(s).poll_shutdown(cx),
+ ApiConnection::Proxied(s) => Pin::new(s).poll_shutdown(cx),
+ }
+ }
+}
+
+impl Connection for ApiConnection {
+ fn connected(&self) -> Connected {
+ match self {
+ ApiConnection::Direct(s) => s.connected(),
+ ApiConnection::Proxied(s) => s.connected(),
+ }
+ }
+}
diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs
new file mode 100644
index 0000000000..5a8a01836f
--- /dev/null
+++ b/mullvad-api/src/relay_list.rs
@@ -0,0 +1,375 @@
+//! A module dedicated to retrieving the relay list from the Mullvad API.
+
+use crate::rest;
+
+use hyper::{header, Method, StatusCode};
+use mullvad_types::{location, relay_list};
+use talpid_types::net::{wireguard, TransportProtocol};
+
+use std::{
+ collections::BTreeMap,
+ future::Future,
+ net::{Ipv4Addr, Ipv6Addr},
+ time::Duration,
+};
+
+/// Fetches relay list from https://api.mullvad.net/app/v1/relays
+#[derive(Clone)]
+pub struct RelayListProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+const RELAY_LIST_TIMEOUT: Duration = Duration::from_secs(15);
+
+impl RelayListProxy {
+ /// Construct a new relay list rest client
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ /// Fetch the relay list
+ pub fn relay_list(
+ &self,
+ etag: Option<String>,
+ ) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> {
+ let service = self.handle.service.clone();
+ let request = self.handle.factory.request("app/v1/relays", Method::GET);
+
+ let future = async move {
+ let mut request = request?;
+ request.set_timeout(RELAY_LIST_TIMEOUT);
+
+ if let Some(ref tag) = etag {
+ request.add_header(header::IF_NONE_MATCH, tag)?;
+ }
+
+ let response = service.request(request).await?;
+ if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED {
+ return Ok(None);
+ }
+ if response.status() != StatusCode::OK {
+ return rest::handle_error_response(response).await;
+ }
+
+ let etag = response
+ .headers()
+ .get(header::ETAG)
+ .and_then(|tag| match tag.to_str() {
+ Ok(tag) => Some(tag.to_string()),
+ Err(_) => {
+ log::error!("Ignoring invalid tag from server: {:?}", tag.as_bytes());
+ None
+ }
+ });
+
+ Ok(Some(
+ rest::deserialize_body::<ServerRelayList>(response)
+ .await?
+ .into_relay_list(etag),
+ ))
+ };
+ future
+ }
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct ServerRelayList {
+ locations: BTreeMap<String, Location>,
+ openvpn: OpenVpn,
+ wireguard: Wireguard,
+ bridge: Bridges,
+}
+
+impl ServerRelayList {
+ fn into_relay_list(self, etag: Option<String>) -> relay_list::RelayList {
+ let mut countries = BTreeMap::new();
+ let Self {
+ locations,
+ openvpn,
+ wireguard,
+ bridge,
+ } = self;
+
+ for (code, location) in locations.into_iter() {
+ match split_location_code(&code) {
+ Some((country_code, city_code)) => {
+ let country_code = country_code.to_lowercase();
+ let city_code = city_code.to_lowercase();
+ let country = countries
+ .entry(country_code.clone())
+ .or_insert_with(|| location_to_country(&location, country_code));
+ country.cities.push(location_to_city(&location, city_code));
+ }
+ None => {
+ log::error!("Bad location code:{}", code);
+ continue;
+ }
+ }
+ }
+
+ Self::add_openvpn_relays(&mut countries, openvpn);
+ Self::add_wireguard_relays(&mut countries, wireguard);
+ Self::add_bridge_relays(&mut countries, bridge);
+
+ relay_list::RelayList {
+ etag: etag.map(|mut tag| {
+ if tag.starts_with("\"") {
+ tag.insert_str(0, "W/");
+ }
+ tag
+ }),
+ countries: countries
+ .into_iter()
+ .map(|(_key, country)| country)
+ .collect(),
+ }
+ }
+
+ fn add_openvpn_relays(
+ countries: &mut BTreeMap<String, relay_list::RelayListCountry>,
+ openvpn: OpenVpn,
+ ) {
+ let openvpn_endpoint_data = openvpn.ports;
+ for mut openvpn_relay in openvpn.relays.into_iter() {
+ openvpn_relay.to_lower();
+ if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) {
+ if let Some(country) = countries.get_mut(country_code) {
+ if let Some(city) = country
+ .cities
+ .iter_mut()
+ .find(|city| city.code == city_code)
+ {
+ let location = location::Location {
+ country: country.name.clone(),
+ country_code: country.code.clone(),
+ city: city.name.clone(),
+ city_code: city.code.clone(),
+ latitude: city.latitude,
+ longitude: city.longitude,
+ };
+ match city
+ .relays
+ .iter_mut()
+ .find(|r| r.hostname == openvpn_relay.hostname)
+ {
+ Some(relay) => relay.tunnels.openvpn = openvpn_endpoint_data.clone(),
+ None => {
+ let mut relay = relay(openvpn_relay, location);
+ relay.tunnels.openvpn = openvpn_endpoint_data.clone();
+ city.relays.push(relay);
+ }
+ };
+ }
+ };
+ }
+ }
+ }
+
+ fn add_wireguard_relays(
+ countries: &mut BTreeMap<String, relay_list::RelayListCountry>,
+ wireguard: Wireguard,
+ ) {
+ let Wireguard {
+ port_ranges,
+ ipv4_gateway,
+ ipv6_gateway,
+ relays,
+ } = wireguard;
+
+ let wireguard_endpoint_data =
+ |public_key: wireguard::PublicKey| relay_list::WireguardEndpointData {
+ port_ranges: port_ranges.clone(),
+ ipv4_gateway,
+ ipv6_gateway,
+ public_key,
+ protocol: TransportProtocol::Udp,
+ };
+
+ for mut wireguard_relay in relays {
+ wireguard_relay.relay.to_lower();
+ if let Some((country_code, city_code)) =
+ split_location_code(&wireguard_relay.relay.location)
+ {
+ if let Some(country) = countries.get_mut(country_code) {
+ if let Some(city) = country
+ .cities
+ .iter_mut()
+ .find(|city| city.code == city_code)
+ {
+ let location = location::Location {
+ country: country.name.clone(),
+ country_code: country.code.clone(),
+ city: city.name.clone(),
+ city_code: city.code.clone(),
+ latitude: city.latitude,
+ longitude: city.longitude,
+ };
+ match city
+ .relays
+ .iter_mut()
+ .find(|r| r.hostname == wireguard_relay.relay.hostname)
+ {
+ Some(relay) => relay
+ .tunnels
+ .wireguard
+ .push(wireguard_endpoint_data(wireguard_relay.public_key)),
+ None => {
+ let mut relay = relay(wireguard_relay.relay, location);
+ relay.ipv6_addr_in = Some(wireguard_relay.ipv6_addr_in);
+ relay.tunnels.wireguard =
+ vec![wireguard_endpoint_data(wireguard_relay.public_key)];
+ city.relays.push(relay);
+ }
+ };
+ }
+ };
+ }
+ }
+ }
+
+ fn add_bridge_relays(
+ countries: &mut BTreeMap<String, relay_list::RelayListCountry>,
+ bridges: Bridges,
+ ) {
+ let Bridges {
+ relays,
+ shadowsocks,
+ } = bridges;
+
+ for mut bridge_relay in relays {
+ bridge_relay.to_lower();
+ if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) {
+ if let Some(country) = countries.get_mut(country_code) {
+ if let Some(city) = country
+ .cities
+ .iter_mut()
+ .find(|city| city.code == city_code)
+ {
+ let location = location::Location {
+ country: country.name.clone(),
+ country_code: country.code.clone(),
+ city: city.name.clone(),
+ city_code: city.code.clone(),
+ latitude: city.latitude,
+ longitude: city.longitude,
+ };
+
+ match city
+ .relays
+ .iter_mut()
+ .find(|r| r.hostname == bridge_relay.hostname)
+ {
+ Some(relay) => {
+ relay.bridges.shadowsocks = shadowsocks.clone();
+ }
+ None => {
+ let mut relay = relay(bridge_relay, location);
+ relay.bridges.shadowsocks = shadowsocks.clone();
+ city.relays.push(relay);
+ }
+ };
+ }
+ };
+ }
+ }
+ }
+}
+
+/// Splits a location code into a country code and a city code. The input is expected to be in a
+/// format like `se-mma`, with `se` being the country code, `mma` being the city code.
+fn split_location_code(location: &str) -> Option<(&str, &str)> {
+ let mut parts = location.split('-');
+ let country = parts.next()?;
+ let city = parts.next()?;
+
+ Some((country, city))
+}
+
+fn location_to_country(location: &Location, code: String) -> relay_list::RelayListCountry {
+ relay_list::RelayListCountry {
+ cities: vec![],
+ name: location.country.clone(),
+ code,
+ }
+}
+
+fn location_to_city(location: &Location, code: String) -> relay_list::RelayListCity {
+ relay_list::RelayListCity {
+ name: location.city.clone(),
+ code,
+ latitude: location.latitude,
+ longitude: location.longitude,
+ relays: vec![],
+ }
+}
+
+fn relay(relay: Relay, location: location::Location) -> relay_list::Relay {
+ relay_list::Relay {
+ hostname: relay.hostname,
+ ipv4_addr_in: relay.ipv4_addr_in,
+ ipv6_addr_in: None,
+ include_in_country: relay.include_in_country,
+ active: relay.active,
+ owned: relay.owned,
+ provider: relay.provider,
+ weight: relay.weight,
+ tunnels: Default::default(),
+ bridges: Default::default(),
+ location: Some(location),
+ }
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct Location {
+ city: String,
+ country: String,
+ latitude: f64,
+ longitude: f64,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct OpenVpn {
+ ports: Vec<relay_list::OpenVpnEndpointData>,
+ relays: Vec<Relay>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct Relay {
+ hostname: String,
+ active: bool,
+ owned: bool,
+ location: String,
+ provider: String,
+ ipv4_addr_in: Ipv4Addr,
+ weight: u64,
+ include_in_country: bool,
+}
+
+impl Relay {
+ fn to_lower(&mut self) {
+ self.hostname = self.hostname.to_lowercase();
+ self.location = self.location.to_lowercase();
+ }
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct Wireguard {
+ port_ranges: Vec<(u16, u16)>,
+ ipv4_gateway: Ipv4Addr,
+ ipv6_gateway: Ipv6Addr,
+ relays: Vec<WireGuardRelay>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct WireGuardRelay {
+ #[serde(flatten)]
+ relay: Relay,
+ ipv6_addr_in: Ipv6Addr,
+ public_key: wireguard::PublicKey,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct Bridges {
+ shadowsocks: Vec<relay_list::ShadowsocksEndpointData>,
+ relays: Vec<Relay>,
+}
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
new file mode 100644
index 0000000000..292bcd0cbb
--- /dev/null
+++ b/mullvad-api/src/rest.rs
@@ -0,0 +1,694 @@
+#[cfg(target_os = "android")]
+pub use crate::https_client_with_sni::SocketBypassRequest;
+use crate::{
+ access::AccessTokenProxy,
+ address_cache::AddressCache,
+ availability::ApiAvailabilityHandle,
+ https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
+ proxy::ApiConnectionMode,
+};
+use futures::{
+ channel::{mpsc, oneshot},
+ stream::StreamExt,
+ Stream, TryFutureExt,
+};
+use hyper::{
+ client::Client,
+ header::{self, HeaderValue},
+ Method, Uri,
+};
+use mullvad_types::account::AccountToken;
+use std::{
+ future::Future,
+ str::FromStr,
+ sync::{Arc, Weak},
+ time::{Duration, Instant},
+};
+use talpid_types::ErrorExt;
+
+pub use hyper::StatusCode;
+
+pub type Request = hyper::Request<hyper::Body>;
+pub type Response = hyper::Response<hyper::Body>;
+
+const USER_AGENT: &str = "mullvad-app";
+
+const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60);
+const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60);
+const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
+const API_IP_CHECK_ERROR_INTERVAL: Duration = Duration::from_secs(15 * 60);
+
+pub type Result<T> = std::result::Result<T, Error>;
+const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// Describes all the ways a REST request can fail
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "Request cancelled")]
+ Aborted,
+
+ #[error(display = "Hyper error")]
+ HyperError(#[error(source)] hyper::Error),
+
+ #[error(display = "Invalid header value")]
+ InvalidHeaderError(#[error(source)] http::header::InvalidHeaderValue),
+
+ #[error(display = "HTTP error")]
+ HttpError(#[error(source)] http::Error),
+
+ #[error(display = "Request timed out")]
+ TimeoutError(#[error(source)] tokio::time::error::Elapsed),
+
+ #[error(display = "Failed to deserialize data")]
+ DeserializeError(#[error(source)] serde_json::Error),
+
+ #[error(display = "Failed to send request to rest client")]
+ SendError,
+
+ #[error(display = "Failed to receive response from rest client")]
+ ReceiveError,
+
+ /// Unexpected response code
+ #[error(display = "Unexpected response status code {} - {}", _0, _1)]
+ ApiError(StatusCode, String),
+
+ /// The string given was not a valid URI.
+ #[error(display = "Not a valid URI")]
+ UriError(#[error(source)] http::uri::InvalidUri),
+}
+
+impl Error {
+ pub fn is_network_error(&self) -> bool {
+ match self {
+ Error::HyperError(_) | Error::TimeoutError(_) => true,
+ _ => false,
+ }
+ }
+
+ /// Returns a new instance for which `abortable_stream::Aborted` is mapped to `Self::Aborted`.
+ fn map_aborted(self) -> Self {
+ if let Error::HyperError(error) = &self {
+ use std::error::Error;
+ let mut source = error.source();
+ while let Some(error) = source {
+ let io_error: Option<&std::io::Error> = error.downcast_ref();
+ if let Some(io_error) = io_error {
+ let abort_error: Option<&crate::abortable_stream::Aborted> =
+ io_error.get_ref().and_then(|inner| inner.downcast_ref());
+ if abort_error.is_some() {
+ return Self::Aborted;
+ }
+ }
+ source = error.source();
+ }
+ }
+ self
+ }
+}
+
+use super::ApiEndpointUpdateCallback;
+
+/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
+/// requests
+pub(crate) struct RequestService<
+ T: Stream<Item = ApiConnectionMode>,
+ F: ApiEndpointUpdateCallback + Send,
+> {
+ command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
+ command_rx: mpsc::UnboundedReceiver<RequestCommand>,
+ connector_handle: HttpsConnectorWithSniHandle,
+ client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
+ proxy_config_provider: T,
+ new_address_callback: F,
+ address_cache: AddressCache,
+ api_availability: ApiAvailabilityHandle,
+}
+
+impl<
+ T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
+ F: ApiEndpointUpdateCallback + Send + Sync + 'static,
+ > RequestService<T, F>
+{
+ /// Constructs a new request service.
+ pub async fn new(
+ sni_hostname: Option<String>,
+ api_availability: ApiAvailabilityHandle,
+ address_cache: AddressCache,
+ mut proxy_config_provider: T,
+ new_address_callback: F,
+ #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
+ ) -> RequestServiceHandle {
+ let (connector, connector_handle) = HttpsConnectorWithSni::new(
+ sni_hostname,
+ address_cache.clone(),
+ #[cfg(target_os = "android")]
+ socket_bypass_tx.clone(),
+ );
+
+ proxy_config_provider
+ .next()
+ .await
+ .map(|config| connector_handle.set_connection_mode(config));
+
+ let (command_tx, command_rx) = mpsc::unbounded();
+ let client = Client::builder().build(connector);
+
+ let command_tx = Arc::new(command_tx);
+
+ let service = Self {
+ command_tx: Arc::downgrade(&command_tx),
+ command_rx,
+ connector_handle,
+ client,
+ proxy_config_provider,
+ new_address_callback,
+ address_cache,
+ api_availability,
+ };
+ let handle = RequestServiceHandle { tx: command_tx };
+ tokio::spawn(service.into_future());
+ handle
+ }
+
+ async fn process_command(&mut self, command: RequestCommand) {
+ match command {
+ RequestCommand::NewRequest(request, completion_tx) => {
+ let tx = self.command_tx.upgrade();
+ let timeout = request.timeout();
+
+ let hyper_request = request.into_request();
+
+ let api_availability = self.api_availability.clone();
+ let suspend_fut = api_availability.wait_for_unsuspend();
+ let request_fut = self.client.request(hyper_request).map_err(Error::from);
+
+ let request_future = async move {
+ let _ = suspend_fut.await;
+ request_fut.await
+ };
+
+ let future = async move {
+ let response = tokio::time::timeout(timeout, request_future)
+ .await
+ .map_err(Error::TimeoutError);
+
+ let response = flatten_result(response).map_err(|error| error.map_aborted());
+
+ if let Err(err) = &response {
+ if err.is_network_error() && !api_availability.get_state().is_offline() {
+ log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
+ if let Some(tx) = tx {
+ let _ = tx.unbounded_send(RequestCommand::NextApiConfig);
+ }
+ }
+ }
+
+ if completion_tx.send(response).is_err() {
+ log::trace!(
+ "Failed to send response to caller, caller channel is shut down"
+ );
+ }
+ };
+ tokio::spawn(future);
+ }
+ RequestCommand::Reset => {
+ self.connector_handle.reset();
+ }
+ RequestCommand::NextApiConfig => {
+ if let Some(new_config) = self.proxy_config_provider.next().await {
+ let endpoint = match new_config.get_endpoint() {
+ Some(endpoint) => endpoint,
+ None => self.address_cache.get_address().await,
+ };
+ // Switch to new connection mode unless rejected by address change callback
+ if (self.new_address_callback)(endpoint).await {
+ self.connector_handle.set_connection_mode(new_config);
+ }
+ }
+ }
+ }
+ }
+
+ async fn into_future(mut self) {
+ while let Some(command) = self.command_rx.next().await {
+ self.process_command(command).await;
+ }
+ self.connector_handle.reset();
+ }
+}
+
+#[derive(Clone)]
+/// A handle to interact with a spawned `RequestService`.
+pub struct RequestServiceHandle {
+ tx: Arc<mpsc::UnboundedSender<RequestCommand>>,
+}
+
+impl RequestServiceHandle {
+ /// Resets the corresponding RequestService, dropping all in-flight requests.
+ pub async fn reset(&self) {
+ let _ = self.tx.unbounded_send(RequestCommand::Reset);
+ }
+
+ /// Submits a `RestRequest` for exectuion to the request service.
+ pub async fn request(&self, request: RestRequest) -> Result<Response> {
+ let (completion_tx, completion_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RequestCommand::NewRequest(request, completion_tx))
+ .map_err(|_| Error::SendError)?;
+ completion_rx.await.map_err(|_| Error::ReceiveError)?
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum RequestCommand {
+ NewRequest(
+ RestRequest,
+ oneshot::Sender<std::result::Result<Response, Error>>,
+ ),
+ Reset,
+ NextApiConfig,
+}
+
+/// A REST request that is sent to the RequestService to be executed.
+#[derive(Debug)]
+pub struct RestRequest {
+ request: Request,
+ timeout: Duration,
+ auth: Option<HeaderValue>,
+}
+
+impl RestRequest {
+ /// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
+ pub fn get(uri: &str) -> Result<Self> {
+ let uri = hyper::Uri::from_str(&uri).map_err(Error::UriError)?;
+
+ let mut builder = http::request::Builder::new()
+ .method(Method::GET)
+ .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
+ .header(header::ACCEPT, HeaderValue::from_static("application/json"));
+ if let Some(host) = uri.host() {
+ builder = builder.header(header::HOST, HeaderValue::from_str(&host)?);
+ };
+
+ let request = builder
+ .uri(uri)
+ .body(hyper::Body::empty())
+ .map_err(Error::HttpError)?;
+
+ Ok(RestRequest {
+ timeout: DEFAULT_TIMEOUT,
+ auth: None,
+ request,
+ })
+ }
+
+ /// Set the auth header with the following format: `Bearer $auth`.
+ pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> {
+ let header = match auth {
+ Some(auth) => Some(
+ HeaderValue::from_str(&format!("Bearer {}", auth))
+ .map_err(Error::InvalidHeaderError)?,
+ ),
+ None => None,
+ };
+
+ self.auth = header;
+ Ok(())
+ }
+
+ /// Sets timeout for the request.
+ pub fn set_timeout(&mut self, timeout: Duration) {
+ self.timeout = timeout;
+ }
+
+ /// Retrieves timeout
+ pub fn timeout(&self) -> Duration {
+ self.timeout
+ }
+
+ pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> {
+ let header_value = http::HeaderValue::from_str(value).map_err(Error::InvalidHeaderError)?;
+ self.request.headers_mut().insert(key, header_value);
+ Ok(())
+ }
+
+ /// Converts into a `hyper::Request<hyper::Body>`
+ fn into_request(self) -> Request {
+ let Self {
+ mut request, auth, ..
+ } = self;
+ if let Some(auth) = auth {
+ request.headers_mut().insert(header::AUTHORIZATION, auth);
+ }
+ request
+ }
+
+ /// Returns the URI of the request
+ pub fn uri(&self) -> &Uri {
+ self.request.uri()
+ }
+}
+
+impl From<Request> for RestRequest {
+ fn from(request: Request) -> Self {
+ Self {
+ request,
+ timeout: DEFAULT_TIMEOUT,
+ auth: None,
+ }
+ }
+}
+
+#[derive(serde::Deserialize)]
+pub struct ErrorResponse {
+ pub code: String,
+}
+
+#[derive(Clone)]
+pub struct RequestFactory {
+ hostname: String,
+ path_prefix: Option<String>,
+ pub timeout: Duration,
+}
+
+impl RequestFactory {
+ pub fn new(hostname: String, path_prefix: Option<String>) -> Self {
+ Self {
+ hostname,
+ path_prefix,
+ timeout: DEFAULT_TIMEOUT,
+ }
+ }
+
+ pub fn request(&self, path: &str, method: Method) -> Result<RestRequest> {
+ self.hyper_request(path, method)
+ .map(RestRequest::from)
+ .map(|req| self.set_request_timeout(req))
+ }
+
+ pub fn get(&self, path: &str) -> Result<RestRequest> {
+ self.hyper_request(path, Method::GET)
+ .map(RestRequest::from)
+ .map(|req| self.set_request_timeout(req))
+ }
+
+ pub fn post(&self, path: &str) -> Result<RestRequest> {
+ self.hyper_request(path, Method::POST)
+ .map(RestRequest::from)
+ .map(|req| self.set_request_timeout(req))
+ }
+
+ pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
+ self.json_request(Method::POST, path, body)
+ }
+
+ fn json_request<S: serde::Serialize>(
+ &self,
+ method: Method,
+ path: &str,
+ body: &S,
+ ) -> Result<RestRequest> {
+ let mut request = self.hyper_request(path, method)?;
+
+ let json_body = serde_json::to_string(&body)?;
+ let body_length = json_body.as_bytes().len() as u64;
+ *request.body_mut() = json_body.into_bytes().into();
+
+ let headers = request.headers_mut();
+ headers.insert(
+ header::CONTENT_LENGTH,
+ HeaderValue::from_str(&body_length.to_string()).map_err(Error::InvalidHeaderError)?,
+ );
+ headers.insert(
+ header::CONTENT_TYPE,
+ HeaderValue::from_static("application/json"),
+ );
+
+ Ok(self.set_request_timeout(RestRequest::from(request)))
+ }
+
+ pub fn delete(&self, path: &str) -> Result<RestRequest> {
+ self.hyper_request(path, Method::DELETE)
+ .map(RestRequest::from)
+ .map(|req| self.set_request_timeout(req))
+ }
+
+ fn hyper_request(&self, path: &str, method: Method) -> Result<Request> {
+ let uri = self.get_uri(path)?;
+ let request = http::request::Builder::new()
+ .method(method)
+ .uri(uri)
+ .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
+ .header(header::ACCEPT, HeaderValue::from_static("application/json"))
+ .header(header::HOST, self.hostname.clone());
+
+ request.body(hyper::Body::empty()).map_err(Error::HttpError)
+ }
+
+ fn get_uri(&self, path: &str) -> Result<Uri> {
+ let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
+ let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
+ hyper::Uri::from_str(&uri).map_err(Error::UriError)
+ }
+
+ fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest {
+ request.timeout = self.timeout;
+ request
+ }
+}
+
+pub fn get_request<T: serde::de::DeserializeOwned>(
+ factory: &RequestFactory,
+ service: RequestServiceHandle,
+ uri: &str,
+ auth: Option<String>,
+ expected_statuses: &'static [hyper::StatusCode],
+) -> impl Future<Output = Result<Response>> + 'static {
+ let request = factory.get(uri);
+ async move {
+ let mut request = request?;
+ request.set_auth(auth)?;
+ let response = service.request(request).await?;
+ parse_rest_response(response, expected_statuses).await
+ }
+}
+
+pub fn send_request(
+ factory: &RequestFactory,
+ service: RequestServiceHandle,
+ uri: &str,
+ method: Method,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
+ expected_statuses: &'static [hyper::StatusCode],
+) -> impl Future<Output = Result<Response>> {
+ let request = factory.request(uri, method);
+
+ async move {
+ let mut request = request?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
+ let response = service.request(request).await?;
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
+ }
+}
+
+pub fn send_json_request<B: serde::Serialize>(
+ factory: &RequestFactory,
+ service: RequestServiceHandle,
+ uri: &str,
+ method: Method,
+ body: &B,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
+ expected_statuses: &'static [hyper::StatusCode],
+) -> impl Future<Output = Result<Response>> {
+ let request = factory.json_request(method, uri, body);
+ async move {
+ let mut request = request?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
+ let response = service.request(request).await?;
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
+ }
+}
+
+pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> {
+ let body_length = get_body_length(&response);
+ deserialize_body_inner(response, body_length).await
+}
+
+async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
+ mut response: Response,
+ body_length: usize,
+) -> Result<T> {
+ let mut body: Vec<u8> = Vec::with_capacity(body_length);
+ while let Some(chunk) = response.body_mut().next().await {
+ body.extend(&chunk?);
+ }
+
+ serde_json::from_slice(&body).map_err(Error::DeserializeError)
+}
+
+fn get_body_length(response: &Response) -> usize {
+ response
+ .headers()
+ .get(header::CONTENT_LENGTH)
+ .and_then(|header_value| header_value.to_str().ok())
+ .and_then(|length| length.parse::<usize>().ok())
+ .unwrap_or(0)
+}
+
+pub async fn parse_rest_response(
+ response: Response,
+ expected_statuses: &'static [hyper::StatusCode],
+) -> Result<Response> {
+ if !expected_statuses.contains(&response.status()) {
+ log::error!(
+ "Unexpected HTTP status code {}, expected codes [{}]",
+ response.status(),
+ expected_statuses
+ .iter()
+ .map(ToString::to_string)
+ .collect::<Vec<_>>()
+ .join(",")
+ );
+ if !response.status().is_success() {
+ return handle_error_response(response).await;
+ }
+ }
+
+ Ok(response)
+}
+
+pub async fn handle_error_response<T>(response: Response) -> Result<T> {
+ let status = response.status();
+ let error_message = match status {
+ hyper::StatusCode::NOT_FOUND => "Not found",
+ hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed",
+ status => match get_body_length(&response) {
+ 0 => status.canonical_reason().unwrap_or("Unexpected error"),
+ body_length => {
+ let err: ErrorResponse = deserialize_body_inner(response, body_length).await?;
+ return Err(Error::ApiError(status, err.code));
+ }
+ },
+ };
+ Err(Error::ApiError(status, error_message.to_owned()))
+}
+
+#[derive(Clone)]
+pub struct MullvadRestHandle {
+ pub(crate) service: RequestServiceHandle,
+ pub factory: RequestFactory,
+ pub availability: ApiAvailabilityHandle,
+ pub token_store: AccessTokenProxy,
+}
+
+impl MullvadRestHandle {
+ pub(crate) fn new(
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ address_cache: AddressCache,
+ availability: ApiAvailabilityHandle,
+ ) -> Self {
+ let token_store = AccessTokenProxy::new(service.clone(), factory.clone());
+
+ let handle = Self {
+ service,
+ factory,
+ availability,
+ token_store,
+ };
+ if !super::API.disable_address_cache {
+ handle.spawn_api_address_fetcher(address_cache);
+ }
+ handle
+ }
+
+ fn spawn_api_address_fetcher(&self, address_cache: AddressCache) {
+ let handle = self.clone();
+ let availability = self.availability.clone();
+
+ tokio::spawn(async move {
+ // always start the fetch after 15 minutes
+ let api_proxy = crate::ApiProxy::new(handle);
+ let mut next_check = Instant::now() + API_IP_CHECK_DELAY;
+
+ let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL;
+ let next_regular_check = || Instant::now() + API_IP_CHECK_INTERVAL;
+
+ let mut interval = tokio::time::interval_at(next_check.into(), TIMER_CHECK_INTERVAL);
+
+ loop {
+ interval.tick().await;
+ if next_check < Instant::now() {
+ if let Err(error) = availability.wait_background().await {
+ log::error!("Failed while waiting for API: {}", error);
+ next_check = next_error_check();
+ continue;
+ }
+ match api_proxy.clone().get_api_addrs().await {
+ Ok(new_addrs) => {
+ if let Some(addr) = new_addrs.get(0) {
+ log::debug!(
+ "Fetched new API address {:?}. Fetching again in {} hours",
+ addr,
+ API_IP_CHECK_INTERVAL.as_secs() / (60 * 60)
+ );
+ if let Err(err) = address_cache.set_address(*addr).await {
+ log::error!(
+ "Failed to save newly updated API address: {}",
+ err
+ );
+ }
+ } else {
+ log::error!("API returned no API addresses");
+ }
+ next_check = next_regular_check();
+ }
+ Err(err) => {
+ log::error!(
+ "Failed to fetch new API addresses: {}. Retrying in {} seconds",
+ err,
+ API_IP_CHECK_ERROR_INTERVAL.as_secs()
+ );
+ next_check = next_error_check();
+ }
+ }
+ }
+ }
+ });
+ }
+
+ pub fn service(&self) -> RequestServiceHandle {
+ self.service.clone()
+ }
+
+ pub fn factory(&self) -> &RequestFactory {
+ &self.factory
+ }
+}
+
+fn flatten_result<T, E>(
+ result: std::result::Result<std::result::Result<T, E>, E>,
+) -> std::result::Result<T, E> {
+ match result {
+ Ok(value) => value,
+ Err(err) => Err(err),
+ }
+}
diff --git a/mullvad-api/src/tls_stream.rs b/mullvad-api/src/tls_stream.rs
new file mode 100644
index 0000000000..cad0268ac3
--- /dev/null
+++ b/mullvad-api/src/tls_stream.rs
@@ -0,0 +1,122 @@
+//! Provides a TLS 1.3 stream with SNI and LE root cert only.
+use std::{
+ io::{self, ErrorKind},
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll},
+};
+
+use hyper::client::connect::{Connected, Connection};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_rustls::{
+ rustls::{self, ClientConfig, ServerName},
+ TlsConnector,
+};
+
+const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem");
+
+pub struct TlsStream<S: AsyncRead + AsyncWrite + Unpin> {
+ stream: tokio_rustls::client::TlsStream<S>,
+}
+
+impl<S> TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ pub async fn connect_https(stream: S, domain: &str) -> io::Result<TlsStream<S>> {
+ lazy_static::lazy_static! {
+ static ref TLS_CONFIG: Arc<ClientConfig> = {
+ let config = ClientConfig::builder()
+ .with_safe_default_cipher_suites()
+ .with_safe_default_kx_groups()
+ .with_protocol_versions(&[&rustls::version::TLS13])
+ .unwrap()
+ .with_root_certificates(read_cert_store())
+ .with_no_client_auth();
+ Arc::new(config)
+ };
+ }
+
+ let connector = TlsConnector::from(TLS_CONFIG.clone());
+
+ let host = match ServerName::try_from(domain) {
+ Ok(n) => n,
+ Err(_) => {
+ return Err(io::Error::new(
+ ErrorKind::InvalidInput,
+ format!("invalid hostname \"{}\"", domain),
+ ));
+ }
+ };
+
+ let stream = connector.connect(host, stream).await?;
+
+ Ok(TlsStream { stream })
+ }
+}
+
+fn read_cert_store() -> rustls::RootCertStore {
+ let mut cert_store = rustls::RootCertStore::empty();
+
+ let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(LE_ROOT_CERT))
+ .expect("Failed to parse pem file");
+ let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs);
+ if num_failures > 0 || num_certs_added != 1 {
+ panic!("Failed to add root cert");
+ }
+
+ cert_store
+}
+
+impl<S> AsyncRead for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.stream).poll_read(cx, buf)
+ }
+}
+
+impl<S> AsyncWrite for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ Pin::new(&mut self.stream).poll_write(cx, buf)
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.stream).poll_flush(cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.stream).poll_shutdown(cx)
+ }
+}
+
+impl<S> Connection for TlsStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin,
+{
+ fn connected(&self) -> Connected {
+ Connected::new()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_cert_loading() {
+ let _certs = read_cert_store();
+ }
+}