summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-10-11 15:53:26 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-03-14 12:08:37 +0100
commitb98c366f8647b17c21bfffd903582a1dc09158fb (patch)
treefa8b4ba6265767f4a82afb8f9a50e4f3a6f57ecc
parent78dc4644a82d7b3fb904ef3cbac8a1f705f0a213 (diff)
downloadmullvadvpn-b98c366f8647b17c21bfffd903582a1dc09158fb.tar.xz
mullvadvpn-b98c366f8647b17c21bfffd903582a1dc09158fb.zip
Implement device concept
-rw-r--r--mullvad-cli/src/cmds/account.rs66
-rw-r--r--mullvad-cli/src/cmds/status.rs5
-rw-r--r--mullvad-cli/src/cmds/tunnel.rs11
-rw-r--r--mullvad-daemon/src/account.rs173
-rw-r--r--mullvad-daemon/src/device.rs748
-rw-r--r--mullvad-daemon/src/lib.rs645
-rw-r--r--mullvad-daemon/src/management_interface.rs89
-rw-r--r--mullvad-daemon/src/relays/mod.rs114
-rw-r--r--mullvad-daemon/src/settings.rs17
-rw-r--r--mullvad-daemon/src/wireguard.rs499
-rw-r--r--mullvad-management-interface/proto/management_interface.proto55
-rw-r--r--mullvad-management-interface/src/types.rs38
-rw-r--r--mullvad-rpc/src/access.rs108
-rw-r--r--mullvad-rpc/src/lib.rs293
-rw-r--r--mullvad-rpc/src/relay_list.rs4
-rw-r--r--mullvad-rpc/src/rest.rs56
-rw-r--r--mullvad-setup/src/main.rs72
-rw-r--r--mullvad-types/src/account.rs21
-rw-r--r--mullvad-types/src/device.rs37
-rw-r--r--mullvad-types/src/lib.rs1
-rw-r--r--mullvad-types/src/settings/mod.rs44
21 files changed, 1692 insertions, 1404 deletions
diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs
index 0bbbc28024..fae3b39396 100644
--- a/mullvad-cli/src/cmds/account.rs
+++ b/mullvad-cli/src/cmds/account.rs
@@ -16,23 +16,17 @@ impl Command for Account {
clap::App::new(self.name())
.about("Control and display information about your Mullvad account")
.setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::App::new("create").about("Create and log in to a new account"))
.subcommand(
- clap::App::new("set").about("Change account").arg(
+ clap::App::new("login").about("Log in to an account").arg(
clap::Arg::new("token")
.help("The Mullvad account token to configure the client with")
.required(false),
),
)
+ .subcommand(clap::App::new("logout").about("Log out of the current account"))
.subcommand(
- clap::App::new("get")
- .about("Display information about the currently configured account"),
- )
- .subcommand(
- clap::App::new("unset").about("Removes the account number from the settings"),
- )
- .subcommand(
- clap::App::new("create")
- .about("Creates a new account and sets it as the active one"),
+ clap::App::new("get").about("Display information about the current account"),
)
.subcommand(
clap::App::new("redeem").about("Redeems a voucher").arg(
@@ -44,7 +38,9 @@ impl Command for Account {
}
async fn run(&self, matches: &clap::ArgMatches) -> Result<()> {
- if let Some(set_matches) = matches.subcommand_matches("set") {
+ if let Some(_matches) = matches.subcommand_matches("create") {
+ self.create().await
+ } else if let Some(set_matches) = matches.subcommand_matches("login") {
let mut token = match set_matches.value_of("token") {
Some(token) => token.to_string(),
None => {
@@ -60,13 +56,11 @@ impl Command for Account {
}
};
token = token.split_whitespace().join("").to_string();
- self.set(Some(token)).await
+ self.login(token).await
+ } else if let Some(_matches) = matches.subcommand_matches("logout") {
+ self.logout().await
} else if let Some(_matches) = matches.subcommand_matches("get") {
self.get().await
- } else if let Some(_matches) = matches.subcommand_matches("unset") {
- self.set(None).await
- } else if let Some(_matches) = matches.subcommand_matches("create") {
- self.create().await
} else if let Some(matches) = matches.subcommand_matches("redeem") {
let voucher = matches.value_of_t_or_exit("voucher");
self.redeem_voucher(voucher).await
@@ -77,24 +71,35 @@ impl Command for Account {
}
impl Account {
- async fn set(&self, token: Option<AccountToken>) -> Result<()> {
+ async fn create(&self) -> Result<()> {
let mut rpc = new_rpc_client().await?;
- rpc.set_account(token.clone().unwrap_or_default()).await?;
- if let Some(token) = token {
- println!("Mullvad account \"{}\" set", token);
- } else {
- println!("Mullvad account removed");
- }
+ rpc.create_new_account(()).await?;
+ println!("New account created!");
+ self.get().await
+ }
+
+ async fn login(&self, token: AccountToken) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ rpc.login_account(token.clone()).await?;
+ println!("Mullvad account \"{}\" set", token);
+ Ok(())
+ }
+
+ async fn logout(&self) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ rpc.logout_account(()).await?;
+ println!("Removed device from Mullvad account");
Ok(())
}
async fn get(&self) -> Result<()> {
let mut rpc = new_rpc_client().await?;
- let settings = rpc.get_settings(()).await?.into_inner();
- if settings.account_token != "" {
- println!("Mullvad account: {}", settings.account_token);
+ let device = rpc.get_device(()).await?.into_inner();
+ if !device.account_token.is_empty() {
+ println!("Mullvad account: {}", device.account_token);
+ println!("Device name : {}", device.device.unwrap().name);
let expiry = rpc
- .get_account_data(settings.account_token)
+ .get_account_data(device.account_token)
.await
.map_err(|error| Error::RpcFailedExt("Failed to fetch account data", error))?
.into_inner();
@@ -108,13 +113,6 @@ impl Account {
Ok(())
}
- async fn create(&self) -> Result<()> {
- let mut rpc = new_rpc_client().await?;
- rpc.create_new_account(()).await?;
- println!("New account created!");
- self.get().await
- }
-
async fn redeem_voucher(&self, mut voucher: String) -> Result<()> {
let mut rpc = new_rpc_client().await?;
voucher.retain(|c| c.is_alphanumeric());
diff --git a/mullvad-cli/src/cmds/status.rs b/mullvad-cli/src/cmds/status.rs
index 8c4a929c30..f5a681e36c 100644
--- a/mullvad-cli/src/cmds/status.rs
+++ b/mullvad-cli/src/cmds/status.rs
@@ -74,10 +74,9 @@ impl Command for Status {
println!("New app version info: {:#?}", app_version_info);
}
}
- EventType::KeyEvent(key_event) => {
+ EventType::Device(device) => {
if verbose {
- print!("Key event: ");
- print_keygen_event(&key_event);
+ println!("Device event: {:#?}", device);
}
}
}
diff --git a/mullvad-cli/src/cmds/tunnel.rs b/mullvad-cli/src/cmds/tunnel.rs
index f3b218648e..f27e29d147 100644
--- a/mullvad-cli/src/cmds/tunnel.rs
+++ b/mullvad-cli/src/cmds/tunnel.rs
@@ -246,20 +246,13 @@ impl Tunnel {
println!("No key is set");
return Ok(());
}
-
- let is_valid = rpc
- .verify_wireguard_key(())
- .await
- .map_err(|error| Error::RpcFailedExt("Failed to verify key", error))?
- .into_inner();
- println!("Key is valid for use with current account: {}", is_valid);
Ok(())
}
async fn process_wireguard_key_generate() -> Result<()> {
let mut rpc = new_rpc_client().await?;
- let keygen_event = rpc.generate_wireguard_key(()).await?;
- print_keygen_event(&keygen_event.into_inner());
+ let keygen_event = rpc.rotate_wireguard_key(()).await?;
+ println!("Rotated WireGuard key");
Ok(())
}
diff --git a/mullvad-daemon/src/account.rs b/mullvad-daemon/src/account.rs
deleted file mode 100644
index f5655c9d1f..0000000000
--- a/mullvad-daemon/src/account.rs
+++ /dev/null
@@ -1,173 +0,0 @@
-use chrono::{DateTime, Utc};
-use futures::future::{abortable, AbortHandle};
-use mullvad_rpc::{
- availability::ApiAvailabilityHandle,
- rest::{self, Error as RestError, MullvadRestHandle},
- AccountsProxy,
-};
-use mullvad_types::account::{AccountToken, VoucherSubmission};
-use std::{future::Future, time::Duration};
-use talpid_core::future_retry::{
- constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered,
-};
-
-const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
-const RETRY_ACTION_MAX_RETRIES: usize = 2;
-
-const RETRY_EXPIRY_CHECK_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
-const RETRY_EXPIRY_CHECK_INTERVAL_FACTOR: u32 = 5;
-const RETRY_EXPIRY_CHECK_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-
-pub struct Account(());
-
-#[derive(Clone)]
-pub struct AccountHandle {
- api_availability: ApiAvailabilityHandle,
- initial_check_abort_handle: AbortHandle,
- proxy: AccountsProxy,
-}
-
-impl AccountHandle {
- pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.create_account(),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub fn get_www_auth_token(
- &self,
- account: AccountToken,
- ) -> impl Future<Output = Result<String, rest::Error>> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.get_www_auth_token(account.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.get_expiry(token.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if handle_expiry_result_inner(&result, &self.api_availability) {
- self.initial_check_abort_handle.abort();
- }
- result
- }
-
- pub async fn submit_voucher(
- &mut self,
- account_token: AccountToken,
- voucher: String,
- ) -> Result<VoucherSubmission, rest::Error> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
- move |result| Self::should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if result.is_ok() {
- self.initial_check_abort_handle.abort();
- self.api_availability.resume_background();
- }
- result
- }
-
- fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
- match result {
- Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
- _ => false,
- }
- }
-}
-
-impl Account {
- pub fn new(
- runtime: tokio::runtime::Handle,
- rpc_handle: MullvadRestHandle,
- token: Option<String>,
- api_availability: ApiAvailabilityHandle,
- ) -> AccountHandle {
- let accounts_proxy = AccountsProxy::new(rpc_handle);
- api_availability.pause_background();
-
- let api_availability_copy = api_availability.clone();
- let accounts_proxy_copy = accounts_proxy.clone();
-
- let (future, initial_check_abort_handle) = abortable(async move {
- let token = if let Some(token) = token {
- token
- } else {
- api_availability.pause_background();
- return;
- };
-
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(
- RETRY_EXPIRY_CHECK_INTERVAL_INITIAL,
- RETRY_EXPIRY_CHECK_INTERVAL_FACTOR,
- )
- .max_delay(RETRY_EXPIRY_CHECK_INTERVAL_MAX),
- );
- let future_generator = move || {
- let wait_online = api_availability.wait_online();
- let expiry_fut = accounts_proxy.get_expiry(token.clone());
- let api_availability_copy = api_availability.clone();
- async move {
- let _ = wait_online.await;
- handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy)
- }
- };
- let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
- retry_future(future_generator, should_retry, retry_strategy).await;
- });
- runtime.spawn(future);
-
- AccountHandle {
- api_availability: api_availability_copy,
- initial_check_abort_handle,
- proxy: accounts_proxy_copy,
- }
- }
-}
-
-fn handle_expiry_result_inner(
- result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>,
- api_availability: &ApiAvailabilityHandle,
-) -> bool {
- match result {
- Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
- api_availability.resume_background();
- true
- }
- Ok(_expiry) => {
- api_availability.pause_background();
- true
- }
- Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => {
- if code == mullvad_rpc::INVALID_ACCOUNT || code == mullvad_rpc::INVALID_AUTH {
- api_availability.pause_background();
- return true;
- }
- false
- }
- Err(_) => false,
- }
-}
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs
new file mode 100644
index 0000000000..42c8ee23bf
--- /dev/null
+++ b/mullvad-daemon/src/device.rs
@@ -0,0 +1,748 @@
+use crate::DaemonEventSender;
+use chrono::{DateTime, Utc};
+use futures::{
+ channel::mpsc,
+ future::{abortable, AbortHandle},
+ stream::StreamExt,
+};
+use mullvad_rpc::{
+ availability::{self, ApiAvailabilityHandle},
+ rest::{self, Error as RestError, MullvadRestHandle},
+ AccountsProxy, DevicesProxy,
+};
+use mullvad_types::{
+ account::{AccountToken, VoucherSubmission},
+ device::{Device, DeviceData, DeviceId},
+ wireguard::{RotationInterval, WireguardData},
+};
+use std::{
+ future::Future,
+ path::Path,
+ sync::{Arc, Mutex},
+ time::Duration,
+};
+use talpid_core::{
+ future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered},
+ mpsc::Sender,
+};
+use talpid_types::{net::wireguard::PrivateKey, ErrorExt};
+use tokio::{
+ fs,
+ io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
+};
+
+/// How often to check whether the key has expired.
+/// A short interval is used in case the computer is ever suspended.
+const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60);
+
+/// File that used to store account and device data.
+const DEVICE_CACHE_FILENAME: &str = "device.json";
+
+const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
+const RETRY_ACTION_MAX_RETRIES: usize = 2;
+
+const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
+const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5;
+const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
+
+pub struct DeviceKeyEvent(pub DeviceData);
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "The account has reached the maximum number of devices")]
+ TooManyDevices,
+ #[error(display = "No device is set")]
+ NoDevice,
+ #[error(display = "The login attempt was aborted")]
+ LoginAborted,
+ #[error(display = "Unexpected HTTP request error")]
+ RestError(#[error(source)] rest::Error),
+ #[error(display = "API availability check was interrupted")]
+ ApiCheckError(#[error(source)] availability::Error),
+ #[error(display = "Failed to read or write device cache")]
+ DeviceIoError(#[error(source)] io::Error),
+ #[error(display = "Failed parse device cache")]
+ ParseDeviceCache(#[error(source)] serde_json::Error),
+}
+
+pub(crate) struct AccountManager {
+ runtime: tokio::runtime::Handle,
+ account_service: AccountService,
+ device_service: DeviceService,
+ inner: Arc<Mutex<AccountManagerInner>>,
+ cache_update_tx: mpsc::UnboundedSender<Option<DeviceData>>,
+ cache_task_join_handle: Option<tokio::task::JoinHandle<()>>,
+ key_update_tx: DaemonEventSender<DeviceKeyEvent>,
+ rotation_abort_handle: Option<AbortHandle>,
+}
+
+struct AccountManagerInner {
+ data: Option<DeviceData>,
+ rotation_interval: RotationInterval,
+}
+
+impl AccountManager {
+ pub async fn new(
+ runtime: tokio::runtime::Handle,
+ rest_handle: rest::MullvadRestHandle,
+ api_availability: ApiAvailabilityHandle,
+ settings_dir: &Path,
+ key_update_tx: DaemonEventSender<DeviceKeyEvent>,
+ ) -> Result<AccountManager, Error> {
+ let (mut cacher, device_data) = DeviceCacher::new(settings_dir).await?;
+ let token = device_data.as_ref().map(|state| state.token.clone());
+ let account_service = Account::new(
+ runtime.clone(),
+ rest_handle.clone(),
+ token,
+ api_availability.clone(),
+ );
+ let should_start_rotation = device_data.is_some();
+ let inner = Arc::new(Mutex::new(AccountManagerInner {
+ data: device_data,
+ rotation_interval: RotationInterval::default(),
+ }));
+
+ let (cache_update_tx, mut cache_update_rx) = mpsc::unbounded();
+ let cache_task_join_handle = runtime.spawn(async move {
+ while let Some(new_device) = cache_update_rx.next().await {
+ if let Err(error) = cacher.write(new_device).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to update device cache")
+ );
+ }
+ }
+ });
+
+ let mut manager = AccountManager {
+ runtime,
+ account_service,
+ device_service: DeviceService::new(rest_handle, api_availability),
+ inner,
+ cache_update_tx,
+ cache_task_join_handle: Some(cache_task_join_handle),
+ key_update_tx,
+ rotation_abort_handle: None,
+ };
+
+ if should_start_rotation {
+ manager.start_key_rotation();
+ }
+
+ Ok(manager)
+ }
+
+ pub fn account_service(&self) -> AccountService {
+ self.account_service.clone()
+ }
+
+ pub fn device_service(&self) -> DeviceService {
+ self.device_service.clone()
+ }
+
+ pub async fn login(&mut self, token: AccountToken) -> Result<DeviceData, Error> {
+ let data = self.device_service.generate_for_account(token).await?;
+ self.logout();
+ {
+ let mut inner = self.inner.lock().unwrap();
+ inner.data.replace(data.clone());
+ let _ = self.cache_update_tx.unbounded_send(Some(data.clone()));
+ }
+ self.start_key_rotation();
+
+ Ok(data)
+ }
+
+ pub fn set(&mut self, data: DeviceData) {
+ self.logout();
+ {
+ let mut inner = self.inner.lock().unwrap();
+ inner.data.replace(data.clone());
+ let _ = self.cache_update_tx.unbounded_send(Some(data));
+ }
+ self.start_key_rotation();
+ }
+
+ /// Log out without waiting for the result.
+ pub fn logout(&mut self) {
+ let fut = self.logout_inner(true);
+ self.runtime.spawn(fut);
+ }
+
+ /// Log out, and wait until the API has removed the device.
+ #[cfg(not(target_os = "android"))]
+ pub fn logout_wait(&mut self) -> impl Future<Output = Result<(), Error>> {
+ self.logout_inner(false)
+ }
+
+ fn logout_inner(&mut self, use_backoff: bool) -> impl Future<Output = Result<(), Error>> {
+ self.stop_key_rotation();
+ let data = {
+ let mut inner = self.inner.lock().unwrap();
+ let _ = self.cache_update_tx.unbounded_send(None);
+ inner.data.take()
+ };
+ let service = self.device_service.clone();
+ async move {
+ if let Some(data) = data {
+ if use_backoff {
+ return service
+ .remove_device_with_backoff(data.token, data.device.id)
+ .await;
+ } else {
+ return service.remove_device(data.token, data.device.id).await;
+ }
+ }
+ Ok(())
+ }
+ }
+
+ pub async fn rotate_key(&mut self) -> Result<WireguardData, Error> {
+ let mut data = {
+ let inner = self.inner.lock().unwrap();
+ inner.data.as_ref().ok_or(Error::NoDevice)?.clone()
+ };
+ self.stop_key_rotation();
+ let result = self
+ .device_service
+ .rotate_key(data.token.clone(), data.device.id.clone())
+ .await;
+ if let Ok(ref wg_data) = result {
+ data.wg_data = wg_data.clone();
+ let mut inner = self.inner.lock().unwrap();
+ inner.data.replace(data.clone());
+ let _ = self.cache_update_tx.unbounded_send(Some(data));
+ }
+ self.start_key_rotation();
+ result
+ }
+
+ pub fn get(&self) -> Option<DeviceData> {
+ self.inner.lock().unwrap().data.clone()
+ }
+
+ pub fn is_some(&self) -> bool {
+ self.inner.lock().unwrap().data.is_some()
+ }
+
+ pub async fn set_rotation_interval(&mut self, interval: RotationInterval) {
+ self.stop_key_rotation();
+ let restart_rotation = {
+ let mut inner = self.inner.lock().unwrap();
+ inner.rotation_interval = interval;
+ inner.data.is_some()
+ };
+ if restart_rotation {
+ self.start_key_rotation();
+ }
+ }
+
+ fn start_key_rotation(&mut self) {
+ self.stop_key_rotation();
+
+ let service = self.device_service.clone();
+ let inner = self.inner.clone();
+ let cache_update_tx = self.cache_update_tx.clone();
+ let key_update_tx = self.key_update_tx.clone();
+
+ let (task, abort_handle) = abortable(async move {
+ loop {
+ tokio::time::sleep(KEY_CHECK_INTERVAL).await;
+
+ let rotation_interval = { inner.lock().unwrap().rotation_interval.clone() };
+
+ let mut state = {
+ match inner.lock().unwrap().data.as_ref() {
+ Some(device_config) => device_config.clone(),
+ None => continue,
+ }
+ };
+
+ if (chrono::Utc::now()
+ .signed_duration_since(state.wg_data.created)
+ .num_seconds() as u64)
+ < rotation_interval.as_duration().as_secs()
+ {
+ continue;
+ }
+
+ match service
+ .rotate_key_with_backoff(state.token.clone(), state.device.id.clone())
+ .await
+ {
+ Ok(wg_data) => {
+ state.wg_data = wg_data;
+ {
+ let mut inner = inner.lock().unwrap();
+ inner.data.replace(state.clone());
+ let _ = cache_update_tx.unbounded_send(Some(state.clone()));
+ }
+ let _ = key_update_tx.send(DeviceKeyEvent(state));
+ }
+ Err(error) => {
+ log::debug!("{}", error.display_chain_with_msg("Stopping key rotation"));
+ }
+ }
+ }
+ });
+ self.runtime.spawn(task);
+ self.rotation_abort_handle = Some(abort_handle);
+ }
+
+ fn stop_key_rotation(&mut self) {
+ if let Some(abort_handle) = self.rotation_abort_handle.take() {
+ abort_handle.abort();
+ }
+ }
+}
+
+impl Drop for AccountManager {
+ fn drop(&mut self) {
+ self.stop_key_rotation();
+ if let Some(cache_task_join_handle) = self.cache_task_join_handle.take() {
+ let _ = self.runtime.block_on(cache_task_join_handle);
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct DeviceService {
+ api_availability: ApiAvailabilityHandle,
+ proxy: DevicesProxy,
+}
+
+impl DeviceService {
+ fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
+ Self {
+ proxy: DevicesProxy::new(handle),
+ api_availability,
+ }
+ }
+
+ /// Generate a new device for a given token
+ pub async fn generate_for_account(&self, token: AccountToken) -> Result<DeviceData, Error> {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ let (device, addresses) = retry_future_n(
+ move || proxy.create(token_copy.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+
+ pub async fn generate_for_account_with_backoff(
+ &self,
+ token: AccountToken,
+ ) -> Result<DeviceData, Error> {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ )
+ .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
+ );
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ let (device, addresses) = retry_future(
+ move || {
+ let wait_online = api_handle.wait_online();
+ let fut = proxy.create(token_copy.clone(), pubkey.clone());
+ async move {
+ let _ = wait_online.await;
+ fut.await
+ }
+ },
+ should_retry_backoff,
+ retry_strategy,
+ )
+ .await?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+
+ pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.remove(token.clone(), device.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await?;
+ Ok(())
+ }
+
+ pub async fn remove_device_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ ), // Not setting a maximum interval
+ );
+
+ retry_future(
+ move || {
+ let wait_online = api_handle.wait_online();
+ let fut = proxy.remove(token.clone(), device.clone());
+ async move {
+ let _ = wait_online.await;
+ fut.await
+ }
+ },
+ should_retry_backoff,
+ retry_strategy,
+ )
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn rotate_key(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+ let addresses = retry_future_n(
+ move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn rotate_key_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ )
+ .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
+ );
+ let addresses = retry_future(
+ move || {
+ let wait_online = api_handle.wait_online();
+ let fut = proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone());
+ async move {
+ let _ = wait_online.await;
+ fut.await
+ }
+ },
+ should_retry_backoff,
+ retry_strategy,
+ )
+ .await?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.list(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(Error::RestError)
+ }
+}
+
+pub struct DeviceCacher {
+ file: io::BufWriter<fs::File>,
+}
+
+impl DeviceCacher {
+ pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> {
+ let mut options = std::fs::OpenOptions::new();
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::OpenOptionsExt;
+ options.mode(0o600);
+ }
+ #[cfg(windows)]
+ {
+ use std::os::windows::fs::OpenOptionsExt;
+ // exclusive access
+ options.share_mode(0);
+ }
+
+ let path = settings_dir.join(DEVICE_CACHE_FILENAME);
+ let cache_exists = path.is_file();
+
+ let mut file = fs::OpenOptions::from(options)
+ .write(true)
+ .read(true)
+ .create(true)
+ .open(path)
+ .await?;
+
+ let device: Option<DeviceData> = if cache_exists {
+ let mut reader = io::BufReader::new(&mut file);
+ let mut buffer = String::new();
+ reader.read_to_string(&mut buffer).await?;
+ if !buffer.is_empty() {
+ serde_json::from_str(&buffer)?
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ Ok((
+ DeviceCacher {
+ file: io::BufWriter::new(file),
+ },
+ device,
+ ))
+ }
+
+ pub async fn write(&mut self, device: Option<DeviceData>) -> Result<(), Error> {
+ let data = serde_json::to_vec_pretty(&device).unwrap();
+
+ self.file.get_mut().set_len(0).await?;
+ self.file.seek(io::SeekFrom::Start(0)).await?;
+ self.file.write_all(&data).await?;
+ self.file.flush().await?;
+ self.file.get_mut().sync_data().await?;
+
+ Ok(())
+ }
+}
+
+#[derive(Clone)]
+pub struct AccountService {
+ api_availability: ApiAvailabilityHandle,
+ initial_check_abort_handle: AbortHandle,
+ proxy: AccountsProxy,
+}
+
+impl AccountService {
+ pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.create_account(),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub fn get_www_auth_token(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<String, rest::Error>> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.get_www_auth_token(account.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.get_expiry(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if handle_expiry_result_inner(&result, &self.api_availability) {
+ self.initial_check_abort_handle.abort();
+ }
+ result
+ }
+
+ pub async fn submit_voucher(
+ &mut self,
+ account_token: AccountToken,
+ voucher: String,
+ ) -> Result<VoucherSubmission, rest::Error> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if result.is_ok() {
+ self.initial_check_abort_handle.abort();
+ self.api_availability.resume_background();
+ }
+ result
+ }
+}
+
+struct Account(());
+
+impl Account {
+ pub fn new(
+ runtime: tokio::runtime::Handle,
+ rpc_handle: MullvadRestHandle,
+ token: Option<String>,
+ api_availability: ApiAvailabilityHandle,
+ ) -> AccountService {
+ let accounts_proxy = AccountsProxy::new(rpc_handle);
+ api_availability.pause_background();
+
+ let api_availability_copy = api_availability.clone();
+ let accounts_proxy_copy = accounts_proxy.clone();
+
+ let (future, initial_check_abort_handle) = abortable(async move {
+ let token = if let Some(token) = token {
+ token
+ } else {
+ api_availability.pause_background();
+ return;
+ };
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ )
+ .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
+ );
+ let future_generator = move || {
+ let wait_online = api_availability.wait_online();
+ let expiry_fut = accounts_proxy.get_expiry(token.clone());
+ let api_availability_copy = api_availability.clone();
+ async move {
+ let _ = wait_online.await;
+ handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy)
+ }
+ };
+ let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
+ retry_future(future_generator, should_retry, retry_strategy).await;
+ });
+ runtime.spawn(future);
+
+ AccountService {
+ api_availability: api_availability_copy,
+ initial_check_abort_handle,
+ proxy: accounts_proxy_copy,
+ }
+ }
+}
+
+fn handle_expiry_result_inner(
+ result: &Result<chrono::DateTime<chrono::Utc>, mullvad_rpc::rest::Error>,
+ api_availability: &ApiAvailabilityHandle,
+) -> bool {
+ match result {
+ Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
+ api_availability.resume_background();
+ true
+ }
+ Ok(_expiry) => {
+ api_availability.pause_background();
+ true
+ }
+ Err(mullvad_rpc::rest::Error::ApiError(_status, code)) => {
+ if code == mullvad_rpc::INVALID_ACCOUNT {
+ api_availability.pause_background();
+ return true;
+ }
+ false
+ }
+ Err(_) => false,
+ }
+}
+
+fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
+ match result {
+ Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
+ _ => false,
+ }
+}
+
+fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
+ match result {
+ Ok(_) => false,
+ Err(error) => {
+ if let RestError::ApiError(status, code) = error {
+ *status != rest::StatusCode::NOT_FOUND
+ && code != mullvad_rpc::INVALID_ACCOUNT
+ && code != mullvad_rpc::KEY_LIMIT_REACHED
+ && code != mullvad_rpc::MAX_DEVICES_REACHED
+ && code != mullvad_rpc::PUBKEY_IN_USE
+ } else {
+ true
+ }
+ }
+ }
+}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 0d9ec96c87..8d38eb8f54 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -4,9 +4,9 @@
#[macro_use]
extern crate serde;
-mod account;
pub mod account_history;
mod api;
+pub mod device;
pub mod exception_logging;
#[cfg(target_os = "macos")]
pub mod exclusion_gid;
@@ -36,6 +36,7 @@ use mullvad_rpc::{
};
use mullvad_types::{
account::{AccountData, AccountToken, VoucherSubmission},
+ device::{Device, DeviceData, DeviceEvent, DeviceId},
endpoint::MullvadEndpoint,
location::{Coordinates, GeoIpLocation},
relay_constraints::{
@@ -46,7 +47,7 @@ use mullvad_types::{
settings::{DnsOptions, DnsState, Settings},
states::{TargetState, TunnelState},
version::{AppVersion, AppVersionInfo},
- wireguard::{KeygenEvent, RotationInterval},
+ wireguard::{KeygenEvent, PublicKey, RotationInterval},
};
use settings::SettingsPersister;
#[cfg(target_os = "android")]
@@ -75,7 +76,8 @@ use talpid_types::android::AndroidContext;
use talpid_types::{
net::{
openvpn::{self, ProxySettings},
- TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType,
+ wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters,
+ TunnelType,
},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
@@ -84,12 +86,6 @@ use talpid_types::{
use tokio::fs;
use tokio::io;
-#[path = "wireguard.rs"]
-mod wireguard;
-
-/// Timeout for first WireGuard key pushing
-const FIRST_KEY_PUSH_TIMEOUT: Duration = Duration::from_secs(5);
-
/// Delay between generating a new WireGuard key and reconnecting
const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60);
@@ -124,13 +120,28 @@ pub enum Error {
#[error(display = "Unable to load account history")]
LoadAccountHistory(#[error(source)] account_history::Error),
+ #[error(display = "Failed to start account manager")]
+ LoadAccountManager(#[error(source)] device::Error),
+
+ #[error(display = "Failed to log in to account")]
+ LoginError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to log out of account")]
+ LogoutError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to rotate WireGuard key")]
+ KeyRotationError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to list devices")]
+ ListDevicesError(#[error(source)] device::Error),
+
+ #[error(display = "Failed to remove device")]
+ RemoveDeviceError(#[error(source)] device::Error),
+
#[cfg(target_os = "linux")]
#[error(display = "Unable to initialize split tunneling")]
InitSplitTunneling(#[error(source)] split_tunnel::Error),
- #[error(display = "The account has too many wireguard keys")]
- TooManyKeys,
-
#[cfg(windows)]
#[error(display = "Split tunneling error")]
SplitTunnelError(#[error(source)] split_tunnel::Error),
@@ -226,8 +237,16 @@ pub enum DaemonCommand {
/// Trigger an asynchronous relay list update. This returns before the relay list is actually
/// updated.
UpdateRelayLocations,
- /// Set which account token to use for subsequent connection attempts.
- SetAccount(ResponseTx<(), settings::Error>, Option<AccountToken>),
+ /// Log in with a given account and create a new device.
+ LoginAccount(ResponseTx<(), Error>, AccountToken),
+ /// Log out of the current account and remove the device, if they exist.
+ LogoutAccount(ResponseTx<(), Error>),
+ /// Return the current device configuration, if there is one.
+ GetDevice(ResponseTx<Option<DeviceData>, Error>),
+ /// Return all the devices for a given account token.
+ ListDevices(ResponseTx<Vec<Device>, Error>, AccountToken),
+ /// Remove device from a given account.
+ RemoveDevice(ResponseTx<(), Error>, AccountToken, DeviceId),
/// Place constraints on the type of tunnel and relay
UpdateRelaySettings(ResponseTx<(), settings::Error>, RelaySettingsUpdate),
/// Set the allow LAN setting.
@@ -256,11 +275,9 @@ pub enum DaemonCommand {
/// Get the daemon settings
GetSettings(oneshot::Sender<Settings>),
/// Generate new wireguard key
- GenerateWireguardKey(ResponseTx<wireguard::KeygenEvent, Error>),
+ RotateWireguardKey(ResponseTx<(), Error>),
/// Return a public key of the currently set wireguard private key, if there is one
- GetWireguardKey(ResponseTx<Option<wireguard::PublicKey>, Error>),
- /// Verify if the currently set wireguard key is valid.
- VerifyWireguardKey(ResponseTx<bool, Error>),
+ GetWireguardKey(ResponseTx<Option<PublicKey>, Error>),
/// Get information about the currently running and latest app versions
GetVersionInfo(oneshot::Sender<Option<AppVersionInfo>>),
/// Get current version of the app
@@ -320,19 +337,14 @@ pub(crate) enum InternalDaemonEvent {
Command(DaemonCommand),
/// Daemon shutdown triggered by a signal, ctrl-c or similar.
TriggerShutdown,
- /// Wireguard key generation event
- WgKeyEvent(
- (
- AccountToken,
- Result<mullvad_types::wireguard::WireguardData, wireguard::Error>,
- ),
- ),
/// New Account created
NewAccountEvent(AccountToken, oneshot::Sender<Result<String, Error>>),
/// The background job fetching new `AppVersionInfo`s got a new info object.
NewAppVersionInfo(AppVersionInfo),
/// Request from REST client to use a different API endpoint.
GenerateApiConnectionMode(api::ApiConnectionModeRequest),
+ /// Sent when a device key is rotated.
+ DeviceKeyEvent(device::DeviceKeyEvent),
/// The split tunnel paths or state were updated.
#[cfg(target_os = "windows")]
ExcludedPathsEvent(ExcludedPathsUpdate, oneshot::Sender<Result<(), Error>>),
@@ -368,6 +380,12 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent {
}
}
+impl From<device::DeviceKeyEvent> for InternalDaemonEvent {
+ fn from(event: device::DeviceKeyEvent) -> Self {
+ InternalDaemonEvent::DeviceKeyEvent(event)
+ }
+}
+
#[derive(Clone, Debug, Eq, PartialEq)]
enum DaemonExecutionState {
Running,
@@ -529,8 +547,8 @@ pub trait EventListener {
/// Or some flag about the currently running version is changed.
fn notify_app_version(&self, app_version_info: AppVersionInfo);
- /// Notify clients of a key generation event.
- fn notify_key_event(&self, key_event: KeygenEvent);
+ /// Notify that device changed (login, logout, or key rotation).
+ fn notify_device_event(&self, event: DeviceEvent);
}
pub struct Daemon<L: EventListener> {
@@ -546,10 +564,9 @@ pub struct Daemon<L: EventListener> {
event_listener: L,
settings: SettingsPersister,
account_history: account_history::AccountHistory,
- account: account::AccountHandle,
+ account_manager: device::AccountManager,
rpc_runtime: mullvad_rpc::MullvadRpcRuntime,
rpc_handle: mullvad_rpc::rest::MullvadRestHandle,
- wireguard_key_manager: wireguard::KeyManager,
version_updater_handle: version_check::VersionUpdaterHandle,
relay_selector: relays::RelaySelector,
last_generated_relay: Option<Relay>,
@@ -584,8 +601,6 @@ where
mullvad_rpc::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await;
- let runtime = tokio::runtime::Handle::current();
-
let (internal_event_tx, internal_event_rx) = command_channel.destructure();
if let Err(error) = migrations::migrate_all(&cache_dir, &settings_dir).await {
@@ -596,7 +611,55 @@ where
}
let settings = SettingsPersister::load(&settings_dir).await;
- let target_state = if settings.get_account_token().is_none() {
+ let tunnel_parameters_generator = MullvadTunnelParametersGenerator {
+ tx: internal_event_tx.clone(),
+ };
+
+ let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
+ &cache_dir,
+ true,
+ #[cfg(target_os = "android")]
+ Self::create_bypass_tx(&internal_event_tx),
+ )
+ .await
+ .map_err(Error::InitRpcFactory)?;
+
+ let api_availability = rpc_runtime.availability_handle();
+ api_availability.suspend();
+
+ let endpoint_updater = api::ApiEndpointUpdaterHandle::new();
+
+ let proxy_provider = api::create_api_config_provider(
+ internal_event_tx.to_specialized_sender(),
+ ApiConnectionMode::Direct,
+ );
+ let rpc_handle = rpc_runtime
+ .mullvad_rest_handle(proxy_provider, endpoint_updater.callback())
+ .await;
+
+ let mut account_manager = device::AccountManager::new(
+ runtime.clone(),
+ rpc_handle.clone(),
+ api_availability.clone(),
+ &settings_dir,
+ internal_event_tx.to_specialized_sender(),
+ )
+ .await
+ .map_err(Error::LoadAccountManager)?;
+ if let Some(rotation_interval) = settings.tunnel_options.wireguard.rotation_interval {
+ account_manager
+ .set_rotation_interval(rotation_interval)
+ .await;
+ }
+
+ let account_history = account_history::AccountHistory::new(
+ &settings_dir,
+ account_manager.get().map(|device| device.token),
+ )
+ .await
+ .map_err(Error::LoadAccountHistory)?;
+
+ let target_state = if !account_manager.is_some() {
PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await
} else if settings.auto_connect {
log::info!("Automatically connecting since auto-connect is turned on");
@@ -605,10 +668,6 @@ where
PersistentTargetState::new(&cache_dir).await
};
- let tunnel_parameters_generator = MullvadTunnelParametersGenerator {
- tx: internal_event_tx.clone(),
- };
-
#[cfg(windows)]
let exclude_paths = if settings.split_tunnel.enable_exclusions {
settings
@@ -621,18 +680,6 @@ where
vec![]
};
- let rpc_runtime = mullvad_rpc::MullvadRpcRuntime::with_cache(
- &cache_dir,
- true,
- #[cfg(target_os = "android")]
- Self::create_bypass_tx(&internal_event_tx),
- )
- .await
- .map_err(Error::InitRpcFactory)?;
-
- let api_availability = rpc_runtime.availability_handle();
- api_availability.suspend();
-
let initial_api_endpoint =
api::get_allowed_endpoint(rpc_runtime.address_cache.get_address().await);
@@ -664,17 +711,8 @@ where
.await
.map_err(Error::TunnelError)?;
- let endpoint_updater = api::ApiEndpointUpdaterHandle::new();
endpoint_updater.set_tunnel_command_tx(Arc::downgrade(&tunnel_command_tx));
- let proxy_provider = api::create_api_config_provider(
- internal_event_tx.to_specialized_sender(),
- ApiConnectionMode::Direct,
- );
- let rpc_handle = rpc_runtime
- .mullvad_rest_handle(proxy_provider, endpoint_updater.callback())
- .await;
-
Self::forward_offline_state(api_availability.clone(), offline_state_rx).await;
let relay_list_listener = event_listener.clone();
@@ -700,28 +738,11 @@ where
settings.show_beta_releases,
);
tokio::spawn(version_updater.run());
- let account_history =
- account_history::AccountHistory::new(&settings_dir, settings.get_account_token())
- .await
- .map_err(Error::LoadAccountHistory)?;
-
- let wireguard_key_manager = wireguard::KeyManager::new(
- internal_event_tx.clone(),
- api_availability.clone(),
- rpc_handle.clone(),
- );
-
- let account = account::Account::new(
- runtime,
- rpc_handle.clone(),
- settings.get_account_token(),
- api_availability.clone(),
- );
// Attempt to download a fresh relay list
relay_selector.update().await;
- let mut daemon = Daemon {
+ let daemon = Daemon {
tunnel_command_tx,
tunnel_state: TunnelState::Disconnected,
target_state,
@@ -734,10 +755,9 @@ where
event_listener,
settings,
account_history,
- account,
+ account_manager,
rpc_runtime,
rpc_handle,
- wireguard_key_manager,
version_updater_handle,
relay_selector,
last_generated_relay: None,
@@ -751,8 +771,6 @@ where
volume_update_tx,
};
- daemon.ensure_wireguard_keys_for_current_account().await;
-
api_availability.unsuspend();
Ok(daemon)
@@ -881,7 +899,6 @@ where
}
Command(command) => self.handle_command(command).await,
TriggerShutdown => self.trigger_shutdown_event(),
- WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await,
NewAccountEvent(account_token, tx) => {
self.handle_new_account_event(account_token, tx).await
}
@@ -891,6 +908,7 @@ where
GenerateApiConnectionMode(request) => {
self.handle_generate_api_connection_mode(request).await
}
+ DeviceKeyEvent(event) => self.handle_device_key_event(event).await,
#[cfg(windows)]
ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await,
}
@@ -967,7 +985,7 @@ where
>,
retry_attempt: u32,
) {
- if let Some(account_token) = self.settings.get_account_token() {
+ if let Some(device) = self.account_manager.get() {
let result = match self.settings.get_relay_settings() {
RelaySettings::CustomTunnelEndpoint(custom_relay) => {
self.last_generated_relay = None;
@@ -987,7 +1005,6 @@ where
&constraints,
self.settings.get_bridge_state(),
retry_attempt,
- self.settings.get_wireguard().is_some(),
)
.ok();
if let Some(relays::RelaySelectorResult {
@@ -1000,7 +1017,7 @@ where
.create_tunnel_parameters(
&exit_relay,
endpoint,
- account_token,
+ device.token,
retry_attempt,
)
.await;
@@ -1111,7 +1128,11 @@ where
.into())
}
MullvadEndpoint::Wireguard(endpoint) => {
- let wg_data = self.settings.get_wireguard().ok_or(Error::NoKeyAvailable)?;
+ let wg_data = self
+ .account_manager
+ .get()
+ .map(|device| device.wg_data)
+ .ok_or(Error::NoKeyAvailable)?;
let tunnel = wireguard::TunnelConfig {
private_key: wg_data.private_key,
addresses: vec![
@@ -1175,7 +1196,13 @@ where
SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await,
GetRelayLocations(tx) => self.on_get_relay_locations(tx),
UpdateRelayLocations => self.on_update_relay_locations().await,
- SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await,
+ LoginAccount(tx, account_token) => self.on_login_account(tx, account_token).await,
+ LogoutAccount(tx) => self.on_logout_account(tx).await,
+ GetDevice(tx) => self.on_get_device(tx).await,
+ ListDevices(tx, account_token) => self.on_list_devices(tx, account_token).await,
+ RemoveDevice(tx, account_token, device_id) => {
+ self.on_remove_device(tx, account_token, device_id).await
+ }
GetAccountHistory(tx) => self.on_get_account_history(tx),
ClearAccountHistory(tx) => self.on_clear_account_history(tx).await,
UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update).await,
@@ -1198,9 +1225,8 @@ where
self.on_set_wireguard_rotation_interval(tx, interval).await
}
GetSettings(tx) => self.on_get_settings(tx),
- GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await,
+ RotateWireguardKey(tx) => self.on_rotate_wireguard_key(tx).await,
GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await,
- VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await,
GetVersionInfo(tx) => self.on_get_version_info(tx).await,
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
@@ -1232,86 +1258,6 @@ where
}
}
- async fn handle_wireguard_key_event(
- &mut self,
- event: (
- AccountToken,
- Result<mullvad_types::wireguard::WireguardData, wireguard::Error>,
- ),
- ) {
- let (account, result) = event;
- // If the account has been reset whilst a key was being generated, the event should be
- // dropped even if a new key was generated.
- if self
- .settings
- .get_account_token()
- .map(|current_account| current_account != account)
- .unwrap_or(true)
- {
- log::info!("Dropping wireguard key event since account has been changed");
- return;
- }
-
- match result {
- Ok(data) => {
- let public_key = data.get_public_key();
- let is_first_key = self.settings.get_wireguard().is_none();
- match self.settings.set_wireguard(Some(data)).await {
- Ok(_) => {
- if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY).await;
- }
- self.event_listener
- .notify_key_event(KeygenEvent::NewKey(public_key));
- if is_first_key {
- self.ensure_key_rotation().await;
- }
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg(
- "Failed to add new wireguard key to account data"
- )
- );
- self.event_listener
- .notify_key_event(KeygenEvent::GenerationFailure)
- }
- }
- }
- Err(wireguard::Error::TooManyKeys) => {
- self.event_listener
- .notify_key_event(KeygenEvent::TooManyKeys);
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg("Failed to generate wireguard key")
- );
- self.event_listener
- .notify_key_event(KeygenEvent::GenerationFailure);
- }
- }
- }
-
- async fn ensure_key_rotation(&mut self) {
- let token = match self.settings.get_account_token() {
- Some(token) => token,
- None => return,
- };
- let public_key = match self.settings.get_wireguard() {
- Some(data) => data.get_public_key(),
- None => return,
- };
- self.wireguard_key_manager
- .set_rotation_interval(
- public_key,
- token,
- self.settings.tunnel_options.wireguard.rotation_interval,
- )
- .await;
- }
-
async fn handle_new_account_event(
&mut self,
new_token: AccountToken,
@@ -1322,12 +1268,12 @@ where
self.set_target_state(TargetState::Unsecured).await;
let _ = tx.send(Ok(new_token));
}
- Err(err) => {
+ Err(error) => {
log::error!(
"{}",
- err.display_chain_with_msg("Failed to save new account")
+ error.display_chain_with_msg("Handling new account failed")
);
- let _ = tx.send(Err(Error::SettingsError(err)));
+ let _ = tx.send(Err(error));
}
};
}
@@ -1409,6 +1355,25 @@ where
let _ = request.response_tx.send(config);
}
+ async fn handle_device_key_event(&mut self, event: device::DeviceKeyEvent) {
+ let device_id = &event.0.device.id;
+ if Some(device_id)
+ != self
+ .account_manager
+ .get()
+ .map(|device| device.device.id)
+ .as_ref()
+ {
+ // Stale config
+ return;
+ }
+ if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
+ self.schedule_reconnect(WG_RECONNECT_DELAY).await;
+ }
+ self.event_listener
+ .notify_device_event(DeviceEvent(Some(Device::from(event.0))));
+ }
+
#[cfg(windows)]
async fn handle_new_excluded_paths(
&mut self,
@@ -1541,7 +1506,7 @@ where
async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) {
let daemon_tx = self.tx.clone();
- let future = self.account.create_account();
+ let future = self.account_manager.account_service().create_account();
tokio::spawn(async move {
match future.await {
Ok(account_token) => {
@@ -1559,7 +1524,7 @@ where
tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>,
account_token: AccountToken,
) {
- let account = self.account.clone();
+ let account = self.account_manager.account_service();
tokio::spawn(async move {
let result = account.check_expiry(account_token).await;
Self::oneshot_send(
@@ -1571,8 +1536,11 @@ where
}
async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) {
- if let Some(account_token) = self.settings.get_account_token() {
- let future = self.account.get_www_auth_token(account_token);
+ if let Some(device) = self.account_manager.get() {
+ let future = self
+ .account_manager
+ .account_service()
+ .get_www_auth_token(device.token);
let rpc_call = async {
Self::oneshot_send(
tx,
@@ -1595,13 +1563,13 @@ where
tx: ResponseTx<VoucherSubmission, Error>,
voucher: String,
) {
- if let Some(account_token) = self.settings.get_account_token() {
- let mut account = self.account.clone();
+ if let Some(device) = self.account_manager.get() {
+ let mut account = self.account_manager.account_service();
tokio::spawn(async move {
Self::oneshot_send(
tx,
account
- .submit_voucher(account_token, voucher)
+ .submit_voucher(device.token, voucher)
.await
.map_err(Error::RestError),
"submit_voucher response",
@@ -1620,90 +1588,103 @@ where
self.relay_selector.update().await;
}
- async fn on_set_account(
- &mut self,
- tx: ResponseTx<(), settings::Error>,
- account_token: Option<String>,
- ) {
- match self.set_account(account_token.clone()).await {
+ async fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) {
+ match self.set_account(Some(account_token)).await {
Ok(account_changed) => {
if account_changed {
- match account_token {
- Some(_) => {
- log::info!(
- "Initiating tunnel restart because the account token changed"
- );
- self.reconnect_tunnel();
- }
- None => {
- log::info!("Disconnecting because account token was cleared");
- self.set_target_state(TargetState::Unsecured).await;
- }
- };
+ log::info!("Initiating tunnel restart because the account token changed");
+ self.reconnect_tunnel();
}
- Self::oneshot_send(tx, Ok(()), "set_account response");
+ Self::oneshot_send(tx, Ok(()), "login_account response");
}
Err(error) => {
- log::error!("{}", error.display_chain_with_msg("Failed to set account"));
- Self::oneshot_send(tx, Err(error), "set_account response");
+ log::error!("{}", error.display_chain_with_msg("Login failed"));
+ Self::oneshot_send(tx, Err(error), "login_account response");
}
}
}
- async fn set_account(
- &mut self,
- account_token: Option<String>,
- ) -> Result<bool, settings::Error> {
- let previous_token = self.settings.get_account_token();
- let account_changed = self
- .settings
- .set_account_token(account_token.clone())
- .await?;
- if account_changed {
- self.event_listener
- .notify_settings(self.settings.to_settings());
-
- let history_token = match account_token {
- Some(token) => token,
- None => previous_token.clone().unwrap_or("".to_string()),
- };
- if let Err(error) = self.account_history.set(history_token).await {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to update account history")
- );
+ async fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) {
+ match self.set_account(None).await {
+ Ok(account_changed) => {
+ if account_changed {
+ log::info!("Disconnecting because account token was cleared");
+ self.set_target_state(TargetState::Unsecured).await;
+ }
+ Self::oneshot_send(tx, Ok(()), "logout_account response");
+ }
+ Err(error) => {
+ log::error!("{}", error.display_chain_with_msg("Logout failed"));
+ Self::oneshot_send(tx, Err(error), "logout_account response");
}
+ }
+ }
- if let Some(previous_token) = previous_token {
- if let Some(previous_key) = self
- .settings
- .get_wireguard()
- .map(|data| data.private_key.public_key())
- {
- let remove_key = self
- .wireguard_key_manager
- .remove_key_with_backoff(previous_token, previous_key);
- tokio::spawn(async move {
- if let Err(error) = remove_key.await {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to remove WireGuard key for previous account"
- )
- );
- }
- });
- }
+ async fn set_account(&mut self, account_token: Option<String>) -> Result<bool, Error> {
+ let previous_token = self.account_manager.get().map(|device| device.token);
+ if previous_token == account_token {
+ return Ok(false);
+ }
+
+ match account_token.clone() {
+ Some(token) => {
+ let device_data = self
+ .account_manager
+ .login(token)
+ .await
+ .map_err(Error::LoginError)?;
+ self.event_listener
+ .notify_device_event(DeviceEvent(Some(Device::from(device_data))));
+ }
+ None => {
+ self.account_manager.logout();
+ self.event_listener.notify_device_event(DeviceEvent(None));
}
- if let Err(error) = self.settings.set_wireguard(None).await {
+ }
+
+ if let Some(token) = account_token.or(previous_token) {
+ if let Err(error) = self.account_history.set(token).await {
log::error!(
"{}",
- error.display_chain_with_msg("Error resetting WireGuard key")
+ error.display_chain_with_msg("Failed to update account history")
);
}
- self.ensure_wireguard_keys_for_current_account().await;
}
- Ok(account_changed)
+
+ Ok(true)
+ }
+
+ async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceData>, Error>) {
+ Self::oneshot_send(tx, Ok(self.account_manager.get()), "get_device response");
+ }
+
+ async fn on_list_devices(&mut self, tx: ResponseTx<Vec<Device>, Error>, token: AccountToken) {
+ Self::oneshot_send(
+ tx,
+ self.account_manager
+ .device_service()
+ .list_devices(token)
+ .await
+ .map_err(Error::ListDevicesError),
+ "list_devices response",
+ );
+ }
+
+ async fn on_remove_device(
+ &mut self,
+ tx: ResponseTx<(), Error>,
+ token: AccountToken,
+ device_id: DeviceId,
+ ) {
+ Self::oneshot_send(
+ tx,
+ self.account_manager
+ .device_service()
+ .remove_device(token, device_id)
+ .await
+ .map_err(Error::RemoveDeviceError),
+ "remove_device response",
+ );
}
fn on_get_account_history(&mut self, tx: oneshot::Sender<Option<AccountToken>>) {
@@ -1723,37 +1704,6 @@ where
Self::oneshot_send(tx, result, "clear_account_history response");
}
- // Remove the key associated with the current account, if there is one.
- // This does not modify settings or account history.
- #[cfg(not(target_os = "android"))]
- fn remove_current_key_rpc(&self) -> impl std::future::Future<Output = Result<(), Error>> {
- let remove_key = if let Some(token) = self.settings.get_account_token() {
- if let Some(wg_data) = self.settings.get_wireguard() {
- Some(
- self.wireguard_key_manager
- .remove_key(token, wg_data.private_key.public_key()),
- )
- } else {
- None
- }
- } else {
- None
- };
-
- async move {
- if let Some(task) = remove_key {
- match task.await {
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- // This result should never occur
- Err(wireguard::Error::TooManyKeys) => Err(Error::TooManyKeys),
- _ => Ok(()),
- }
- } else {
- Ok(())
- }
- }
- }
-
async fn on_get_version_info(&mut self, tx: oneshot::Sender<Option<AppVersionInfo>>) {
if self.app_version_info.is_none() {
log::debug!("No version cache found. Fetching new info");
@@ -1795,17 +1745,13 @@ where
async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) {
let mut last_error = Ok(());
- let remove_key = self.remove_current_key_rpc();
- tokio::spawn(async move {
- if let Err(error) = remove_key.await {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Failed to remove WireGuard key for previous account"
- )
- );
- }
- });
+ if let Err(error) = self.account_manager.logout_wait().await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to clear device cache")
+ );
+ last_error = Err(Error::LogoutError(error));
+ }
if let Err(error) = self.account_history.clear().await {
log::error!(
@@ -2315,7 +2261,9 @@ where
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response");
if settings_changed {
- self.ensure_key_rotation().await;
+ self.account_manager
+ .set_rotation_interval(interval.unwrap_or_default())
+ .await;
self.event_listener
.notify_settings(self.settings.to_settings());
}
@@ -2327,128 +2275,25 @@ where
}
}
- async fn ensure_wireguard_keys_for_current_account(&mut self) {
- if let Some(account) = self.settings.get_account_token() {
- if self.settings.get_wireguard().is_none() {
- log::info!("Generating new WireGuard key for account");
- self.wireguard_key_manager
- .spawn_key_generation_task(account, Some(FIRST_KEY_PUSH_TIMEOUT))
- .await;
- } else {
- log::info!("Account already has WireGuard key");
- self.ensure_key_rotation().await;
- }
- }
- }
-
- async fn on_generate_wireguard_key(&mut self, tx: ResponseTx<KeygenEvent, Error>) {
- match self.on_generate_wireguard_key_inner().await {
- Ok(key_event) => {
- Self::oneshot_send(tx, Ok(key_event), "generate_wireguard_key");
- }
- Err(e) => {
- log::error!(
- "{}",
- e.display_chain_with_msg("Failed to generate new wireguard key")
- );
- Self::oneshot_send(tx, Err(e), "generate_wireguard_key");
- }
- }
- }
-
- async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, Error> {
- let account_token = self
- .settings
- .get_account_token()
- .ok_or(Error::NoAccountToken)?;
- let wireguard_data = self.settings.get_wireguard();
-
- let gen_result = match &wireguard_data {
- Some(wireguard_data) => {
- self.wireguard_key_manager
- .replace_key(account_token.clone(), wireguard_data.get_public_key())
- .await
- }
- None => {
- self.wireguard_key_manager
- .generate_key_sync(account_token.clone())
- .await
- }
- };
-
- match gen_result {
- Ok(new_data) => {
- let public_key = new_data.get_public_key();
- self.settings
- .set_wireguard(Some(new_data))
- .await
- .map_err(Error::SettingsError)?;
- if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY).await;
- }
- let keygen_event = KeygenEvent::NewKey(public_key.clone());
- self.event_listener.notify_key_event(keygen_event.clone());
-
- // update automatic rotation
- self.wireguard_key_manager
- .set_rotation_interval(
- public_key,
- account_token,
- self.settings.tunnel_options.wireguard.rotation_interval,
- )
- .await;
-
- Ok(keygen_event)
- }
- Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys),
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
+ async fn on_rotate_wireguard_key(&mut self, tx: ResponseTx<(), Error>) {
+ let result = self.account_manager.rotate_key().await;
+ if let Ok(ref _wg_data) = result {
+ let device = self.account_manager.get().map(Device::from);
+ self.event_listener
+ .notify_device_event(DeviceEvent(device.clone()));
}
+ let _ = tx.send(result.map(|_| ()).map_err(Error::KeyRotationError));
}
- async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<wireguard::PublicKey>, Error>) {
- let result = if self.settings.get_account_token().is_some() {
- Ok(self
- .settings
- .get_wireguard()
- .map(|data| data.get_public_key()))
+ async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<PublicKey>, Error>) {
+ let result = if let Some(device) = self.account_manager.get() {
+ Ok(Some(device.wg_data.get_public_key()))
} else {
Err(Error::NoAccountToken)
};
Self::oneshot_send(tx, result, "get_wireguard_key response");
}
- async fn on_verify_wireguard_key(&mut self, tx: ResponseTx<bool, Error>) {
- let account = match self.settings.get_account_token() {
- Some(account) => account,
- None => {
- Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response");
- return;
- }
- };
- let public_key = match self.settings.get_wireguard() {
- Some(wg_data) => wg_data.private_key.public_key(),
- None => {
- Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response");
- return;
- }
- };
-
- let verification_rpc = self
- .wireguard_key_manager
- .verify_wireguard_key(account, public_key);
-
- tokio::spawn(async move {
- let result = match verification_rpc.await {
- Ok(is_valid) => Ok(is_valid),
- Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)),
- Err(wireguard::Error::ApiCheckError(error)) => Err(Error::ApiCheckError(error)),
- Err(wireguard::Error::TooManyKeys) => return,
- };
- Self::oneshot_send(tx, result, "verify_wireguard_key response");
- });
- }
-
fn on_get_settings(&self, tx: oneshot::Sender<Settings>) {
Self::oneshot_send(tx, self.settings.to_settings(), "get_settings response");
}
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index ba828ed903..e6f84660c1 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -378,20 +378,25 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_daemon_error)
}
- async fn set_account(&self, request: Request<AccountToken>) -> ServiceResult<()> {
- log::debug!("set_account");
+ async fn login_account(&self, request: Request<AccountToken>) -> ServiceResult<()> {
+ log::debug!("login_account");
let account_token = request.into_inner();
- let account_token = if account_token == "" {
- None
- } else {
- Some(account_token)
- };
let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::SetAccount(tx, account_token))?;
+ self.send_command_to_daemon(DaemonCommand::LoginAccount(tx, account_token))?;
self.wait_for_result(rx)
.await?
.map(Response::new)
- .map_err(map_settings_error)
+ .map_err(map_daemon_error)
+ }
+
+ async fn logout_account(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("logout_account");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::LogoutAccount(tx))?;
+ self.wait_for_result(rx)
+ .await?
+ .map(Response::new)
+ .map_err(map_daemon_error)
}
async fn get_account_data(
@@ -479,6 +484,44 @@ impl ManagementService for ManagementServiceImpl {
})
}
+ // Device management
+ async fn get_device(&self, _: Request<()>) -> ServiceResult<types::DeviceConfig> {
+ log::debug!("get_device");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::GetDevice(tx))?;
+ let device = self
+ .wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)?
+ .ok_or(Status::new(Code::NotFound, "no device is set"))?;
+ Ok(Response::new(types::DeviceConfig::from(device)))
+ }
+
+ async fn list_devices(
+ &self,
+ request: Request<AccountToken>,
+ ) -> ServiceResult<types::DeviceList> {
+ log::debug!("list_devices");
+ let (tx, rx) = oneshot::channel();
+ let token = request.into_inner();
+ self.send_command_to_daemon(DaemonCommand::ListDevices(tx, token))?;
+ let device = self.wait_for_result(rx).await?.map_err(map_daemon_error)?;
+ Ok(Response::new(types::DeviceList::from(device)))
+ }
+
+ async fn remove_device(&self, request: Request<types::DeviceRemoval>) -> ServiceResult<()> {
+ log::debug!("remove_device");
+ let (tx, rx) = oneshot::channel();
+ let removal = request.into_inner();
+ self.send_command_to_daemon(DaemonCommand::RemoveDevice(
+ tx,
+ removal.account_token,
+ removal.device_id,
+ ))?;
+ self.wait_for_result(rx).await?.map_err(map_daemon_error)?;
+ Ok(Response::new(()))
+ }
+
// WireGuard key management
//
@@ -515,15 +558,13 @@ impl ManagementService for ManagementServiceImpl {
.map_err(map_settings_error)
}
- async fn generate_wireguard_key(&self, _: Request<()>) -> ServiceResult<types::KeygenEvent> {
- // TODO: return error for TooManyKeys, GenerationFailure
- // on success, simply return the new key or nil
- log::debug!("generate_wireguard_key");
+ async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("rotate_wireguard_key");
let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::GenerateWireguardKey(tx))?;
+ self.send_command_to_daemon(DaemonCommand::RotateWireguardKey(tx))?;
self.wait_for_result(rx)
.await?
- .map(|event| Response::new(types::KeygenEvent::from(event)))
+ .map(Response::new)
.map_err(map_daemon_error)
}
@@ -538,16 +579,6 @@ impl ManagementService for ManagementServiceImpl {
}
}
- async fn verify_wireguard_key(&self, _: Request<()>) -> ServiceResult<bool> {
- log::debug!("verify_wireguard_key");
- let (tx, rx) = oneshot::channel();
- self.send_command_to_daemon(DaemonCommand::VerifyWireguardKey(tx))?;
- self.wait_for_result(rx)
- .await?
- .map(Response::new)
- .map_err(map_daemon_error)
- }
-
// Split tunneling
//
@@ -832,11 +863,11 @@ impl EventListener for ManagementInterfaceEventBroadcaster {
})
}
- fn notify_key_event(&self, key_event: mullvad_types::wireguard::KeygenEvent) {
- log::debug!("Broadcasting new wireguard key event");
+ fn notify_device_event(&self, device: mullvad_types::device::DeviceEvent) {
+ log::debug!("Broadcasting device event");
self.notify(types::DaemonEvent {
- event: Some(daemon_event::Event::KeyEvent(types::KeygenEvent::from(
- key_event,
+ event: Some(daemon_event::Event::Device(types::DeviceEvent::from(
+ device,
))),
})
}
diff --git a/mullvad-daemon/src/relays/mod.rs b/mullvad-daemon/src/relays/mod.rs
index 332ca5fea3..c4a136369c 100644
--- a/mullvad-daemon/src/relays/mod.rs
+++ b/mullvad-daemon/src/relays/mod.rs
@@ -276,7 +276,6 @@ impl RelaySelector {
relay_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> Result<RelaySelectorResult, Error> {
match relay_constraints.tunnel_protocol {
Constraint::Only(TunnelType::OpenVpn) => self.get_openvpn_endpoint(
@@ -293,12 +292,9 @@ impl RelaySelector {
&relay_constraints.wireguard_constraints,
retry_attempt,
),
- Constraint::Any => self.get_any_tunnel_endpoint(
- relay_constraints,
- bridge_state,
- retry_attempt,
- wg_key_exists,
- ),
+ Constraint::Any => {
+ self.get_any_tunnel_endpoint(relay_constraints, bridge_state, retry_attempt)
+ }
}
}
@@ -479,14 +475,9 @@ impl RelaySelector {
relay_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> Result<RelaySelectorResult, Error> {
- let preferred_constraints = self.preferred_constraints(
- &relay_constraints,
- bridge_state,
- retry_attempt,
- wg_key_exists,
- );
+ let preferred_constraints =
+ self.preferred_constraints(&relay_constraints, bridge_state, retry_attempt);
let original_matcher: RelayMatcher<_> = relay_constraints.clone().into();
let preferred_tunnel_protocol = preferred_constraints.tunnel_protocol;
@@ -543,14 +534,12 @@ impl RelaySelector {
original_constraints: &RelayConstraints,
bridge_state: BridgeState,
retry_attempt: u32,
- wg_key_exists: bool,
) -> RelayConstraints {
let (preferred_port, preferred_protocol, preferred_tunnel) = self
.preferred_tunnel_constraints(
retry_attempt,
&original_constraints.location,
&original_constraints.providers,
- wg_key_exists,
);
let mut relay_constraints = original_constraints.clone();
@@ -731,7 +720,6 @@ impl RelaySelector {
retry_attempt: u32,
location_constraint: &Constraint<LocationConstraint>,
providers_constraint: &Constraint<Providers>,
- wg_key_exists: bool,
) -> (Constraint<u16>, TransportProtocol, TunnelType) {
#[cfg(target_os = "windows")]
{
@@ -757,7 +745,7 @@ impl RelaySelector {
});
// If location does not support WireGuard, defer to preferred OpenVPN tunnel
// constraints
- if !location_supports_wireguard || !wg_key_exists {
+ if !location_supports_wireguard {
let (preferred_port, preferred_protocol) =
Self::preferred_openvpn_constraints(retry_attempt);
return (preferred_port, preferred_protocol, TunnelType::OpenVpn);
@@ -1159,7 +1147,7 @@ mod test {
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
@@ -1167,7 +1155,7 @@ mod test {
for attempt in 0..10 {
assert!(relay_selector
- .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.is_ok());
}
@@ -1184,7 +1172,7 @@ mod test {
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::Off, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1192,7 +1180,7 @@ mod test {
for attempt in 0..10 {
assert!(relay_selector
- .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_any_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.is_ok());
}
@@ -1205,7 +1193,6 @@ mod test {
&relay_constraints,
BridgeState::Off,
attempt,
- true,
);
assert_eq!(
preferred.tunnel_protocol,
@@ -1215,7 +1202,6 @@ mod test {
&relay_constraints,
BridgeState::Off,
attempt,
- true,
) {
Ok(result) if matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)) => (),
_ => panic!("OpenVPN endpoint was not selected"),
@@ -1250,14 +1236,14 @@ mod test {
// The same host cannot be used for entry and exit
assert!(relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.is_err());
relay_constraints.wireguard_constraints.entry_location = Constraint::Only(location2);
// If the entry and exit differ, this should succeed
assert!(relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.is_ok());
}
@@ -1286,7 +1272,7 @@ mod test {
// The exit must not equal the entry
let exit_relay = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.map_err(|error| error.to_string())?
.exit_relay;
@@ -1301,7 +1287,7 @@ mod test {
endpoint,
..
} = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.map_err(|error| error.to_string())?;
assert_eq!(exit_relay.hostname, specific_hostname);
@@ -1336,7 +1322,7 @@ mod test {
});
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1362,7 +1348,7 @@ mod test {
..RelayConstraints::default()
};
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
@@ -1381,14 +1367,14 @@ mod test {
#[cfg(all(unix, not(target_os = "android")))]
{
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 0);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::Wireguard)
);
}
let preferred =
- relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2, true);
+ relay_selector.preferred_constraints(&relay_constraints, BridgeState::On, 2);
assert_eq!(
preferred.tunnel_protocol,
Constraint::Only(TunnelType::OpenVpn)
@@ -1405,54 +1391,6 @@ mod test {
}
#[test]
- fn test_wg_relay_with_no_key() {
- let mut relay_constraints = RelayConstraints {
- tunnel_protocol: Constraint::Only(TunnelType::Wireguard),
- ..RelayConstraints::default()
- };
-
- let relay_selector = new_relay_selector();
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect("Failed to get WireGuard relay when WireGuard relay was specified as the only tunnel protocol");
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
-
- relay_constraints.tunnel_protocol = Constraint::Any;
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect("Failed to get OpenVPN relay with tunnel protocol constraint set to Any and without a WireGuard key");
-
- assert!(matches!(result.endpoint, MullvadEndpoint::OpenVpn(_)));
-
- let wireguard_specific_location = LocationConstraint::Hostname(
- "se".to_string(),
- "got".to_string(),
- "se9-wireguard".to_string(),
- );
- relay_constraints.location = Constraint::Only(wireguard_specific_location);
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, false)
- .expect(
- "Failed to get a valid WireGuard relay when tunnel constraints are set to any
- tunnel protocol and with a wireguard specific location without a wireguard key",
- );
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
-
- let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
- .expect(
- "Failed to get a valid WireGuard relay when tunnel constraints are set to any
- tunnel protocol and with a wireguard specific location with a wireguard key",
- );
-
- assert!(matches!(result.endpoint, MullvadEndpoint::Wireguard(_)));
- }
-
- #[test]
fn test_selecting_any_relay_will_consider_multihop() {
let relay_constraints = RelayConstraints {
wireguard_constraints: WireguardConstraints {
@@ -1467,7 +1405,7 @@ mod test {
let relay_selector = new_relay_selector();
- let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ let result = relay_selector.get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection");
// Windows will ignore WireGuard until WireGuard is supported well enough
// TODO: Remove this caveat once Windows defaults to using WireGuard
@@ -1502,7 +1440,7 @@ mod test {
fn test_selecting_wireguard_location_will_consider_multihop() {
let relay_selector = new_relay_selector();
- let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0, true)
+ let result = relay_selector.get_tunnel_endpoint(&WIREGUARD_MULTIHOP_CONSTRAINTS, BridgeState::Off, 0)
.expect("Failed to get relay when tunnel constraints are set to Any and retrying the selection");
@@ -1526,7 +1464,7 @@ mod test {
let relay_selector = new_relay_selector();
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(result.entry_relay.is_some());
@@ -1555,7 +1493,7 @@ mod test {
let relay_selector = new_relay_selector();
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, 0)
.expect("Failed to get WireGuard TCP relay");
let endpoint = result.endpoint.unwrap_wireguard();
assert!(matches!(endpoint.peer.protocol, TransportProtocol::Tcp));
@@ -1570,7 +1508,7 @@ mod test {
const INVALID_UDP_PORTS: [u16; 2] = [80, 443];
for attempt in 0..1000 {
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(!INVALID_UDP_PORTS.contains(&result.endpoint.to_endpoint().address.port()));
assert_eq!(
@@ -1587,7 +1525,7 @@ mod test {
const VALID_TCP_PORTS: [u16; 3] = [80, 443, 5001];
for attempt in 0..1000 {
let result = relay_selector
- .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt, true)
+ .get_tunnel_endpoint(&relay_constraints, BridgeState::Off, attempt)
.expect("Failed to get WireGuard TCP multihop relay");
assert!(VALID_TCP_PORTS.contains(&result.endpoint.to_endpoint().address.port()));
assert_eq!(
@@ -1609,7 +1547,7 @@ mod test {
..RelayConstraints::default()
};
relay_selector
- .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&constraints, BridgeState::Off, 0)
.expect_err("Successfully selected a relay that should be filtered");
constraints.location = Constraint::Only(LocationConstraint::Hostname(
@@ -1619,7 +1557,7 @@ mod test {
));
relay_selector
- .get_tunnel_endpoint(&constraints, BridgeState::Off, 0, true)
+ .get_tunnel_endpoint(&constraints, BridgeState::Off, 0)
.expect_err("Successfully selected a relay that should be filtered");
}
}
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs
index ec610f63d4..bf3fe710c8 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings.rs
@@ -3,7 +3,7 @@ use futures::TryFutureExt;
use mullvad_types::{
relay_constraints::{BridgeSettings, BridgeState, RelaySettingsUpdate},
settings::{DnsOptions, Settings},
- wireguard::{RotationInterval, WireguardData},
+ wireguard::RotationInterval,
};
#[cfg(target_os = "windows")]
use std::collections::HashSet;
@@ -191,21 +191,6 @@ impl SettingsPersister {
settings
}
- /// Changes account number to the one given. Also saves the new settings to disk.
- /// The boolean in the Result indicates if the account token changed or not
- pub async fn set_account_token(
- &mut self,
- account_token: Option<String>,
- ) -> Result<bool, Error> {
- let should_save = self.settings.set_account_token(account_token);
- self.update(should_save).await
- }
-
- pub async fn set_wireguard(&mut self, wireguard: Option<WireguardData>) -> Result<bool, Error> {
- let should_save = self.settings.set_wireguard(wireguard);
- self.update(should_save).await
- }
-
pub async fn update_relay_settings(
&mut self,
update: RelaySettingsUpdate,
diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs
deleted file mode 100644
index eb198b858b..0000000000
--- a/mullvad-daemon/src/wireguard.rs
+++ /dev/null
@@ -1,499 +0,0 @@
-use crate::{DaemonEventSender, InternalDaemonEvent};
-use chrono::offset::Utc;
-use mullvad_rpc::{
- availability::ApiAvailabilityHandle,
- rest::{Error as RestError, MullvadRestHandle},
-};
-use mullvad_types::account::AccountToken;
-pub use mullvad_types::wireguard::*;
-use std::{future::Future, pin::Pin, time::Duration};
-
-use futures::future::{abortable, AbortHandle};
-#[cfg(not(target_os = "android"))]
-use talpid_core::future_retry::constant_interval;
-use talpid_core::{
- future_retry::{retry_future, retry_future_n, ExponentialBackoff, Jittered},
- mpsc::Sender,
-};
-
-pub use talpid_types::net::wireguard::{
- ConnectionConfig, PrivateKey, TunnelConfig, TunnelParameters,
-};
-use talpid_types::ErrorExt;
-
-/// How long to wait before starting key rotation
-const ROTATION_START_DELAY: Duration = Duration::from_secs(60 * 3);
-
-/// How often to check whether the key has expired.
-/// A short interval is used in case the computer is ever suspended.
-const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60);
-
-const RETRY_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
-const RETRY_INTERVAL_FACTOR: u32 = 5;
-const RETRY_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-
-#[cfg(not(target_os = "android"))]
-const SHORT_RETRY_INTERVAL: Duration = Duration::ZERO;
-
-const MAX_KEY_REMOVAL_RETRIES: usize = 2;
-
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- #[error(display = "Unexpected HTTP request error")]
- RestError(#[error(source)] mullvad_rpc::rest::Error),
- #[error(display = "API availability check was interrupted")]
- ApiCheckError(#[error(source)] mullvad_rpc::availability::Error),
- #[error(display = "Account already has maximum number of keys")]
- TooManyKeys,
-}
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-pub struct KeyManager {
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- current_job: Option<AbortHandle>,
-
- abort_scheduler_tx: Option<AbortHandle>,
- auto_rotation_interval: RotationInterval,
-}
-
-impl KeyManager {
- pub(crate) fn new(
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- ) -> Self {
- Self {
- daemon_tx,
- availability_handle,
- http_handle,
- current_job: None,
- abort_scheduler_tx: None,
- auto_rotation_interval: RotationInterval::default(),
- }
- }
-
- /// Reset key rotation, cancelling the current one and starting a new one for the specified
- /// account
- pub async fn reset_rotation(&mut self, current_key: PublicKey, account_token: AccountToken) {
- self.run_automatic_rotation(account_token, current_key)
- .await
- }
-
- /// Update automatic key rotation interval
- /// Passing `None` for the interval will cause the default value to be used.
- pub async fn set_rotation_interval(
- &mut self,
- current_key: PublicKey,
- account_token: AccountToken,
- auto_rotation_interval: Option<RotationInterval>,
- ) {
- self.auto_rotation_interval = auto_rotation_interval.unwrap_or_default();
- self.reset_rotation(current_key, account_token).await;
- }
-
- /// Stop current key generation
- pub fn reset(&mut self) {
- if let Some(job) = self.current_job.take() {
- job.abort()
- }
- }
-
- /// Generate a new private key
- pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> {
- self.reset();
- let private_key = PrivateKey::new_from_random();
-
- self.push_future_generator(account, private_key, None)()
- .await
- .map_err(Self::map_rpc_error)
- }
-
- /// Replace a key for an account synchronously
- pub async fn replace_key(
- &mut self,
- account: AccountToken,
- old_key: PublicKey,
- ) -> Result<WireguardData> {
- self.reset();
-
- let new_key = PrivateKey::new_from_random();
- Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await
- }
-
- /// Verifies whether a key is valid or not.
- pub fn verify_wireguard_key(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<bool>> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- async move {
- match rpc.get_wireguard_key(account, &key).await {
- Ok(_) => Ok(true),
- Err(mullvad_rpc::rest::Error::ApiError(status, _code))
- if status == mullvad_rpc::StatusCode::NOT_FOUND =>
- {
- Ok(false)
- }
- Err(err) => Err(Self::map_rpc_error(err)),
- }
- }
- }
-
- /// Removes a key from an account
- #[cfg(not(target_os = "android"))]
- pub fn remove_key(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<()>> {
- self.remove_key_inner(account, key, constant_interval(SHORT_RETRY_INTERVAL), false)
- }
-
- /// Removes a key from an account
- pub fn remove_key_with_backoff(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- ) -> impl Future<Output = Result<()>> {
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
- self.remove_key_inner(account, key, retry_strategy, true)
- }
-
- fn remove_key_inner<D: Iterator<Item = Duration> + 'static>(
- &self,
- account: AccountToken,
- key: talpid_types::net::wireguard::PublicKey,
- retry_strategy: D,
- offline_check: bool,
- ) -> impl Future<Output = Result<()>> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- let api_handle = self.availability_handle.clone();
- let api_handle_2 = api_handle.clone();
- let future = retry_future_n(
- move || {
- let remove_key = rpc.remove_wireguard_key(account.clone(), key.clone());
- let wait_future = api_handle.wait_online();
- async move {
- if offline_check {
- let _ = wait_future.await;
- }
- remove_key.await
- }
- },
- move |result| match result {
- Ok(_) => false,
- Err(error) => Self::should_retry_removal(error, &api_handle_2),
- },
- retry_strategy,
- MAX_KEY_REMOVAL_RETRIES,
- );
- async move { future.await.map_err(Self::map_rpc_error) }
- }
-
- fn should_retry_removal(error: &RestError, api_handle: &ApiAvailabilityHandle) -> bool {
- error.is_network_error() && !api_handle.get_state().is_offline()
- }
-
- fn should_retry(error: &RestError) -> bool {
- if let RestError::ApiError(_status, code) = &error {
- code != mullvad_rpc::INVALID_ACCOUNT && code != mullvad_rpc::KEY_LIMIT_REACHED
- } else {
- true
- }
- }
-
- /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel.
- pub async fn spawn_key_generation_task(
- &mut self,
- account: AccountToken,
- timeout: Option<Duration>,
- ) {
- self.reset();
- let private_key = PrivateKey::new_from_random();
-
- let error_tx = self.daemon_tx.clone();
- let error_account = account.clone();
-
- let mut inner_future_generator =
- self.push_future_generator(account.clone(), private_key, timeout);
-
- let availability_handle = self.availability_handle.clone();
-
- let future_generator = move || {
- let wait_available = availability_handle.wait_background();
- let fut = inner_future_generator();
- let error_tx = error_tx.clone();
- let error_account = error_account.clone();
- async move {
- let error_account_copy = error_account.clone();
- wait_available.await.map_err(|error| {
- let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent((
- error_account_copy,
- Err(Error::ApiCheckError(error)),
- )));
- false
- })?;
- let response = fut.await;
- match response {
- Ok(addresses) => Ok(addresses),
- Err(err) => {
- let should_retry = Self::should_retry(&err);
- let _ = error_tx.send(InternalDaemonEvent::WgKeyEvent((
- error_account,
- Err(Self::map_rpc_error(err)),
- )));
- Err(should_retry)
- }
- }
- }
- };
-
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
-
- let should_retry = move |result: &std::result::Result<_, bool>| -> bool {
- match result {
- Ok(_) => false,
- Err(should_retry) => *should_retry,
- }
- };
-
- let upload_future = retry_future(future_generator, should_retry, retry_strategy);
-
- let (cancellable_upload, abort_handle) = abortable(Box::pin(upload_future));
- let daemon_tx = self.daemon_tx.clone();
- let future = async move {
- match cancellable_upload.await {
- Ok(Ok(wireguard_data)) => {
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account,
- Ok(wireguard_data),
- )));
- }
- Ok(Err(_)) => {}
- Err(_) => {
- log::error!("Key generation cancelled");
- }
- }
- };
-
- tokio::spawn(Box::pin(future));
- self.current_job = Some(abort_handle);
- }
-
- fn push_future_generator(
- &self,
- account: AccountToken,
- private_key: PrivateKey,
- timeout: Option<Duration>,
- ) -> Box<
- dyn FnMut() -> Pin<
- Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send>,
- > + Send,
- > {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone());
- let public_key = private_key.public_key();
-
- let push_future =
- move || -> std::pin::Pin<Box<dyn Future<Output = std::result::Result<WireguardData, RestError>> + Send >> {
- let key = private_key.clone();
- let address_future = rpc
- .push_wg_key(account.clone(), public_key.clone(), timeout);
- Box::pin(async move {
- let addresses = address_future.await?;
- Ok(WireguardData {
- private_key: key,
- addresses,
- created: Utc::now(),
- })
- })
- };
- Box::new(push_future)
- }
-
- async fn replace_key_rpc(
- http_handle: MullvadRestHandle,
- account: AccountToken,
- old_key: PublicKey,
- new_key: PrivateKey,
- ) -> Result<WireguardData> {
- let mut rpc = mullvad_rpc::WireguardKeyProxy::new(http_handle);
- let new_public_key = new_key.public_key();
- let addresses = rpc
- .replace_wg_key(account, old_key.key, new_public_key)
- .await
- .map_err(Self::map_rpc_error)?;
- Ok(WireguardData {
- private_key: new_key,
- addresses,
- created: Utc::now(),
- })
- }
-
- fn map_rpc_error(err: mullvad_rpc::rest::Error) -> Error {
- match &err {
- // TODO: Consider handling the invalid account case too.
- mullvad_rpc::rest::Error::ApiError(status, message)
- if *status == mullvad_rpc::StatusCode::BAD_REQUEST
- && message == mullvad_rpc::KEY_LIMIT_REACHED =>
- {
- Error::TooManyKeys
- }
- _ => Error::RestError(err),
- }
- }
-
- async fn wait_for_key_expiry(key: &PublicKey, rotation_interval_secs: u64) {
- let mut interval = tokio::time::interval(KEY_CHECK_INTERVAL);
- interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
- loop {
- interval.tick().await;
- if (Utc::now().signed_duration_since(key.created)).num_seconds() as u64
- >= rotation_interval_secs
- {
- return;
- }
- }
- }
-
- async fn create_automatic_rotation(
- daemon_tx: DaemonEventSender,
- availability_handle: ApiAvailabilityHandle,
- http_handle: MullvadRestHandle,
- mut public_key: PublicKey,
- rotation_interval_secs: u64,
- account_token: AccountToken,
- ) {
- tokio::time::sleep(ROTATION_START_DELAY).await;
-
- let rotate_key_for_account =
- move |old_key: &PublicKey| -> Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>> {
- let wait_available = availability_handle.wait_background();
- let rotate = Self::rotate_key(
- daemon_tx.clone(),
- http_handle.clone(),
- account_token.clone(),
- old_key.clone(),
- );
- Box::pin(async move {
- wait_available.await?;
- rotate.await
- })
- };
-
- loop {
- Self::wait_for_key_expiry(&public_key, rotation_interval_secs).await;
-
- let rotate_key_for_account_copy = rotate_key_for_account.clone();
- match Self::rotate_key_with_retries(public_key.clone(), rotate_key_for_account_copy)
- .await
- {
- Ok(new_key) => public_key = new_key,
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg(
- "Stopping automatic key rotation due to an error"
- )
- );
- return;
- }
- }
- }
- }
-
- fn rotate_key(
- daemon_tx: DaemonEventSender,
- http_handle: MullvadRestHandle,
- account_token: AccountToken,
- old_key: PublicKey,
- ) -> impl Future<Output = Result<PublicKey>> {
- let new_key = PrivateKey::new_from_random();
- let rpc_result =
- Self::replace_key_rpc(http_handle, account_token.clone(), old_key, new_key);
-
- async move {
- match rpc_result.await {
- Ok(data) => {
- // Update account data
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account_token,
- Ok(data.clone()),
- )));
- Ok(data.get_public_key())
- }
- Err(Error::TooManyKeys) => {
- let _ = daemon_tx.send(InternalDaemonEvent::WgKeyEvent((
- account_token,
- Err(Error::TooManyKeys),
- )));
- Err(Error::TooManyKeys)
- }
- Err(unknown) => Err(unknown),
- }
- }
- }
-
- async fn rotate_key_with_retries<F>(old_key: PublicKey, rotate_key: F) -> Result<PublicKey>
- where
- F: FnMut(&PublicKey) -> std::pin::Pin<Box<dyn Future<Output = Result<PublicKey>> + Send>>
- + Clone
- + 'static,
- {
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(RETRY_INTERVAL_INITIAL, RETRY_INTERVAL_FACTOR)
- .max_delay(RETRY_INTERVAL_MAX),
- );
- let should_retry = move |result: &Result<PublicKey>| -> bool {
- match result {
- Ok(_) => false,
- Err(error) => match error {
- Error::RestError(error) => Self::should_retry(error),
- _ => false,
- },
- }
- };
-
- retry_future(
- move || rotate_key.clone()(&old_key),
- should_retry,
- retry_strategy,
- )
- .await
- }
-
- async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) {
- self.stop_automatic_rotation();
-
- log::debug!("Starting automatic key rotation job");
- // Schedule cancellable series of repeating rotation tasks
- let fut = Self::create_automatic_rotation(
- self.daemon_tx.clone(),
- self.availability_handle.clone(),
- self.http_handle.clone(),
- public_key,
- self.auto_rotation_interval.as_duration().as_secs(),
- account_token,
- );
- let (request, abort_handle) = abortable(Box::pin(fut));
-
- tokio::spawn(request);
- self.abort_scheduler_tx = Some(abort_handle);
- }
-
- fn stop_automatic_rotation(&mut self) {
- if let Some(abort_handle) = self.abort_scheduler_tx.take() {
- log::info!("Stopping automatic key rotation");
- abort_handle.abort();
- }
- }
-}
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index e690557aae..701a2e267e 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -44,19 +44,24 @@ service ManagementService {
// Account management
rpc CreateNewAccount(google.protobuf.Empty) returns (google.protobuf.StringValue) {}
- rpc SetAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc LoginAccount(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc LogoutAccount(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetAccountData(google.protobuf.StringValue) returns (AccountData) {}
rpc GetAccountHistory(google.protobuf.Empty) returns (AccountHistory) {}
rpc ClearAccountHistory(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetWwwAuthToken(google.protobuf.Empty) returns (google.protobuf.StringValue) {}
rpc SubmitVoucher(google.protobuf.StringValue) returns (VoucherSubmission) {}
+ // Device management
+ rpc GetDevice(google.protobuf.Empty) returns (DeviceConfig) {}
+ rpc ListDevices(google.protobuf.StringValue) returns (DeviceList) {}
+ rpc RemoveDevice(DeviceRemoval) returns (google.protobuf.Empty) {}
+
// WireGuard key management
rpc SetWireguardRotationInterval(google.protobuf.Duration) returns (google.protobuf.Empty) {}
rpc ResetWireguardRotationInterval(google.protobuf.Empty) returns (google.protobuf.Empty) {}
- rpc GenerateWireguardKey(google.protobuf.Empty) returns (KeygenEvent) {}
+ rpc RotateWireguardKey(google.protobuf.Empty) returns (google.protobuf.Empty) {}
rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {}
- rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {}
// Split tunneling (Linux)
rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {}
@@ -265,16 +270,15 @@ message BridgeState {
}
message Settings {
- string account_token = 1;
- RelaySettings relay_settings = 2;
- BridgeSettings bridge_settings = 3;
- BridgeState bridge_state = 4;
- bool allow_lan = 5;
- bool block_when_disconnected = 6;
- bool auto_connect = 7;
- TunnelOptions tunnel_options = 8;
- bool show_beta_releases = 9;
- SplitTunnelSettings split_tunnel = 10;
+ RelaySettings relay_settings = 1;
+ BridgeSettings bridge_settings = 2;
+ BridgeState bridge_state = 3;
+ bool allow_lan = 4;
+ bool block_when_disconnected = 5;
+ bool auto_connect = 6;
+ TunnelOptions tunnel_options = 7;
+ bool show_beta_releases = 8;
+ SplitTunnelSettings split_tunnel = 9;
}
message SplitTunnelSettings {
@@ -521,10 +525,33 @@ message DaemonEvent {
Settings settings = 2;
RelayList relay_list = 3;
AppVersionInfo version_info = 4;
- KeygenEvent key_event = 5;
+ DeviceEvent device = 5;
}
}
message RelayList {
repeated RelayListCountry countries = 1;
}
+
+message DeviceConfig {
+ string account_token = 1;
+ Device device = 2;
+}
+
+message Device {
+ string id = 1;
+ string name = 2;
+}
+
+message DeviceList {
+ repeated Device devices = 1;
+}
+
+message DeviceRemoval {
+ string account_token = 1;
+ string device_id = 2;
+}
+
+message DeviceEvent {
+ Device device = 1;
+}
diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs
index 5398927569..eb6dbf6a31 100644
--- a/mullvad-management-interface/src/types.rs
+++ b/mullvad-management-interface/src/types.rs
@@ -197,6 +197,43 @@ impl From<mullvad_types::states::TunnelState> for TunnelState {
}
}
+impl From<mullvad_types::device::Device> for Device {
+ fn from(device: mullvad_types::device::Device) -> Self {
+ Device {
+ id: device.id,
+ name: device.name,
+ }
+ }
+}
+
+impl From<mullvad_types::device::DeviceEvent> for DeviceEvent {
+ fn from(event: mullvad_types::device::DeviceEvent) -> Self {
+ DeviceEvent {
+ device: event.0.map(|device| Device::from(device)),
+ }
+ }
+}
+
+impl From<mullvad_types::device::DeviceData> for DeviceConfig {
+ fn from(device: mullvad_types::device::DeviceData) -> Self {
+ DeviceConfig {
+ account_token: device.token,
+ device: Some(Device::from(device.device)),
+ }
+ }
+}
+
+impl From<Vec<mullvad_types::device::Device>> for DeviceList {
+ fn from(devices: Vec<mullvad_types::device::Device>) -> Self {
+ DeviceList {
+ devices: devices
+ .into_iter()
+ .map(|device| Device::from(device))
+ .collect(),
+ }
+ }
+}
+
impl From<mullvad_types::wireguard::KeygenEvent> for KeygenEvent {
fn from(event: mullvad_types::wireguard::KeygenEvent) -> Self {
use keygen_event::KeygenEvent as Event;
@@ -387,7 +424,6 @@ impl From<&mullvad_types::settings::Settings> for Settings {
let split_tunnel = None;
Self {
- account_token: settings.get_account_token().unwrap_or_default(),
relay_settings: Some(RelaySettings::from(settings.get_relay_settings())),
bridge_settings: Some(BridgeSettings::from(settings.bridge_settings.clone())),
bridge_state: Some(BridgeState::from(settings.get_bridge_state())),
diff --git a/mullvad-rpc/src/access.rs b/mullvad-rpc/src/access.rs
new file mode 100644
index 0000000000..b58ceee809
--- /dev/null
+++ b/mullvad-rpc/src/access.rs
@@ -0,0 +1,108 @@
+use crate::{
+ rest,
+ rest::{RequestFactory, RequestServiceHandle},
+};
+use hyper::StatusCode;
+use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
+use std::{
+ collections::HashMap,
+ sync::{Arc, Mutex},
+};
+use talpid_types::ErrorExt;
+
+pub const AUTH_URL_PREFIX: &str = "auth/v1-beta1";
+
+#[derive(Clone)]
+pub struct AccessTokenProxy {
+ service: RequestServiceHandle,
+ factory: RequestFactory,
+ access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
+}
+
+impl AccessTokenProxy {
+ pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
+ Self {
+ service,
+ factory,
+ access_from_account: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ /// Obtain access token for an account, requesting a new one from the API if necessary.
+ pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
+ let existing_token = {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .get(account.as_str())
+ .cloned()
+ };
+ if let Some(access_token) = existing_token {
+ if access_token.is_expired() {
+ log::debug!("Replacing expired access token");
+ return self.request_new_token(account.clone()).await;
+ }
+ log::trace!("Using stored access token");
+ return Ok(access_token.access_token.clone());
+ }
+ self.request_new_token(account.clone()).await
+ }
+
+ /// Remove an access token if the API response calls for it.
+ pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
+ if let Err(rest::Error::ApiError(_status, code)) = response {
+ if code == crate::INVALID_ACCESS_TOKEN {
+ log::debug!("Dropping invalid access token");
+ self.remove_token(account);
+ }
+ }
+ }
+
+ /// Removes a stored access token.
+ fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .remove(account)
+ .map(|v| v.access_token)
+ }
+
+ async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
+ log::debug!("Fetching access token for an account");
+ let access_token = self
+ .fetch_access_token(account.clone())
+ .await
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to obtain access token")
+ );
+ error
+ })?;
+ self.access_from_account
+ .lock()
+ .unwrap()
+ .insert(account, access_token.clone());
+ Ok(access_token.access_token)
+ }
+
+ async fn fetch_access_token(
+ &self,
+ account_token: AccountToken,
+ ) -> Result<AccessTokenData, rest::Error> {
+ #[derive(serde::Serialize)]
+ struct AccessTokenRequest {
+ account_token: String,
+ }
+ let request = AccessTokenRequest { account_token };
+
+ let service = self.service.clone();
+
+ let rest_request = self
+ .factory
+ .post_json(&format!("{}/token", AUTH_URL_PREFIX), &request)?;
+ let response = service.request(rest_request).await?;
+ let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
+ rest::deserialize_body(response).await
+ }
+}
diff --git a/mullvad-rpc/src/lib.rs b/mullvad-rpc/src/lib.rs
index 614aa3bdb6..a49e392320 100644
--- a/mullvad-rpc/src/lib.rs
+++ b/mullvad-rpc/src/lib.rs
@@ -7,6 +7,7 @@ use futures::Stream;
use hyper::Method;
use mullvad_types::{
account::{AccountToken, VoucherSubmission},
+ device::{Device, DeviceId, DeviceName},
version::AppVersion,
};
use proxy::ApiConnectionMode;
@@ -29,6 +30,7 @@ mod tls_stream;
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
+mod access;
mod address_cache;
mod relay_list;
pub use address_cache::AddressCache;
@@ -44,11 +46,17 @@ pub const INVALID_VOUCHER: &str = "INVALID_VOUCHER";
/// Error code returned by the Mullvad API if the account token is invalid.
pub const INVALID_ACCOUNT: &str = "INVALID_ACCOUNT";
-/// Error code returned by the Mullvad API if the account token is missing or invalid.
-pub const INVALID_AUTH: &str = "INVALID_AUTH";
+/// Error code returned by the Mullvad API if the access token is invalid.
+pub const INVALID_ACCESS_TOKEN: &str = "INVALID_ACCESS_TOKEN";
+
+pub const MAX_DEVICES_REACHED: &str = "MAX_DEVICES_REACHED";
+pub const PUBKEY_IN_USE: &str = "PUBKEY_IN_USE";
pub const API_IP_CACHE_FILENAME: &str = "api-ip-address.txt";
+const ACCOUNTS_URL_PREFIX: &str = "accounts/v1-beta1";
+const APP_URL_PREFIX: &str = "app/v1";
+
lazy_static::lazy_static! {
static ref API: ApiEndpoint = ApiEndpoint::get();
}
@@ -257,7 +265,7 @@ impl MullvadRpcRuntime {
self.socket_bypass_tx.clone(),
)
.await;
- let factory = rest::RequestFactory::new(API.host.clone(), Some("app".to_owned()));
+ let factory = rest::RequestFactory::new(API.host.clone(), None);
rest::MullvadRestHandle::new(
service,
@@ -296,7 +304,7 @@ pub struct AccountsProxy {
#[derive(serde::Deserialize)]
struct AccountResponse {
token: AccountToken,
- expires: DateTime<Utc>,
+ expiry: DateTime<Utc>,
}
impl AccountsProxy {
@@ -309,18 +317,21 @@ impl AccountsProxy {
account: AccountToken,
) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> {
let service = self.handle.service.clone();
-
- let response = rest::send_request(
- &self.handle.factory,
- service,
- "/v1/me",
- Method::GET,
- Some(account),
- &[StatusCode::OK],
- );
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
async move {
- let account: AccountResponse = rest::deserialize_body(response.await?).await?;
- Ok(account.expires)
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/accounts/me", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let account: AccountResponse = rest::deserialize_body(response?).await?;
+ Ok(account.expiry)
}
}
@@ -329,7 +340,7 @@ impl AccountsProxy {
let response = rest::send_request(
&self.handle.factory,
service,
- "/v1/accounts",
+ &format!("{}/accounts", ACCOUNTS_URL_PREFIX),
Method::POST,
None,
&[StatusCode::CREATED],
@@ -352,18 +363,23 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
let submission = VoucherSubmission { voucher_code };
- let response = rest::post_request_with_json(
- &self.handle.factory,
- service,
- "/v1/submit-voucher",
- &submission,
- Some(account_token),
- &[StatusCode::OK],
- );
-
- async move { rest::deserialize_body(response.await?).await }
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/submit-voucher", APP_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account_token)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
}
pub fn get_www_auth_token(
@@ -376,22 +392,206 @@ impl AccountsProxy {
}
let service = self.handle.service.clone();
- let response = rest::send_request(
- &self.handle.factory,
- service,
- "/v1/www-auth-token",
- Method::POST,
- Some(account),
- &[StatusCode::OK],
- );
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
async move {
- let response: AuthTokenResponse = rest::deserialize_body(response.await?).await?;
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/www-auth-token", APP_URL_PREFIX),
+ Method::POST,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ let response: AuthTokenResponse = rest::deserialize_body(response?).await?;
Ok(response.auth_token)
}
}
}
+#[derive(Clone)]
+pub struct DevicesProxy {
+ handle: rest::MullvadRestHandle,
+}
+
+#[derive(serde::Deserialize)]
+struct DeviceResponse {
+ id: DeviceId,
+ name: DeviceName,
+ ipv4_address: ipnetwork::Ipv4Network,
+ ipv6_address: ipnetwork::Ipv6Network,
+}
+
+impl DevicesProxy {
+ pub fn new(handle: rest::MullvadRestHandle) -> Self {
+ Self { handle }
+ }
+
+ pub fn create(
+ &self,
+ account: AccountToken,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct DeviceSubmission {
+ pubkey: wireguard::PublicKey,
+ kind: String,
+ }
+
+ let submission = DeviceSubmission {
+ pubkey,
+ // TODO: constant
+ kind: "App".to_string(),
+ };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::POST,
+ &submission,
+ Some((access_proxy, account)),
+ &[StatusCode::CREATED],
+ )
+ .await;
+
+ let response: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ id,
+ name,
+ ipv4_address,
+ ipv6_address,
+ ..
+ } = response;
+
+ Ok((
+ Device { id, name },
+ mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ },
+ ))
+ }
+ }
+
+ pub fn get(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<Device, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn list(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices", ACCOUNTS_URL_PREFIX),
+ Method::GET,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+ rest::deserialize_body(response?).await
+ }
+ }
+
+ pub fn remove(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ ) -> impl Future<Output = Result<(), rest::Error>> {
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+ async move {
+ let response = rest::send_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}", ACCOUNTS_URL_PREFIX, id),
+ Method::DELETE,
+ Some((access_proxy, account)),
+ &[StatusCode::NO_CONTENT],
+ )
+ .await;
+
+ response?;
+ Ok(())
+ }
+ }
+
+ pub fn replace_wg_key(
+ &self,
+ account: AccountToken,
+ id: DeviceId,
+ pubkey: wireguard::PublicKey,
+ ) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>>
+ {
+ #[derive(serde::Serialize)]
+ struct RotateDevicePubkey {
+ pubkey: wireguard::PublicKey,
+ }
+ let req_body = RotateDevicePubkey { pubkey };
+
+ let service = self.handle.service.clone();
+ let factory = self.handle.factory.clone();
+ let access_proxy = self.handle.token_store.clone();
+
+ async move {
+ let response = rest::send_json_request(
+ &factory,
+ service,
+ &format!("{}/devices/{}/pubkey", ACCOUNTS_URL_PREFIX, id),
+ Method::PUT,
+ &req_body,
+ Some((access_proxy, account)),
+ &[StatusCode::OK],
+ )
+ .await;
+
+ let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
+ let DeviceResponse {
+ ipv4_address,
+ ipv6_address,
+ ..
+ } = updated_device;
+ Ok(mullvad_types::wireguard::AssociatedAddresses {
+ ipv4_address,
+ ipv6_address,
+ })
+ }
+ }
+}
+
pub struct ProblemReportProxy {
handle: rest::MullvadRestHandle,
}
@@ -425,10 +625,11 @@ impl ProblemReportProxy {
let service = self.handle.service.clone();
- let request = rest::post_request_with_json(
+ let request = rest::send_json_request(
&self.handle.factory,
service,
- "/v1/problem-report",
+ &format!("{}/problem-report", APP_URL_PREFIX),
+ Method::POST,
&report,
None,
&[StatusCode::NO_CONTENT],
@@ -467,7 +668,7 @@ impl AppVersionProxy {
) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> {
let service = self.handle.service.clone();
- let path = format!("/v1/releases/{}/{}", platform, app_version);
+ let path = format!("{}/releases/{}/{}", APP_URL_PREFIX, platform, app_version);
let request = self.handle.factory.request(&path, Method::GET);
async move {
@@ -508,7 +709,10 @@ impl WireguardKeyProxy {
let service = self.handle.service.clone();
let body = PublishRequest { pubkey: public_key };
- let request = self.handle.factory.post_json(&"/v1/wireguard-keys", &body);
+ let request = self
+ .handle
+ .factory
+ .post_json(&"app/v1/wireguard-keys", &body);
async move {
let mut request = request?;
if let Some(timeout) = timeout {
@@ -538,10 +742,11 @@ impl WireguardKeyProxy {
let service = self.handle.service.clone();
let body = ReplacementRequest { old, new };
- let response = rest::post_request_with_json(
+ let response = rest::send_json_request(
&self.handle.factory,
service,
- &"/v1/replace-wireguard-key",
+ &"app/v1/replace-wireguard-key",
+ Method::POST,
&body,
Some(account_token),
[StatusCode::CREATED, StatusCode::OK].as_slice(),
@@ -562,7 +767,7 @@ impl WireguardKeyProxy {
&self.handle.factory,
service,
&format!(
- "/v1/wireguard-keys/{}",
+ "app/v1/wireguard-keys/{}",
urlencoding::encode(&key.to_base64())
),
Method::GET,
@@ -584,7 +789,7 @@ impl WireguardKeyProxy {
&self.handle.factory,
service,
&format!(
- "/v1/wireguard-keys/{}",
+ "app/v1/wireguard-keys/{}",
urlencoding::encode(&key.to_base64())
),
Method::DELETE,
@@ -614,7 +819,7 @@ impl ApiProxy {
let response = rest::send_request(
&self.handle.factory,
service,
- "/v1/api-addrs",
+ &format!("{}/api-addrs", APP_URL_PREFIX),
Method::GET,
None,
&[StatusCode::OK],
diff --git a/mullvad-rpc/src/relay_list.rs b/mullvad-rpc/src/relay_list.rs
index f1ed2217fd..5a8a01836f 100644
--- a/mullvad-rpc/src/relay_list.rs
+++ b/mullvad-rpc/src/relay_list.rs
@@ -13,7 +13,7 @@ use std::{
time::Duration,
};
-/// Fetches relay list from <https://api.mullvad.net/v1/relays>
+/// Fetches relay list from https://api.mullvad.net/app/v1/relays
#[derive(Clone)]
pub struct RelayListProxy {
handle: rest::MullvadRestHandle,
@@ -33,7 +33,7 @@ impl RelayListProxy {
etag: Option<String>,
) -> impl Future<Output = Result<Option<relay_list::RelayList>, rest::Error>> {
let service = self.handle.service.clone();
- let request = self.handle.factory.request("/v1/relays", Method::GET);
+ let request = self.handle.factory.request("app/v1/relays", Method::GET);
let future = async move {
let mut request = request?;
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 17362cce05..c7e5d02fb1 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -1,6 +1,7 @@
#[cfg(target_os = "android")]
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
+ access::AccessTokenProxy,
address_cache::AddressCache,
availability::ApiAvailabilityHandle,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
@@ -17,6 +18,7 @@ use hyper::{
header::{self, HeaderValue},
Method, Uri,
};
+use mullvad_types::account::AccountToken;
use std::{
future::Future,
str::FromStr,
@@ -302,11 +304,11 @@ impl RestRequest {
})
}
- /// Set the auth header with the following format: `Token $auth`.
+ /// Set the auth header with the following format: `Bearer $auth`.
pub fn set_auth(&mut self, auth: Option<String>) -> Result<()> {
let header = match auth {
Some(auth) => Some(
- HeaderValue::from_str(&format!("Token {}", auth))
+ HeaderValue::from_str(&format!("Bearer {}", auth))
.map_err(Error::InvalidHeaderError)?,
),
None => None,
@@ -399,7 +401,16 @@ impl RequestFactory {
}
pub fn post_json<S: serde::Serialize>(&self, path: &str, body: &S) -> Result<RestRequest> {
- let mut request = self.hyper_request(path, Method::POST)?;
+ self.json_request(Method::POST, path, body)
+ }
+
+ fn json_request<S: serde::Serialize>(
+ &self,
+ method: Method,
+ path: &str,
+ body: &S,
+ ) -> Result<RestRequest> {
+ let mut request = self.hyper_request(path, method)?;
let json_body = serde_json::to_string(&body)?;
let body_length = json_body.as_bytes().len() as u64;
@@ -468,33 +479,52 @@ pub fn send_request(
service: RequestServiceHandle,
uri: &str,
method: Method,
- auth: Option<String>,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
let request = factory.request(uri, method);
async move {
let mut request = request?;
- request.set_auth(auth)?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
}
}
-pub fn post_request_with_json<B: serde::Serialize>(
+pub fn send_json_request<B: serde::Serialize>(
factory: &RequestFactory,
service: RequestServiceHandle,
uri: &str,
+ method: Method,
body: &B,
- auth: Option<String>,
+ auth: Option<(AccessTokenProxy, AccountToken)>,
expected_statuses: &'static [hyper::StatusCode],
) -> impl Future<Output = Result<Response>> {
- let request = factory.post_json(uri, body);
+ let request = factory.json_request(method, uri, body);
async move {
let mut request = request?;
- request.set_auth(auth)?;
+ if let Some((store, account)) = &auth {
+ let access_token = store.get_token(&account).await?;
+ request.set_auth(Some(access_token))?;
+ }
let response = service.request(request).await?;
- parse_rest_response(response, expected_statuses).await
+ let result = parse_rest_response(response, expected_statuses).await;
+
+ if let Some((store, account)) = &auth {
+ store.check_response(&account, &result);
+ }
+
+ result
}
}
@@ -554,6 +584,7 @@ pub struct MullvadRestHandle {
pub(crate) service: RequestServiceHandle,
pub factory: RequestFactory,
availability: ApiAvailabilityHandle,
+ pub token_store: AccessTokenProxy,
}
impl MullvadRestHandle {
@@ -563,10 +594,13 @@ impl MullvadRestHandle {
address_cache: AddressCache,
availability: ApiAvailabilityHandle,
) -> Self {
+ let token_store = AccessTokenProxy::new(service.clone(), factory.clone());
+
let handle = Self {
service,
factory,
availability,
+ token_store,
};
if !super::API.disable_address_cache {
handle.spawn_api_address_fetcher(address_cache);
diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs
index e65b1278f8..37061c8854 100644
--- a/mullvad-setup/src/main.rs
+++ b/mullvad-setup/src/main.rs
@@ -72,8 +72,11 @@ pub enum Error {
#[error(display = "Failed to obtain cache directory path")]
CachePathError(#[error(source)] mullvad_paths::Error),
- #[error(display = "Failed to update the settings")]
- SettingsError(#[error(source)] mullvad_daemon::settings::Error),
+ #[error(display = "Failed to read the device cache")]
+ ReadDeviceCacheError(#[error(source)] mullvad_daemon::device::Error),
+
+ #[error(display = "Failed to write the device cache")]
+ WriteDeviceCacheError(#[error(source)] mullvad_daemon::device::Error),
#[error(display = "Cannot parse the version string")]
ParseVersionStringError,
@@ -161,41 +164,40 @@ async fn reset_firewall() -> Result<(), Error> {
async fn remove_wireguard_key() -> Result<(), Error> {
let (cache_path, settings_path) = get_paths()?;
- let mut settings = mullvad_daemon::settings::SettingsPersister::load(&settings_path).await;
+ let (mut cacher, data) = mullvad_daemon::device::DeviceCacher::new(&settings_path)
+ .await
+ .map_err(Error::ReadDeviceCacheError)?;
+ if let Some(device) = data {
+ let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false)
+ .await
+ .map_err(Error::RpcInitializationError)?;
- if let Some(token) = settings.get_account_token() {
- if let Some(wg_data) = settings.get_wireguard() {
- let rpc_runtime = MullvadRpcRuntime::with_cache(&cache_path, false)
- .await
- .map_err(Error::RpcInitializationError)?;
- let mut key_proxy = mullvad_rpc::WireguardKeyProxy::new(
- rpc_runtime
- .mullvad_rest_handle(
- ApiConnectionMode::try_from_cache(&cache_path)
- .await
- .into_repeat(),
- |_| async { true },
- )
- .await,
- );
- retry_future_n(
- move || {
- key_proxy.remove_wireguard_key(token.clone(), wg_data.private_key.public_key())
- },
- move |result| match result {
- Err(error) => error.is_network_error(),
- _ => false,
- },
- constant_interval(KEY_RETRY_INTERVAL),
- KEY_RETRY_MAX_RETRIES,
- )
+ let proxy = mullvad_rpc::DevicesProxy::new(
+ rpc_runtime
+ .mullvad_rest_handle(
+ ApiConnectionMode::try_from_cache(&cache_path)
+ .await
+ .into_repeat(),
+ |_| async { true },
+ )
+ .await,
+ );
+ retry_future_n(
+ move || proxy.remove(device.token.clone(), device.device.id.clone()),
+ move |result| match result {
+ Err(error) => error.is_network_error(),
+ _ => false,
+ },
+ constant_interval(KEY_RETRY_INTERVAL),
+ KEY_RETRY_MAX_RETRIES,
+ )
+ .await
+ .map_err(Error::RemoveKeyError)?;
+
+ cacher
+ .write(None)
.await
- .map_err(Error::RemoveKeyError)?;
- settings
- .set_wireguard(None)
- .await
- .map_err(Error::SettingsError)?;
- }
+ .map_err(Error::WriteDeviceCacheError)?;
}
Ok(())
diff --git a/mullvad-types/src/account.rs b/mullvad-types/src/account.rs
index b5479640e6..16f6a963f2 100644
--- a/mullvad-types/src/account.rs
+++ b/mullvad-types/src/account.rs
@@ -3,9 +3,12 @@ use chrono::{offset::Utc, DateTime};
use jnix::IntoJava;
use serde::{Deserialize, Serialize};
-/// Identifier used to authenticate or identify a Mullvad account.
+/// Identifier used to identify a Mullvad account.
pub type AccountToken = String;
+/// Identifier used to authenticate a Mullvad account.
+pub type AccessToken = String;
+
/// Account expiration info returned by the API via `/v1/me`.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[cfg_attr(target_os = "android", derive(IntoJava))]
@@ -18,7 +21,7 @@ pub struct AccountData {
impl AccountData {
/// Return true if the account has no time left.
pub fn is_expired(&self) -> bool {
- self.expiry >= Utc::now()
+ Utc::now() >= self.expiry
}
}
@@ -35,3 +38,17 @@ pub struct VoucherSubmission {
#[cfg_attr(target_os = "android", jnix(map = "|expiry| expiry.to_string()"))]
pub new_expiry: DateTime<Utc>,
}
+
+/// Token used for authentication in the API.
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+pub struct AccessTokenData {
+ pub access_token: AccessToken,
+ pub expiry: DateTime<Utc>,
+}
+
+impl AccessTokenData {
+ /// Return true if the token is no longer valid.
+ pub fn is_expired(&self) -> bool {
+ Utc::now() >= self.expiry
+ }
+}
diff --git a/mullvad-types/src/device.rs b/mullvad-types/src/device.rs
new file mode 100644
index 0000000000..e40a3d7080
--- /dev/null
+++ b/mullvad-types/src/device.rs
@@ -0,0 +1,37 @@
+use crate::{account::AccountToken, wireguard};
+use serde::{Deserialize, Serialize};
+use talpid_types::net::wireguard::PublicKey;
+
+/// UUID for a device.
+pub type DeviceId = String;
+
+/// Human-readable device identifier.
+pub type DeviceName = String;
+
+/// Contains data for a device returned by the API.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+pub struct Device {
+ pub id: DeviceId,
+ pub name: DeviceName,
+ pub pubkey: PublicKey,
+}
+
+impl Eq for Device {}
+
+/// A complete device configuration.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+pub struct DeviceData {
+ pub token: AccountToken,
+ pub device: Device,
+ pub wg_data: wireguard::WireguardData,
+}
+
+impl From<DeviceData> for Device {
+ fn from(data: DeviceData) -> Device {
+ data.device
+ }
+}
+
+/// Emitted when logging in or out of an account, or when the device changes.
+#[derive(Clone, Debug)]
+pub struct DeviceEvent(pub Option<Device>);
diff --git a/mullvad-types/src/lib.rs b/mullvad-types/src/lib.rs
index e93ab2f606..6d636aceb5 100644
--- a/mullvad-types/src/lib.rs
+++ b/mullvad-types/src/lib.rs
@@ -2,6 +2,7 @@
pub mod account;
pub mod auth_failed;
+pub mod device;
pub mod endpoint;
pub mod location;
pub mod relay_constraints;
diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs
index 26a24202a5..63ccb480a2 100644
--- a/mullvad-types/src/settings/mod.rs
+++ b/mullvad-types/src/settings/mod.rs
@@ -61,9 +61,6 @@ impl Serialize for SettingsVersion {
#[cfg_attr(target_os = "android", derive(IntoJava))]
#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))]
pub struct Settings {
- account_token: Option<String>,
- #[cfg_attr(target_os = "android", jnix(skip))]
- wireguard: Option<wireguard::WireguardData>,
relay_settings: RelaySettings,
#[cfg_attr(target_os = "android", jnix(skip))]
pub bridge_settings: BridgeSettings,
@@ -102,8 +99,6 @@ pub struct SplitTunnelSettings {
impl Default for Settings {
fn default() -> Self {
Settings {
- account_token: None,
- wireguard: None,
relay_settings: RelaySettings::Normal(RelayConstraints {
location: Constraint::Only(LocationConstraint::Country("se".to_owned())),
..Default::default()
@@ -123,45 +118,6 @@ impl Default for Settings {
}
impl Settings {
- pub fn get_account_token(&self) -> Option<String> {
- self.account_token.clone()
- }
-
- /// Changes account number to the one given. Also saves the new settings to disk.
- /// The boolean in the Result indicates if the account token changed or not
- pub fn set_account_token(&mut self, mut account_token: Option<String>) -> bool {
- if account_token.as_ref().map(String::len) == Some(0) {
- log::debug!("Setting empty account token is treated as unsetting it");
- account_token = None;
- }
- if account_token != self.account_token {
- if account_token.is_none() {
- log::info!("Unsetting account token");
- } else if self.account_token.is_none() {
- log::info!("Setting account token");
- } else {
- log::info!("Changing account token")
- }
- self.account_token = account_token;
- true
- } else {
- false
- }
- }
-
- pub fn get_wireguard(&self) -> Option<wireguard::WireguardData> {
- self.wireguard.clone()
- }
-
- pub fn set_wireguard(&mut self, wireguard: Option<wireguard::WireguardData>) -> bool {
- if wireguard != self.wireguard {
- self.wireguard = wireguard;
- true
- } else {
- false
- }
- }
-
pub fn get_relay_settings(&self) -> RelaySettings {
self.relay_settings.clone()
}