diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-03-23 16:51:43 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-23 16:51:43 +0100 |
| commit | b1910fc14517d284c05aeded77d6f34e3949e6ef (patch) | |
| tree | f762f61e1ed40bb4be8558122f3bdc4fff59570f /mullvad-api/src | |
| parent | 74d76ff8f071d96ffbc6fbb8d92ec961b839f416 (diff) | |
| parent | 918e6588d2276122e391fb426be3245d0664e245 (diff) | |
| download | mullvadvpn-b1910fc14517d284c05aeded77d6f34e3949e6ef.tar.xz mullvadvpn-b1910fc14517d284c05aeded77d6f34e3949e6ef.zip | |
Merge branch 'rename-rpc-crate'
Diffstat (limited to 'mullvad-api/src')
| -rw-r--r-- | mullvad-api/src/abortable_stream.rs | 199 | ||||
| -rw-r--r-- | mullvad-api/src/access.rs | 110 | ||||
| -rw-r--r-- | mullvad-api/src/address_cache.rs | 118 | ||||
| -rw-r--r-- | mullvad-api/src/availability.rs | 170 | ||||
| -rw-r--r-- | mullvad-api/src/bin/relay_list.rs | 41 | ||||
| -rw-r--r-- | mullvad-api/src/device.rs | 196 | ||||
| -rw-r--r-- | mullvad-api/src/https_client_with_sni.rs | 351 | ||||
| -rw-r--r-- | mullvad-api/src/lib.rs | 530 | ||||
| -rw-r--r-- | mullvad-api/src/proxy.rs | 204 | ||||
| -rw-r--r-- | mullvad-api/src/relay_list.rs | 375 | ||||
| -rw-r--r-- | mullvad-api/src/rest.rs | 694 | ||||
| -rw-r--r-- | mullvad-api/src/tls_stream.rs | 122 |
12 files changed, 3110 insertions, 0 deletions
diff --git a/mullvad-api/src/abortable_stream.rs b/mullvad-api/src/abortable_stream.rs new file mode 100644 index 0000000000..af217c5768 --- /dev/null +++ b/mullvad-api/src/abortable_stream.rs @@ -0,0 +1,199 @@ +//! Wrapper around a stream to make it abortable. This allows in-flight requests to be cancelled +//! immediately instead of after the socket times out. + +use futures::channel::oneshot; +use hyper::client::connect::{Connected, Connection}; +use std::{ + future::Future, + io, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(err_derive::Error, Debug)] +#[error(display = "Stream is closed")] +pub struct Aborted(()); + +#[derive(Clone, Debug)] +pub struct AbortableStreamHandle { + tx: Arc<Mutex<Option<oneshot::Sender<()>>>>, +} + +impl AbortableStreamHandle { + pub fn close(self) { + if let Some(tx) = self.tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } + + /// Returns whether the stream has already stopped on its own. + pub fn is_closed(&self) -> bool { + self.tx + .lock() + .unwrap() + .as_ref() + .map(|tx| tx.is_canceled()) + .unwrap_or(true) + } +} + +pub struct AbortableStream<S: Unpin> { + stream: S, + shutdown_rx: oneshot::Receiver<()>, +} + +impl<S> AbortableStream<S> +where + S: Unpin + Send + 'static, +{ + pub fn new(stream: S) -> (Self, AbortableStreamHandle) { + let (tx, rx) = oneshot::channel(); + let stream_handle = AbortableStreamHandle { + tx: Arc::new(Mutex::new(Some(tx))), + }; + ( + Self { + stream, + shutdown_rx: rx, + }, + stream_handle, + ) + } +} + +impl<S> AsyncWrite for AbortableStream<S> +where + S: AsyncWrite + Unpin + Send + 'static, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + Aborted(()), + ))); + } + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + Aborted(()), + ))); + } + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl<S> AsyncRead for AbortableStream<S> +where + S: AsyncRead + Unpin + Send + 'static, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if let Poll::Ready(_) = Pin::new(&mut self.shutdown_rx).poll(cx) { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + Aborted(()), + ))); + } + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl<S> Connection for AbortableStream<S> +where + S: Connection + Unpin, +{ + fn connected(&self) -> Connected { + self.stream.connected() + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::time::Duration; + use tokio::io::AsyncReadExt; + + /// Test whether the abort handle stops the stream. + #[test] + fn test_abort() { + let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); + + let (client, _server) = tokio::io::duplex(64); + + runtime.block_on(async move { + let (mut stream, abort_handle) = AbortableStream::new(client); + + let stream_task = tokio::spawn(async move { + let mut buf = vec![]; + stream.read_to_end(&mut buf).await + }); + + abort_handle.close(); + let result = tokio::time::timeout(Duration::from_secs(1), stream_task) + .await + .unwrap(); + assert!( + matches!(result, Ok(Err(error)) if error.kind() == io::ErrorKind::ConnectionReset) + ); + }); + } + + /// Test the `AbortableStreamHandle::is_closed` method when explicitly closed. + #[test] + fn test_shutdown_signal() { + let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); + + let (client, _server) = tokio::io::duplex(64); + + runtime.block_on(async move { + let (_stream, abort_handle) = AbortableStream::new(client); + let abort_handle_2 = abort_handle.clone(); + assert!(!abort_handle_2.is_closed()); + abort_handle.close(); + assert!(abort_handle_2.is_closed()); + }); + } + + /// Test the `AbortableStreamHandle::is_closed` method when the stream stops on its own. + #[test] + fn test_shutdown_signal_normal() { + let runtime = tokio::runtime::Runtime::new().expect("Failed to initialize runtime"); + + let (client, server) = tokio::io::duplex(64); + + runtime.block_on(async move { + let (mut stream, abort_handle) = AbortableStream::new(client); + + assert!(!abort_handle.is_closed()); + + let stream_task = tokio::spawn(async move { + drop(server); + let mut buf = vec![]; + stream.read_to_end(&mut buf).await + }); + + assert!(tokio::time::timeout(Duration::from_secs(1), stream_task) + .await + .unwrap() + .is_ok()); + assert!(abort_handle.is_closed()); + }); + } +} diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs new file mode 100644 index 0000000000..d95a5319c2 --- /dev/null +++ b/mullvad-api/src/access.rs @@ -0,0 +1,110 @@ +use crate::{ + rest, + rest::{RequestFactory, RequestServiceHandle}, +}; +use hyper::StatusCode; +use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; +use talpid_types::ErrorExt; + +pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1"; + +#[derive(Clone)] +pub struct AccessTokenProxy { + service: RequestServiceHandle, + factory: RequestFactory, + access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>, +} + +impl AccessTokenProxy { + pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { + Self { + service, + factory, + access_from_account: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Obtain access token for an account, requesting a new one from the API if necessary. + pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> { + let existing_token = { + self.access_from_account + .lock() + .unwrap() + .get(account.as_str()) + .cloned() + }; + if let Some(access_token) = existing_token { + if access_token.is_expired() { + log::debug!("Replacing expired access token"); + return self.request_new_token(account.clone()).await; + } + log::trace!("Using stored access token"); + return Ok(access_token.access_token.clone()); + } + self.request_new_token(account.clone()).await + } + + /// Remove an access token if the API response calls for it. + pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) { + if let Err(rest::Error::ApiError(_status, code)) = response { + if code == crate::INVALID_ACCESS_TOKEN { + log::debug!("Dropping invalid access token"); + self.remove_token(account); + } + } + } + + /// Removes a stored access token. + fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> { + self.access_from_account + .lock() + .unwrap() + .remove(account) + .map(|v| v.access_token) + } + + async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> { + log::debug!("Fetching access token for an account"); + let access_token = self + .fetch_access_token(account.clone()) + .await + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain access token") + ); + error + })?; + self.access_from_account + .lock() + .unwrap() + .insert(account, access_token.clone()); + Ok(access_token.access_token) + } + + async fn fetch_access_token( + &self, + account_token: AccountToken, + ) -> Result<AccessTokenData, rest::Error> { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_number: String, + } + let request = AccessTokenRequest { + account_number: account_token, + }; + + let service = self.service.clone(); + + let rest_request = self + .factory + .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await + } +} diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs new file mode 100644 index 0000000000..3b6fcba074 --- /dev/null +++ b/mullvad-api/src/address_cache.rs @@ -0,0 +1,118 @@ +use super::API; +use std::{io, net::SocketAddr, path::Path, sync::Arc}; +use tokio::{ + fs, + io::{AsyncReadExt, AsyncWriteExt}, + sync::Mutex, +}; + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + #[error(display = "Failed to open the address cache file")] + OpenAddressCache(#[error(source)] io::Error), + + #[error(display = "Failed to read the address cache file")] + ReadAddressCache(#[error(source)] io::Error), + + #[error(display = "Failed to parse the address cache file")] + ParseAddressCache, + + #[error(display = "Failed to update the address cache file")] + WriteAddressCache(#[error(source)] io::Error), + + #[error(display = "The address cache is empty")] + EmptyAddressCache, +} + +#[derive(Clone)] +pub struct AddressCache { + inner: Arc<Mutex<AddressCacheInner>>, + write_path: Option<Arc<Path>>, +} + +impl AddressCache { + /// Initialize cache using the hardcoded address, and write changes to `write_path`. + pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> { + Self::new_inner(API.addr, write_path) + } + + /// Initialize cache using `read_path`, and write changes to `write_path`. + pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> { + log::debug!("Loading API addresses from {}", read_path.display()); + Self::new_inner(read_address_file(read_path).await?, write_path) + } + + fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> { + let cache = AddressCacheInner::from_address(address); + log::debug!("Using API address: {}", cache.address); + + let address_cache = Self { + inner: Arc::new(Mutex::new(cache)), + write_path: write_path.map(|cache| Arc::from(cache)), + }; + Ok(address_cache) + } + + /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. + pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> { + if hostname.eq_ignore_ascii_case(&API.host) { + Some(self.get_address().await) + } else { + None + } + } + + /// Returns the currently selected address. + pub async fn get_address(&self) -> SocketAddr { + self.inner.lock().await.address + } + + pub async fn set_address(&self, address: SocketAddr) -> io::Result<()> { + let mut inner = self.inner.lock().await; + if address != inner.address { + self.save_to_disk(&address).await?; + inner.address = address; + } + Ok(()) + } + + async fn save_to_disk(&self, address: &SocketAddr) -> io::Result<()> { + let write_path = match self.write_path.as_ref() { + Some(write_path) => write_path, + None => return Ok(()), + }; + + let temp_path = write_path.with_file_name("api-cache.temp"); + + let mut file = fs::File::create(&temp_path).await?; + let mut contents = address.to_string(); + contents += "\n"; + file.write_all(contents.as_bytes()).await?; + file.sync_data().await?; + + fs::rename(&temp_path, write_path).await + } +} + +#[derive(Clone, PartialEq, Eq)] +struct AddressCacheInner { + address: SocketAddr, +} + +impl AddressCacheInner { + fn from_address(address: SocketAddr) -> Self { + Self { address } + } +} + +async fn read_address_file(path: &Path) -> Result<SocketAddr, Error> { + let mut file = fs::File::open(path) + .await + .map_err(|error| Error::OpenAddressCache(error))?; + let mut address = String::new(); + file.read_to_string(&mut address) + .await + .map_err(Error::ReadAddressCache)?; + address.trim().parse().map_err(|_| Error::ParseAddressCache) +} diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs new file mode 100644 index 0000000000..2cf40cf53b --- /dev/null +++ b/mullvad-api/src/availability.rs @@ -0,0 +1,170 @@ +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; +use tokio::sync::broadcast; + +const CHANNEL_CAPACITY: usize = 100; + +#[derive(err_derive::Error, Debug)] +pub enum Error { + /// The [`ApiAvailability`] instance was dropped, or the receiver lagged behind. + #[error(display = "API availability instance was dropped")] + Interrupted(#[error(source)] broadcast::error::RecvError), +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)] +pub struct State { + suspended: bool, + pause_background: bool, + offline: bool, +} + +impl State { + pub fn is_suspended(&self) -> bool { + self.suspended + } + + pub fn is_background_paused(&self) -> bool { + self.offline || self.pause_background || self.suspended + } + + pub fn is_offline(&self) -> bool { + self.offline + } +} + +pub struct ApiAvailability { + state: Arc<Mutex<State>>, + tx: broadcast::Sender<State>, +} + +impl ApiAvailability { + pub fn new(initial_state: State) -> Self { + let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY); + let state = Arc::new(Mutex::new(initial_state)); + ApiAvailability { state, tx } + } + + pub fn get_state(&self) -> State { + *self.state.lock().unwrap() + } + + pub fn handle(&self) -> ApiAvailabilityHandle { + ApiAvailabilityHandle { + state: self.state.clone(), + tx: self.tx.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ApiAvailabilityHandle { + state: Arc<Mutex<State>>, + tx: broadcast::Sender<State>, +} + +impl ApiAvailabilityHandle { + pub fn suspend(&self) { + log::debug!("Suspending API requests"); + let mut state = self.state.lock().unwrap(); + if !state.suspended { + state.suspended = true; + let _ = self.tx.send(*state); + } + } + + pub fn unsuspend(&self) { + log::debug!("Unsuspending API requests"); + let mut state = self.state.lock().unwrap(); + if state.suspended { + state.suspended = false; + let _ = self.tx.send(*state); + } + } + + pub fn pause_background(&self) { + log::debug!("Pausing background API requests"); + let mut state = self.state.lock().unwrap(); + if !state.pause_background { + state.pause_background = true; + let _ = self.tx.send(*state); + } + } + + pub fn resume_background(&self) { + log::debug!("Resuming background API requests"); + let mut state = self.state.lock().unwrap(); + if state.pause_background { + state.pause_background = false; + let _ = self.tx.send(*state); + } + } + + pub fn set_offline(&self, offline: bool) { + if offline { + log::debug!("Pausing API requests due to being offline"); + } else { + log::debug!("Resuming API requests due to coming online"); + } + let mut state = self.state.lock().unwrap(); + if state.offline != offline { + state.offline = offline; + let _ = self.tx.send(*state); + } + } + + pub fn get_state(&self) -> State { + *self.state.lock().unwrap() + } + + pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> { + self.wait_for_state(|state| !state.is_suspended()) + } + + pub fn when_bg_resumes<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> { + let wait_task = self.wait_for_state(|state| !state.is_background_paused()); + async move { + let _ = wait_task.await; + task.await + } + } + + pub fn wait_background(&self) -> impl Future<Output = Result<(), Error>> { + self.wait_for_state(|state| !state.is_background_paused()) + } + + pub fn when_online<F: Future<Output = O>, O>(&self, task: F) -> impl Future<Output = O> { + let wait_task = self.wait_for_state(|state| !state.is_offline()); + async move { + let _ = wait_task.await; + task.await + } + } + + pub fn wait_online(&self) -> impl Future<Output = Result<(), Error>> { + self.wait_for_state(|state| !state.is_offline()) + } + + fn wait_for_state( + &self, + state_ready: impl Fn(State) -> bool, + ) -> impl Future<Output = Result<(), Error>> { + let mut rx = self.tx.subscribe(); + let state = self.state.clone(); + + async move { + let current_state = { *state.lock().unwrap() }; + if state_ready(current_state) { + return Ok(()); + } + + loop { + let new_state = rx.recv().await?; + if state_ready(new_state) { + return Ok(()); + } + } + } + } +} diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs new file mode 100644 index 0000000000..2139e51f54 --- /dev/null +++ b/mullvad-api/src/bin/relay_list.rs @@ -0,0 +1,41 @@ +//! Fetches and prints the full relay list in JSON. +//! Used by the installer artifact packer to bundle the latest available +//! relay list at the time of creating the installer. + +use mullvad_api::{self, proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; +use std::process; +use talpid_types::ErrorExt; + +#[tokio::main] +async fn main() { + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) + .expect("Failed to load runtime"); + + let relay_list_request = RelayListProxy::new( + runtime + .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat(), |_| async { true }) + .await, + ) + .relay_list(None) + .await; + + let relay_list = match relay_list_request { + Ok(relay_list) => relay_list, + Err(RestError::TimeoutError(_)) => { + eprintln!("Request timed out"); + process::exit(2); + } + Err(e @ RestError::DeserializeError(_)) => { + eprintln!( + "{}", + e.display_chain_with_msg("Failed to deserialize relay list") + ); + process::exit(3); + } + Err(e) => { + eprintln!("{}", e.display_chain_with_msg("Failed to fetch relay list")); + process::exit(1); + } + }; + println!("{}", serde_json::to_string_pretty(&relay_list).unwrap()); +} diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs new file mode 100644 index 0000000000..de572aa20d --- /dev/null +++ b/mullvad-api/src/device.rs @@ -0,0 +1,196 @@ +use http::{Method, StatusCode}; +use mullvad_types::{ + account::AccountToken, + device::{Device, DeviceId, DeviceName, DevicePort}, +}; +use std::future::Future; +use talpid_types::net::wireguard; + +use crate::rest; + +use super::ACCOUNTS_URL_PREFIX; + +#[derive(Clone)] +pub struct DevicesProxy { + handle: rest::MullvadRestHandle, +} + +#[derive(serde::Deserialize)] +struct DeviceResponse { + id: DeviceId, + name: DeviceName, + pubkey: wireguard::PublicKey, + ipv4_address: ipnetwork::Ipv4Network, + ipv6_address: ipnetwork::Ipv6Network, + ports: Vec<DevicePort>, +} + +impl DevicesProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn create( + &self, + account: AccountToken, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>> + { + #[derive(serde::Serialize)] + struct DeviceSubmission { + pubkey: wireguard::PublicKey, + } + + let submission = DeviceSubmission { pubkey }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account)), + &[StatusCode::CREATED], + ) + .await; + + let response: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + id, + name, + pubkey, + ipv4_address, + ipv6_address, + ports, + .. + } = response; + + Ok(( + Device { + id, + name, + pubkey, + ports, + }, + mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }, + )) + } + } + + pub fn get( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<Device, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn list( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn remove( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<(), rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::DELETE, + Some((access_proxy, account)), + &[StatusCode::NO_CONTENT], + ) + .await; + + response?; + Ok(()) + } + } + + pub fn replace_wg_key( + &self, + account: AccountToken, + id: DeviceId, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + { + #[derive(serde::Serialize)] + struct RotateDevicePubkey { + pubkey: wireguard::PublicKey, + } + let req_body = RotateDevicePubkey { pubkey }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id), + Method::PUT, + &req_body, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + ipv4_address, + ipv6_address, + .. + } = updated_device; + Ok(mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }) + } + } +} diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs new file mode 100644 index 0000000000..409492712e --- /dev/null +++ b/mullvad-api/src/https_client_with_sni.rs @@ -0,0 +1,351 @@ +use crate::{ + abortable_stream::{AbortableStream, AbortableStreamHandle}, + proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, + tls_stream::TlsStream, + AddressCache, +}; +use futures::{channel::mpsc, future, pin_mut, StreamExt}; +#[cfg(target_os = "android")] +use futures::{channel::oneshot, sink::SinkExt}; +use http::uri::Scheme; +use hyper::{ + client::connect::dns::{GaiResolver, Name}, + service::Service, + Uri, +}; +use shadowsocks::{ + config::ServerType, + context::{Context as SsContext, SharedContext}, + crypto::v1::CipherKind, + relay::tcprelay::ProxyClientStream, + ServerConfig, +}; +#[cfg(target_os = "android")] +use std::os::unix::io::{AsRawFd, RawFd}; +use std::{ + fmt, + future::Future, + io, + net::{IpAddr, SocketAddr}, + pin::Pin, + str::{self, FromStr}, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, +}; +use talpid_types::ErrorExt; + +use tokio::{ + net::{TcpSocket, TcpStream}, + time::timeout, +}; + +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Clone)] +pub struct HttpsConnectorWithSniHandle { + tx: mpsc::UnboundedSender<HttpsConnectorRequest>, +} + +impl HttpsConnectorWithSniHandle { + /// Stop all streams produced by this connector + pub fn reset(&self) { + let _ = self.tx.unbounded_send(HttpsConnectorRequest::Reset); + } + + /// Change the proxy settings for the connector + pub fn set_connection_mode(&self, proxy: ApiConnectionMode) { + let _ = self + .tx + .unbounded_send(HttpsConnectorRequest::SetConnectionMode(proxy)); + } +} + +enum HttpsConnectorRequest { + Reset, + SetConnectionMode(ApiConnectionMode), +} + +#[derive(Clone)] +enum InnerConnectionMode { + /// Connect directly to the target. + Direct, + /// Connect to the destination via a proxy. + Proxied(ParsedShadowsocksConfig), +} + +#[derive(Clone)] +struct ParsedShadowsocksConfig { + peer: SocketAddr, + password: String, + cipher: CipherKind, +} + +impl From<ParsedShadowsocksConfig> for ServerConfig { + fn from(config: ParsedShadowsocksConfig) -> Self { + ServerConfig::new(config.peer, config.password, config.cipher) + } +} + +#[derive(err_derive::Error, Debug)] +enum ProxyConfigError { + #[error(display = "Unrecognized cipher selected: {}", _0)] + InvalidCipher(String), +} + +impl TryFrom<ApiConnectionMode> for InnerConnectionMode { + type Error = ProxyConfigError; + + fn try_from(config: ApiConnectionMode) -> Result<Self, Self::Error> { + Ok(match config { + ApiConnectionMode::Direct => InnerConnectionMode::Direct, + ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(config)) => { + InnerConnectionMode::Proxied(ParsedShadowsocksConfig { + peer: config.peer, + password: config.password, + cipher: CipherKind::from_str(&config.cipher) + .map_err(|_| ProxyConfigError::InvalidCipher(config.cipher))?, + }) + } + }) + } +} + +/// A Connector for the `https` scheme. +#[derive(Clone)] +pub struct HttpsConnectorWithSni { + inner: Arc<Mutex<HttpsConnectorWithSniInner>>, + sni_hostname: Option<String>, + address_cache: AddressCache, + abort_notify: Arc<tokio::sync::Notify>, + proxy_context: SharedContext, + #[cfg(target_os = "android")] + socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, +} + +struct HttpsConnectorWithSniInner { + stream_handles: Vec<AbortableStreamHandle>, + proxy_config: InnerConnectionMode, +} + +#[cfg(target_os = "android")] +pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>); + +impl HttpsConnectorWithSni { + pub fn new( + sni_hostname: Option<String>, + address_cache: AddressCache, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> (Self, HttpsConnectorWithSniHandle) { + let (tx, mut rx) = mpsc::unbounded(); + let abort_notify = Arc::new(tokio::sync::Notify::new()); + let inner = Arc::new(Mutex::new(HttpsConnectorWithSniInner { + stream_handles: vec![], + proxy_config: InnerConnectionMode::Direct, + })); + + let inner_copy = inner.clone(); + let notify = abort_notify.clone(); + tokio::spawn(async move { + // Handle requests by `HttpsConnectorWithSniHandle`s + while let Some(request) = rx.next().await { + let handles = { + let mut inner = inner_copy.lock().unwrap(); + + if let HttpsConnectorRequest::SetConnectionMode(config) = request { + match InnerConnectionMode::try_from(config) { + Ok(config) => { + inner.proxy_config = config; + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to parse new API proxy config" + ) + ); + } + } + } + + std::mem::take(&mut inner.stream_handles) + }; + for handle in handles { + handle.close(); + } + notify.notify_waiters(); + } + }); + + ( + HttpsConnectorWithSni { + inner, + sni_hostname, + address_cache, + abort_notify, + proxy_context: SsContext::new_shared(ServerType::Local), + #[cfg(target_os = "android")] + socket_bypass_tx, + }, + HttpsConnectorWithSniHandle { tx }, + ) + } + + async fn open_socket( + addr: SocketAddr, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> std::io::Result<TcpStream> { + let socket = match addr { + SocketAddr::V4(_) => TcpSocket::new_v4()?, + SocketAddr::V6(_) => TcpSocket::new_v6()?, + }; + + #[cfg(target_os = "android")] + if let Some(mut tx) = socket_bypass_tx { + let (done_tx, done_rx) = oneshot::channel(); + let _ = tx.send((socket.as_raw_fd(), done_tx)).await; + if let Err(_) = done_rx.await { + log::error!("Failed to bypass socket, connection might fail"); + } + } + + timeout(CONNECT_TIMEOUT, socket.connect(addr)) + .await + .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? + } + + async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> { + let hostname = uri.host().ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid url, missing host", + ))?; + let port = uri.port_u16().unwrap_or(443); + if let Some(addr) = hostname.parse::<IpAddr>().ok() { + return Ok(SocketAddr::new(addr, port)); + } + + // Preferentially, use cached address. + // + if let Some(addr) = address_cache.resolve_hostname(hostname).await { + return Ok(SocketAddr::new(addr.ip(), port)); + } + + // Use getaddrinfo as a fallback + // + let mut addrs = GaiResolver::new() + .call( + Name::from_str(&hostname) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?, + ) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let addr = addrs + .next() + .ok_or(io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; + Ok(SocketAddr::new(addr.ip(), port)) + } +} + +impl fmt::Debug for HttpsConnectorWithSni { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpsConnectorWithSni").finish() + } +} + +impl Service<Uri> for HttpsConnectorWithSni { + type Response = AbortableStream<ApiConnection>; + type Error = io::Error; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + let mut inner = self.inner.lock().unwrap(); + inner.stream_handles.retain(|handle| !handle.is_closed()); + Poll::Ready(Ok(())) + } + + fn call(&mut self, uri: Uri) -> Self::Future { + let sni_hostname = self + .sni_hostname + .clone() + .or_else(|| uri.host().map(str::to_owned)) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") + }); + let inner = self.inner.clone(); + let abort_notify = self.abort_notify.clone(); + let proxy_context = self.proxy_context.clone(); + #[cfg(target_os = "android")] + let socket_bypass_tx = self.socket_bypass_tx.clone(); + let address_cache = self.address_cache.clone(); + + let fut = async move { + if uri.scheme() != Some(&Scheme::HTTPS) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid url, not https", + )); + } + + let hostname = sni_hostname?; + let addr = Self::resolve_address(address_cache, uri).await?; + + // Loop until we have established a connection. This starts over if a new endpoint + // is selected while connecting. + let stream = loop { + let notify = abort_notify.notified(); + let config = { inner.lock().unwrap().proxy_config.clone() }; + let stream_fut = async { + match config { + InnerConnectionMode::Direct => { + let socket = Self::open_socket( + addr, + #[cfg(target_os = "android")] + socket_bypass_tx.clone(), + ) + .await?; + let tls_stream = TlsStream::connect_https(socket, &hostname).await?; + Ok::<_, io::Error>(ApiConnection::Direct(tls_stream)) + } + InnerConnectionMode::Proxied(proxy_config) => { + let socket = Self::open_socket( + proxy_config.peer, + #[cfg(target_os = "android")] + socket_bypass_tx.clone(), + ) + .await?; + let proxy = ProxyClientStream::from_stream( + proxy_context.clone(), + socket, + &ServerConfig::from(proxy_config), + addr, + ); + let tls_stream = TlsStream::connect_https(proxy, &hostname).await?; + Ok(ApiConnection::Proxied(tls_stream)) + } + } + }; + + pin_mut!(stream_fut); + pin_mut!(notify); + + // Wait for connection. Abort and retry if we switched to a different server. + if let future::Either::Left((stream, _)) = future::select(stream_fut, notify).await + { + break stream?; + } + }; + + let (stream, socket_handle) = AbortableStream::new(stream); + + { + let mut inner = inner.lock().unwrap(); + inner.stream_handles.push(socket_handle); + } + + Ok(stream) + }; + + Box::pin(fut) + } +} diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs new file mode 100644 index 0000000000..40742fe41d --- /dev/null +++ b/mullvad-api/src/lib.rs @@ -0,0 +1,530 @@ +#![deny(rust_2018_idioms)] + +use chrono::{offset::Utc, DateTime}; +#[cfg(target_os = "android")] +use futures::channel::mpsc; +use futures::Stream; +use hyper::Method; +use mullvad_types::{ + account::{AccountToken, VoucherSubmission}, + version::AppVersion, +}; +use proxy::ApiConnectionMode; +use std::{ + collections::BTreeMap, + future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::Path, +}; +use talpid_types::ErrorExt; + +pub mod availability; +use availability::{ApiAvailability, ApiAvailabilityHandle}; +pub mod rest; + +mod abortable_stream; +mod https_client_with_sni; +pub mod proxy; +mod tls_stream; +#[cfg(target_os = "android")] +pub use crate::https_client_with_sni::SocketBypassRequest; + +mod access; +mod address_cache; +pub mod device; +mod relay_list; +pub use address_cache::AddressCache; +pub use device::DevicesProxy; +pub use hyper::StatusCode; +pub use relay_list::RelayListProxy; + +/// Error code returned by the Mullvad API if the voucher has alreaby been used. +pub const VOUCHER_USED: &str = "VOUCHER_USED"; + +/// Error code returned by the Mullvad API if the voucher code is invalid. +pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER"; + +/// Error code returned by the Mullvad API if the account token is invalid. +pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT"; + +/// Error code returned by the Mullvad API if the access token is invalid. +pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN"; + +pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED"; +pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE"; + +pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt"; + +const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1"; +const APP_URL_PREFIX: &str = "app/v1"; + +lazy_static::lazy_static! { + static ref API: ApiEndpoint = ApiEndpoint::get(); +} + +/// A hostname and socketaddr to reach the Mullvad REST API over. +struct ApiEndpoint { + host: String, + addr: SocketAddr, + disable_address_cache: bool, +} + +impl ApiEndpoint { + /// Returns the endpoint to connect to the API over. + /// + /// # Panics + /// + /// Panics if `MULLVAD_API_ADDR` has invalid contents or if only one of + /// `MULLVAD_API_ADDR` or `MULLVAD_API_HOST` has been set but not the other. + fn get() -> ApiEndpoint { + const API_HOST_DEFAULT: &str = "api.mullvad.net"; + const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(193, 138, 218, 78)); + const API_PORT_DEFAULT: u16 = 443; + + fn read_var(key: &'static str) -> Option<String> { + use std::env; + match env::var(key) { + Ok(v) => Some(v), + Err(env::VarError::NotPresent) => None, + Err(env::VarError::NotUnicode(_)) => panic!("{} does not contain valid UTF-8", key), + } + } + + let host_var = read_var("MULLVAD_API_HOST"); + let address_var = read_var("MULLVAD_API_ADDR"); + + let mut api = ApiEndpoint { + host: API_HOST_DEFAULT.to_owned(), + addr: SocketAddr::new(API_IP_DEFAULT, API_PORT_DEFAULT), + disable_address_cache: false, + }; + + if cfg!(feature = "api-override") { + match (host_var, address_var) { + (None, None) => (), + (Some(_), None) => panic!("MULLVAD_API_HOST is set, but not MULLVAD_API_ADDR"), + (None, Some(_)) => panic!("MULLVAD_API_ADDR is set, but not MULLVAD_API_HOST"), + (Some(user_host), Some(user_addr)) => { + api.host = user_host; + api.addr = user_addr + .parse() + .expect("MULLVAD_API_ADDR is not a valid socketaddr"); + api.disable_address_cache = true; + log::debug!("Overriding API. Using {} at {}", api.host, api.addr); + } + } + } else { + if host_var.is_some() || address_var.is_some() { + log::warn!( + "MULLVAD_API_HOST and MULLVAD_API_ADDR are ignored in production builds" + ); + } + } + api + } +} + +/// A type that helps with the creation of API connections. +pub struct Runtime { + handle: tokio::runtime::Handle, + pub address_cache: AddressCache, + api_availability: availability::ApiAvailability, + #[cfg(target_os = "android")] + socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, +} + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Failed to construct a rest client")] + RestError(#[error(source)] rest::Error), + + #[error(display = "Failed to load address cache")] + AddressCacheError(#[error(source)] address_cache::Error), + + #[error(display = "API availability check failed")] + ApiCheckError(#[error(source)] availability::Error), +} + +/// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`. +/// It should return a future that determines whether to reject the new endpoint or not. +pub trait ApiEndpointUpdateCallback: Fn(SocketAddr) -> Self::AcceptedNewEndpoint { + type AcceptedNewEndpoint: Future<Output = bool> + Send; +} + +impl<U, T: Future<Output = bool> + Send> ApiEndpointUpdateCallback for U +where + U: Fn(SocketAddr) -> T, +{ + type AcceptedNewEndpoint = T; +} + +impl Runtime { + /// Create a new `Runtime`. + pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> { + Self::new_inner( + handle, + #[cfg(target_os = "android")] + None, + ) + } + + fn new_inner( + handle: tokio::runtime::Handle, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> Result<Self, Error> { + Ok(Runtime { + handle, + address_cache: AddressCache::new(None)?, + api_availability: ApiAvailability::new(availability::State::default()), + #[cfg(target_os = "android")] + socket_bypass_tx, + }) + } + + /// Create a new `Runtime` using the specified directories. + /// Try to use the cache directory first, and fall back on the bundled address otherwise. + pub async fn with_cache( + cache_dir: &Path, + write_changes: bool, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> Result<Self, Error> { + let handle = tokio::runtime::Handle::current(); + if API.disable_address_cache { + return Self::new_inner( + handle, + #[cfg(target_os = "android")] + socket_bypass_tx, + ); + } + + let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); + let write_file = if write_changes { + Some(cache_file.clone().into_boxed_path()) + } else { + None + }; + + let address_cache = match AddressCache::from_file(&cache_file, write_file.clone()).await { + Ok(cache) => cache, + Err(error) => { + if cache_file.exists() { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to load cached API addresses. Falling back on bundled address" + ) + ); + } + AddressCache::new(write_file)? + } + }; + + Ok(Runtime { + handle, + address_cache, + api_availability: ApiAvailability::new(availability::State::default()), + #[cfg(target_os = "android")] + socket_bypass_tx, + }) + } + + /// Creates a new request service and returns a handle to it. + async fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>( + &self, + sni_hostname: Option<String>, + proxy_provider: T, + new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> rest::RequestServiceHandle { + let service_handle = rest::RequestService::new( + sni_hostname, + self.api_availability.handle(), + self.address_cache.clone(), + proxy_provider, + new_address_callback, + #[cfg(target_os = "android")] + socket_bypass_tx, + ) + .await; + service_handle + } + + /// Returns a request factory initialized to create requests for the master API + pub async fn mullvad_rest_handle< + T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static, + >( + &self, + proxy_provider: T, + new_address_callback: impl ApiEndpointUpdateCallback + Send + Sync + 'static, + ) -> rest::MullvadRestHandle { + let service = self + .new_request_service( + Some(API.host.clone()), + proxy_provider, + new_address_callback, + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), + ) + .await; + let factory = rest::RequestFactory::new(API.host.clone(), None); + + rest::MullvadRestHandle::new( + service, + factory, + self.address_cache.clone(), + self.availability_handle(), + ) + } + + /// Returns a new request service handle + pub async fn rest_handle(&mut self) -> rest::RequestServiceHandle { + self.new_request_service( + None, + ApiConnectionMode::Direct.into_repeat(), + |_| async { true }, + #[cfg(target_os = "android")] + None, + ) + .await + } + + pub fn handle(&mut self) -> &mut tokio::runtime::Handle { + &mut self.handle + } + + pub fn availability_handle(&self) -> ApiAvailabilityHandle { + self.api_availability.handle() + } +} + +#[derive(Clone)] +pub struct AccountsProxy { + handle: rest::MullvadRestHandle, +} + +#[derive(serde::Deserialize)] +struct AccountResponse { + number: AccountToken, + expiry: DateTime<Utc>, +} + +impl AccountsProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn get_expiry( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let account: AccountResponse = rest::deserialize_body(response?).await?; + Ok(account.expiry) + } + } + + pub fn create_account(&mut self) -> impl Future<Output = Result<AccountToken, rest::Error>> { + let service = self.handle.service.clone(); + let response = rest::send_request( + &self.handle.factory, + service, + &format!("{}/accounts", ACCOUNTS_URL_PREFIX), + Method::POST, + None, + &[StatusCode::CREATED], + ); + + async move { + let account: AccountResponse = rest::deserialize_body(response.await?).await?; + Ok(account.number) + } + } + + pub fn submit_voucher( + &mut self, + account_token: AccountToken, + voucher_code: String, + ) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> { + #[derive(serde::Serialize)] + struct VoucherSubmission { + voucher_code: String, + } + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + let submission = VoucherSubmission { voucher_code }; + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/submit-voucher", APP_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account_token)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn get_www_auth_token( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<String, rest::Error>> { + #[derive(serde::Deserialize)] + struct AuthTokenResponse { + auth_token: String, + } + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/www-auth-token", APP_URL_PREFIX), + Method::POST, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + let response: AuthTokenResponse = rest::deserialize_body(response?).await?; + Ok(response.auth_token) + } + } +} + +pub struct ProblemReportProxy { + handle: rest::MullvadRestHandle, +} + +impl ProblemReportProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn problem_report( + &self, + email: &str, + message: &str, + log: &str, + metadata: &BTreeMap<String, String>, + ) -> impl Future<Output = Result<(), rest::Error>> { + #[derive(serde::Serialize)] + struct ProblemReport { + address: String, + message: String, + log: String, + metadata: BTreeMap<String, String>, + } + + let report = ProblemReport { + address: email.to_owned(), + message: message.to_owned(), + log: log.to_owned(), + metadata: metadata.clone(), + }; + + let service = self.handle.service.clone(); + + let request = rest::send_json_request( + &self.handle.factory, + service, + &format!("{}/problem-report", APP_URL_PREFIX), + Method::POST, + &report, + None, + &[StatusCode::NO_CONTENT], + ); + + async move { + request.await?; + Ok(()) + } + } +} + +#[derive(Clone)] +pub struct AppVersionProxy { + handle: rest::MullvadRestHandle, +} + +#[derive(serde::Deserialize, Debug)] +pub struct AppVersionResponse { + pub supported: bool, + pub latest: AppVersion, + pub latest_stable: Option<AppVersion>, + pub latest_beta: AppVersion, +} + +impl AppVersionProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn version_check( + &self, + app_version: AppVersion, + platform: &str, + platform_version: String, + ) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> { + let service = self.handle.service.clone(); + + let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version); + let request = self.handle.factory.request(&path, Method::GET); + + async move { + let mut request = request?; + request.add_header("M-Platform-Version", &platform_version)?; + + let response = service.request(request).await?; + let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(parsed_response).await + } + } +} + +#[derive(Clone)] +pub struct ApiProxy { + handle: rest::MullvadRestHandle, +} + +impl ApiProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> { + let service = self.handle.service.clone(); + + let response = rest::send_request( + &self.handle.factory, + service, + &format!("{}/api-addrs", APP_URL_PREFIX), + Method::GET, + None, + &[StatusCode::OK], + ) + .await?; + + rest::deserialize_body(response).await + } +} diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs new file mode 100644 index 0000000000..009a1960dc --- /dev/null +++ b/mullvad-api/src/proxy.rs @@ -0,0 +1,204 @@ +use crate::tls_stream::TlsStream; +use futures::Stream; +use hyper::client::connect::{Connected, Connection}; +use rand::{distributions::Alphanumeric, Rng}; +use serde::{Deserialize, Serialize}; +use shadowsocks::relay::tcprelay::ProxyClientStream; +use std::{ + fmt, io, + net::SocketAddr, + path::Path, + pin::Pin, + task::{self, Poll}, +}; +use talpid_types::{net::openvpn::ShadowsocksProxySettings, ErrorExt}; +use tokio::{ + fs, + io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, + net::TcpStream, +}; + +const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json"; + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +pub enum ApiConnectionMode { + /// Connect directly to the target. + Direct, + /// Connect to the destination via a proxy. + Proxied(ProxyConfig), +} + +impl fmt::Display for ApiConnectionMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + ApiConnectionMode::Direct => write!(f, "unproxied"), + ApiConnectionMode::Proxied(settings) => settings.fmt(f), + } + } +} + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +pub enum ProxyConfig { + Shadowsocks(ShadowsocksProxySettings), +} + +impl fmt::Display for ProxyConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + // TODO: Do not hardcode TCP + ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer), + } + } +} + +impl ApiConnectionMode { + /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`. + /// This returns `ApiConnectionMode::Direct` if reading from disk fails for any reason. + pub async fn try_from_cache(cache_dir: &Path) -> Self { + Self::from_cache(cache_dir).await.unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to read API endpoint cache") + ); + ApiConnectionMode::Direct + }) + } + + /// Reads the proxy config from `CURRENT_CONFIG_FILENAME`. + /// If the file does not exist, this returns `Ok(ApiConnectionMode::Direct)`. + async fn from_cache(cache_dir: &Path) -> io::Result<Self> { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + match fs::read_to_string(path).await { + Ok(s) => serde_json::from_str(&s).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg(&format!( + "Failed to deserialize \"{}\"", + CURRENT_CONFIG_FILENAME + )) + ); + io::Error::new(io::ErrorKind::Other, "deserialization failed") + }), + Err(error) => { + if error.kind() == io::ErrorKind::NotFound { + Ok(ApiConnectionMode::Direct) + } else { + Err(error) + } + } + } + } + + /// Stores this config to `CURRENT_CONFIG_FILENAME`. + /// The content is saved to a temporary file first, which ensures that + /// consumers of the file never end up with partial content. + pub async fn save(&self, cache_dir: &Path) -> io::Result<()> { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + let mut temp_ext = String::from("temp"); + temp_ext.push_str( + &rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect::<String>(), + ); + let temp_path = path.with_extension(temp_ext); + + { + let mut file = fs::File::create(&temp_path).await?; + let json = serde_json::to_string_pretty(self) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "serialization failed"))?; + file.write_all(json.as_bytes()).await?; + file.write_all(b"\n").await?; + file.sync_data().await?; + } + + fs::rename(&temp_path, path).await + } + + /// Attempts to remove `CURRENT_CONFIG_FILENAME`, if it exists. + pub async fn try_delete_cache(cache_dir: &Path) { + let path = cache_dir.join(CURRENT_CONFIG_FILENAME); + if let Err(err) = fs::remove_file(path).await { + if err.kind() != std::io::ErrorKind::NotFound { + log::error!( + "{}", + err.display_chain_with_msg("Failed to remove old API config") + ); + } + } + } + + /// Returns the remote address, or `None` for `ApiConnectionMode::Direct`. + pub fn get_endpoint(&self) -> Option<SocketAddr> { + match self { + ApiConnectionMode::Proxied(ProxyConfig::Shadowsocks(ss)) => Some(ss.peer), + ApiConnectionMode::Direct => None, + } + } + + pub fn is_proxy(&self) -> bool { + *self != ApiConnectionMode::Direct + } + + /// Convenience function that returns a stream that repeats + /// this config forever. + pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> { + futures::stream::repeat(self) + } +} + +/// Stream that is either a regular TLS stream or TLS via shadowsocks +pub enum ApiConnection { + Direct(TlsStream<TcpStream>), + Proxied(TlsStream<ProxyClientStream<TcpStream>>), +} + +impl AsyncRead for ApiConnection { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + match Pin::get_mut(self) { + ApiConnection::Direct(s) => Pin::new(s).poll_read(cx, buf), + ApiConnection::Proxied(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ApiConnection { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + match Pin::get_mut(self) { + ApiConnection::Direct(s) => Pin::new(s).poll_write(cx, buf), + ApiConnection::Proxied(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match Pin::get_mut(self) { + ApiConnection::Direct(s) => Pin::new(s).poll_flush(cx), + ApiConnection::Proxied(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match Pin::get_mut(self) { + ApiConnection::Direct(s) => Pin::new(s).poll_shutdown(cx), + ApiConnection::Proxied(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +impl Connection for ApiConnection { + fn connected(&self) -> Connected { + match self { + ApiConnection::Direct(s) => s.connected(), + ApiConnection::Proxied(s) => s.connected(), + } + } +} diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs new file mode 100644 index 0000000000..5a8a01836f --- /dev/null +++ b/mullvad-api/src/relay_list.rs @@ -0,0 +1,375 @@ +//! A module dedicated to retrieving the relay list from the Mullvad API. + +use crate::rest; + +use hyper::{header, Method, StatusCode}; +use mullvad_types::{location, relay_list}; +use talpid_types::net::{wireguard, TransportProtocol}; + +use std::{ + collections::BTreeMap, + future::Future, + net::{Ipv4Addr, Ipv6Addr}, + time::Duration, +}; + +/// Fetches relay list from https://api.mullvad.net/app/v1/relays +#[derive(Clone)] +pub struct RelayListProxy { + handle: rest::MullvadRestHandle, +} + +const RELAY_LIST_TIMEOUT: Duration = Duration::from_secs(15); + +impl RelayListProxy { + /// Construct a new relay list rest client + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + /// Fetch the relay list + pub fn relay_list( + &self, + etag: Option<String>, + ) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> { + let service = self.handle.service.clone(); + let request = self.handle.factory.request("app/v1/relays", Method::GET); + + let future = async move { + let mut request = request?; + request.set_timeout(RELAY_LIST_TIMEOUT); + + if let Some(ref tag) = etag { + request.add_header(header::IF_NONE_MATCH, tag)?; + } + + let response = service.request(request).await?; + if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED { + return Ok(None); + } + if response.status() != StatusCode::OK { + return rest::handle_error_response(response).await; + } + + let etag = response + .headers() + .get(header::ETAG) + .and_then(|tag| match tag.to_str() { + Ok(tag) => Some(tag.to_string()), + Err(_) => { + log::error!("Ignoring invalid tag from server: {:?}", tag.as_bytes()); + None + } + }); + + Ok(Some( + rest::deserialize_body::<ServerRelayList>(response) + .await? + .into_relay_list(etag), + )) + }; + future + } +} + +#[derive(Debug, serde::Deserialize)] +struct ServerRelayList { + locations: BTreeMap<String, Location>, + openvpn: OpenVpn, + wireguard: Wireguard, + bridge: Bridges, +} + +impl ServerRelayList { + fn into_relay_list(self, etag: Option<String>) -> relay_list::RelayList { + let mut countries = BTreeMap::new(); + let Self { + locations, + openvpn, + wireguard, + bridge, + } = self; + + for (code, location) in locations.into_iter() { + match split_location_code(&code) { + Some((country_code, city_code)) => { + let country_code = country_code.to_lowercase(); + let city_code = city_code.to_lowercase(); + let country = countries + .entry(country_code.clone()) + .or_insert_with(|| location_to_country(&location, country_code)); + country.cities.push(location_to_city(&location, city_code)); + } + None => { + log::error!("Bad location code:{}", code); + continue; + } + } + } + + Self::add_openvpn_relays(&mut countries, openvpn); + Self::add_wireguard_relays(&mut countries, wireguard); + Self::add_bridge_relays(&mut countries, bridge); + + relay_list::RelayList { + etag: etag.map(|mut tag| { + if tag.starts_with("\"") { + tag.insert_str(0, "W/"); + } + tag + }), + countries: countries + .into_iter() + .map(|(_key, country)| country) + .collect(), + } + } + + fn add_openvpn_relays( + countries: &mut BTreeMap<String, relay_list::RelayListCountry>, + openvpn: OpenVpn, + ) { + let openvpn_endpoint_data = openvpn.ports; + for mut openvpn_relay in openvpn.relays.into_iter() { + openvpn_relay.to_lower(); + if let Some((country_code, city_code)) = split_location_code(&openvpn_relay.location) { + if let Some(country) = countries.get_mut(country_code) { + if let Some(city) = country + .cities + .iter_mut() + .find(|city| city.code == city_code) + { + let location = location::Location { + country: country.name.clone(), + country_code: country.code.clone(), + city: city.name.clone(), + city_code: city.code.clone(), + latitude: city.latitude, + longitude: city.longitude, + }; + match city + .relays + .iter_mut() + .find(|r| r.hostname == openvpn_relay.hostname) + { + Some(relay) => relay.tunnels.openvpn = openvpn_endpoint_data.clone(), + None => { + let mut relay = relay(openvpn_relay, location); + relay.tunnels.openvpn = openvpn_endpoint_data.clone(); + city.relays.push(relay); + } + }; + } + }; + } + } + } + + fn add_wireguard_relays( + countries: &mut BTreeMap<String, relay_list::RelayListCountry>, + wireguard: Wireguard, + ) { + let Wireguard { + port_ranges, + ipv4_gateway, + ipv6_gateway, + relays, + } = wireguard; + + let wireguard_endpoint_data = + |public_key: wireguard::PublicKey| relay_list::WireguardEndpointData { + port_ranges: port_ranges.clone(), + ipv4_gateway, + ipv6_gateway, + public_key, + protocol: TransportProtocol::Udp, + }; + + for mut wireguard_relay in relays { + wireguard_relay.relay.to_lower(); + if let Some((country_code, city_code)) = + split_location_code(&wireguard_relay.relay.location) + { + if let Some(country) = countries.get_mut(country_code) { + if let Some(city) = country + .cities + .iter_mut() + .find(|city| city.code == city_code) + { + let location = location::Location { + country: country.name.clone(), + country_code: country.code.clone(), + city: city.name.clone(), + city_code: city.code.clone(), + latitude: city.latitude, + longitude: city.longitude, + }; + match city + .relays + .iter_mut() + .find(|r| r.hostname == wireguard_relay.relay.hostname) + { + Some(relay) => relay + .tunnels + .wireguard + .push(wireguard_endpoint_data(wireguard_relay.public_key)), + None => { + let mut relay = relay(wireguard_relay.relay, location); + relay.ipv6_addr_in = Some(wireguard_relay.ipv6_addr_in); + relay.tunnels.wireguard = + vec![wireguard_endpoint_data(wireguard_relay.public_key)]; + city.relays.push(relay); + } + }; + } + }; + } + } + } + + fn add_bridge_relays( + countries: &mut BTreeMap<String, relay_list::RelayListCountry>, + bridges: Bridges, + ) { + let Bridges { + relays, + shadowsocks, + } = bridges; + + for mut bridge_relay in relays { + bridge_relay.to_lower(); + if let Some((country_code, city_code)) = split_location_code(&bridge_relay.location) { + if let Some(country) = countries.get_mut(country_code) { + if let Some(city) = country + .cities + .iter_mut() + .find(|city| city.code == city_code) + { + let location = location::Location { + country: country.name.clone(), + country_code: country.code.clone(), + city: city.name.clone(), + city_code: city.code.clone(), + latitude: city.latitude, + longitude: city.longitude, + }; + + match city + .relays + .iter_mut() + .find(|r| r.hostname == bridge_relay.hostname) + { + Some(relay) => { + relay.bridges.shadowsocks = shadowsocks.clone(); + } + None => { + let mut relay = relay(bridge_relay, location); + relay.bridges.shadowsocks = shadowsocks.clone(); + city.relays.push(relay); + } + }; + } + }; + } + } + } +} + +/// Splits a location code into a country code and a city code. The input is expected to be in a +/// format like `se-mma`, with `se` being the country code, `mma` being the city code. +fn split_location_code(location: &str) -> Option<(&str, &str)> { + let mut parts = location.split('-'); + let country = parts.next()?; + let city = parts.next()?; + + Some((country, city)) +} + +fn location_to_country(location: &Location, code: String) -> relay_list::RelayListCountry { + relay_list::RelayListCountry { + cities: vec![], + name: location.country.clone(), + code, + } +} + +fn location_to_city(location: &Location, code: String) -> relay_list::RelayListCity { + relay_list::RelayListCity { + name: location.city.clone(), + code, + latitude: location.latitude, + longitude: location.longitude, + relays: vec![], + } +} + +fn relay(relay: Relay, location: location::Location) -> relay_list::Relay { + relay_list::Relay { + hostname: relay.hostname, + ipv4_addr_in: relay.ipv4_addr_in, + ipv6_addr_in: None, + include_in_country: relay.include_in_country, + active: relay.active, + owned: relay.owned, + provider: relay.provider, + weight: relay.weight, + tunnels: Default::default(), + bridges: Default::default(), + location: Some(location), + } +} + +#[derive(Debug, serde::Deserialize)] +struct Location { + city: String, + country: String, + latitude: f64, + longitude: f64, +} + +#[derive(Debug, serde::Deserialize)] +struct OpenVpn { + ports: Vec<relay_list::OpenVpnEndpointData>, + relays: Vec<Relay>, +} + +#[derive(Debug, serde::Deserialize)] +struct Relay { + hostname: String, + active: bool, + owned: bool, + location: String, + provider: String, + ipv4_addr_in: Ipv4Addr, + weight: u64, + include_in_country: bool, +} + +impl Relay { + fn to_lower(&mut self) { + self.hostname = self.hostname.to_lowercase(); + self.location = self.location.to_lowercase(); + } +} + +#[derive(Debug, serde::Deserialize)] +struct Wireguard { + port_ranges: Vec<(u16, u16)>, + ipv4_gateway: Ipv4Addr, + ipv6_gateway: Ipv6Addr, + relays: Vec<WireGuardRelay>, +} + +#[derive(Debug, serde::Deserialize)] +struct WireGuardRelay { + #[serde(flatten)] + relay: Relay, + ipv6_addr_in: Ipv6Addr, + public_key: wireguard::PublicKey, +} + +#[derive(Debug, serde::Deserialize)] +struct Bridges { + shadowsocks: Vec<relay_list::ShadowsocksEndpointData>, + relays: Vec<Relay>, +} diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs new file mode 100644 index 0000000000..292bcd0cbb --- /dev/null +++ b/mullvad-api/src/rest.rs @@ -0,0 +1,694 @@ +#[cfg(target_os = "android")] +pub use crate::https_client_with_sni::SocketBypassRequest; +use crate::{ + access::AccessTokenProxy, + address_cache::AddressCache, + availability::ApiAvailabilityHandle, + https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, + proxy::ApiConnectionMode, +}; +use futures::{ + channel::{mpsc, oneshot}, + stream::StreamExt, + Stream, TryFutureExt, +}; +use hyper::{ + client::Client, + header::{self, HeaderValue}, + Method, Uri, +}; +use mullvad_types::account::AccountToken; +use std::{ + future::Future, + str::FromStr, + sync::{Arc, Weak}, + time::{Duration, Instant}, +}; +use talpid_types::ErrorExt; + +pub use hyper::StatusCode; + +pub type Request = hyper::Request<hyper::Body>; +pub type Response = hyper::Response<hyper::Body>; + +const USER_AGENT: &str = "mullvad-app"; + +const TIMER_CHECK_INTERVAL: Duration = Duration::from_secs(60); +const API_IP_CHECK_DELAY: Duration = Duration::from_secs(15 * 60); +const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); +const API_IP_CHECK_ERROR_INTERVAL: Duration = Duration::from_secs(15 * 60); + +pub type Result<T> = std::result::Result<T, Error>; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +/// Describes all the ways a REST request can fail +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "Request cancelled")] + Aborted, + + #[error(display = "Hyper error")] + HyperError(#[error(source)] hyper::Error), + + #[error(display = "Invalid header value")] + InvalidHeaderError(#[error(source)] http::header::InvalidHeaderValue), + + #[error(display = "HTTP error")] + HttpError(#[error(source)] http::Error), + + #[error(display = "Request timed out")] + TimeoutError(#[error(source)] tokio::time::error::Elapsed), + + #[error(display = "Failed to deserialize data")] + DeserializeError(#[error(source)] serde_json::Error), + + #[error(display = "Failed to send request to rest client")] + SendError, + + #[error(display = "Failed to receive response from rest client")] + ReceiveError, + + /// Unexpected response code + #[error(display = "Unexpected response status code {} - {}", _0, _1)] + ApiError(StatusCode, String), + + /// The string given was not a valid URI. + #[error(display = "Not a valid URI")] + UriError(#[error(source)] http::uri::InvalidUri), +} + +impl Error { + pub fn is_network_error(&self) -> bool { + match self { + Error::HyperError(_) | Error::TimeoutError(_) => true, + _ => false, + } + } + + /// Returns a new instance for which `abortable_stream::Aborted` is mapped to `Self::Aborted`. + fn map_aborted(self) -> Self { + if let Error::HyperError(error) = &self { + use std::error::Error; + let mut source = error.source(); + while let Some(error) = source { + let io_error: Option<&std::io::Error> = error.downcast_ref(); + if let Some(io_error) = io_error { + let abort_error: Option<&crate::abortable_stream::Aborted> = + io_error.get_ref().and_then(|inner| inner.downcast_ref()); + if abort_error.is_some() { + return Self::Aborted; + } + } + source = error.source(); + } + } + self + } +} + +use super::ApiEndpointUpdateCallback; + +/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight +/// requests +pub(crate) struct RequestService< + T: Stream<Item = ApiConnectionMode>, + F: ApiEndpointUpdateCallback + Send, +> { + command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>, + command_rx: mpsc::UnboundedReceiver<RequestCommand>, + connector_handle: HttpsConnectorWithSniHandle, + client: hyper::Client<HttpsConnectorWithSni, hyper::Body>, + proxy_config_provider: T, + new_address_callback: F, + address_cache: AddressCache, + api_availability: ApiAvailabilityHandle, +} + +impl< + T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static, + F: ApiEndpointUpdateCallback + Send + Sync + 'static, + > RequestService<T, F> +{ + /// Constructs a new request service. + pub async fn new( + sni_hostname: Option<String>, + api_availability: ApiAvailabilityHandle, + address_cache: AddressCache, + mut proxy_config_provider: T, + new_address_callback: F, + #[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>, + ) -> RequestServiceHandle { + let (connector, connector_handle) = HttpsConnectorWithSni::new( + sni_hostname, + address_cache.clone(), + #[cfg(target_os = "android")] + socket_bypass_tx.clone(), + ); + + proxy_config_provider + .next() + .await + .map(|config| connector_handle.set_connection_mode(config)); + + let (command_tx, command_rx) = mpsc::unbounded(); + let client = Client::builder().build(connector); + + let command_tx = Arc::new(command_tx); + + let service = Self { + command_tx: Arc::downgrade(&command_tx), + command_rx, + connector_handle, + client, + proxy_config_provider, + new_address_callback, + address_cache, + api_availability, + }; + let handle = RequestServiceHandle { tx: command_tx }; + tokio::spawn(service.into_future()); + handle + } + + async fn process_command(&mut self, command: RequestCommand) { + match command { + RequestCommand::NewRequest(request, completion_tx) => { + let tx = self.command_tx.upgrade(); + let timeout = request.timeout(); + + let hyper_request = request.into_request(); + + let api_availability = self.api_availability.clone(); + let suspend_fut = api_availability.wait_for_unsuspend(); + let request_fut = self.client.request(hyper_request).map_err(Error::from); + + let request_future = async move { + let _ = suspend_fut.await; + request_fut.await + }; + + let future = async move { + let response = tokio::time::timeout(timeout, request_future) + .await + .map_err(Error::TimeoutError); + + let response = flatten_result(response).map_err(|error| error.map_aborted()); + + if let Err(err) = &response { + if err.is_network_error() && !api_availability.get_state().is_offline() { + log::error!("{}", err.display_chain_with_msg("HTTP request failed")); + if let Some(tx) = tx { + let _ = tx.unbounded_send(RequestCommand::NextApiConfig); + } + } + } + + if completion_tx.send(response).is_err() { + log::trace!( + "Failed to send response to caller, caller channel is shut down" + ); + } + }; + tokio::spawn(future); + } + RequestCommand::Reset => { + self.connector_handle.reset(); + } + RequestCommand::NextApiConfig => { + if let Some(new_config) = self.proxy_config_provider.next().await { + let endpoint = match new_config.get_endpoint() { + Some(endpoint) => endpoint, + None => self.address_cache.get_address().await, + }; + // Switch to new connection mode unless rejected by address change callback + if (self.new_address_callback)(endpoint).await { + self.connector_handle.set_connection_mode(new_config); + } + } + } + } + } + + async fn into_future(mut self) { + while let Some(command) = self.command_rx.next().await { + self.process_command(command).await; + } + self.connector_handle.reset(); + } +} + +#[derive(Clone)] +/// A handle to interact with a spawned `RequestService`. +pub struct RequestServiceHandle { + tx: Arc<mpsc::UnboundedSender<RequestCommand>>, +} + +impl RequestServiceHandle { + /// Resets the corresponding RequestService, dropping all in-flight requests. + pub async fn reset(&self) { + let _ = self.tx.unbounded_send(RequestCommand::Reset); + } + + /// Submits a `RestRequest` for exectuion to the request service. + pub async fn request(&self, request: RestRequest) -> Result<Response> { + let (completion_tx, completion_rx) = oneshot::channel(); + self.tx + .unbounded_send(RequestCommand::NewRequest(request, completion_tx)) + .map_err(|_| Error::SendError)?; + completion_rx.await.map_err(|_| Error::ReceiveError)? + } +} + +#[derive(Debug)] +pub(crate) enum RequestCommand { + NewRequest( + RestRequest, + oneshot::Sender<std::result::Result<Response, Error>>, + ), + Reset, + NextApiConfig, +} + +/// A REST request that is sent to the RequestService to be executed. +#[derive(Debug)] +pub struct RestRequest { + request: Request, + timeout: Duration, + auth: Option<HeaderValue>, +} + +impl RestRequest { + /// Constructs a GET request with the given URI. Returns an error if the URI is not valid. + pub fn get(uri: &str) -> Result<Self> { + let uri = hyper::Uri::from_str(&uri).map_err(Error::UriError)?; + + let mut builder = http::request::Builder::new() + .method(Method::GET) + .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) + .header(header::ACCEPT, HeaderValue::from_static("application/json")); + if let Some(host) = uri.host() { + builder = builder.header(header::HOST, HeaderValue::from_str(&host)?); + }; + + let request = builder + .uri(uri) + .body(hyper::Body::empty()) + .map_err(Error::HttpError)?; + + Ok(RestRequest { + timeout: DEFAULT_TIMEOUT, + auth: None, + request, + }) + } + + /// Set the auth header with the following format: `Bearer $auth`. + pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> { + let header = match auth { + Some(auth) => Some( + HeaderValue::from_str(&format!("Bearer {}", auth)) + .map_err(Error::InvalidHeaderError)?, + ), + None => None, + }; + + self.auth = header; + Ok(()) + } + + /// Sets timeout for the request. + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout; + } + + /// Retrieves timeout + pub fn timeout(&self) -> Duration { + self.timeout + } + + pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> { + let header_value = http::HeaderValue::from_str(value).map_err(Error::InvalidHeaderError)?; + self.request.headers_mut().insert(key, header_value); + Ok(()) + } + + /// Converts into a `hyper::Request<hyper::Body>` + fn into_request(self) -> Request { + let Self { + mut request, auth, .. + } = self; + if let Some(auth) = auth { + request.headers_mut().insert(header::AUTHORIZATION, auth); + } + request + } + + /// Returns the URI of the request + pub fn uri(&self) -> &Uri { + self.request.uri() + } +} + +impl From<Request> for RestRequest { + fn from(request: Request) -> Self { + Self { + request, + timeout: DEFAULT_TIMEOUT, + auth: None, + } + } +} + +#[derive(serde::Deserialize)] +pub struct ErrorResponse { + pub code: String, +} + +#[derive(Clone)] +pub struct RequestFactory { + hostname: String, + path_prefix: Option<String>, + pub timeout: Duration, +} + +impl RequestFactory { + pub fn new(hostname: String, path_prefix: Option<String>) -> Self { + Self { + hostname, + path_prefix, + timeout: DEFAULT_TIMEOUT, + } + } + + pub fn request(&self, path: &str, method: Method) -> Result<RestRequest> { + self.hyper_request(path, method) + .map(RestRequest::from) + .map(|req| self.set_request_timeout(req)) + } + + pub fn get(&self, path: &str) -> Result<RestRequest> { + self.hyper_request(path, Method::GET) + .map(RestRequest::from) + .map(|req| self.set_request_timeout(req)) + } + + pub fn post(&self, path: &str) -> Result<RestRequest> { + self.hyper_request(path, Method::POST) + .map(RestRequest::from) + .map(|req| self.set_request_timeout(req)) + } + + pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> { + self.json_request(Method::POST, path, body) + } + + fn json_request<S: serde::Serialize>( + &self, + method: Method, + path: &str, + body: &S, + ) -> Result<RestRequest> { + let mut request = self.hyper_request(path, method)?; + + let json_body = serde_json::to_string(&body)?; + let body_length = json_body.as_bytes().len() as u64; + *request.body_mut() = json_body.into_bytes().into(); + + let headers = request.headers_mut(); + headers.insert( + header::CONTENT_LENGTH, + HeaderValue::from_str(&body_length.to_string()).map_err(Error::InvalidHeaderError)?, + ); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + Ok(self.set_request_timeout(RestRequest::from(request))) + } + + pub fn delete(&self, path: &str) -> Result<RestRequest> { + self.hyper_request(path, Method::DELETE) + .map(RestRequest::from) + .map(|req| self.set_request_timeout(req)) + } + + fn hyper_request(&self, path: &str, method: Method) -> Result<Request> { + let uri = self.get_uri(path)?; + let request = http::request::Builder::new() + .method(method) + .uri(uri) + .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) + .header(header::ACCEPT, HeaderValue::from_static("application/json")) + .header(header::HOST, self.hostname.clone()); + + request.body(hyper::Body::empty()).map_err(Error::HttpError) + } + + fn get_uri(&self, path: &str) -> Result<Uri> { + let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or(""); + let uri = format!("https://{}/{}{}", self.hostname, prefix, path); + hyper::Uri::from_str(&uri).map_err(Error::UriError) + } + + fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest { + request.timeout = self.timeout; + request + } +} + +pub fn get_request<T: serde::de::DeserializeOwned>( + factory: &RequestFactory, + service: RequestServiceHandle, + uri: &str, + auth: Option<String>, + expected_statuses: &'static [hyper::StatusCode], +) -> impl Future<Output = Result<Response>> + 'static { + let request = factory.get(uri); + async move { + let mut request = request?; + request.set_auth(auth)?; + let response = service.request(request).await?; + parse_rest_response(response, expected_statuses).await + } +} + +pub fn send_request( + factory: &RequestFactory, + service: RequestServiceHandle, + uri: &str, + method: Method, + auth: Option<(AccessTokenProxy, AccountToken)>, + expected_statuses: &'static [hyper::StatusCode], +) -> impl Future<Output = Result<Response>> { + let request = factory.request(uri, method); + + async move { + let mut request = request?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } + let response = service.request(request).await?; + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result + } +} + +pub fn send_json_request<B: serde::Serialize>( + factory: &RequestFactory, + service: RequestServiceHandle, + uri: &str, + method: Method, + body: &B, + auth: Option<(AccessTokenProxy, AccountToken)>, + expected_statuses: &'static [hyper::StatusCode], +) -> impl Future<Output = Result<Response>> { + let request = factory.json_request(method, uri, body); + async move { + let mut request = request?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } + let response = service.request(request).await?; + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result + } +} + +pub async fn deserialize_body<T: serde::de::DeserializeOwned>(response: Response) -> Result<T> { + let body_length = get_body_length(&response); + deserialize_body_inner(response, body_length).await +} + +async fn deserialize_body_inner<T: serde::de::DeserializeOwned>( + mut response: Response, + body_length: usize, +) -> Result<T> { + let mut body: Vec<u8> = Vec::with_capacity(body_length); + while let Some(chunk) = response.body_mut().next().await { + body.extend(&chunk?); + } + + serde_json::from_slice(&body).map_err(Error::DeserializeError) +} + +fn get_body_length(response: &Response) -> usize { + response + .headers() + .get(header::CONTENT_LENGTH) + .and_then(|header_value| header_value.to_str().ok()) + .and_then(|length| length.parse::<usize>().ok()) + .unwrap_or(0) +} + +pub async fn parse_rest_response( + response: Response, + expected_statuses: &'static [hyper::StatusCode], +) -> Result<Response> { + if !expected_statuses.contains(&response.status()) { + log::error!( + "Unexpected HTTP status code {}, expected codes [{}]", + response.status(), + expected_statuses + .iter() + .map(ToString::to_string) + .collect::<Vec<_>>() + .join(",") + ); + if !response.status().is_success() { + return handle_error_response(response).await; + } + } + + Ok(response) +} + +pub async fn handle_error_response<T>(response: Response) -> Result<T> { + let status = response.status(); + let error_message = match status { + hyper::StatusCode::NOT_FOUND => "Not found", + hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed", + status => match get_body_length(&response) { + 0 => status.canonical_reason().unwrap_or("Unexpected error"), + body_length => { + let err: ErrorResponse = deserialize_body_inner(response, body_length).await?; + return Err(Error::ApiError(status, err.code)); + } + }, + }; + Err(Error::ApiError(status, error_message.to_owned())) +} + +#[derive(Clone)] +pub struct MullvadRestHandle { + pub(crate) service: RequestServiceHandle, + pub factory: RequestFactory, + pub availability: ApiAvailabilityHandle, + pub token_store: AccessTokenProxy, +} + +impl MullvadRestHandle { + pub(crate) fn new( + service: RequestServiceHandle, + factory: RequestFactory, + address_cache: AddressCache, + availability: ApiAvailabilityHandle, + ) -> Self { + let token_store = AccessTokenProxy::new(service.clone(), factory.clone()); + + let handle = Self { + service, + factory, + availability, + token_store, + }; + if !super::API.disable_address_cache { + handle.spawn_api_address_fetcher(address_cache); + } + handle + } + + fn spawn_api_address_fetcher(&self, address_cache: AddressCache) { + let handle = self.clone(); + let availability = self.availability.clone(); + + tokio::spawn(async move { + // always start the fetch after 15 minutes + let api_proxy = crate::ApiProxy::new(handle); + let mut next_check = Instant::now() + API_IP_CHECK_DELAY; + + let next_error_check = || Instant::now() + API_IP_CHECK_ERROR_INTERVAL; + let next_regular_check = || Instant::now() + API_IP_CHECK_INTERVAL; + + let mut interval = tokio::time::interval_at(next_check.into(), TIMER_CHECK_INTERVAL); + + loop { + interval.tick().await; + if next_check < Instant::now() { + if let Err(error) = availability.wait_background().await { + log::error!("Failed while waiting for API: {}", error); + next_check = next_error_check(); + continue; + } + match api_proxy.clone().get_api_addrs().await { + Ok(new_addrs) => { + if let Some(addr) = new_addrs.get(0) { + log::debug!( + "Fetched new API address {:?}. Fetching again in {} hours", + addr, + API_IP_CHECK_INTERVAL.as_secs() / (60 * 60) + ); + if let Err(err) = address_cache.set_address(*addr).await { + log::error!( + "Failed to save newly updated API address: {}", + err + ); + } + } else { + log::error!("API returned no API addresses"); + } + next_check = next_regular_check(); + } + Err(err) => { + log::error!( + "Failed to fetch new API addresses: {}. Retrying in {} seconds", + err, + API_IP_CHECK_ERROR_INTERVAL.as_secs() + ); + next_check = next_error_check(); + } + } + } + } + }); + } + + pub fn service(&self) -> RequestServiceHandle { + self.service.clone() + } + + pub fn factory(&self) -> &RequestFactory { + &self.factory + } +} + +fn flatten_result<T, E>( + result: std::result::Result<std::result::Result<T, E>, E>, +) -> std::result::Result<T, E> { + match result { + Ok(value) => value, + Err(err) => Err(err), + } +} diff --git a/mullvad-api/src/tls_stream.rs b/mullvad-api/src/tls_stream.rs new file mode 100644 index 0000000000..cad0268ac3 --- /dev/null +++ b/mullvad-api/src/tls_stream.rs @@ -0,0 +1,122 @@ +//! Provides a TLS 1.3 stream with SNI and LE root cert only. +use std::{ + io::{self, ErrorKind}, + pin::Pin, + sync::Arc, + task::{self, Poll}, +}; + +use hyper::client::connect::{Connected, Connection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{ + rustls::{self, ClientConfig, ServerName}, + TlsConnector, +}; + +const LE_ROOT_CERT: &[u8] = include_bytes!("../le_root_cert.pem"); + +pub struct TlsStream<S: AsyncRead + AsyncWrite + Unpin> { + stream: tokio_rustls::client::TlsStream<S>, +} + +impl<S> TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub async fn connect_https(stream: S, domain: &str) -> io::Result<TlsStream<S>> { + lazy_static::lazy_static! { + static ref TLS_CONFIG: Arc<ClientConfig> = { + let config = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(read_cert_store()) + .with_no_client_auth(); + Arc::new(config) + }; + } + + let connector = TlsConnector::from(TLS_CONFIG.clone()); + + let host = match ServerName::try_from(domain) { + Ok(n) => n, + Err(_) => { + return Err(io::Error::new( + ErrorKind::InvalidInput, + format!("invalid hostname \"{}\"", domain), + )); + } + }; + + let stream = connector.connect(host, stream).await?; + + Ok(TlsStream { stream }) + } +} + +fn read_cert_store() -> rustls::RootCertStore { + let mut cert_store = rustls::RootCertStore::empty(); + + let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(LE_ROOT_CERT)) + .expect("Failed to parse pem file"); + let (num_certs_added, num_failures) = cert_store.add_parsable_certificates(&certs); + if num_failures > 0 || num_certs_added != 1 { + panic!("Failed to add root cert"); + } + + cert_store +} + +impl<S> AsyncRead for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl<S> AsyncWrite for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl<S> Connection for TlsStream<S> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_cert_loading() { + let _certs = read_cert_store(); + } +} |
