diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-10-11 15:53:26 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-14 12:08:37 +0100 |
| commit | b98c366f8647b17c21bfffd903582a1dc09158fb (patch) | |
| tree | fa8b4ba6265767f4a82afb8f9a50e4f3a6f57ecc | |
| parent | 78dc4644a82d7b3fb904ef3cbac8a1f705f0a213 (diff) | |
| download | mullvadvpn-b98c366f8647b17c21bfffd903582a1dc09158fb.tar.xz mullvadvpn-b98c366f8647b17c21bfffd903582a1dc09158fb.zip | |
Implement device concept
| -rw-r--r-- | mullvad-cli/src/cmds/account.rs | 66 | ||||
| -rw-r--r-- | mullvad-cli/src/cmds/status.rs | 5 | ||||
| -rw-r--r-- | mullvad-cli/src/cmds/tunnel.rs | 11 | ||||
| -rw-r--r-- | mullvad-daemon/src/account.rs | 173 | ||||
| -rw-r--r-- | mullvad-daemon/src/device.rs | 748 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 645 | ||||
| -rw-r--r-- | mullvad-daemon/src/management_interface.rs | 89 | ||||
| -rw-r--r-- | mullvad-daemon/src/relays/mod.rs | 114 | ||||
| -rw-r--r-- | mullvad-daemon/src/settings.rs | 17 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 499 | ||||
| -rw-r--r-- | mullvad-management-interface/proto/management_interface.proto | 55 | ||||
| -rw-r--r-- | mullvad-management-interface/src/types.rs | 38 | ||||
| -rw-r--r-- | mullvad-rpc/src/access.rs | 108 | ||||
| -rw-r--r-- | mullvad-rpc/src/lib.rs | 293 | ||||
| -rw-r--r-- | mullvad-rpc/src/relay_list.rs | 4 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 56 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 72 | ||||
| -rw-r--r-- | mullvad-types/src/account.rs | 21 | ||||
| -rw-r--r-- | mullvad-types/src/device.rs | 37 | ||||
| -rw-r--r-- | mullvad-types/src/lib.rs | 1 | ||||
| -rw-r--r-- | mullvad-types/src/settings/mod.rs | 44 |
21 files changed, 1692 insertions, 1404 deletions
diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs index 0bbbc28024..fae3b39396 100644 --- a/mullvad-cli/src/cmds/account.rs +++ b/mullvad-cli/src/cmds/account.rs @@ -16,23 +16,17 @@ impl Command for Account { clap::App::new(self.name()) .about("Control and display information about your Mullvad account") .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::App::new("create").about("Create and log in to a new account")) .subcommand( - clap::App::new("set").about("Change account").arg( + clap::App::new("login").about("Log in to an account").arg( clap::Arg::new("token") .help("The Mullvad account token to configure the client with") .required(false), ), ) + .subcommand(clap::App::new("logout").about("Log out of the current account")) .subcommand( - clap::App::new("get") - .about("Display information about the currently configured account"), - ) - .subcommand( - clap::App::new("unset").about("Removes the account number from the settings"), - ) - .subcommand( - clap::App::new("create") - .about("Creates a new account and sets it as the active one"), + clap::App::new("get").about("Display information about the current account"), ) .subcommand( clap::App::new("redeem").about("Redeems a voucher").arg( @@ -44,7 +38,9 @@ impl Command for Account { } async fn run(&self, matches: &clap::ArgMatches) -> Result<()> { - if let Some(set_matches) = matches.subcommand_matches("set") { + if let Some(_matches) = matches.subcommand_matches("create") { + self.create().await + } else if let Some(set_matches) = matches.subcommand_matches("login") { let mut token = match set_matches.value_of("token") { Some(token) => token.to_string(), None => { @@ -60,13 +56,11 @@ impl Command for Account { } }; token = token.split_whitespace().join("").to_string(); - self.set(Some(token)).await + self.login(token).await + } else if let Some(_matches) = matches.subcommand_matches("logout") { + self.logout().await } else if let Some(_matches) = matches.subcommand_matches("get") { self.get().await - } else if let Some(_matches) = matches.subcommand_matches("unset") { - self.set(None).await - } else if let Some(_matches) = matches.subcommand_matches("create") { - self.create().await } else if let Some(matches) = matches.subcommand_matches("redeem") { let voucher = matches.value_of_t_or_exit("voucher"); self.redeem_voucher(voucher).await @@ -77,24 +71,35 @@ impl Command for Account { } impl Account { - async fn set(&self, token: Option<AccountToken>) -> Result<()> { + async fn create(&self) -> Result<()> { let mut rpc = new_rpc_client().await?; - rpc.set_account(token.clone().unwrap_or_default()).await?; - if let Some(token) = token { - println!("Mullvad account \"{}\" set", token); - } else { - println!("Mullvad account removed"); - } + rpc.create_new_account(()).await?; + println!("New account created!"); + self.get().await + } + + async fn login(&self, token: AccountToken) -> Result<()> { + let mut rpc = new_rpc_client().await?; + rpc.login_account(token.clone()).await?; + println!("Mullvad account \"{}\" set", token); + Ok(()) + } + + async fn logout(&self) -> Result<()> { + let mut rpc = new_rpc_client().await?; + rpc.logout_account(()).await?; + println!("Removed device from Mullvad account"); Ok(()) } async fn get(&self) -> Result<()> { let mut rpc = new_rpc_client().await?; - let settings = rpc.get_settings(()).await?.into_inner(); - if settings.account_token != "" { - println!("Mullvad account: {}", settings.account_token); + let device = rpc.get_device(()).await?.into_inner(); + if !device.account_token.is_empty() { + println!("Mullvad account: {}", device.account_token); + println!("Device name : {}", device.device.unwrap().name); let expiry = rpc - .get_account_data(settings.account_token) + .get_account_data(device.account_token) .await .map_err(|error| Error::RpcFailedExt("Failed to fetch account data", error))? .into_inner(); @@ -108,13 +113,6 @@ impl Account { Ok(()) } - async fn create(&self) -> Result<()> { - let mut rpc = new_rpc_client().await?; - rpc.create_new_account(()).await?; - println!("New account created!"); - self.get().await - } - async fn redeem_voucher(&self, mut voucher: String) -> Result<()> { let mut rpc = new_rpc_client().await?; voucher.retain(|c| c.is_alphanumeric()); diff --git a/mullvad-cli/src/cmds/status.rs b/mullvad-cli/src/cmds/status.rs index 8c4a929c30..f5a681e36c 100644 --- a/mullvad-cli/src/cmds/status.rs +++ b/mullvad-cli/src/cmds/status.rs @@ -74,10 +74,9 @@ impl Command for Status { println!("New app version info: {:#?}", app_version_info); } } - EventType::KeyEvent(key_event) => { + EventType::Device(device) => { if verbose { - print!("Key event: "); - print_keygen_event(&key_event); + println!("Device event: {:#?}", device); } } } diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs index f3b218648e..f27e29d147 100644 --- a/mullvad-cli/src/cmds/tunnel.rs +++ b/mullvad-cli/src/cmds/tunnel.rs @@ -246,20 +246,13 @@ impl Tunnel { println!("No key is set"); return Ok(()); } - - let is_valid = rpc - .verify_wireguard_key(()) - .await - .map_err(|error| Error::RpcFailedExt("Failed to verify key", error))? - .into_inner(); - println!("Key is valid for use with current account: {}", is_valid); Ok(()) } async fn process_wireguard_key_generate() -> Result<()> { let mut rpc = new_rpc_client().await?; - let keygen_event = rpc.generate_wireguard_key(()).await?; - print_keygen_event(&keygen_event.into_inner()); + let keygen_event = rpc.rotate_wireguard_key(()).await?; + println!("Rotated WireGuard key"); Ok(()) } diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs deleted file mode 100644 index f5655c9d1f..0000000000 --- a/mullvad-daemon/src/account.rs +++ /dev/null @@ -1,173 +0,0 @@ -use chrono::{DateTime, Utc}; -use futures::future::{abortable, AbortHandle}; -use mullvad_rpc::{ - availability::ApiAvailabilityHandle, - rest::{self, Error as RestError, MullvadRestHandle}, - AccountsProxy, -}; -use mullvad_types::account::{AccountToken, VoucherSubmission}; -use std::{future::Future, time::Duration}; -use talpid_core::future_retry::{ - constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered, -}; - -const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; -const RETRY_ACTION_MAX_RETRIES: usize = 2; - -const RETRY_EXPIRY_CHECK_INTERVAL_INITIAL: Duration = Duration::from_secs(4); -const RETRY_EXPIRY_CHECK_INTERVAL_FACTOR: u32 = 5; -const RETRY_EXPIRY_CHECK_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); - -pub struct Account(()); - -#[derive(Clone)] -pub struct AccountHandle { - api_availability: ApiAvailabilityHandle, - initial_check_abort_handle: AbortHandle, - proxy: AccountsProxy, -} - -impl AccountHandle { - pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> { - let mut proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - retry_future_n( - move || proxy.create_account(), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - } - - pub fn get_www_auth_token( - &self, - account: AccountToken, - ) -> impl Future<Output = Result<String, rest::Error>> { - let proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - retry_future_n( - move || proxy.get_www_auth_token(account.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - } - - pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { - let proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - let result = retry_future_n( - move || proxy.get_expiry(token.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - .await; - if handle_expiry_result_inner(&result, &self.api_availability) { - self.initial_check_abort_handle.abort(); - } - result - } - - pub async fn submit_voucher( - &mut self, - account_token: AccountToken, - voucher: String, - ) -> Result<VoucherSubmission, rest::Error> { - let mut proxy = self.proxy.clone(); - let api_handle = self.api_availability.clone(); - let result = retry_future_n( - move || proxy.submit_voucher(account_token.clone(), voucher.clone()), - move |result| Self::should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, - ) - .await; - if result.is_ok() { - self.initial_check_abort_handle.abort(); - self.api_availability.resume_background(); - } - result - } - - fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool { - match result { - Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), - _ => false, - } - } -} - -impl Account { - pub fn new( - runtime: tokio::runtime::Handle, - rpc_handle: MullvadRestHandle, - token: Option<String>, - api_availability: ApiAvailabilityHandle, - ) -> AccountHandle { - let accounts_proxy = AccountsProxy::new(rpc_handle); - api_availability.pause_background(); - - let api_availability_copy = api_availability.clone(); - let accounts_proxy_copy = accounts_proxy.clone(); - - let (future, initial_check_abort_handle) = abortable(async move { - let token = if let Some(token) = token { - token - } else { - api_availability.pause_background(); - return; - }; - - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new( - RETRY_EXPIRY_CHECK_INTERVAL_INITIAL, - RETRY_EXPIRY_CHECK_INTERVAL_FACTOR, - ) - .max_delay(RETRY_EXPIRY_CHECK_INTERVAL_MAX), - ); - let future_generator = move || { - let wait_online = api_availability.wait_online(); - let expiry_fut = accounts_proxy.get_expiry(token.clone()); - let api_availability_copy = api_availability.clone(); - async move { - let _ = wait_online.await; - handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) - } - }; - let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; - retry_future(future_generator, should_retry, retry_strategy).await; - }); - runtime.spawn(future); - - AccountHandle { - api_availability: api_availability_copy, - initial_check_abort_handle, - proxy: accounts_proxy_copy, - } - } -} - -fn handle_expiry_result_inner( - result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>, - api_availability: &ApiAvailabilityHandle, -) -> bool { - match result { - Ok(_expiry) if *_expiry >= chrono::Utc::now() => { - api_availability.resume_background(); - true - } - Ok(_expiry) => { - api_availability.pause_background(); - true - } - Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => { - if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH { - api_availability.pause_background(); - return true; - } - false - } - Err(_) => false, - } -} diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs new file mode 100644 index 0000000000..42c8ee23bf --- /dev/null +++ b/mullvad-daemon/src/device.rs @@ -0,0 +1,748 @@ +use crate::DaemonEventSender; +use chrono::{DateTime, Utc}; +use futures::{ + channel::mpsc, + future::{abortable, AbortHandle}, + stream::StreamExt, +}; +use mullvad_rpc::{ + availability::{self, ApiAvailabilityHandle}, + rest::{self, Error as RestError, MullvadRestHandle}, + AccountsProxy, DevicesProxy, +}; +use mullvad_types::{ + account::{AccountToken, VoucherSubmission}, + device::{Device, DeviceData, DeviceId}, + wireguard::{RotationInterval, WireguardData}, +}; +use std::{ + future::Future, + path::Path, + sync::{Arc, Mutex}, + time::Duration, +}; +use talpid_core::{ + future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered}, + mpsc::Sender, +}; +use talpid_types::{net::wireguard::PrivateKey, ErrorExt}; +use tokio::{ + fs, + io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; + +/// How often to check whether the key has expired. +/// A short interval is used in case the computer is ever suspended. +const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60); + +/// File that used to store account and device data. +const DEVICE_CACHE_FILENAME: &str = "device.json"; + +const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; +const RETRY_ACTION_MAX_RETRIES: usize = 2; + +const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4); +const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5; +const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); + +pub struct DeviceKeyEvent(pub DeviceData); + +#[derive(err_derive::Error, Debug)] +pub enum Error { + #[error(display = "The account has reached the maximum number of devices")] + TooManyDevices, + #[error(display = "No device is set")] + NoDevice, + #[error(display = "The login attempt was aborted")] + LoginAborted, + #[error(display = "Unexpected HTTP request error")] + RestError(#[error(source)] rest::Error), + #[error(display = "API availability check was interrupted")] + ApiCheckError(#[error(source)] availability::Error), + #[error(display = "Failed to read or write device cache")] + DeviceIoError(#[error(source)] io::Error), + #[error(display = "Failed parse device cache")] + ParseDeviceCache(#[error(source)] serde_json::Error), +} + +pub(crate) struct AccountManager { + runtime: tokio::runtime::Handle, + account_service: AccountService, + device_service: DeviceService, + inner: Arc<Mutex<AccountManagerInner>>, + cache_update_tx: mpsc::UnboundedSender<Option<DeviceData>>, + cache_task_join_handle: Option<tokio::task::JoinHandle<()>>, + key_update_tx: DaemonEventSender<DeviceKeyEvent>, + rotation_abort_handle: Option<AbortHandle>, +} + +struct AccountManagerInner { + data: Option<DeviceData>, + rotation_interval: RotationInterval, +} + +impl AccountManager { + pub async fn new( + runtime: tokio::runtime::Handle, + rest_handle: rest::MullvadRestHandle, + api_availability: ApiAvailabilityHandle, + settings_dir: &Path, + key_update_tx: DaemonEventSender<DeviceKeyEvent>, + ) -> Result<AccountManager, Error> { + let (mut cacher, device_data) = DeviceCacher::new(settings_dir).await?; + let token = device_data.as_ref().map(|state| state.token.clone()); + let account_service = Account::new( + runtime.clone(), + rest_handle.clone(), + token, + api_availability.clone(), + ); + let should_start_rotation = device_data.is_some(); + let inner = Arc::new(Mutex::new(AccountManagerInner { + data: device_data, + rotation_interval: RotationInterval::default(), + })); + + let (cache_update_tx, mut cache_update_rx) = mpsc::unbounded(); + let cache_task_join_handle = runtime.spawn(async move { + while let Some(new_device) = cache_update_rx.next().await { + if let Err(error) = cacher.write(new_device).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update device cache") + ); + } + } + }); + + let mut manager = AccountManager { + runtime, + account_service, + device_service: DeviceService::new(rest_handle, api_availability), + inner, + cache_update_tx, + cache_task_join_handle: Some(cache_task_join_handle), + key_update_tx, + rotation_abort_handle: None, + }; + + if should_start_rotation { + manager.start_key_rotation(); + } + + Ok(manager) + } + + pub fn account_service(&self) -> AccountService { + self.account_service.clone() + } + + pub fn device_service(&self) -> DeviceService { + self.device_service.clone() + } + + pub async fn login(&mut self, token: AccountToken) -> Result<DeviceData, Error> { + let data = self.device_service.generate_for_account(token).await?; + self.logout(); + { + let mut inner = self.inner.lock().unwrap(); + inner.data.replace(data.clone()); + let _ = self.cache_update_tx.unbounded_send(Some(data.clone())); + } + self.start_key_rotation(); + + Ok(data) + } + + pub fn set(&mut self, data: DeviceData) { + self.logout(); + { + let mut inner = self.inner.lock().unwrap(); + inner.data.replace(data.clone()); + let _ = self.cache_update_tx.unbounded_send(Some(data)); + } + self.start_key_rotation(); + } + + /// Log out without waiting for the result. + pub fn logout(&mut self) { + let fut = self.logout_inner(true); + self.runtime.spawn(fut); + } + + /// Log out, and wait until the API has removed the device. + #[cfg(not(target_os = "android"))] + pub fn logout_wait(&mut self) -> impl Future<Output = Result<(), Error>> { + self.logout_inner(false) + } + + fn logout_inner(&mut self, use_backoff: bool) -> impl Future<Output = Result<(), Error>> { + self.stop_key_rotation(); + let data = { + let mut inner = self.inner.lock().unwrap(); + let _ = self.cache_update_tx.unbounded_send(None); + inner.data.take() + }; + let service = self.device_service.clone(); + async move { + if let Some(data) = data { + if use_backoff { + return service + .remove_device_with_backoff(data.token, data.device.id) + .await; + } else { + return service.remove_device(data.token, data.device.id).await; + } + } + Ok(()) + } + } + + pub async fn rotate_key(&mut self) -> Result<WireguardData, Error> { + let mut data = { + let inner = self.inner.lock().unwrap(); + inner.data.as_ref().ok_or(Error::NoDevice)?.clone() + }; + self.stop_key_rotation(); + let result = self + .device_service + .rotate_key(data.token.clone(), data.device.id.clone()) + .await; + if let Ok(ref wg_data) = result { + data.wg_data = wg_data.clone(); + let mut inner = self.inner.lock().unwrap(); + inner.data.replace(data.clone()); + let _ = self.cache_update_tx.unbounded_send(Some(data)); + } + self.start_key_rotation(); + result + } + + pub fn get(&self) -> Option<DeviceData> { + self.inner.lock().unwrap().data.clone() + } + + pub fn is_some(&self) -> bool { + self.inner.lock().unwrap().data.is_some() + } + + pub async fn set_rotation_interval(&mut self, interval: RotationInterval) { + self.stop_key_rotation(); + let restart_rotation = { + let mut inner = self.inner.lock().unwrap(); + inner.rotation_interval = interval; + inner.data.is_some() + }; + if restart_rotation { + self.start_key_rotation(); + } + } + + fn start_key_rotation(&mut self) { + self.stop_key_rotation(); + + let service = self.device_service.clone(); + let inner = self.inner.clone(); + let cache_update_tx = self.cache_update_tx.clone(); + let key_update_tx = self.key_update_tx.clone(); + + let (task, abort_handle) = abortable(async move { + loop { + tokio::time::sleep(KEY_CHECK_INTERVAL).await; + + let rotation_interval = { inner.lock().unwrap().rotation_interval.clone() }; + + let mut state = { + match inner.lock().unwrap().data.as_ref() { + Some(device_config) => device_config.clone(), + None => continue, + } + }; + + if (chrono::Utc::now() + .signed_duration_since(state.wg_data.created) + .num_seconds() as u64) + < rotation_interval.as_duration().as_secs() + { + continue; + } + + match service + .rotate_key_with_backoff(state.token.clone(), state.device.id.clone()) + .await + { + Ok(wg_data) => { + state.wg_data = wg_data; + { + let mut inner = inner.lock().unwrap(); + inner.data.replace(state.clone()); + let _ = cache_update_tx.unbounded_send(Some(state.clone())); + } + let _ = key_update_tx.send(DeviceKeyEvent(state)); + } + Err(error) => { + log::debug!("{}", error.display_chain_with_msg("Stopping key rotation")); + } + } + } + }); + self.runtime.spawn(task); + self.rotation_abort_handle = Some(abort_handle); + } + + fn stop_key_rotation(&mut self) { + if let Some(abort_handle) = self.rotation_abort_handle.take() { + abort_handle.abort(); + } + } +} + +impl Drop for AccountManager { + fn drop(&mut self) { + self.stop_key_rotation(); + if let Some(cache_task_join_handle) = self.cache_task_join_handle.take() { + let _ = self.runtime.block_on(cache_task_join_handle); + } + } +} + +#[derive(Clone)] +pub struct DeviceService { + api_availability: ApiAvailabilityHandle, + proxy: DevicesProxy, +} + +impl DeviceService { + fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self { + Self { + proxy: DevicesProxy::new(handle), + api_availability, + } + } + + /// Generate a new device for a given token + pub async fn generate_for_account(&self, token: AccountToken) -> Result<DeviceData, Error> { + let private_key = PrivateKey::new_from_random(); + let pubkey = private_key.public_key(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let token_copy = token.clone(); + let (device, addresses) = retry_future_n( + move || proxy.create(token_copy.clone(), pubkey.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await?; + + Ok(DeviceData { + token, + device, + wg_data: WireguardData { + private_key, + addresses, + created: Utc::now(), + }, + }) + } + + pub async fn generate_for_account_with_backoff( + &self, + token: AccountToken, + ) -> Result<DeviceData, Error> { + let private_key = PrivateKey::new_from_random(); + let pubkey = private_key.public_key(); + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ) + .max_delay(RETRY_BACKOFF_INTERVAL_MAX), + ); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let token_copy = token.clone(); + let (device, addresses) = retry_future( + move || { + let wait_online = api_handle.wait_online(); + let fut = proxy.create(token_copy.clone(), pubkey.clone()); + async move { + let _ = wait_online.await; + fut.await + } + }, + should_retry_backoff, + retry_strategy, + ) + .await?; + + Ok(DeviceData { + token, + device, + wg_data: WireguardData { + private_key, + addresses, + created: Utc::now(), + }, + }) + } + + pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.remove(token.clone(), device.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await?; + Ok(()) + } + + pub async fn remove_device_with_backoff( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<(), Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ), // Not setting a maximum interval + ); + + retry_future( + move || { + let wait_online = api_handle.wait_online(); + let fut = proxy.remove(token.clone(), device.clone()); + async move { + let _ = wait_online.await; + fut.await + } + }, + should_retry_backoff, + retry_strategy, + ) + .await?; + + Ok(()) + } + + pub async fn rotate_key( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<WireguardData, Error> { + let private_key = PrivateKey::new_from_random(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let pubkey = private_key.public_key(); + let addresses = retry_future_n( + move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await?; + + Ok(WireguardData { + private_key, + addresses, + created: Utc::now(), + }) + } + + pub async fn rotate_key_with_backoff( + &self, + token: AccountToken, + device: DeviceId, + ) -> Result<WireguardData, Error> { + let private_key = PrivateKey::new_from_random(); + + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let pubkey = private_key.public_key(); + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ) + .max_delay(RETRY_BACKOFF_INTERVAL_MAX), + ); + let addresses = retry_future( + move || { + let wait_online = api_handle.wait_online(); + let fut = proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()); + async move { + let _ = wait_online.await; + fut.await + } + }, + should_retry_backoff, + retry_strategy, + ) + .await?; + + Ok(WireguardData { + private_key, + addresses, + created: Utc::now(), + }) + } + + pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.list(token.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await + .map_err(Error::RestError) + } +} + +pub struct DeviceCacher { + file: io::BufWriter<fs::File>, +} + +impl DeviceCacher { + pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> { + let mut options = std::fs::OpenOptions::new(); + #[cfg(unix)] + { + use std::os::unix::fs::OpenOptionsExt; + options.mode(0o600); + } + #[cfg(windows)] + { + use std::os::windows::fs::OpenOptionsExt; + // exclusive access + options.share_mode(0); + } + + let path = settings_dir.join(DEVICE_CACHE_FILENAME); + let cache_exists = path.is_file(); + + let mut file = fs::OpenOptions::from(options) + .write(true) + .read(true) + .create(true) + .open(path) + .await?; + + let device: Option<DeviceData> = if cache_exists { + let mut reader = io::BufReader::new(&mut file); + let mut buffer = String::new(); + reader.read_to_string(&mut buffer).await?; + if !buffer.is_empty() { + serde_json::from_str(&buffer)? + } else { + None + } + } else { + None + }; + + Ok(( + DeviceCacher { + file: io::BufWriter::new(file), + }, + device, + )) + } + + pub async fn write(&mut self, device: Option<DeviceData>) -> Result<(), Error> { + let data = serde_json::to_vec_pretty(&device).unwrap(); + + self.file.get_mut().set_len(0).await?; + self.file.seek(io::SeekFrom::Start(0)).await?; + self.file.write_all(&data).await?; + self.file.flush().await?; + self.file.get_mut().sync_data().await?; + + Ok(()) + } +} + +#[derive(Clone)] +pub struct AccountService { + api_availability: ApiAvailabilityHandle, + initial_check_abort_handle: AbortHandle, + proxy: AccountsProxy, +} + +impl AccountService { + pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> { + let mut proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.create_account(), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + } + + pub fn get_www_auth_token( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<String, rest::Error>> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + retry_future_n( + move || proxy.get_www_auth_token(account.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + } + + pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> { + let proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let result = retry_future_n( + move || proxy.get_expiry(token.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await; + if handle_expiry_result_inner(&result, &self.api_availability) { + self.initial_check_abort_handle.abort(); + } + result + } + + pub async fn submit_voucher( + &mut self, + account_token: AccountToken, + voucher: String, + ) -> Result<VoucherSubmission, rest::Error> { + let mut proxy = self.proxy.clone(); + let api_handle = self.api_availability.clone(); + let result = retry_future_n( + move || proxy.submit_voucher(account_token.clone(), voucher.clone()), + move |result| should_retry(result, &api_handle), + constant_interval(RETRY_ACTION_INTERVAL), + RETRY_ACTION_MAX_RETRIES, + ) + .await; + if result.is_ok() { + self.initial_check_abort_handle.abort(); + self.api_availability.resume_background(); + } + result + } +} + +struct Account(()); + +impl Account { + pub fn new( + runtime: tokio::runtime::Handle, + rpc_handle: MullvadRestHandle, + token: Option<String>, + api_availability: ApiAvailabilityHandle, + ) -> AccountService { + let accounts_proxy = AccountsProxy::new(rpc_handle); + api_availability.pause_background(); + + let api_availability_copy = api_availability.clone(); + let accounts_proxy_copy = accounts_proxy.clone(); + + let (future, initial_check_abort_handle) = abortable(async move { + let token = if let Some(token) = token { + token + } else { + api_availability.pause_background(); + return; + }; + + let retry_strategy = Jittered::jitter( + ExponentialBackoff::new( + RETRY_BACKOFF_INTERVAL_INITIAL, + RETRY_BACKOFF_INTERVAL_FACTOR, + ) + .max_delay(RETRY_BACKOFF_INTERVAL_MAX), + ); + let future_generator = move || { + let wait_online = api_availability.wait_online(); + let expiry_fut = accounts_proxy.get_expiry(token.clone()); + let api_availability_copy = api_availability.clone(); + async move { + let _ = wait_online.await; + handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) + } + }; + let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; + retry_future(future_generator, should_retry, retry_strategy).await; + }); + runtime.spawn(future); + + AccountService { + api_availability: api_availability_copy, + initial_check_abort_handle, + proxy: accounts_proxy_copy, + } + } +} + +fn handle_expiry_result_inner( + result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>, + api_availability: &ApiAvailabilityHandle, +) -> bool { + match result { + Ok(_expiry) if *_expiry >= chrono::Utc::now() => { + api_availability.resume_background(); + true + } + Ok(_expiry) => { + api_availability.pause_background(); + true + } + Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => { + if code == mullvad_rpc::INVALID_ACCOUNT { + api_availability.pause_background(); + return true; + } + false + } + Err(_) => false, + } +} + +fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool { + match result { + Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), + _ => false, + } +} + +fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool { + match result { + Ok(_) => false, + Err(error) => { + if let RestError::ApiError(status, code) = error { + *status != rest::StatusCode::NOT_FOUND + && code != mullvad_rpc::INVALID_ACCOUNT + && code != mullvad_rpc::KEY_LIMIT_REACHED + && code != mullvad_rpc::MAX_DEVICES_REACHED + && code != mullvad_rpc::PUBKEY_IN_USE + } else { + true + } + } + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 0d9ec96c87..8d38eb8f54 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -4,9 +4,9 @@ #[macro_use] extern crate serde; -mod account; pub mod account_history; mod api; +pub mod device; pub mod exception_logging; #[cfg(target_os = "macos")] pub mod exclusion_gid; @@ -36,6 +36,7 @@ use mullvad_rpc::{ }; use mullvad_types::{ account::{AccountData, AccountToken, VoucherSubmission}, + device::{Device, DeviceData, DeviceEvent, DeviceId}, endpoint::MullvadEndpoint, location::{Coordinates, GeoIpLocation}, relay_constraints::{ @@ -46,7 +47,7 @@ use mullvad_types::{ settings::{DnsOptions, DnsState, Settings}, states::{TargetState, TunnelState}, version::{AppVersion, AppVersionInfo}, - wireguard::{KeygenEvent, RotationInterval}, + wireguard::{KeygenEvent, PublicKey, RotationInterval}, }; use settings::SettingsPersister; #[cfg(target_os = "android")] @@ -75,7 +76,8 @@ use talpid_types::android::AndroidContext; use talpid_types::{ net::{ openvpn::{self, ProxySettings}, - TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType, + wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, + TunnelType, }, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, @@ -84,12 +86,6 @@ use talpid_types::{ use tokio::fs; use tokio::io; -#[path = "wireguard.rs"] -mod wireguard; - -/// Timeout for first WireGuard key pushing -const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5); - /// Delay between generating a new WireGuard key and reconnecting const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60); @@ -124,13 +120,28 @@ pub enum Error { #[error(display = "Unable to load account history")] LoadAccountHistory(#[error(source)] account_history::Error), + #[error(display = "Failed to start account manager")] + LoadAccountManager(#[error(source)] device::Error), + + #[error(display = "Failed to log in to account")] + LoginError(#[error(source)] device::Error), + + #[error(display = "Failed to log out of account")] + LogoutError(#[error(source)] device::Error), + + #[error(display = "Failed to rotate WireGuard key")] + KeyRotationError(#[error(source)] device::Error), + + #[error(display = "Failed to list devices")] + ListDevicesError(#[error(source)] device::Error), + + #[error(display = "Failed to remove device")] + RemoveDeviceError(#[error(source)] device::Error), + #[cfg(target_os = "linux")] #[error(display = "Unable to initialize split tunneling")] InitSplitTunneling(#[error(source)] split_tunnel::Error), - #[error(display = "The account has too many wireguard keys")] - TooManyKeys, - #[cfg(windows)] #[error(display = "Split tunneling error")] SplitTunnelError(#[error(source)] split_tunnel::Error), @@ -226,8 +237,16 @@ pub enum DaemonCommand { /// Trigger an asynchronous relay list update. This returns before the relay list is actually /// updated. UpdateRelayLocations, - /// Set which account token to use for subsequent connection attempts. - SetAccount(ResponseTx<(), settings::Error>, Option<AccountToken>), + /// Log in with a given account and create a new device. + LoginAccount(ResponseTx<(), Error>, AccountToken), + /// Log out of the current account and remove the device, if they exist. + LogoutAccount(ResponseTx<(), Error>), + /// Return the current device configuration, if there is one. + GetDevice(ResponseTx<Option<DeviceData>, Error>), + /// Return all the devices for a given account token. + ListDevices(ResponseTx<Vec<Device>, Error>, AccountToken), + /// Remove device from a given account. + RemoveDevice(ResponseTx<(), Error>, AccountToken, DeviceId), /// Place constraints on the type of tunnel and relay UpdateRelaySettings(ResponseTx<(), settings::Error>, RelaySettingsUpdate), /// Set the allow LAN setting. @@ -256,11 +275,9 @@ pub enum DaemonCommand { /// Get the daemon settings GetSettings(oneshot::Sender<Settings>), /// Generate new wireguard key - GenerateWireguardKey(ResponseTx<wireguard::KeygenEvent, Error>), + RotateWireguardKey(ResponseTx<(), Error>), /// Return a public key of the currently set wireguard private key, if there is one - GetWireguardKey(ResponseTx<Option<wireguard::PublicKey>, Error>), - /// Verify if the currently set wireguard key is valid. - VerifyWireguardKey(ResponseTx<bool, Error>), + GetWireguardKey(ResponseTx<Option<PublicKey>, Error>), /// Get information about the currently running and latest app versions GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>), /// Get current version of the app @@ -320,19 +337,14 @@ pub(crate) enum InternalDaemonEvent { Command(DaemonCommand), /// Daemon shutdown triggered by a signal, ctrl-c or similar. TriggerShutdown, - /// Wireguard key generation event - WgKeyEvent( - ( - AccountToken, - Result<mullvad_types::wireguard::WireguardData, wireguard::Error>, - ), - ), /// New Account created NewAccountEvent(AccountToken, oneshot::Sender<Result<String, Error>>), /// The background job fetching new `AppVersionInfo`s got a new info object. NewAppVersionInfo(AppVersionInfo), /// Request from REST client to use a different API endpoint. GenerateApiConnectionMode(api::ApiConnectionModeRequest), + /// Sent when a device key is rotated. + DeviceKeyEvent(device::DeviceKeyEvent), /// The split tunnel paths or state were updated. #[cfg(target_os = "windows")] ExcludedPathsEvent(ExcludedPathsUpdate, oneshot::Sender<Result<(), Error>>), @@ -368,6 +380,12 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent { } } +impl From<device::DeviceKeyEvent> for InternalDaemonEvent { + fn from(event: device::DeviceKeyEvent) -> Self { + InternalDaemonEvent::DeviceKeyEvent(event) + } +} + #[derive(Clone, Debug, Eq, PartialEq)] enum DaemonExecutionState { Running, @@ -529,8 +547,8 @@ pub trait EventListener { /// Or some flag about the currently running version is changed. fn notify_app_version(&self, app_version_info: AppVersionInfo); - /// Notify clients of a key generation event. - fn notify_key_event(&self, key_event: KeygenEvent); + /// Notify that device changed (login, logout, or key rotation). + fn notify_device_event(&self, event: DeviceEvent); } pub struct Daemon<L: EventListener> { @@ -546,10 +564,9 @@ pub struct Daemon<L: EventListener> { event_listener: L, settings: SettingsPersister, account_history: account_history::AccountHistory, - account: account::AccountHandle, + account_manager: device::AccountManager, rpc_runtime: mullvad_rpc::MullvadRpcRuntime, rpc_handle: mullvad_rpc::rest::MullvadRestHandle, - wireguard_key_manager: wireguard::KeyManager, version_updater_handle: version_check::VersionUpdaterHandle, relay_selector: relays::RelaySelector, last_generated_relay: Option<Relay>, @@ -584,8 +601,6 @@ where mullvad_rpc::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; - let runtime = tokio::runtime::Handle::current(); - let (internal_event_tx, internal_event_rx) = command_channel.destructure(); if let Err(error) = migrations::migrate_all(&cache_dir, &settings_dir).await { @@ -596,7 +611,55 @@ where } let settings = SettingsPersister::load(&settings_dir).await; - let target_state = if settings.get_account_token().is_none() { + let tunnel_parameters_generator = MullvadTunnelParametersGenerator { + tx: internal_event_tx.clone(), + }; + + let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( + &cache_dir, + true, + #[cfg(target_os = "android")] + Self::create_bypass_tx(&internal_event_tx), + ) + .await + .map_err(Error::InitRpcFactory)?; + + let api_availability = rpc_runtime.availability_handle(); + api_availability.suspend(); + + let endpoint_updater = api::ApiEndpointUpdaterHandle::new(); + + let proxy_provider = api::create_api_config_provider( + internal_event_tx.to_specialized_sender(), + ApiConnectionMode::Direct, + ); + let rpc_handle = rpc_runtime + .mullvad_rest_handle(proxy_provider, endpoint_updater.callback()) + .await; + + let mut account_manager = device::AccountManager::new( + runtime.clone(), + rpc_handle.clone(), + api_availability.clone(), + &settings_dir, + internal_event_tx.to_specialized_sender(), + ) + .await + .map_err(Error::LoadAccountManager)?; + if let Some(rotation_interval) = settings.tunnel_options.wireguard.rotation_interval { + account_manager + .set_rotation_interval(rotation_interval) + .await; + } + + let account_history = account_history::AccountHistory::new( + &settings_dir, + account_manager.get().map(|device| device.token), + ) + .await + .map_err(Error::LoadAccountHistory)?; + + let target_state = if !account_manager.is_some() { PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await } else if settings.auto_connect { log::info!("Automatically connecting since auto-connect is turned on"); @@ -605,10 +668,6 @@ where PersistentTargetState::new(&cache_dir).await }; - let tunnel_parameters_generator = MullvadTunnelParametersGenerator { - tx: internal_event_tx.clone(), - }; - #[cfg(windows)] let exclude_paths = if settings.split_tunnel.enable_exclusions { settings @@ -621,18 +680,6 @@ where vec![] }; - let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache( - &cache_dir, - true, - #[cfg(target_os = "android")] - Self::create_bypass_tx(&internal_event_tx), - ) - .await - .map_err(Error::InitRpcFactory)?; - - let api_availability = rpc_runtime.availability_handle(); - api_availability.suspend(); - let initial_api_endpoint = api::get_allowed_endpoint(rpc_runtime.address_cache.get_address().await); @@ -664,17 +711,8 @@ where .await .map_err(Error::TunnelError)?; - let endpoint_updater = api::ApiEndpointUpdaterHandle::new(); endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx)); - let proxy_provider = api::create_api_config_provider( - internal_event_tx.to_specialized_sender(), - ApiConnectionMode::Direct, - ); - let rpc_handle = rpc_runtime - .mullvad_rest_handle(proxy_provider, endpoint_updater.callback()) - .await; - Self::forward_offline_state(api_availability.clone(), offline_state_rx).await; let relay_list_listener = event_listener.clone(); @@ -700,28 +738,11 @@ where settings.show_beta_releases, ); tokio::spawn(version_updater.run()); - let account_history = - account_history::AccountHistory::new(&settings_dir, settings.get_account_token()) - .await - .map_err(Error::LoadAccountHistory)?; - - let wireguard_key_manager = wireguard::KeyManager::new( - internal_event_tx.clone(), - api_availability.clone(), - rpc_handle.clone(), - ); - - let account = account::Account::new( - runtime, - rpc_handle.clone(), - settings.get_account_token(), - api_availability.clone(), - ); // Attempt to download a fresh relay list relay_selector.update().await; - let mut daemon = Daemon { + let daemon = Daemon { tunnel_command_tx, tunnel_state: TunnelState::Disconnected, target_state, @@ -734,10 +755,9 @@ where event_listener, settings, account_history, - account, + account_manager, rpc_runtime, rpc_handle, - wireguard_key_manager, version_updater_handle, relay_selector, last_generated_relay: None, @@ -751,8 +771,6 @@ where volume_update_tx, }; - daemon.ensure_wireguard_keys_for_current_account().await; - api_availability.unsuspend(); Ok(daemon) @@ -881,7 +899,6 @@ where } Command(command) => self.handle_command(command).await, TriggerShutdown => self.trigger_shutdown_event(), - WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await, NewAccountEvent(account_token, tx) => { self.handle_new_account_event(account_token, tx).await } @@ -891,6 +908,7 @@ where GenerateApiConnectionMode(request) => { self.handle_generate_api_connection_mode(request).await } + DeviceKeyEvent(event) => self.handle_device_key_event(event).await, #[cfg(windows)] ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await, } @@ -967,7 +985,7 @@ where >, retry_attempt: u32, ) { - if let Some(account_token) = self.settings.get_account_token() { + if let Some(device) = self.account_manager.get() { let result = match self.settings.get_relay_settings() { RelaySettings::CustomTunnelEndpoint(custom_relay) => { self.last_generated_relay = None; @@ -987,7 +1005,6 @@ where &constraints, self.settings.get_bridge_state(), retry_attempt, - self.settings.get_wireguard().is_some(), ) .ok(); if let Some(relays::RelaySelectorResult { @@ -1000,7 +1017,7 @@ where .create_tunnel_parameters( &exit_relay, endpoint, - account_token, + device.token, retry_attempt, ) .await; @@ -1111,7 +1128,11 @@ where .into()) } MullvadEndpoint::Wireguard(endpoint) => { - let wg_data = self.settings.get_wireguard().ok_or(Error::NoKeyAvailable)?; + let wg_data = self + .account_manager + .get() + .map(|device| device.wg_data) + .ok_or(Error::NoKeyAvailable)?; let tunnel = wireguard::TunnelConfig { private_key: wg_data.private_key, addresses: vec![ @@ -1175,7 +1196,13 @@ where SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await, GetRelayLocations(tx) => self.on_get_relay_locations(tx), UpdateRelayLocations => self.on_update_relay_locations().await, - SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await, + LoginAccount(tx, account_token) => self.on_login_account(tx, account_token).await, + LogoutAccount(tx) => self.on_logout_account(tx).await, + GetDevice(tx) => self.on_get_device(tx).await, + ListDevices(tx, account_token) => self.on_list_devices(tx, account_token).await, + RemoveDevice(tx, account_token, device_id) => { + self.on_remove_device(tx, account_token, device_id).await + } GetAccountHistory(tx) => self.on_get_account_history(tx), ClearAccountHistory(tx) => self.on_clear_account_history(tx).await, UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update).await, @@ -1198,9 +1225,8 @@ where self.on_set_wireguard_rotation_interval(tx, interval).await } GetSettings(tx) => self.on_get_settings(tx), - GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await, + RotateWireguardKey(tx) => self.on_rotate_wireguard_key(tx).await, GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await, - VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await, GetVersionInfo(tx) => self.on_get_version_info(tx).await, GetCurrentVersion(tx) => self.on_get_current_version(tx), #[cfg(not(target_os = "android"))] @@ -1232,86 +1258,6 @@ where } } - async fn handle_wireguard_key_event( - &mut self, - event: ( - AccountToken, - Result<mullvad_types::wireguard::WireguardData, wireguard::Error>, - ), - ) { - let (account, result) = event; - // If the account has been reset whilst a key was being generated, the event should be - // dropped even if a new key was generated. - if self - .settings - .get_account_token() - .map(|current_account| current_account != account) - .unwrap_or(true) - { - log::info!("Dropping wireguard key event since account has been changed"); - return; - } - - match result { - Ok(data) => { - let public_key = data.get_public_key(); - let is_first_key = self.settings.get_wireguard().is_none(); - match self.settings.set_wireguard(Some(data)).await { - Ok(_) => { - if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY).await; - } - self.event_listener - .notify_key_event(KeygenEvent::NewKey(public_key)); - if is_first_key { - self.ensure_key_rotation().await; - } - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg( - "Failed to add new wireguard key to account data" - ) - ); - self.event_listener - .notify_key_event(KeygenEvent::GenerationFailure) - } - } - } - Err(wireguard::Error::TooManyKeys) => { - self.event_listener - .notify_key_event(KeygenEvent::TooManyKeys); - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg("Failed to generate wireguard key") - ); - self.event_listener - .notify_key_event(KeygenEvent::GenerationFailure); - } - } - } - - async fn ensure_key_rotation(&mut self) { - let token = match self.settings.get_account_token() { - Some(token) => token, - None => return, - }; - let public_key = match self.settings.get_wireguard() { - Some(data) => data.get_public_key(), - None => return, - }; - self.wireguard_key_manager - .set_rotation_interval( - public_key, - token, - self.settings.tunnel_options.wireguard.rotation_interval, - ) - .await; - } - async fn handle_new_account_event( &mut self, new_token: AccountToken, @@ -1322,12 +1268,12 @@ where self.set_target_state(TargetState::Unsecured).await; let _ = tx.send(Ok(new_token)); } - Err(err) => { + Err(error) => { log::error!( "{}", - err.display_chain_with_msg("Failed to save new account") + error.display_chain_with_msg("Handling new account failed") ); - let _ = tx.send(Err(Error::SettingsError(err))); + let _ = tx.send(Err(error)); } }; } @@ -1409,6 +1355,25 @@ where let _ = request.response_tx.send(config); } + async fn handle_device_key_event(&mut self, event: device::DeviceKeyEvent) { + let device_id = &event.0.device.id; + if Some(device_id) + != self + .account_manager + .get() + .map(|device| device.device.id) + .as_ref() + { + // Stale config + return; + } + if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { + self.schedule_reconnect(WG_RECONNECT_DELAY).await; + } + self.event_listener + .notify_device_event(DeviceEvent(Some(Device::from(event.0)))); + } + #[cfg(windows)] async fn handle_new_excluded_paths( &mut self, @@ -1541,7 +1506,7 @@ where async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) { let daemon_tx = self.tx.clone(); - let future = self.account.create_account(); + let future = self.account_manager.account_service().create_account(); tokio::spawn(async move { match future.await { Ok(account_token) => { @@ -1559,7 +1524,7 @@ where tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>, account_token: AccountToken, ) { - let account = self.account.clone(); + let account = self.account_manager.account_service(); tokio::spawn(async move { let result = account.check_expiry(account_token).await; Self::oneshot_send( @@ -1571,8 +1536,11 @@ where } async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) { - if let Some(account_token) = self.settings.get_account_token() { - let future = self.account.get_www_auth_token(account_token); + if let Some(device) = self.account_manager.get() { + let future = self + .account_manager + .account_service() + .get_www_auth_token(device.token); let rpc_call = async { Self::oneshot_send( tx, @@ -1595,13 +1563,13 @@ where tx: ResponseTx<VoucherSubmission, Error>, voucher: String, ) { - if let Some(account_token) = self.settings.get_account_token() { - let mut account = self.account.clone(); + if let Some(device) = self.account_manager.get() { + let mut account = self.account_manager.account_service(); tokio::spawn(async move { Self::oneshot_send( tx, account - .submit_voucher(account_token, voucher) + .submit_voucher(device.token, voucher) .await .map_err(Error::RestError), "submit_voucher response", @@ -1620,90 +1588,103 @@ where self.relay_selector.update().await; } - async fn on_set_account( - &mut self, - tx: ResponseTx<(), settings::Error>, - account_token: Option<String>, - ) { - match self.set_account(account_token.clone()).await { + async fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) { + match self.set_account(Some(account_token)).await { Ok(account_changed) => { if account_changed { - match account_token { - Some(_) => { - log::info!( - "Initiating tunnel restart because the account token changed" - ); - self.reconnect_tunnel(); - } - None => { - log::info!("Disconnecting because account token was cleared"); - self.set_target_state(TargetState::Unsecured).await; - } - }; + log::info!("Initiating tunnel restart because the account token changed"); + self.reconnect_tunnel(); } - Self::oneshot_send(tx, Ok(()), "set_account response"); + Self::oneshot_send(tx, Ok(()), "login_account response"); } Err(error) => { - log::error!("{}", error.display_chain_with_msg("Failed to set account")); - Self::oneshot_send(tx, Err(error), "set_account response"); + log::error!("{}", error.display_chain_with_msg("Login failed")); + Self::oneshot_send(tx, Err(error), "login_account response"); } } } - async fn set_account( - &mut self, - account_token: Option<String>, - ) -> Result<bool, settings::Error> { - let previous_token = self.settings.get_account_token(); - let account_changed = self - .settings - .set_account_token(account_token.clone()) - .await?; - if account_changed { - self.event_listener - .notify_settings(self.settings.to_settings()); - - let history_token = match account_token { - Some(token) => token, - None => previous_token.clone().unwrap_or("".to_string()), - }; - if let Err(error) = self.account_history.set(history_token).await { - log::error!( - "{}", - error.display_chain_with_msg("Failed to update account history") - ); + async fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) { + match self.set_account(None).await { + Ok(account_changed) => { + if account_changed { + log::info!("Disconnecting because account token was cleared"); + self.set_target_state(TargetState::Unsecured).await; + } + Self::oneshot_send(tx, Ok(()), "logout_account response"); + } + Err(error) => { + log::error!("{}", error.display_chain_with_msg("Logout failed")); + Self::oneshot_send(tx, Err(error), "logout_account response"); } + } + } - if let Some(previous_token) = previous_token { - if let Some(previous_key) = self - .settings - .get_wireguard() - .map(|data| data.private_key.public_key()) - { - let remove_key = self - .wireguard_key_manager - .remove_key_with_backoff(previous_token, previous_key); - tokio::spawn(async move { - if let Err(error) = remove_key.await { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to remove WireGuard key for previous account" - ) - ); - } - }); - } + async fn set_account(&mut self, account_token: Option<String>) -> Result<bool, Error> { + let previous_token = self.account_manager.get().map(|device| device.token); + if previous_token == account_token { + return Ok(false); + } + + match account_token.clone() { + Some(token) => { + let device_data = self + .account_manager + .login(token) + .await + .map_err(Error::LoginError)?; + self.event_listener + .notify_device_event(DeviceEvent(Some(Device::from(device_data)))); + } + None => { + self.account_manager.logout(); + self.event_listener.notify_device_event(DeviceEvent(None)); } - if let Err(error) = self.settings.set_wireguard(None).await { + } + + if let Some(token) = account_token.or(previous_token) { + if let Err(error) = self.account_history.set(token).await { log::error!( "{}", - error.display_chain_with_msg("Error resetting WireGuard key") + error.display_chain_with_msg("Failed to update account history") ); } - self.ensure_wireguard_keys_for_current_account().await; } - Ok(account_changed) + + Ok(true) + } + + async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceData>, Error>) { + Self::oneshot_send(tx, Ok(self.account_manager.get()), "get_device response"); + } + + async fn on_list_devices(&mut self, tx: ResponseTx<Vec<Device>, Error>, token: AccountToken) { + Self::oneshot_send( + tx, + self.account_manager + .device_service() + .list_devices(token) + .await + .map_err(Error::ListDevicesError), + "list_devices response", + ); + } + + async fn on_remove_device( + &mut self, + tx: ResponseTx<(), Error>, + token: AccountToken, + device_id: DeviceId, + ) { + Self::oneshot_send( + tx, + self.account_manager + .device_service() + .remove_device(token, device_id) + .await + .map_err(Error::RemoveDeviceError), + "remove_device response", + ); } fn on_get_account_history(&mut self, tx: oneshot::Sender<Option<AccountToken>>) { @@ -1723,37 +1704,6 @@ where Self::oneshot_send(tx, result, "clear_account_history response"); } - // Remove the key associated with the current account, if there is one. - // This does not modify settings or account history. - #[cfg(not(target_os = "android"))] - fn remove_current_key_rpc(&self) -> impl std::future::Future<Output = Result<(), Error>> { - let remove_key = if let Some(token) = self.settings.get_account_token() { - if let Some(wg_data) = self.settings.get_wireguard() { - Some( - self.wireguard_key_manager - .remove_key(token, wg_data.private_key.public_key()), - ) - } else { - None - } - } else { - None - }; - - async move { - if let Some(task) = remove_key { - match task.await { - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - // This result should never occur - Err(wireguard::Error::TooManyKeys) => Err(Error::TooManyKeys), - _ => Ok(()), - } - } else { - Ok(()) - } - } - } - async fn on_get_version_info(&mut self, tx: oneshot::Sender<Option<AppVersionInfo>>) { if self.app_version_info.is_none() { log::debug!("No version cache found. Fetching new info"); @@ -1795,17 +1745,13 @@ where async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) { let mut last_error = Ok(()); - let remove_key = self.remove_current_key_rpc(); - tokio::spawn(async move { - if let Err(error) = remove_key.await { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to remove WireGuard key for previous account" - ) - ); - } - }); + if let Err(error) = self.account_manager.logout_wait().await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to clear device cache") + ); + last_error = Err(Error::LogoutError(error)); + } if let Err(error) = self.account_history.clear().await { log::error!( @@ -2315,7 +2261,9 @@ where Ok(settings_changed) => { Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response"); if settings_changed { - self.ensure_key_rotation().await; + self.account_manager + .set_rotation_interval(interval.unwrap_or_default()) + .await; self.event_listener .notify_settings(self.settings.to_settings()); } @@ -2327,128 +2275,25 @@ where } } - async fn ensure_wireguard_keys_for_current_account(&mut self) { - if let Some(account) = self.settings.get_account_token() { - if self.settings.get_wireguard().is_none() { - log::info!("Generating new WireGuard key for account"); - self.wireguard_key_manager - .spawn_key_generation_task(account, Some(FIRST_KEY_PUSH_TIMEOUT)) - .await; - } else { - log::info!("Account already has WireGuard key"); - self.ensure_key_rotation().await; - } - } - } - - async fn on_generate_wireguard_key(&mut self, tx: ResponseTx<KeygenEvent, Error>) { - match self.on_generate_wireguard_key_inner().await { - Ok(key_event) => { - Self::oneshot_send(tx, Ok(key_event), "generate_wireguard_key"); - } - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg("Failed to generate new wireguard key") - ); - Self::oneshot_send(tx, Err(e), "generate_wireguard_key"); - } - } - } - - async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, Error> { - let account_token = self - .settings - .get_account_token() - .ok_or(Error::NoAccountToken)?; - let wireguard_data = self.settings.get_wireguard(); - - let gen_result = match &wireguard_data { - Some(wireguard_data) => { - self.wireguard_key_manager - .replace_key(account_token.clone(), wireguard_data.get_public_key()) - .await - } - None => { - self.wireguard_key_manager - .generate_key_sync(account_token.clone()) - .await - } - }; - - match gen_result { - Ok(new_data) => { - let public_key = new_data.get_public_key(); - self.settings - .set_wireguard(Some(new_data)) - .await - .map_err(Error::SettingsError)?; - if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY).await; - } - let keygen_event = KeygenEvent::NewKey(public_key.clone()); - self.event_listener.notify_key_event(keygen_event.clone()); - - // update automatic rotation - self.wireguard_key_manager - .set_rotation_interval( - public_key, - account_token, - self.settings.tunnel_options.wireguard.rotation_interval, - ) - .await; - - Ok(keygen_event) - } - Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys), - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), + async fn on_rotate_wireguard_key(&mut self, tx: ResponseTx<(), Error>) { + let result = self.account_manager.rotate_key().await; + if let Ok(ref _wg_data) = result { + let device = self.account_manager.get().map(Device::from); + self.event_listener + .notify_device_event(DeviceEvent(device.clone())); } + let _ = tx.send(result.map(|_| ()).map_err(Error::KeyRotationError)); } - async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<wireguard::PublicKey>, Error>) { - let result = if self.settings.get_account_token().is_some() { - Ok(self - .settings - .get_wireguard() - .map(|data| data.get_public_key())) + async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<PublicKey>, Error>) { + let result = if let Some(device) = self.account_manager.get() { + Ok(Some(device.wg_data.get_public_key())) } else { Err(Error::NoAccountToken) }; Self::oneshot_send(tx, result, "get_wireguard_key response"); } - async fn on_verify_wireguard_key(&mut self, tx: ResponseTx<bool, Error>) { - let account = match self.settings.get_account_token() { - Some(account) => account, - None => { - Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response"); - return; - } - }; - let public_key = match self.settings.get_wireguard() { - Some(wg_data) => wg_data.private_key.public_key(), - None => { - Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response"); - return; - } - }; - - let verification_rpc = self - .wireguard_key_manager - .verify_wireguard_key(account, public_key); - - tokio::spawn(async move { - let result = match verification_rpc.await { - Ok(is_valid) => Ok(is_valid), - Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), - Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)), - Err(wireguard::Error::TooManyKeys) => return, - }; - Self::oneshot_send(tx, result, "verify_wireguard_key response"); - }); - } - fn on_get_settings(&self, tx: oneshot::Sender<Settings>) { Self::oneshot_send(tx, self.settings.to_settings(), "get_settings response"); } diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index ba828ed903..e6f84660c1 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -378,20 +378,25 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_daemon_error) } - async fn set_account(&self, request: Request<AccountToken>) -> ServiceResult<()> { - log::debug!("set_account"); + async fn login_account(&self, request: Request<AccountToken>) -> ServiceResult<()> { + log::debug!("login_account"); let account_token = request.into_inner(); - let account_token = if account_token == "" { - None - } else { - Some(account_token) - }; let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::SetAccount(tx, account_token))?; + self.send_command_to_daemon(DaemonCommand::LoginAccount(tx, account_token))?; self.wait_for_result(rx) .await? .map(Response::new) - .map_err(map_settings_error) + .map_err(map_daemon_error) + } + + async fn logout_account(&self, _: Request<()>) -> ServiceResult<()> { + log::debug!("logout_account"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::LogoutAccount(tx))?; + self.wait_for_result(rx) + .await? + .map(Response::new) + .map_err(map_daemon_error) } async fn get_account_data( @@ -479,6 +484,44 @@ impl ManagementService for ManagementServiceImpl { }) } + // Device management + async fn get_device(&self, _: Request<()>) -> ServiceResult<types::DeviceConfig> { + log::debug!("get_device"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::GetDevice(tx))?; + let device = self + .wait_for_result(rx) + .await? + .map_err(map_daemon_error)? + .ok_or(Status::new(Code::NotFound, "no device is set"))?; + Ok(Response::new(types::DeviceConfig::from(device))) + } + + async fn list_devices( + &self, + request: Request<AccountToken>, + ) -> ServiceResult<types::DeviceList> { + log::debug!("list_devices"); + let (tx, rx) = oneshot::channel(); + let token = request.into_inner(); + self.send_command_to_daemon(DaemonCommand::ListDevices(tx, token))?; + let device = self.wait_for_result(rx).await?.map_err(map_daemon_error)?; + Ok(Response::new(types::DeviceList::from(device))) + } + + async fn remove_device(&self, request: Request<types::DeviceRemoval>) -> ServiceResult<()> { + log::debug!("remove_device"); + let (tx, rx) = oneshot::channel(); + let removal = request.into_inner(); + self.send_command_to_daemon(DaemonCommand::RemoveDevice( + tx, + removal.account_token, + removal.device_id, + ))?; + self.wait_for_result(rx).await?.map_err(map_daemon_error)?; + Ok(Response::new(())) + } + // WireGuard key management // @@ -515,15 +558,13 @@ impl ManagementService for ManagementServiceImpl { .map_err(map_settings_error) } - async fn generate_wireguard_key(&self, _: Request<()>) -> ServiceResult<types::KeygenEvent> { - // TODO: return error for TooManyKeys, GenerationFailure - // on success, simply return the new key or nil - log::debug!("generate_wireguard_key"); + async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> { + log::debug!("rotate_wireguard_key"); let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::GenerateWireguardKey(tx))?; + self.send_command_to_daemon(DaemonCommand::RotateWireguardKey(tx))?; self.wait_for_result(rx) .await? - .map(|event| Response::new(types::KeygenEvent::from(event))) + .map(Response::new) .map_err(map_daemon_error) } @@ -538,16 +579,6 @@ impl ManagementService for ManagementServiceImpl { } } - async fn verify_wireguard_key(&self, _: Request<()>) -> ServiceResult<bool> { - log::debug!("verify_wireguard_key"); - let (tx, rx) = oneshot::channel(); - self.send_command_to_daemon(DaemonCommand::VerifyWireguardKey(tx))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_daemon_error) - } - // Split tunneling // @@ -832,11 +863,11 @@ impl EventListener for ManagementInterfaceEventBroadcaster { }) } - fn notify_key_event(&self, key_event: mullvad_types::wireguard::KeygenEvent) { - log::debug!("Broadcasting new wireguard key event"); + fn notify_device_event(&self, device: mullvad_types::device::DeviceEvent) { + log::debug!("Broadcasting device event"); self.notify(types::DaemonEvent { - event: Some(daemon_event::Event::KeyEvent(types::KeygenEvent::from( - key_event, + event: Some(daemon_event::Event::Device(types::DeviceEvent::from( + device, ))), }) } diff --git a/mullvad-daemon/src/relays/mod.rs b/mullvad-daemon/src/relays/mod.rs index 332ca5fea3..c4a136369c 100644 --- a/mullvad-daemon/src/relays/mod.rs +++ b/mullvad-daemon/src/relays/mod.rs @@ -276,7 +276,6 @@ impl RelaySelector { relay_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> Result<RelaySelectorResult, Error> { match relay_constraints.tunnel_protocol { Constraint::Only(TunnelType::OpenVpn) => self.get_openvpn_endpoint( @@ -293,12 +292,9 @@ impl RelaySelector { &relay_constraints.wireguard_constraints, retry_attempt, ), - Constraint::Any => self.get_any_tunnel_endpoint( - relay_constraints, - bridge_state, - retry_attempt, - wg_key_exists, - ), + Constraint::Any => { + self.get_any_tunnel_endpoint(relay_constraints, bridge_state, retry_attempt) + } } } @@ -479,14 +475,9 @@ impl RelaySelector { relay_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> Result<RelaySelectorResult, Error> { - let preferred_constraints = self.preferred_constraints( - &relay_constraints, - bridge_state, - retry_attempt, - wg_key_exists, - ); + let preferred_constraints = + self.preferred_constraints(&relay_constraints, bridge_state, retry_attempt); let original_matcher: RelayMatcher<_> = relay_constraints.clone().into(); let preferred_tunnel_protocol = preferred_constraints.tunnel_protocol; @@ -543,14 +534,12 @@ impl RelaySelector { original_constraints: &RelayConstraints, bridge_state: BridgeState, retry_attempt: u32, - wg_key_exists: bool, ) -> RelayConstraints { let (preferred_port, preferred_protocol, preferred_tunnel) = self .preferred_tunnel_constraints( retry_attempt, &original_constraints.location, &original_constraints.providers, - wg_key_exists, ); let mut relay_constraints = original_constraints.clone(); @@ -731,7 +720,6 @@ impl RelaySelector { retry_attempt: u32, location_constraint: &Constraint<LocationConstraint>, providers_constraint: &Constraint<Providers>, - wg_key_exists: bool, ) -> (Constraint<u16>, TransportProtocol, TunnelType) { #[cfg(target_os = "windows")] { @@ -757,7 +745,7 @@ impl RelaySelector { }); // If location does not support WireGuard, defer to preferred OpenVPN tunnel // constraints - if !location_supports_wireguard || !wg_key_exists { + if !location_supports_wireguard { let (preferred_port, preferred_protocol) = Self::preferred_openvpn_constraints(retry_attempt); return (preferred_port, preferred_protocol, TunnelType::OpenVpn); @@ -1159,7 +1147,7 @@ mod test { }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) @@ -1167,7 +1155,7 @@ mod test { for attempt in 0..10 { assert!(relay_selector - .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .is_ok()); } @@ -1184,7 +1172,7 @@ mod test { }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1192,7 +1180,7 @@ mod test { for attempt in 0..10 { assert!(relay_selector - .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .is_ok()); } @@ -1205,7 +1193,6 @@ mod test { &relay_constraints, BridgeState::Off, attempt, - true, ); assert_eq!( preferred.tunnel_protocol, @@ -1215,7 +1202,6 @@ mod test { &relay_constraints, BridgeState::Off, attempt, - true, ) { Ok(result) if matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)) => (), _ => panic!("OpenVPN endpoint was not selected"), @@ -1250,14 +1236,14 @@ mod test { // The same host cannot be used for entry and exit assert!(relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .is_err()); relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location2); // If the entry and exit differ, this should succeed assert!(relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .is_ok()); } @@ -1286,7 +1272,7 @@ mod test { // The exit must not equal the entry let exit_relay = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .map_err(|error| error.to_string())? .exit_relay; @@ -1301,7 +1287,7 @@ mod test { endpoint, .. } = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .map_err(|error| error.to_string())?; assert_eq!(exit_relay.hostname, specific_hostname); @@ -1336,7 +1322,7 @@ mod test { }); let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1362,7 +1348,7 @@ mod test { ..RelayConstraints::default() }; let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) @@ -1381,14 +1367,14 @@ mod test { #[cfg(all(unix, not(target_os = "android")))] { let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::Wireguard) ); } let preferred = - relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2, true); + relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2); assert_eq!( preferred.tunnel_protocol, Constraint::Only(TunnelType::OpenVpn) @@ -1405,54 +1391,6 @@ mod test { } #[test] - fn test_wg_relay_with_no_key() { - let mut relay_constraints = RelayConstraints { - tunnel_protocol: Constraint::Only(TunnelType::Wireguard), - ..RelayConstraints::default() - }; - - let relay_selector = new_relay_selector(); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect("Failed to get WireGuard relay when WireGuard relay was specified as the only tunnel protocol"); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - - relay_constraints.tunnel_protocol = Constraint::Any; - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect("Failed to get OpenVPN relay with tunnel protocol constraint set to Any and without a WireGuard key"); - - assert!(matches!(result.endpoint, MullvadEndpoint::OpenVpn(_))); - - let wireguard_specific_location = LocationConstraint::Hostname( - "se".to_string(), - "got".to_string(), - "se9-wireguard".to_string(), - ); - relay_constraints.location = Constraint::Only(wireguard_specific_location); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false) - .expect( - "Failed to get a valid WireGuard relay when tunnel constraints are set to any - tunnel protocol and with a wireguard specific location without a wireguard key", - ); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - - let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) - .expect( - "Failed to get a valid WireGuard relay when tunnel constraints are set to any - tunnel protocol and with a wireguard specific location with a wireguard key", - ); - - assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_))); - } - - #[test] fn test_selecting_any_relay_will_consider_multihop() { let relay_constraints = RelayConstraints { wireguard_constraints: WireguardConstraints { @@ -1467,7 +1405,7 @@ mod test { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection"); // Windows will ignore WireGuard until WireGuard is supported well enough // TODO: Remove this caveat once Windows defaults to using WireGuard @@ -1502,7 +1440,7 @@ mod test { fn test_selecting_wireguard_location_will_consider_multihop() { let relay_selector = new_relay_selector(); - let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, true) + let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0) .expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection"); @@ -1526,7 +1464,7 @@ mod test { let relay_selector = new_relay_selector(); let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get WireGuard TCP multihop relay"); assert!(result.entry_relay.is_some()); @@ -1555,7 +1493,7 @@ mod test { let relay_selector = new_relay_selector(); let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0) .expect("Failed to get WireGuard TCP relay"); let endpoint = result.endpoint.unwrap_wireguard(); assert!(matches!(endpoint.peer.protocol, TransportProtocol::Tcp)); @@ -1570,7 +1508,7 @@ mod test { const INVALID_UDP_PORTS: [u16; 2] = [80, 443]; for attempt in 0..1000 { let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .expect("Failed to get WireGuard TCP multihop relay"); assert!(!INVALID_UDP_PORTS.contains(&result.endpoint.to_endpoint().address.port())); assert_eq!( @@ -1587,7 +1525,7 @@ mod test { const VALID_TCP_PORTS: [u16; 3] = [80, 443, 5001]; for attempt in 0..1000 { let result = relay_selector - .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true) + .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt) .expect("Failed to get WireGuard TCP multihop relay"); assert!(VALID_TCP_PORTS.contains(&result.endpoint.to_endpoint().address.port())); assert_eq!( @@ -1609,7 +1547,7 @@ mod test { ..RelayConstraints::default() }; relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&constraints, BridgeState::Off, 0) .expect_err("Successfully selected a relay that should be filtered"); constraints.location = Constraint::Only(LocationConstraint::Hostname( @@ -1619,7 +1557,7 @@ mod test { )); relay_selector - .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true) + .get_tunnel_endpoint(&constraints, BridgeState::Off, 0) .expect_err("Successfully selected a relay that should be filtered"); } } diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index ec610f63d4..bf3fe710c8 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -3,7 +3,7 @@ use futures::TryFutureExt; use mullvad_types::{ relay_constraints::{BridgeSettings, BridgeState, RelaySettingsUpdate}, settings::{DnsOptions, Settings}, - wireguard::{RotationInterval, WireguardData}, + wireguard::RotationInterval, }; #[cfg(target_os = "windows")] use std::collections::HashSet; @@ -191,21 +191,6 @@ impl SettingsPersister { settings } - /// Changes account number to the one given. Also saves the new settings to disk. - /// The boolean in the Result indicates if the account token changed or not - pub async fn set_account_token( - &mut self, - account_token: Option<String>, - ) -> Result<bool, Error> { - let should_save = self.settings.set_account_token(account_token); - self.update(should_save).await - } - - pub async fn set_wireguard(&mut self, wireguard: Option<WireguardData>) -> Result<bool, Error> { - let should_save = self.settings.set_wireguard(wireguard); - self.update(should_save).await - } - pub async fn update_relay_settings( &mut self, update: RelaySettingsUpdate, diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs deleted file mode 100644 index eb198b858b..0000000000 --- a/mullvad-daemon/src/wireguard.rs +++ /dev/null @@ -1,499 +0,0 @@ -use crate::{DaemonEventSender, InternalDaemonEvent}; -use chrono::offset::Utc; -use mullvad_rpc::{ - availability::ApiAvailabilityHandle, - rest::{Error as RestError, MullvadRestHandle}, -}; -use mullvad_types::account::AccountToken; -pub use mullvad_types::wireguard::*; -use std::{future::Future, pin::Pin, time::Duration}; - -use futures::future::{abortable, AbortHandle}; -#[cfg(not(target_os = "android"))] -use talpid_core::future_retry::constant_interval; -use talpid_core::{ - future_retry::{retry_future, retry_future_n, ExponentialBackoff, Jittered}, - mpsc::Sender, -}; - -pub use talpid_types::net::wireguard::{ - ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters, -}; -use talpid_types::ErrorExt; - -/// How long to wait before starting key rotation -const ROTATION_START_DELAY: Duration = Duration::from_secs(60 * 3); - -/// How often to check whether the key has expired. -/// A short interval is used in case the computer is ever suspended. -const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60); - -const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4); -const RETRY_INTERVAL_FACTOR: u32 = 5; -const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); - -#[cfg(not(target_os = "android"))] -const SHORT_RETRY_INTERVAL: Duration = Duration::ZERO; - -const MAX_KEY_REMOVAL_RETRIES: usize = 2; - -#[derive(err_derive::Error, Debug)] -pub enum Error { - #[error(display = "Unexpected HTTP request error")] - RestError(#[error(source)] mullvad_rpc::rest::Error), - #[error(display = "API availability check was interrupted")] - ApiCheckError(#[error(source)] mullvad_rpc::availability::Error), - #[error(display = "Account already has maximum number of keys")] - TooManyKeys, -} - -pub type Result<T> = std::result::Result<T, Error>; - -pub struct KeyManager { - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - current_job: Option<AbortHandle>, - - abort_scheduler_tx: Option<AbortHandle>, - auto_rotation_interval: RotationInterval, -} - -impl KeyManager { - pub(crate) fn new( - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - ) -> Self { - Self { - daemon_tx, - availability_handle, - http_handle, - current_job: None, - abort_scheduler_tx: None, - auto_rotation_interval: RotationInterval::default(), - } - } - - /// Reset key rotation, cancelling the current one and starting a new one for the specified - /// account - pub async fn reset_rotation(&mut self, current_key: PublicKey, account_token: AccountToken) { - self.run_automatic_rotation(account_token, current_key) - .await - } - - /// Update automatic key rotation interval - /// Passing `None` for the interval will cause the default value to be used. - pub async fn set_rotation_interval( - &mut self, - current_key: PublicKey, - account_token: AccountToken, - auto_rotation_interval: Option<RotationInterval>, - ) { - self.auto_rotation_interval = auto_rotation_interval.unwrap_or_default(); - self.reset_rotation(current_key, account_token).await; - } - - /// Stop current key generation - pub fn reset(&mut self) { - if let Some(job) = self.current_job.take() { - job.abort() - } - } - - /// Generate a new private key - pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> { - self.reset(); - let private_key = PrivateKey::new_from_random(); - - self.push_future_generator(account, private_key, None)() - .await - .map_err(Self::map_rpc_error) - } - - /// Replace a key for an account synchronously - pub async fn replace_key( - &mut self, - account: AccountToken, - old_key: PublicKey, - ) -> Result<WireguardData> { - self.reset(); - - let new_key = PrivateKey::new_from_random(); - Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await - } - - /// Verifies whether a key is valid or not. - pub fn verify_wireguard_key( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<bool>> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - async move { - match rpc.get_wireguard_key(account, &key).await { - Ok(_) => Ok(true), - Err(mullvad_rpc::rest::Error::ApiError(status, _code)) - if status == mullvad_rpc::StatusCode::NOT_FOUND => - { - Ok(false) - } - Err(err) => Err(Self::map_rpc_error(err)), - } - } - } - - /// Removes a key from an account - #[cfg(not(target_os = "android"))] - pub fn remove_key( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<()>> { - self.remove_key_inner(account, key, constant_interval(SHORT_RETRY_INTERVAL), false) - } - - /// Removes a key from an account - pub fn remove_key_with_backoff( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - ) -> impl Future<Output = Result<()>> { - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - self.remove_key_inner(account, key, retry_strategy, true) - } - - fn remove_key_inner<D: Iterator<Item = Duration> + 'static>( - &self, - account: AccountToken, - key: talpid_types::net::wireguard::PublicKey, - retry_strategy: D, - offline_check: bool, - ) -> impl Future<Output = Result<()>> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - let api_handle = self.availability_handle.clone(); - let api_handle_2 = api_handle.clone(); - let future = retry_future_n( - move || { - let remove_key = rpc.remove_wireguard_key(account.clone(), key.clone()); - let wait_future = api_handle.wait_online(); - async move { - if offline_check { - let _ = wait_future.await; - } - remove_key.await - } - }, - move |result| match result { - Ok(_) => false, - Err(error) => Self::should_retry_removal(error, &api_handle_2), - }, - retry_strategy, - MAX_KEY_REMOVAL_RETRIES, - ); - async move { future.await.map_err(Self::map_rpc_error) } - } - - fn should_retry_removal(error: &RestError, api_handle: &ApiAvailabilityHandle) -> bool { - error.is_network_error() && !api_handle.get_state().is_offline() - } - - fn should_retry(error: &RestError) -> bool { - if let RestError::ApiError(_status, code) = &error { - code != mullvad_rpc::INVALID_ACCOUNT && code != mullvad_rpc::KEY_LIMIT_REACHED - } else { - true - } - } - - /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel. - pub async fn spawn_key_generation_task( - &mut self, - account: AccountToken, - timeout: Option<Duration>, - ) { - self.reset(); - let private_key = PrivateKey::new_from_random(); - - let error_tx = self.daemon_tx.clone(); - let error_account = account.clone(); - - let mut inner_future_generator = - self.push_future_generator(account.clone(), private_key, timeout); - - let availability_handle = self.availability_handle.clone(); - - let future_generator = move || { - let wait_available = availability_handle.wait_background(); - let fut = inner_future_generator(); - let error_tx = error_tx.clone(); - let error_account = error_account.clone(); - async move { - let error_account_copy = error_account.clone(); - wait_available.await.map_err(|error| { - let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( - error_account_copy, - Err(Error::ApiCheckError(error)), - ))); - false - })?; - let response = fut.await; - match response { - Ok(addresses) => Ok(addresses), - Err(err) => { - let should_retry = Self::should_retry(&err); - let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent(( - error_account, - Err(Self::map_rpc_error(err)), - ))); - Err(should_retry) - } - } - } - }; - - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - - let should_retry = move |result: &std::result::Result<_, bool>| -> bool { - match result { - Ok(_) => false, - Err(should_retry) => *should_retry, - } - }; - - let upload_future = retry_future(future_generator, should_retry, retry_strategy); - - let (cancellable_upload, abort_handle) = abortable(Box::pin(upload_future)); - let daemon_tx = self.daemon_tx.clone(); - let future = async move { - match cancellable_upload.await { - Ok(Ok(wireguard_data)) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account, - Ok(wireguard_data), - ))); - } - Ok(Err(_)) => {} - Err(_) => { - log::error!("Key generation cancelled"); - } - } - }; - - tokio::spawn(Box::pin(future)); - self.current_job = Some(abort_handle); - } - - fn push_future_generator( - &self, - account: AccountToken, - private_key: PrivateKey, - timeout: Option<Duration>, - ) -> Box< - dyn FnMut() -> Pin< - Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>, - > + Send, - > { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - let public_key = private_key.public_key(); - - let push_future = - move || -> std::pin::Pin<Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send >> { - let key = private_key.clone(); - let address_future = rpc - .push_wg_key(account.clone(), public_key.clone(), timeout); - Box::pin(async move { - let addresses = address_future.await?; - Ok(WireguardData { - private_key: key, - addresses, - created: Utc::now(), - }) - }) - }; - Box::new(push_future) - } - - async fn replace_key_rpc( - http_handle: MullvadRestHandle, - account: AccountToken, - old_key: PublicKey, - new_key: PrivateKey, - ) -> Result<WireguardData> { - let mut rpc = mullvad_rpc::WireguardKeyProxy::new(http_handle); - let new_public_key = new_key.public_key(); - let addresses = rpc - .replace_wg_key(account, old_key.key, new_public_key) - .await - .map_err(Self::map_rpc_error)?; - Ok(WireguardData { - private_key: new_key, - addresses, - created: Utc::now(), - }) - } - - fn map_rpc_error(err: mullvad_rpc::rest::Error) -> Error { - match &err { - // TODO: Consider handling the invalid account case too. - mullvad_rpc::rest::Error::ApiError(status, message) - if *status == mullvad_rpc::StatusCode::BAD_REQUEST - && message == mullvad_rpc::KEY_LIMIT_REACHED => - { - Error::TooManyKeys - } - _ => Error::RestError(err), - } - } - - async fn wait_for_key_expiry(key: &PublicKey, rotation_interval_secs: u64) { - let mut interval = tokio::time::interval(KEY_CHECK_INTERVAL); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - interval.tick().await; - if (Utc::now().signed_duration_since(key.created)).num_seconds() as u64 - >= rotation_interval_secs - { - return; - } - } - } - - async fn create_automatic_rotation( - daemon_tx: DaemonEventSender, - availability_handle: ApiAvailabilityHandle, - http_handle: MullvadRestHandle, - mut public_key: PublicKey, - rotation_interval_secs: u64, - account_token: AccountToken, - ) { - tokio::time::sleep(ROTATION_START_DELAY).await; - - let rotate_key_for_account = - move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> { - let wait_available = availability_handle.wait_background(); - let rotate = Self::rotate_key( - daemon_tx.clone(), - http_handle.clone(), - account_token.clone(), - old_key.clone(), - ); - Box::pin(async move { - wait_available.await?; - rotate.await - }) - }; - - loop { - Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await; - - let rotate_key_for_account_copy = rotate_key_for_account.clone(); - match Self::rotate_key_with_retries(public_key.clone(), rotate_key_for_account_copy) - .await - { - Ok(new_key) => public_key = new_key, - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg( - "Stopping automatic key rotation due to an error" - ) - ); - return; - } - } - } - } - - fn rotate_key( - daemon_tx: DaemonEventSender, - http_handle: MullvadRestHandle, - account_token: AccountToken, - old_key: PublicKey, - ) -> impl Future<Output = Result<PublicKey>> { - let new_key = PrivateKey::new_from_random(); - let rpc_result = - Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key); - - async move { - match rpc_result.await { - Ok(data) => { - // Update account data - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token, - Ok(data.clone()), - ))); - Ok(data.get_public_key()) - } - Err(Error::TooManyKeys) => { - let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent(( - account_token, - Err(Error::TooManyKeys), - ))); - Err(Error::TooManyKeys) - } - Err(unknown) => Err(unknown), - } - } - } - - async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey> - where - F: FnMut(&PublicKey) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> - + Clone - + 'static, - { - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR) - .max_delay(RETRY_INTERVAL_MAX), - ); - let should_retry = move |result: &Result<PublicKey>| -> bool { - match result { - Ok(_) => false, - Err(error) => match error { - Error::RestError(error) => Self::should_retry(error), - _ => false, - }, - } - }; - - retry_future( - move || rotate_key.clone()(&old_key), - should_retry, - retry_strategy, - ) - .await - } - - async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) { - self.stop_automatic_rotation(); - - log::debug!("Starting automatic key rotation job"); - // Schedule cancellable series of repeating rotation tasks - let fut = Self::create_automatic_rotation( - self.daemon_tx.clone(), - self.availability_handle.clone(), - self.http_handle.clone(), - public_key, - self.auto_rotation_interval.as_duration().as_secs(), - account_token, - ); - let (request, abort_handle) = abortable(Box::pin(fut)); - - tokio::spawn(request); - self.abort_scheduler_tx = Some(abort_handle); - } - - fn stop_automatic_rotation(&mut self) { - if let Some(abort_handle) = self.abort_scheduler_tx.take() { - log::info!("Stopping automatic key rotation"); - abort_handle.abort(); - } - } -} diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index e690557aae..701a2e267e 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -44,19 +44,24 @@ service ManagementService { // Account management rpc CreateNewAccount(google.protobuf.Empty) returns (google.protobuf.StringValue) {} - rpc SetAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc LoginAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc LogoutAccount(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetAccountData(google.protobuf.StringValue) returns (AccountData) {} rpc GetAccountHistory(google.protobuf.Empty) returns (AccountHistory) {} rpc ClearAccountHistory(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetWwwAuthToken(google.protobuf.Empty) returns (google.protobuf.StringValue) {} rpc SubmitVoucher(google.protobuf.StringValue) returns (VoucherSubmission) {} + // Device management + rpc GetDevice(google.protobuf.Empty) returns (DeviceConfig) {} + rpc ListDevices(google.protobuf.StringValue) returns (DeviceList) {} + rpc RemoveDevice(DeviceRemoval) returns (google.protobuf.Empty) {} + // WireGuard key management rpc SetWireguardRotationInterval(google.protobuf.Duration) returns (google.protobuf.Empty) {} rpc ResetWireguardRotationInterval(google.protobuf.Empty) returns (google.protobuf.Empty) {} - rpc GenerateWireguardKey(google.protobuf.Empty) returns (KeygenEvent) {} + rpc RotateWireguardKey(google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {} - rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {} // Split tunneling (Linux) rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {} @@ -265,16 +270,15 @@ message BridgeState { } message Settings { - string account_token = 1; - RelaySettings relay_settings = 2; - BridgeSettings bridge_settings = 3; - BridgeState bridge_state = 4; - bool allow_lan = 5; - bool block_when_disconnected = 6; - bool auto_connect = 7; - TunnelOptions tunnel_options = 8; - bool show_beta_releases = 9; - SplitTunnelSettings split_tunnel = 10; + RelaySettings relay_settings = 1; + BridgeSettings bridge_settings = 2; + BridgeState bridge_state = 3; + bool allow_lan = 4; + bool block_when_disconnected = 5; + bool auto_connect = 6; + TunnelOptions tunnel_options = 7; + bool show_beta_releases = 8; + SplitTunnelSettings split_tunnel = 9; } message SplitTunnelSettings { @@ -521,10 +525,33 @@ message DaemonEvent { Settings settings = 2; RelayList relay_list = 3; AppVersionInfo version_info = 4; - KeygenEvent key_event = 5; + DeviceEvent device = 5; } } message RelayList { repeated RelayListCountry countries = 1; } + +message DeviceConfig { + string account_token = 1; + Device device = 2; +} + +message Device { + string id = 1; + string name = 2; +} + +message DeviceList { + repeated Device devices = 1; +} + +message DeviceRemoval { + string account_token = 1; + string device_id = 2; +} + +message DeviceEvent { + Device device = 1; +} diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 5398927569..eb6dbf6a31 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -197,6 +197,43 @@ impl From<mullvad_types::states::TunnelState> for TunnelState { } } +impl From<mullvad_types::device::Device> for Device { + fn from(device: mullvad_types::device::Device) -> Self { + Device { + id: device.id, + name: device.name, + } + } +} + +impl From<mullvad_types::device::DeviceEvent> for DeviceEvent { + fn from(event: mullvad_types::device::DeviceEvent) -> Self { + DeviceEvent { + device: event.0.map(|device| Device::from(device)), + } + } +} + +impl From<mullvad_types::device::DeviceData> for DeviceConfig { + fn from(device: mullvad_types::device::DeviceData) -> Self { + DeviceConfig { + account_token: device.token, + device: Some(Device::from(device.device)), + } + } +} + +impl From<Vec<mullvad_types::device::Device>> for DeviceList { + fn from(devices: Vec<mullvad_types::device::Device>) -> Self { + DeviceList { + devices: devices + .into_iter() + .map(|device| Device::from(device)) + .collect(), + } + } +} + impl From<mullvad_types::wireguard::KeygenEvent> for KeygenEvent { fn from(event: mullvad_types::wireguard::KeygenEvent) -> Self { use keygen_event::KeygenEvent as Event; @@ -387,7 +424,6 @@ impl From<&mullvad_types::settings::Settings> for Settings { let split_tunnel = None; Self { - account_token: settings.get_account_token().unwrap_or_default(), relay_settings: Some(RelaySettings::from(settings.get_relay_settings())), bridge_settings: Some(BridgeSettings::from(settings.bridge_settings.clone())), bridge_state: Some(BridgeState::from(settings.get_bridge_state())), diff --git a/mullvad-rpc/src/access.rs b/mullvad-rpc/src/access.rs new file mode 100644 index 0000000000..b58ceee809 --- /dev/null +++ b/mullvad-rpc/src/access.rs @@ -0,0 +1,108 @@ +use crate::{ + rest, + rest::{RequestFactory, RequestServiceHandle}, +}; +use hyper::StatusCode; +use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; +use talpid_types::ErrorExt; + +pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1"; + +#[derive(Clone)] +pub struct AccessTokenProxy { + service: RequestServiceHandle, + factory: RequestFactory, + access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>, +} + +impl AccessTokenProxy { + pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { + Self { + service, + factory, + access_from_account: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Obtain access token for an account, requesting a new one from the API if necessary. + pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> { + let existing_token = { + self.access_from_account + .lock() + .unwrap() + .get(account.as_str()) + .cloned() + }; + if let Some(access_token) = existing_token { + if access_token.is_expired() { + log::debug!("Replacing expired access token"); + return self.request_new_token(account.clone()).await; + } + log::trace!("Using stored access token"); + return Ok(access_token.access_token.clone()); + } + self.request_new_token(account.clone()).await + } + + /// Remove an access token if the API response calls for it. + pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) { + if let Err(rest::Error::ApiError(_status, code)) = response { + if code == crate::INVALID_ACCESS_TOKEN { + log::debug!("Dropping invalid access token"); + self.remove_token(account); + } + } + } + + /// Removes a stored access token. + fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> { + self.access_from_account + .lock() + .unwrap() + .remove(account) + .map(|v| v.access_token) + } + + async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> { + log::debug!("Fetching access token for an account"); + let access_token = self + .fetch_access_token(account.clone()) + .await + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to obtain access token") + ); + error + })?; + self.access_from_account + .lock() + .unwrap() + .insert(account, access_token.clone()); + Ok(access_token.access_token) + } + + async fn fetch_access_token( + &self, + account_token: AccountToken, + ) -> Result<AccessTokenData, rest::Error> { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_token: String, + } + let request = AccessTokenRequest { account_token }; + + let service = self.service.clone(); + + let rest_request = self + .factory + .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await + } +} diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs index 614aa3bdb6..a49e392320 100644 --- a/mullvad-rpc/src/lib.rs +++ b/mullvad-rpc/src/lib.rs @@ -7,6 +7,7 @@ use futures::Stream; use hyper::Method; use mullvad_types::{ account::{AccountToken, VoucherSubmission}, + device::{Device, DeviceId, DeviceName}, version::AppVersion, }; use proxy::ApiConnectionMode; @@ -29,6 +30,7 @@ mod tls_stream; #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; +mod access; mod address_cache; mod relay_list; pub use address_cache::AddressCache; @@ -44,11 +46,17 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER"; /// Error code returned by the Mullvad API if the account token is invalid. pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT"; -/// Error code returned by the Mullvad API if the account token is missing or invalid. -pub const INVALID_AUTH: &str = "INVALID_AUTH"; +/// Error code returned by the Mullvad API if the access token is invalid. +pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN"; + +pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED"; +pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE"; pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt"; +const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1"; +const APP_URL_PREFIX: &str = "app/v1"; + lazy_static::lazy_static! { static ref API: ApiEndpoint = ApiEndpoint::get(); } @@ -257,7 +265,7 @@ impl MullvadRpcRuntime { self.socket_bypass_tx.clone(), ) .await; - let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned())); + let factory = rest::RequestFactory::new(API.host.clone(), None); rest::MullvadRestHandle::new( service, @@ -296,7 +304,7 @@ pub struct AccountsProxy { #[derive(serde::Deserialize)] struct AccountResponse { token: AccountToken, - expires: DateTime<Utc>, + expiry: DateTime<Utc>, } impl AccountsProxy { @@ -309,18 +317,21 @@ impl AccountsProxy { account: AccountToken, ) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> { let service = self.handle.service.clone(); - - let response = rest::send_request( - &self.handle.factory, - service, - "/v1/me", - Method::GET, - Some(account), - &[StatusCode::OK], - ); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); async move { - let account: AccountResponse = rest::deserialize_body(response.await?).await?; - Ok(account.expires) + let response = rest::send_request( + &factory, + service, + &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let account: AccountResponse = rest::deserialize_body(response?).await?; + Ok(account.expiry) } } @@ -329,7 +340,7 @@ impl AccountsProxy { let response = rest::send_request( &self.handle.factory, service, - "/v1/accounts", + &format!("{}/accounts", ACCOUNTS_URL_PREFIX), Method::POST, None, &[StatusCode::CREATED], @@ -352,18 +363,23 @@ impl AccountsProxy { } let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); let submission = VoucherSubmission { voucher_code }; - let response = rest::post_request_with_json( - &self.handle.factory, - service, - "/v1/submit-voucher", - &submission, - Some(account_token), - &[StatusCode::OK], - ); - - async move { rest::deserialize_body(response.await?).await } + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/submit-voucher", APP_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account_token)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } } pub fn get_www_auth_token( @@ -376,22 +392,206 @@ impl AccountsProxy { } let service = self.handle.service.clone(); - let response = rest::send_request( - &self.handle.factory, - service, - "/v1/www-auth-token", - Method::POST, - Some(account), - &[StatusCode::OK], - ); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); async move { - let response: AuthTokenResponse = rest::deserialize_body(response.await?).await?; + let response = rest::send_request( + &factory, + service, + &format!("{}/www-auth-token", APP_URL_PREFIX), + Method::POST, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + let response: AuthTokenResponse = rest::deserialize_body(response?).await?; Ok(response.auth_token) } } } +#[derive(Clone)] +pub struct DevicesProxy { + handle: rest::MullvadRestHandle, +} + +#[derive(serde::Deserialize)] +struct DeviceResponse { + id: DeviceId, + name: DeviceName, + ipv4_address: ipnetwork::Ipv4Network, + ipv6_address: ipnetwork::Ipv6Network, +} + +impl DevicesProxy { + pub fn new(handle: rest::MullvadRestHandle) -> Self { + Self { handle } + } + + pub fn create( + &self, + account: AccountToken, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>> + { + #[derive(serde::Serialize)] + struct DeviceSubmission { + pubkey: wireguard::PublicKey, + kind: String, + } + + let submission = DeviceSubmission { + pubkey, + // TODO: constant + kind: "App".to_string(), + }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::POST, + &submission, + Some((access_proxy, account)), + &[StatusCode::CREATED], + ) + .await; + + let response: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + id, + name, + ipv4_address, + ipv6_address, + .. + } = response; + + Ok(( + Device { id, name }, + mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }, + )) + } + } + + pub fn get( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<Device, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn list( + &self, + account: AccountToken, + ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices", ACCOUNTS_URL_PREFIX), + Method::GET, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + rest::deserialize_body(response?).await + } + } + + pub fn remove( + &self, + account: AccountToken, + id: DeviceId, + ) -> impl Future<Output = Result<(), rest::Error>> { + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + async move { + let response = rest::send_request( + &factory, + service, + &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id), + Method::DELETE, + Some((access_proxy, account)), + &[StatusCode::NO_CONTENT], + ) + .await; + + response?; + Ok(()) + } + } + + pub fn replace_wg_key( + &self, + account: AccountToken, + id: DeviceId, + pubkey: wireguard::PublicKey, + ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>> + { + #[derive(serde::Serialize)] + struct RotateDevicePubkey { + pubkey: wireguard::PublicKey, + } + let req_body = RotateDevicePubkey { pubkey }; + + let service = self.handle.service.clone(); + let factory = self.handle.factory.clone(); + let access_proxy = self.handle.token_store.clone(); + + async move { + let response = rest::send_json_request( + &factory, + service, + &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id), + Method::PUT, + &req_body, + Some((access_proxy, account)), + &[StatusCode::OK], + ) + .await; + + let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; + let DeviceResponse { + ipv4_address, + ipv6_address, + .. + } = updated_device; + Ok(mullvad_types::wireguard::AssociatedAddresses { + ipv4_address, + ipv6_address, + }) + } + } +} + pub struct ProblemReportProxy { handle: rest::MullvadRestHandle, } @@ -425,10 +625,11 @@ impl ProblemReportProxy { let service = self.handle.service.clone(); - let request = rest::post_request_with_json( + let request = rest::send_json_request( &self.handle.factory, service, - "/v1/problem-report", + &format!("{}/problem-report", APP_URL_PREFIX), + Method::POST, &report, None, &[StatusCode::NO_CONTENT], @@ -467,7 +668,7 @@ impl AppVersionProxy { ) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> { let service = self.handle.service.clone(); - let path = format!("/v1/releases/{}/{}", platform, app_version); + let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version); let request = self.handle.factory.request(&path, Method::GET); async move { @@ -508,7 +709,10 @@ impl WireguardKeyProxy { let service = self.handle.service.clone(); let body = PublishRequest { pubkey: public_key }; - let request = self.handle.factory.post_json(&"/v1/wireguard-keys", &body); + let request = self + .handle + .factory + .post_json(&"app/v1/wireguard-keys", &body); async move { let mut request = request?; if let Some(timeout) = timeout { @@ -538,10 +742,11 @@ impl WireguardKeyProxy { let service = self.handle.service.clone(); let body = ReplacementRequest { old, new }; - let response = rest::post_request_with_json( + let response = rest::send_json_request( &self.handle.factory, service, - &"/v1/replace-wireguard-key", + &"app/v1/replace-wireguard-key", + Method::POST, &body, Some(account_token), [StatusCode::CREATED, StatusCode::OK].as_slice(), @@ -562,7 +767,7 @@ impl WireguardKeyProxy { &self.handle.factory, service, &format!( - "/v1/wireguard-keys/{}", + "app/v1/wireguard-keys/{}", urlencoding::encode(&key.to_base64()) ), Method::GET, @@ -584,7 +789,7 @@ impl WireguardKeyProxy { &self.handle.factory, service, &format!( - "/v1/wireguard-keys/{}", + "app/v1/wireguard-keys/{}", urlencoding::encode(&key.to_base64()) ), Method::DELETE, @@ -614,7 +819,7 @@ impl ApiProxy { let response = rest::send_request( &self.handle.factory, service, - "/v1/api-addrs", + &format!("{}/api-addrs", APP_URL_PREFIX), Method::GET, None, &[StatusCode::OK], diff --git a/mullvad-rpc/src/relay_list.rs b/mullvad-rpc/src/relay_list.rs index f1ed2217fd..5a8a01836f 100644 --- a/mullvad-rpc/src/relay_list.rs +++ b/mullvad-rpc/src/relay_list.rs @@ -13,7 +13,7 @@ use std::{ time::Duration, }; -/// Fetches relay list from <https://api.mullvad.net/v1/relays> +/// Fetches relay list from https://api.mullvad.net/app/v1/relays #[derive(Clone)] pub struct RelayListProxy { handle: rest::MullvadRestHandle, @@ -33,7 +33,7 @@ impl RelayListProxy { etag: Option<String>, ) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> { let service = self.handle.service.clone(); - let request = self.handle.factory.request("/v1/relays", Method::GET); + let request = self.handle.factory.request("app/v1/relays", Method::GET); let future = async move { let mut request = request?; diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 17362cce05..c7e5d02fb1 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -1,6 +1,7 @@ #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ + access::AccessTokenProxy, address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, @@ -17,6 +18,7 @@ use hyper::{ header::{self, HeaderValue}, Method, Uri, }; +use mullvad_types::account::AccountToken; use std::{ future::Future, str::FromStr, @@ -302,11 +304,11 @@ impl RestRequest { }) } - /// Set the auth header with the following format: `Token $auth`. + /// Set the auth header with the following format: `Bearer $auth`. pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> { let header = match auth { Some(auth) => Some( - HeaderValue::from_str(&format!("Token {}", auth)) + HeaderValue::from_str(&format!("Bearer {}", auth)) .map_err(Error::InvalidHeaderError)?, ), None => None, @@ -399,7 +401,16 @@ impl RequestFactory { } pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> { - let mut request = self.hyper_request(path, Method::POST)?; + self.json_request(Method::POST, path, body) + } + + fn json_request<S: serde::Serialize>( + &self, + method: Method, + path: &str, + body: &S, + ) -> Result<RestRequest> { + let mut request = self.hyper_request(path, method)?; let json_body = serde_json::to_string(&body)?; let body_length = json_body.as_bytes().len() as u64; @@ -468,33 +479,52 @@ pub fn send_request( service: RequestServiceHandle, uri: &str, method: Method, - auth: Option<String>, + auth: Option<(AccessTokenProxy, AccountToken)>, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future<Output = Result<Response>> { let request = factory.request(uri, method); async move { let mut request = request?; - request.set_auth(auth)?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result } } -pub fn post_request_with_json<B: serde::Serialize>( +pub fn send_json_request<B: serde::Serialize>( factory: &RequestFactory, service: RequestServiceHandle, uri: &str, + method: Method, body: &B, - auth: Option<String>, + auth: Option<(AccessTokenProxy, AccountToken)>, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future<Output = Result<Response>> { - let request = factory.post_json(uri, body); + let request = factory.json_request(method, uri, body); async move { let mut request = request?; - request.set_auth(auth)?; + if let Some((store, account)) = &auth { + let access_token = store.get_token(&account).await?; + request.set_auth(Some(access_token))?; + } let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await + let result = parse_rest_response(response, expected_statuses).await; + + if let Some((store, account)) = &auth { + store.check_response(&account, &result); + } + + result } } @@ -554,6 +584,7 @@ pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, availability: ApiAvailabilityHandle, + pub token_store: AccessTokenProxy, } impl MullvadRestHandle { @@ -563,10 +594,13 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { + let token_store = AccessTokenProxy::new(service.clone(), factory.clone()); + let handle = Self { service, factory, availability, + token_store, }; if !super::API.disable_address_cache { handle.spawn_api_address_fetcher(address_cache); diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index e65b1278f8..37061c8854 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -72,8 +72,11 @@ pub enum Error { #[error(display = "Failed to obtain cache directory path")] CachePathError(#[error(source)] mullvad_paths::Error), - #[error(display = "Failed to update the settings")] - SettingsError(#[error(source)] mullvad_daemon::settings::Error), + #[error(display = "Failed to read the device cache")] + ReadDeviceCacheError(#[error(source)] mullvad_daemon::device::Error), + + #[error(display = "Failed to write the device cache")] + WriteDeviceCacheError(#[error(source)] mullvad_daemon::device::Error), #[error(display = "Cannot parse the version string")] ParseVersionStringError, @@ -161,41 +164,40 @@ async fn reset_firewall() -> Result<(), Error> { async fn remove_wireguard_key() -> Result<(), Error> { let (cache_path, settings_path) = get_paths()?; - let mut settings = mullvad_daemon::settings::SettingsPersister::load(&settings_path).await; + let (mut cacher, data) = mullvad_daemon::device::DeviceCacher::new(&settings_path) + .await + .map_err(Error::ReadDeviceCacheError)?; + if let Some(device) = data { + let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false) + .await + .map_err(Error::RpcInitializationError)?; - if let Some(token) = settings.get_account_token() { - if let Some(wg_data) = settings.get_wireguard() { - let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false) - .await - .map_err(Error::RpcInitializationError)?; - let mut key_proxy = mullvad_rpc::WireguardKeyProxy::new( - rpc_runtime - .mullvad_rest_handle( - ApiConnectionMode::try_from_cache(&cache_path) - .await - .into_repeat(), - |_| async { true }, - ) - .await, - ); - retry_future_n( - move || { - key_proxy.remove_wireguard_key(token.clone(), wg_data.private_key.public_key()) - }, - move |result| match result { - Err(error) => error.is_network_error(), - _ => false, - }, - constant_interval(KEY_RETRY_INTERVAL), - KEY_RETRY_MAX_RETRIES, - ) + let proxy = mullvad_rpc::DevicesProxy::new( + rpc_runtime + .mullvad_rest_handle( + ApiConnectionMode::try_from_cache(&cache_path) + .await + .into_repeat(), + |_| async { true }, + ) + .await, + ); + retry_future_n( + move || proxy.remove(device.token.clone(), device.device.id.clone()), + move |result| match result { + Err(error) => error.is_network_error(), + _ => false, + }, + constant_interval(KEY_RETRY_INTERVAL), + KEY_RETRY_MAX_RETRIES, + ) + .await + .map_err(Error::RemoveKeyError)?; + + cacher + .write(None) .await - .map_err(Error::RemoveKeyError)?; - settings - .set_wireguard(None) - .await - .map_err(Error::SettingsError)?; - } + .map_err(Error::WriteDeviceCacheError)?; } Ok(()) diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs index b5479640e6..16f6a963f2 100644 --- a/mullvad-types/src/account.rs +++ b/mullvad-types/src/account.rs @@ -3,9 +3,12 @@ use chrono::{offset::Utc, DateTime}; use jnix::IntoJava; use serde::{Deserialize, Serialize}; -/// Identifier used to authenticate or identify a Mullvad account. +/// Identifier used to identify a Mullvad account. pub type AccountToken = String; +/// Identifier used to authenticate a Mullvad account. +pub type AccessToken = String; + /// Account expiration info returned by the API via `/v1/me`. #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[cfg_attr(target_os = "android", derive(IntoJava))] @@ -18,7 +21,7 @@ pub struct AccountData { impl AccountData { /// Return true if the account has no time left. pub fn is_expired(&self) -> bool { - self.expiry >= Utc::now() + Utc::now() >= self.expiry } } @@ -35,3 +38,17 @@ pub struct VoucherSubmission { #[cfg_attr(target_os = "android", jnix(map = "|expiry| expiry.to_string()"))] pub new_expiry: DateTime<Utc>, } + +/// Token used for authentication in the API. +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct AccessTokenData { + pub access_token: AccessToken, + pub expiry: DateTime<Utc>, +} + +impl AccessTokenData { + /// Return true if the token is no longer valid. + pub fn is_expired(&self) -> bool { + Utc::now() >= self.expiry + } +} diff --git a/mullvad-types/src/device.rs b/mullvad-types/src/device.rs new file mode 100644 index 0000000000..e40a3d7080 --- /dev/null +++ b/mullvad-types/src/device.rs @@ -0,0 +1,37 @@ +use crate::{account::AccountToken, wireguard}; +use serde::{Deserialize, Serialize}; +use talpid_types::net::wireguard::PublicKey; + +/// UUID for a device. +pub type DeviceId = String; + +/// Human-readable device identifier. +pub type DeviceName = String; + +/// Contains data for a device returned by the API. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct Device { + pub id: DeviceId, + pub name: DeviceName, + pub pubkey: PublicKey, +} + +impl Eq for Device {} + +/// A complete device configuration. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct DeviceData { + pub token: AccountToken, + pub device: Device, + pub wg_data: wireguard::WireguardData, +} + +impl From<DeviceData> for Device { + fn from(data: DeviceData) -> Device { + data.device + } +} + +/// Emitted when logging in or out of an account, or when the device changes. +#[derive(Clone, Debug)] +pub struct DeviceEvent(pub Option<Device>); diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs index e93ab2f606..6d636aceb5 100644 --- a/mullvad-types/src/lib.rs +++ b/mullvad-types/src/lib.rs @@ -2,6 +2,7 @@ pub mod account; pub mod auth_failed; +pub mod device; pub mod endpoint; pub mod location; pub mod relay_constraints; diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs index 26a24202a5..63ccb480a2 100644 --- a/mullvad-types/src/settings/mod.rs +++ b/mullvad-types/src/settings/mod.rs @@ -61,9 +61,6 @@ impl Serialize for SettingsVersion { #[cfg_attr(target_os = "android", derive(IntoJava))] #[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] pub struct Settings { - account_token: Option<String>, - #[cfg_attr(target_os = "android", jnix(skip))] - wireguard: Option<wireguard::WireguardData>, relay_settings: RelaySettings, #[cfg_attr(target_os = "android", jnix(skip))] pub bridge_settings: BridgeSettings, @@ -102,8 +99,6 @@ pub struct SplitTunnelSettings { impl Default for Settings { fn default() -> Self { Settings { - account_token: None, - wireguard: None, relay_settings: RelaySettings::Normal(RelayConstraints { location: Constraint::Only(LocationConstraint::Country("se".to_owned())), ..Default::default() @@ -123,45 +118,6 @@ impl Default for Settings { } impl Settings { - pub fn get_account_token(&self) -> Option<String> { - self.account_token.clone() - } - - /// Changes account number to the one given. Also saves the new settings to disk. - /// The boolean in the Result indicates if the account token changed or not - pub fn set_account_token(&mut self, mut account_token: Option<String>) -> bool { - if account_token.as_ref().map(String::len) == Some(0) { - log::debug!("Setting empty account token is treated as unsetting it"); - account_token = None; - } - if account_token != self.account_token { - if account_token.is_none() { - log::info!("Unsetting account token"); - } else if self.account_token.is_none() { - log::info!("Setting account token"); - } else { - log::info!("Changing account token") - } - self.account_token = account_token; - true - } else { - false - } - } - - pub fn get_wireguard(&self) -> Option<wireguard::WireguardData> { - self.wireguard.clone() - } - - pub fn set_wireguard(&mut self, wireguard: Option<wireguard::WireguardData>) -> bool { - if wireguard != self.wireguard { - self.wireguard = wireguard; - true - } else { - false - } - } - pub fn get_relay_settings(&self) -> RelaySettings { self.relay_settings.clone() } |
