summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-api/src/rest.rs2
-rw-r--r--mullvad-daemon/src/device.rs1136
-rw-r--r--mullvad-daemon/src/device/api.rs126
-rw-r--r--mullvad-daemon/src/device/mod.rs807
-rw-r--r--mullvad-daemon/src/device/service.rs424
-rw-r--r--mullvad-daemon/src/lib.rs10
6 files changed, 1363 insertions, 1142 deletions
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index c3f48b3e8c..4297bd92ed 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -244,7 +244,7 @@ pub struct RequestServiceHandle {
impl RequestServiceHandle {
/// Resets the corresponding RequestService, dropping all in-flight requests.
- pub async fn reset(&self) {
+ pub fn reset(&self) {
let _ = self.tx.unbounded_send(RequestCommand::Reset);
}
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs
deleted file mode 100644
index 23ce331502..0000000000
--- a/mullvad-daemon/src/device.rs
+++ /dev/null
@@ -1,1136 +0,0 @@
-use chrono::{DateTime, Utc};
-use futures::{
- channel::{mpsc, oneshot},
- future::{abortable, AbortHandle},
- stream::StreamExt,
-};
-use mullvad_api::{
- availability::ApiAvailabilityHandle,
- rest::{self, Error as RestError, MullvadRestHandle},
- AccountsProxy, DevicesProxy,
-};
-use mullvad_types::{
- account::{AccountToken, VoucherSubmission},
- device::{Device, DeviceData, DeviceEvent, DeviceId},
- wireguard::{RotationInterval, WireguardData},
-};
-use std::{
- future::Future,
- path::Path,
- sync::{
- atomic::{AtomicBool, Ordering},
- Arc,
- },
- time::{Duration, SystemTime},
-};
-use talpid_core::{
- future_retry::{constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered},
- mpsc::Sender,
-};
-use talpid_types::{
- net::{wireguard::PrivateKey, TunnelType},
- tunnel::TunnelStateTransition,
- ErrorExt,
-};
-use tokio::{
- fs,
- io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
-};
-
-/// How often to check whether the key has expired.
-/// A short interval is used in case the computer is ever suspended.
-const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(5 * 60);
-
-/// File that used to store account and device data.
-const DEVICE_CACHE_FILENAME: &str = "device.json";
-
-const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
-const RETRY_ACTION_MAX_RETRIES: usize = 2;
-
-const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
-const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5;
-const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-
-/// How long to keep the known status for [AccountManagerHandle::validate_device].
-const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10);
-
-/// How long to wait on logout (device removal) before letting it continue as a background task.
-const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2);
-
-/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts
-/// to set up a WireGuard tunnel.
-const WG_DEVICE_CHECK_THRESHOLD: usize = 3;
-
-#[derive(err_derive::Error, Debug)]
-pub enum Error {
- #[error(display = "The account already has a maximum number of devices")]
- MaxDevicesReached,
- #[error(display = "No device is set")]
- NoDevice,
- #[error(display = "Device not found")]
- InvalidDevice,
- #[error(display = "Invalid account")]
- InvalidAccount,
- #[error(display = "Failed to read or write device cache")]
- DeviceIoError(#[error(source)] io::Error),
- #[error(display = "Failed parse device cache")]
- ParseDeviceCache(#[error(source)] serde_json::Error),
- #[error(display = "Unexpected HTTP request error")]
- OtherRestError(#[error(source)] rest::Error),
- #[error(display = "The task was aborted")]
- Cancelled,
- #[error(display = "The account manager is down")]
- AccountManagerDown,
-}
-
-#[derive(Clone)]
-pub(crate) enum InnerDeviceEvent {
- /// The device was removed due to user (or daemon) action.
- Logout,
- /// Logged in to a new device.
- Login(DeviceData),
- /// The device was updated remotely, but not its key.
- Updated(DeviceData),
- /// The key was rotated.
- RotatedKey(DeviceData),
- /// Device was removed because it was not found remotely.
- Revoked,
-}
-
-impl From<InnerDeviceEvent> for DeviceEvent {
- fn from(event: InnerDeviceEvent) -> DeviceEvent {
- match event {
- InnerDeviceEvent::Logout => DeviceEvent::revoke(false),
- InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false),
- InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true),
- InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false),
- InnerDeviceEvent::Revoked => DeviceEvent::revoke(true),
- }
- }
-}
-
-impl InnerDeviceEvent {
- fn data(&self) -> Option<&DeviceData> {
- match self {
- InnerDeviceEvent::Login(data) => Some(&data),
- InnerDeviceEvent::Updated(data) => Some(&data),
- InnerDeviceEvent::RotatedKey(data) => Some(&data),
- InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
- }
- }
-
- fn into_data(self) -> Option<DeviceData> {
- match self {
- InnerDeviceEvent::Login(data) => Some(data),
- InnerDeviceEvent::Updated(data) => Some(data),
- InnerDeviceEvent::RotatedKey(data) => Some(data),
- InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
- }
- }
-}
-
-impl Error {
- pub fn is_network_error(&self) -> bool {
- if let Error::OtherRestError(error) = self {
- error.is_network_error()
- } else {
- false
- }
- }
-}
-
-pub enum ValidationResult {
- /// The device and key were valid.
- Valid,
- /// The device was valid but the key was replaced
- RotatedKey,
- /// The device was valid but one or more fields, such as ports, were replaced
- Updated,
- /// The device was not found remotely and was removed from the cache.
- Removed,
-}
-
-type ResponseTx<T> = oneshot::Sender<Result<T, Error>>;
-
-enum AccountManagerCommand {
- Login(AccountToken, ResponseTx<()>),
- Logout(ResponseTx<()>),
- SetData(DeviceData, ResponseTx<()>),
- GetData(ResponseTx<Option<DeviceData>>),
- RotateKey(ResponseTx<()>),
- SetRotationInterval(RotationInterval, ResponseTx<()>),
- GetRotationInterval(ResponseTx<RotationInterval>),
- ValidateDevice(ResponseTx<ValidationResult>),
- ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>),
- Shutdown(oneshot::Sender<()>),
-}
-
-#[derive(Clone)]
-pub(crate) struct AccountManagerHandle {
- cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>,
- pub account_service: AccountService,
- pub device_service: DeviceService,
-}
-
-impl AccountManagerHandle {
- pub async fn login(&self, token: AccountToken) -> Result<(), Error> {
- self.send_command(|tx| AccountManagerCommand::Login(token, tx))
- .await
- }
-
- pub async fn logout(&self) -> Result<(), Error> {
- self.send_command(|tx| AccountManagerCommand::Logout(tx))
- .await
- }
-
- pub async fn set(&self, data: DeviceData) -> Result<(), Error> {
- self.send_command(|tx| AccountManagerCommand::SetData(data, tx))
- .await
- }
-
- pub async fn data(&self) -> Result<Option<DeviceData>, Error> {
- self.send_command(|tx| AccountManagerCommand::GetData(tx))
- .await
- }
-
- pub async fn rotate_key(&self) -> Result<(), Error> {
- self.send_command(|tx| AccountManagerCommand::RotateKey(tx))
- .await
- }
-
- pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> {
- self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx))
- .await
- }
-
- pub async fn rotation_interval(&self) -> Result<RotationInterval, Error> {
- self.send_command(|tx| AccountManagerCommand::GetRotationInterval(tx))
- .await
- }
-
- pub async fn validate_device(&self) -> Result<ValidationResult, Error> {
- self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx))
- .await
- }
-
- pub async fn receive_events(
- &self,
- events_tx: impl Sender<InnerDeviceEvent> + Send + 'static,
- ) -> Result<(), Error> {
- self.send_command(|tx| {
- AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx)
- })
- .await
- }
-
- pub async fn shutdown(self) {
- let (tx, rx) = oneshot::channel();
- let _ = self
- .cmd_tx
- .unbounded_send(AccountManagerCommand::Shutdown(tx));
- let _ = rx.await;
- }
-
- async fn send_command<T>(
- &self,
- make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand,
- ) -> Result<T, Error> {
- let (tx, rx) = oneshot::channel();
- self.cmd_tx
- .unbounded_send(make_cmd(tx))
- .map_err(|_| Error::AccountManagerDown)?;
- rx.await.map_err(|_| Error::AccountManagerDown)?
- }
-}
-
-pub(crate) struct AccountManager {
- cacher: DeviceCacher,
- device_service: DeviceService,
- data: Option<DeviceData>,
- rotation_interval: RotationInterval,
- listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>,
- last_validation: Option<SystemTime>,
-}
-
-impl AccountManager {
- pub async fn spawn(
- rest_handle: rest::MullvadRestHandle,
- api_availability: ApiAvailabilityHandle,
- settings_dir: &Path,
- initial_rotation_interval: RotationInterval,
- ) -> Result<AccountManagerHandle, Error> {
- let (cacher, data) = DeviceCacher::new(settings_dir).await?;
- let token = data.as_ref().map(|state| state.token.clone());
- let account_service =
- spawn_account_service(rest_handle.clone(), token, api_availability.clone());
-
- let (cmd_tx, cmd_rx) = mpsc::unbounded();
-
- let device_service = DeviceService::new(rest_handle, api_availability);
- let manager = AccountManager {
- cacher,
- device_service: device_service.clone(),
- data,
- rotation_interval: initial_rotation_interval,
- listeners: vec![],
- last_validation: None,
- };
-
- tokio::spawn(manager.run(cmd_rx));
- let handle = AccountManagerHandle {
- cmd_tx,
- account_service,
- device_service,
- };
- KeyUpdater::spawn(handle.clone()).await?;
- Ok(handle)
- }
-
- async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) {
- let mut shutdown_tx = None;
- while let Some(cmd) = cmd_rx.next().await {
- match cmd {
- AccountManagerCommand::Shutdown(tx) => {
- shutdown_tx = Some(tx);
- break;
- }
- other => self.service_command(other).await,
- }
- }
- self.shutdown().await;
- if let Some(tx) = shutdown_tx {
- let _ = tx.send(());
- }
- log::debug!("Account manager has stopped");
- }
-
- async fn service_command(&mut self, cmd: AccountManagerCommand) {
- match cmd {
- AccountManagerCommand::Login(token, tx) => {
- let _ = tx.send(self.login(token).await);
- }
- AccountManagerCommand::Logout(tx) => {
- let _ = tx.send(self.logout().await);
- }
- AccountManagerCommand::SetData(data, tx) => {
- let _ = tx.send(self.set(InnerDeviceEvent::Login(data)).await);
- }
- AccountManagerCommand::GetData(tx) => {
- let _ = tx.send(Ok(self.data.clone()));
- }
- AccountManagerCommand::RotateKey(tx) => {
- let _ = tx.send(self.rotate_key().await);
- }
- AccountManagerCommand::SetRotationInterval(interval, tx) => {
- self.rotation_interval = interval;
- let _ = tx.send(Ok(()));
- }
- AccountManagerCommand::GetRotationInterval(tx) => {
- let _ = tx.send(Ok(self.rotation_interval));
- }
- AccountManagerCommand::ValidateDevice(tx) => {
- let _ = tx.send(self.validate_device().await);
- }
- AccountManagerCommand::ReceiveEvents(events_tx, tx) => {
- let _ = tx.send(Ok(self.listeners.push(events_tx)));
- }
- AccountManagerCommand::Shutdown(_) => unreachable!("shutdown is handled earlier"),
- }
- }
-
- async fn login(&mut self, token: AccountToken) -> Result<(), Error> {
- let data = self.device_service.generate_for_account(token).await?;
- self.set(InnerDeviceEvent::Login(data)).await?;
- Ok(())
- }
-
- async fn logout(&mut self) -> Result<(), Error> {
- if self.data.is_some() {
- self.cacher.write(None).await?;
- let _ = tokio::time::timeout(LOGOUT_TIMEOUT, self.logout_inner()).await;
-
- let event = InnerDeviceEvent::Logout;
- self.listeners
- .retain(|listener| listener.send(event.clone()).is_ok());
- }
- Ok(())
- }
-
- fn logout_inner(&mut self) -> tokio::task::JoinHandle<()> {
- let prev_data = self.data.take();
- let service = self.device_service.clone();
-
- tokio::spawn(async move {
- if let Some(data) = prev_data {
- if let Err(error) = service
- .remove_device_with_backoff(data.token, data.device.id)
- .await
- {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to remove a previous device")
- );
- }
- }
- })
- }
-
- async fn set(&mut self, event: InnerDeviceEvent) -> Result<(), Error> {
- let data = event.data();
- if data == self.data.as_ref() {
- return Ok(());
- }
-
- self.cacher.write(data).await?;
- self.last_validation = None;
-
- if self
- .data
- .as_ref()
- .map(|current| data.as_ref().map(|d| &d.device.id) != Some(&current.device.id))
- .unwrap_or(false)
- {
- // Remove the existing device if its ID differs. Otherwise, only update
- // the data.
- self.logout_inner();
- }
-
- self.data = data.cloned();
-
- self.listeners
- .retain(|listener| listener.send(event.clone()).is_ok());
-
- Ok(())
- }
-
- async fn rotate_key(&mut self) -> Result<(), Error> {
- // TODO: Update all data opportunistically?
- let data = self.data.as_ref().ok_or(Error::NoDevice)?;
-
- let wg_data = self
- .device_service
- .rotate_key(data.token.clone(), data.device.id.clone())
- .await?;
-
- // Copy the data to keep a predictable state if an error occurs.
- let mut new_data = data.clone();
- new_data.device.pubkey = wg_data.private_key.public_key();
- new_data.wg_data = wg_data;
- self.set(InnerDeviceEvent::RotatedKey(new_data)).await
- }
-
- /// Check if the device is valid for the account, and yank it if it no longer exists.
- /// This also updates any associated data and returns whether it changed.
- async fn validate_device(&mut self) -> Result<ValidationResult, Error> {
- log::debug!("Checking whether the device is still valid");
-
- if let Some(result) = self.cached_validation() {
- log::debug!("The current device is still valid");
- return Ok(result);
- }
-
- let data = self.data.as_ref().ok_or(Error::NoDevice)?;
-
- match self
- .device_service
- .get(data.token.clone(), data.device.id.clone())
- .await
- {
- Ok(device) => {
- if device.pubkey == data.device.pubkey {
- if device == data.device {
- log::debug!("The current device is still valid");
- Ok(ValidationResult::Valid)
- } else {
- log::debug!("Updating data for the current device");
- // Copy the data to keep a predictable state if an error occurs.
- let new_data = DeviceData {
- device,
- ..data.clone()
- };
- self.set(InnerDeviceEvent::Updated(new_data)).await?;
- Ok(ValidationResult::Updated)
- }
- } else {
- log::debug!("Rotating invalid WireGuard key");
- self.rotate_key().await?;
- Ok(ValidationResult::RotatedKey)
- }
- }
- Err(Error::InvalidAccount) | Err(Error::InvalidDevice) => {
- log::debug!("The current device is no longer valid for this account");
-
- self.cacher.write(None).await?;
- self.data = None;
-
- let event = InnerDeviceEvent::Revoked;
- self.listeners
- .retain(|listener| listener.send(event.clone()).is_ok());
-
- Ok(ValidationResult::Removed)
- }
- Err(error) => Err(error),
- }
- }
-
- fn cached_validation(&mut self) -> Option<ValidationResult> {
- if self.data.is_none() {
- return None;
- }
-
- let now = SystemTime::now();
-
- let elapsed = self
- .last_validation
- .and_then(|last_check| now.duration_since(last_check).ok())
- .unwrap_or(VALIDITY_CACHE_TIMEOUT);
-
- if elapsed >= VALIDITY_CACHE_TIMEOUT {
- self.last_validation = Some(now);
- return None;
- }
-
- Some(ValidationResult::Valid)
- }
-
- async fn shutdown(self) {
- self.cacher.finalize().await;
- }
-}
-
-struct KeyUpdater {
- handle: AccountManagerHandle,
- rx: mpsc::UnboundedReceiver<InnerDeviceEvent>,
- data: Option<DeviceData>,
-}
-
-impl KeyUpdater {
- async fn spawn(handle: AccountManagerHandle) -> Result<(), Error> {
- let (tx, rx) = mpsc::unbounded();
- handle.receive_events(tx).await?;
- let data = handle.data().await?;
- let mut key_rotator = KeyUpdater { handle, rx, data };
-
- tokio::spawn(async move {
- loop {
- tokio::time::sleep(KEY_CHECK_INTERVAL).await;
-
- if let Err(error) = key_rotator.check_key_validity().await {
- if let Error::AccountManagerDown = error {
- break;
- }
- log::error!(
- "{}",
- error.display_chain_with_msg("Stopping key rotation task due to an error")
- );
- break;
- }
- }
- log::debug!("Stopping key updater");
- });
-
- Ok(())
- }
-
- async fn check_key_validity(&mut self) -> Result<(), Error> {
- let rotation_interval = self.handle.rotation_interval().await?;
- let data = self.wait_for_data().await?;
-
- if (chrono::Utc::now()
- .signed_duration_since(data.wg_data.created)
- .num_seconds() as u64)
- < rotation_interval.as_duration().as_secs()
- {
- return Ok(());
- }
-
- let mut data = data.clone();
-
- let rotation_fut = self
- .handle
- .device_service
- .rotate_key_with_backoff(data.token.clone(), data.device.id.clone());
-
- match futures::future::select(Box::pin(rotation_fut), self.rx.next()).await {
- futures::future::Either::Left((Ok(wg_data), _)) => {
- log::debug!("Rotating WireGuard key");
- data.device.pubkey = wg_data.private_key.public_key();
- data.wg_data = wg_data;
- self.handle.set(data).await?;
- }
- futures::future::Either::Left((Err(error), _)) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Stopping key rotation due to an error")
- );
-
- // Forget the current device. Key rotation will restart when
- // it is updated in any way.
- self.data = None;
- }
- futures::future::Either::Right((event, _)) => {
- // Abort key rotation if the device changed
- if let Some(event) = event {
- self.data = event.into_data();
- } else {
- return Err(Error::AccountManagerDown);
- }
- }
- }
-
- Ok(())
- }
-
- async fn wait_for_data(&mut self) -> Result<&DeviceData, Error> {
- while let Ok(item) = self.rx.try_next() {
- match item {
- Some(event) => {
- self.data = event.into_data();
- }
- None => return Err(Error::AccountManagerDown),
- }
- }
-
- match self.data {
- Some(ref data) => Ok(data),
- None => loop {
- let event = self.rx.next().await;
- match event {
- Some(event) => {
- if let Some(data) = event.into_data() {
- self.data = Some(data);
- break Ok(self.data.as_ref().unwrap());
- }
- }
- None => break Err(Error::AccountManagerDown),
- }
- },
- }
- }
-}
-
-#[derive(Clone)]
-pub struct DeviceService {
- api_availability: ApiAvailabilityHandle,
- proxy: DevicesProxy,
-}
-
-impl DeviceService {
- pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
- Self {
- proxy: DevicesProxy::new(handle),
- api_availability,
- }
- }
-
- /// Generate a new device for a given token
- pub async fn generate_for_account(&self, token: AccountToken) -> Result<DeviceData, Error> {
- let private_key = PrivateKey::new_from_random();
- let pubkey = private_key.public_key();
-
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let token_copy = token.clone();
- let (device, addresses) = retry_future_n(
- move || proxy.create(token_copy.clone(), pubkey.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await
- .map_err(map_rest_error)?;
-
- Ok(DeviceData {
- token,
- device,
- wg_data: WireguardData {
- private_key,
- addresses,
- created: Utc::now(),
- },
- })
- }
-
- pub async fn generate_for_account_with_backoff(
- &self,
- token: AccountToken,
- ) -> Result<DeviceData, Error> {
- let private_key = PrivateKey::new_from_random();
- let pubkey = private_key.public_key();
-
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let token_copy = token.clone();
- let (device, addresses) = retry_future(
- move || api_handle.when_online(proxy.create(token_copy.clone(), pubkey.clone())),
- should_retry_backoff,
- retry_strategy(),
- )
- .await
- .map_err(map_rest_error)?;
-
- Ok(DeviceData {
- token,
- device,
- wg_data: WireguardData {
- private_key,
- addresses,
- created: Utc::now(),
- },
- })
- }
-
- pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.remove(token.clone(), device.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await
- .map_err(map_rest_error)?;
- Ok(())
- }
-
- pub async fn remove_device_with_backoff(
- &self,
- token: AccountToken,
- device: DeviceId,
- ) -> Result<(), Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
-
- let retry_strategy = Jittered::jitter(
- ExponentialBackoff::new(
- RETRY_BACKOFF_INTERVAL_INITIAL,
- RETRY_BACKOFF_INTERVAL_FACTOR,
- ), // Not setting a maximum interval
- );
-
- retry_future(
- // NOTE: Not honoring "paused" state, because the account may have no time on it.
- move || api_handle.when_online(proxy.remove(token.clone(), device.clone())),
- should_retry_backoff,
- retry_strategy,
- )
- .await
- .map_err(map_rest_error)?;
-
- Ok(())
- }
-
- pub async fn rotate_key(
- &self,
- token: AccountToken,
- device: DeviceId,
- ) -> Result<WireguardData, Error> {
- let private_key = PrivateKey::new_from_random();
-
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let pubkey = private_key.public_key();
- let addresses = retry_future_n(
- move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await
- .map_err(map_rest_error)?;
-
- Ok(WireguardData {
- private_key,
- addresses,
- created: Utc::now(),
- })
- }
-
- pub async fn rotate_key_with_backoff(
- &self,
- token: AccountToken,
- device: DeviceId,
- ) -> Result<WireguardData, Error> {
- let private_key = PrivateKey::new_from_random();
-
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let pubkey = private_key.public_key();
-
- let addresses = retry_future(
- move || {
- api_handle.when_bg_resumes(proxy.replace_wg_key(
- token.clone(),
- device.clone(),
- pubkey.clone(),
- ))
- },
- should_retry_backoff,
- retry_strategy(),
- )
- .await
- .map_err(map_rest_error)?;
-
- Ok(WireguardData {
- private_key,
- addresses,
- created: Utc::now(),
- })
- }
-
- pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.list(token.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await
- .map_err(map_rest_error)
- }
-
- pub async fn list_devices_with_backoff(
- &self,
- token: AccountToken,
- ) -> Result<Vec<Device>, Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
-
- retry_future(
- move || api_handle.when_online(proxy.list(token.clone())),
- should_retry_backoff,
- retry_strategy(),
- )
- .await
- .map_err(map_rest_error)
- }
-
- pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result<Device, Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.get(token.clone(), device.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await
- .map_err(map_rest_error)
- }
-}
-
-pub struct DeviceCacher {
- file: io::BufWriter<fs::File>,
- path: std::path::PathBuf,
-}
-
-impl DeviceCacher {
- pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> {
- let mut options = std::fs::OpenOptions::new();
- #[cfg(unix)]
- {
- use std::os::unix::fs::OpenOptionsExt;
- options.mode(0o600);
- }
- #[cfg(windows)]
- {
- use std::os::windows::fs::OpenOptionsExt;
- // exclusive access
- options.share_mode(0);
- }
-
- let path = settings_dir.join(DEVICE_CACHE_FILENAME);
- let cache_exists = path.is_file();
-
- let mut file = fs::OpenOptions::from(options)
- .write(true)
- .read(true)
- .create(true)
- .open(&path)
- .await?;
-
- let device: Option<DeviceData> = if cache_exists {
- let mut reader = io::BufReader::new(&mut file);
- let mut buffer = String::new();
- reader.read_to_string(&mut buffer).await?;
- if !buffer.is_empty() {
- serde_json::from_str(&buffer)?
- } else {
- None
- }
- } else {
- None
- };
-
- Ok((
- DeviceCacher {
- file: io::BufWriter::new(file),
- path,
- },
- device,
- ))
- }
-
- pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> {
- let data = serde_json::to_vec_pretty(&device).unwrap();
-
- self.file.get_mut().set_len(0).await?;
- self.file.seek(io::SeekFrom::Start(0)).await?;
- self.file.write_all(&data).await?;
- self.file.flush().await?;
- self.file.get_mut().sync_data().await?;
-
- Ok(())
- }
-
- pub async fn remove(self) -> Result<(), Error> {
- let path = {
- let DeviceCacher { path, file } = self;
- let std_file = file.into_inner().into_std().await;
- let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
- path
- };
- tokio::fs::remove_file(path).await?;
- Ok(())
- }
-
- async fn finalize(self) {
- let std_file = self.file.into_inner().into_std().await;
- let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
- }
-}
-
-#[derive(Clone)]
-pub struct AccountService {
- api_availability: ApiAvailabilityHandle,
- initial_check_abort_handle: AbortHandle,
- proxy: AccountsProxy,
-}
-
-impl AccountService {
- pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.create_account(),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub fn get_www_auth_token(
- &self,
- account: AccountToken,
- ) -> impl Future<Output = Result<String, rest::Error>> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- retry_future_n(
- move || proxy.get_www_auth_token(account.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- }
-
- pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
- let proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.get_expiry(token.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if handle_expiry_result_inner(&result, &self.api_availability) {
- self.initial_check_abort_handle.abort();
- }
- result
- }
-
- pub async fn submit_voucher(
- &mut self,
- account_token: AccountToken,
- voucher: String,
- ) -> Result<VoucherSubmission, rest::Error> {
- let mut proxy = self.proxy.clone();
- let api_handle = self.api_availability.clone();
- let result = retry_future_n(
- move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
- move |result| should_retry(result, &api_handle),
- constant_interval(RETRY_ACTION_INTERVAL),
- RETRY_ACTION_MAX_RETRIES,
- )
- .await;
- if result.is_ok() {
- self.initial_check_abort_handle.abort();
- self.api_availability.resume_background();
- }
- result
- }
-}
-
-pub fn spawn_account_service(
- api_handle: MullvadRestHandle,
- token: Option<String>,
- api_availability: ApiAvailabilityHandle,
-) -> AccountService {
- let accounts_proxy = AccountsProxy::new(api_handle);
- api_availability.pause_background();
-
- let api_availability_copy = api_availability.clone();
- let accounts_proxy_copy = accounts_proxy.clone();
-
- let (future, initial_check_abort_handle) = abortable(async move {
- let token = if let Some(token) = token {
- token
- } else {
- api_availability.pause_background();
- return;
- };
-
- let future_generator = move || {
- let expiry_fut = api_availability.when_online(accounts_proxy.get_expiry(token.clone()));
- let api_availability_copy = api_availability.clone();
- async move { handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) }
- };
- let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
- retry_future(future_generator, should_retry, retry_strategy()).await;
- });
- tokio::spawn(future);
-
- AccountService {
- api_availability: api_availability_copy,
- initial_check_abort_handle,
- proxy: accounts_proxy_copy,
- }
-}
-
-fn handle_expiry_result_inner(
- result: &Result<chrono::DateTime<chrono::Utc>, mullvad_api::rest::Error>,
- api_availability: &ApiAvailabilityHandle,
-) -> bool {
- match result {
- Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
- api_availability.resume_background();
- true
- }
- Ok(_expiry) => {
- api_availability.pause_background();
- true
- }
- Err(mullvad_api::rest::Error::ApiError(_status, code)) => {
- if code == mullvad_api::INVALID_ACCOUNT {
- api_availability.pause_background();
- return true;
- }
- false
- }
- Err(_) => false,
- }
-}
-
-fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
- match result {
- Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
- _ => false,
- }
-}
-
-fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
- match result {
- Ok(_) => false,
- Err(error) => {
- if let RestError::ApiError(status, code) = error {
- *status != rest::StatusCode::NOT_FOUND
- && code != mullvad_api::INVALID_ACCOUNT
- && code != mullvad_api::MAX_DEVICES_REACHED
- && code != mullvad_api::PUBKEY_IN_USE
- } else {
- true
- }
- }
- }
-}
-
-fn map_rest_error(error: rest::Error) -> Error {
- match error {
- RestError::ApiError(status, ref code) => {
- if status == rest::StatusCode::NOT_FOUND {
- return Error::InvalidDevice;
- }
- match code.as_str() {
- mullvad_api::INVALID_ACCOUNT => Error::InvalidAccount,
- mullvad_api::MAX_DEVICES_REACHED => Error::MaxDevicesReached,
- _ => Error::OtherRestError(error),
- }
- }
- error => Error::OtherRestError(error),
- }
-}
-
-fn retry_strategy() -> Jittered<ExponentialBackoff> {
- Jittered::jitter(
- ExponentialBackoff::new(
- RETRY_BACKOFF_INTERVAL_INITIAL,
- RETRY_BACKOFF_INTERVAL_FACTOR,
- )
- .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
- )
-}
-
-/// Checks if the current device is valid if a WireGuard tunnel cannot be set up
-/// after multiple attempts.
-pub(crate) struct TunnelStateChangeHandler {
- manager: AccountManagerHandle,
- check_validity: Arc<AtomicBool>,
- wg_retry_attempt: usize,
-}
-
-impl TunnelStateChangeHandler {
- pub fn new(manager: AccountManagerHandle) -> Self {
- Self {
- manager,
- check_validity: Arc::new(AtomicBool::new(true)),
- wg_retry_attempt: 0,
- }
- }
-
- pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) {
- match new_state {
- TunnelStateTransition::Connecting(endpoint) => {
- if endpoint.tunnel_type != TunnelType::Wireguard {
- return;
- }
- self.wg_retry_attempt += 1;
- if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 {
- let handle = self.manager.clone();
- let check_validity = self.check_validity.clone();
- tokio::spawn(async move {
- if !check_validity.swap(false, Ordering::SeqCst) {
- return;
- }
- if let Err(error) = handle.validate_device().await {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to check device validity")
- );
- if error.is_network_error() {
- check_validity.store(true, Ordering::SeqCst);
- }
- }
- });
- }
- }
- TunnelStateTransition::Connected(_) | TunnelStateTransition::Disconnected => {
- self.check_validity.store(true, Ordering::SeqCst);
- self.wg_retry_attempt = 0;
- }
- _ => (),
- }
- }
-}
diff --git a/mullvad-daemon/src/device/api.rs b/mullvad-daemon/src/device/api.rs
new file mode 100644
index 0000000000..9005c41364
--- /dev/null
+++ b/mullvad-daemon/src/device/api.rs
@@ -0,0 +1,126 @@
+use std::pin::Pin;
+
+use futures::{future::FusedFuture, Future};
+use mullvad_types::{
+ device::{Device, DeviceData},
+ wireguard::WireguardData,
+};
+
+use super::{Error, ResponseTx};
+pub struct CurrentApiCall {
+ current_call: Option<Call>,
+}
+
+impl CurrentApiCall {
+ pub fn new() -> Self {
+ Self { current_call: None }
+ }
+
+ pub fn clear(&mut self) {
+ self.current_call = None;
+ }
+
+ pub fn set_login(&mut self, login: ApiCall<DeviceData>, tx: ResponseTx<()>) {
+ self.current_call = Some(Call::Login(login, Some(tx)));
+ }
+
+ pub fn set_oneshot_rotation(&mut self, rotation: ApiCall<WireguardData>) {
+ self.current_call = Some(Call::OneshotKeyRotation(rotation));
+ }
+
+ pub fn set_timed_rotation(&mut self, rotation: ApiCall<WireguardData>) {
+ self.current_call = Some(Call::TimerKeyRotation(rotation));
+ }
+
+ pub fn set_validation(&mut self, validation: ApiCall<Device>) {
+ self.current_call = Some(Call::Validation(validation));
+ }
+
+ pub fn is_validating(&self) -> bool {
+ match &self.current_call {
+ Some(Call::Validation(_)) | Some(Call::OneshotKeyRotation(_)) => true,
+ _ => false,
+ }
+ }
+
+ pub fn is_running_timed_totation(&self) -> bool {
+ matches!(&self.current_call, Some(Call::TimerKeyRotation(_)))
+ }
+
+ pub fn is_idle(&self) -> bool {
+ self.current_call.is_none()
+ }
+
+ pub fn is_logging_in(&self) -> bool {
+ use Call::*;
+ match &self.current_call {
+ Some(Login(..)) => true,
+ _ => false,
+ }
+ }
+}
+
+impl Future for CurrentApiCall {
+ type Output = ApiResult;
+
+ fn poll(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Self::Output> {
+ match self.current_call.as_mut() {
+ Some(call) => {
+ let result = Pin::new(call).poll(cx);
+ if result.is_ready() {
+ self.current_call = None;
+ }
+ result
+ }
+ None => panic!("Polled an unfinished future"),
+ }
+ }
+}
+
+impl FusedFuture for CurrentApiCall {
+ fn is_terminated(&self) -> bool {
+ self.current_call.is_none()
+ }
+}
+
+type ApiCall<T> = Pin<Box<dyn Future<Output = Result<T, Error>> + Send>>;
+
+enum Call {
+ Login(ApiCall<DeviceData>, Option<ResponseTx<()>>),
+ TimerKeyRotation(ApiCall<WireguardData>),
+ OneshotKeyRotation(ApiCall<WireguardData>),
+ Validation(ApiCall<Device>),
+}
+
+impl futures::Future for Call {
+ type Output = ApiResult;
+
+ fn poll(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Self::Output> {
+ use Call::*;
+ match &mut *self {
+ Login(call, tx) => {
+ if let std::task::Poll::Ready(response) = Pin::new(call).poll(cx) {
+ std::task::Poll::Ready(ApiResult::Login(response, tx.take().unwrap()))
+ } else {
+ std::task::Poll::Pending
+ }
+ }
+ TimerKeyRotation(call) | OneshotKeyRotation(call) => {
+ Pin::new(call).poll(cx).map(ApiResult::Rotation)
+ }
+ Validation(call) => Pin::new(call).poll(cx).map(ApiResult::Validation),
+ }
+ }
+}
+
+pub enum ApiResult {
+ Login(Result<DeviceData, Error>, ResponseTx<()>),
+ Rotation(Result<WireguardData, Error>),
+ Validation(Result<Device, Error>),
+}
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs
new file mode 100644
index 0000000000..b2bb669eb2
--- /dev/null
+++ b/mullvad-daemon/src/device/mod.rs
@@ -0,0 +1,807 @@
+use chrono::{DateTime, Utc};
+use futures::{
+ channel::{mpsc, oneshot},
+ stream::StreamExt,
+ FutureExt,
+};
+
+use mullvad_api::{availability::ApiAvailabilityHandle, rest};
+use mullvad_types::{
+ account::AccountToken,
+ device::{Device, DeviceData, DeviceEvent},
+ wireguard::{RotationInterval, WireguardData},
+};
+use std::{
+ future::Future,
+ path::Path,
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+ },
+ time::{Duration, SystemTime},
+};
+use talpid_core::mpsc::Sender;
+use talpid_types::{net::TunnelType, tunnel::TunnelStateTransition, ErrorExt};
+use tokio::{
+ fs,
+ io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
+};
+
+mod api;
+mod service;
+pub use service::{spawn_account_service, AccountService, DeviceService};
+
+/// File that used to store account and device data.
+const DEVICE_CACHE_FILENAME: &str = "device.json";
+
+/// How long to keep the known status for [AccountManagerHandle::validate_device].
+const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// How long to wait on logout (device removal) before letting it continue as a background task.
+const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2);
+
+/// Validate the current device once for every `WG_DEVICE_CHECK_THRESHOLD` failed attempts
+/// to set up a WireGuard tunnel.
+const WG_DEVICE_CHECK_THRESHOLD: usize = 3;
+
+#[derive(err_derive::Error, Debug)]
+pub enum Error {
+ #[error(display = "The account already has a maximum number of devices")]
+ MaxDevicesReached,
+ #[error(display = "No device is set")]
+ NoDevice,
+ #[error(display = "Device not found")]
+ InvalidDevice,
+ #[error(display = "Invalid account")]
+ InvalidAccount,
+ #[error(display = "Failed to read or write device cache")]
+ DeviceIoError(#[error(source)] io::Error),
+ #[error(display = "Failed parse device cache")]
+ ParseDeviceCache(#[error(source)] serde_json::Error),
+ #[error(display = "Unexpected HTTP request error")]
+ OtherRestError(#[error(source)] rest::Error),
+ #[error(display = "The device update task is not running")]
+ Cancelled,
+ /// Intended to be broadcast to requesters
+ #[error(display = "Broadcast error")]
+ ResponseFailure(#[error(source)] Arc<Error>),
+ #[error(display = "Account changed during operation")]
+ AccountChange,
+ #[error(display = "The account manager is down")]
+ AccountManagerDown,
+}
+
+#[derive(Clone)]
+pub(crate) enum InnerDeviceEvent {
+ /// The device was removed due to user (or daemon) action.
+ Logout,
+ /// Logged in to a new device.
+ Login(DeviceData),
+ /// The device was updated remotely, but not its key.
+ Updated(DeviceData),
+ /// The key was rotated.
+ RotatedKey(DeviceData),
+ /// Device was removed because it was not found remotely.
+ Revoked,
+}
+
+impl From<InnerDeviceEvent> for DeviceEvent {
+ fn from(event: InnerDeviceEvent) -> DeviceEvent {
+ match event {
+ InnerDeviceEvent::Logout => DeviceEvent::revoke(false),
+ InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true),
+ InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Revoked => DeviceEvent::revoke(true),
+ }
+ }
+}
+
+impl InnerDeviceEvent {
+ fn data(&self) -> Option<&DeviceData> {
+ match self {
+ InnerDeviceEvent::Login(data) => Some(&data),
+ InnerDeviceEvent::Updated(data) => Some(&data),
+ InnerDeviceEvent::RotatedKey(data) => Some(&data),
+ InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
+ }
+ }
+}
+
+impl Error {
+ pub fn is_network_error(&self) -> bool {
+ if let Error::OtherRestError(error) = self {
+ error.is_network_error()
+ } else {
+ false
+ }
+ }
+}
+
+type ResponseTx<T> = oneshot::Sender<Result<T, Error>>;
+
+enum AccountManagerCommand {
+ Login(AccountToken, ResponseTx<()>),
+ Logout(ResponseTx<()>),
+ SetData(DeviceData, ResponseTx<()>),
+ GetData(ResponseTx<Option<DeviceData>>),
+ RotateKey(ResponseTx<()>),
+ SetRotationInterval(RotationInterval, ResponseTx<()>),
+ ValidateDevice(ResponseTx<()>),
+ ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>),
+ Shutdown(oneshot::Sender<()>),
+}
+
+#[derive(Clone)]
+pub(crate) struct AccountManagerHandle {
+ cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>,
+ pub account_service: AccountService,
+ pub device_service: DeviceService,
+}
+
+impl AccountManagerHandle {
+ pub async fn login(&self, token: AccountToken) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Login(token, tx))
+ .await
+ }
+
+ pub async fn logout(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Logout(tx))
+ .await
+ }
+
+ pub async fn set(&self, data: DeviceData) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetData(data, tx))
+ .await
+ }
+
+ pub async fn data(&self) -> Result<Option<DeviceData>, Error> {
+ self.send_command(|tx| AccountManagerCommand::GetData(tx))
+ .await
+ }
+
+ pub async fn rotate_key(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::RotateKey(tx))
+ .await
+ }
+
+ pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx))
+ .await
+ }
+
+ pub async fn validate_device(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx))
+ .await
+ }
+
+ pub async fn receive_events(
+ &self,
+ events_tx: impl Sender<InnerDeviceEvent> + Send + 'static,
+ ) -> Result<(), Error> {
+ self.send_command(|tx| {
+ AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx)
+ })
+ .await
+ }
+
+ pub async fn shutdown(self) {
+ let (tx, rx) = oneshot::channel();
+ let _ = self
+ .cmd_tx
+ .unbounded_send(AccountManagerCommand::Shutdown(tx));
+ let _ = rx.await;
+ }
+
+ async fn send_command<T>(
+ &self,
+ make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand,
+ ) -> Result<T, Error> {
+ let (tx, rx) = oneshot::channel();
+ self.cmd_tx
+ .unbounded_send(make_cmd(tx))
+ .map_err(|_| Error::AccountManagerDown)?;
+ rx.await.map_err(|_| Error::AccountManagerDown)?
+ }
+}
+
+pub(crate) struct AccountManager {
+ cacher: DeviceCacher,
+ device_service: DeviceService,
+ data: Option<DeviceData>,
+ rotation_interval: RotationInterval,
+ listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>,
+ last_validation: Option<SystemTime>,
+ validation_requests: Vec<ResponseTx<()>>,
+ rotation_requests: Vec<ResponseTx<()>>,
+}
+
+impl AccountManager {
+ pub async fn spawn(
+ rest_handle: rest::MullvadRestHandle,
+ api_availability: ApiAvailabilityHandle,
+ settings_dir: &Path,
+ initial_rotation_interval: RotationInterval,
+ ) -> Result<AccountManagerHandle, Error> {
+ let (cacher, data) = DeviceCacher::new(settings_dir).await?;
+ let token = data.as_ref().map(|state| state.token.clone());
+ let account_service =
+ spawn_account_service(rest_handle.clone(), token, api_availability.clone());
+
+ let (cmd_tx, cmd_rx) = mpsc::unbounded();
+
+ let device_service = DeviceService::new(rest_handle, api_availability);
+ let manager = AccountManager {
+ cacher,
+ device_service: device_service.clone(),
+ data,
+ rotation_interval: initial_rotation_interval,
+ listeners: vec![],
+ last_validation: None,
+ validation_requests: vec![],
+ rotation_requests: vec![],
+ };
+
+ tokio::spawn(manager.run(cmd_rx));
+ let handle = AccountManagerHandle {
+ cmd_tx,
+ account_service,
+ device_service,
+ };
+ Ok(handle)
+ }
+
+ async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) {
+ let mut shutdown_tx = None;
+ let mut current_api_call = api::CurrentApiCall::new();
+
+ loop {
+ futures::select! {
+ api_result = current_api_call => {
+ self.consume_api_result(api_result, &mut current_api_call).await;
+ }
+
+ cmd = cmd_rx.next() => {
+ match cmd {
+ Some(AccountManagerCommand::Shutdown(tx)) => {
+ shutdown_tx = Some(tx);
+ break;
+ }
+ Some(AccountManagerCommand::Login(token, tx)) => {
+ let job = self.device_service
+ .generate_for_account(token);
+ current_api_call.set_login(Box::pin(job), tx);
+ }
+ Some(AccountManagerCommand::Logout(tx)) => {
+ current_api_call.clear();
+ self.logout(tx).await;
+ }
+ Some(AccountManagerCommand::SetData(data, tx)) => {
+ let _ = tx.send(self.set(InnerDeviceEvent::Login(data)).await);
+ }
+ Some(AccountManagerCommand::GetData(tx)) => {
+ let _ = tx.send(Ok(self.data.clone()));
+ }
+ Some(AccountManagerCommand::RotateKey(tx)) => {
+ if current_api_call.is_logging_in() {
+ let _ = tx.send(Err(Error::AccountChange));
+ continue
+ }
+ if current_api_call.is_validating() {
+ self.rotation_requests.push(tx);
+ continue
+ }
+ match self.initiate_key_rotation() {
+ Ok(api_call) => {
+ current_api_call.set_oneshot_rotation(Box::pin(api_call))
+ },
+ Err(err) => {
+ let _ = tx.send(Err(err));
+ }
+ }
+ }
+ Some(AccountManagerCommand::SetRotationInterval(interval, tx)) => {
+ self.rotation_interval = interval;
+ if current_api_call.is_running_timed_totation() {
+ if let Some(timed_rotation) = self.spawn_timed_key_rotation() {
+ current_api_call.set_timed_rotation(Box::pin(timed_rotation))
+ }
+ }
+ let _ = tx.send(Ok(()));
+ }
+ Some(AccountManagerCommand::ValidateDevice(tx)) => {
+ self.handle_validation_request(tx, &mut current_api_call);
+ }
+ Some(AccountManagerCommand::ReceiveEvents(events_tx, tx)) => {
+ let _ = tx.send(Ok(self.listeners.push(events_tx)));
+ },
+ None => {
+ break;
+ }
+ }
+ }
+ }
+
+ if current_api_call.is_idle() {
+ if let Some(timed_rotation) = self.spawn_timed_key_rotation() {
+ current_api_call.set_timed_rotation(Box::pin(timed_rotation))
+ }
+ }
+ }
+ self.shutdown().await;
+ if let Some(tx) = shutdown_tx {
+ let _ = tx.send(());
+ }
+ log::debug!("Account manager has stopped");
+ }
+
+ fn handle_validation_request(
+ &mut self,
+ tx: ResponseTx<()>,
+ current_api_call: &mut api::CurrentApiCall,
+ ) {
+ if current_api_call.is_logging_in() {
+ let _ = tx.send(Err(Error::AccountChange));
+ return;
+ }
+ if current_api_call.is_validating() {
+ self.validation_requests.push(tx);
+ return;
+ }
+ if self.cached_validation() {
+ let _ = tx.send(Ok(()));
+ return;
+ }
+
+ match self.validation_call() {
+ Ok(call) => {
+ current_api_call.set_validation(Box::pin(call));
+ self.validation_requests.push(tx);
+ }
+ Err(err) => {
+ let _ = tx.send(Err(err));
+ }
+ }
+ }
+
+ async fn consume_api_result(
+ &mut self,
+ result: api::ApiResult,
+ api_call: &mut api::CurrentApiCall,
+ ) {
+ use api::ApiResult::*;
+ match result {
+ Login(data, tx) => self.consume_login(data, tx).await,
+ Rotation(rotation_response) => self.consume_rotation_result(rotation_response).await,
+ Validation(data_response) => self.consume_validation(data_response, api_call).await,
+ }
+ }
+
+ async fn consume_login(
+ &mut self,
+ device_response: Result<DeviceData, Error>,
+ tx: ResponseTx<()>,
+ ) {
+ let _ = tx.send(async { self.set(InnerDeviceEvent::Login(device_response?)).await }.await);
+ }
+
+ async fn consume_validation(
+ &mut self,
+ response: Result<Device, Error>,
+ api_call: &mut api::CurrentApiCall,
+ ) {
+ let current_data = match self.data.as_ref() {
+ Some(data) => data,
+ None => {
+ panic!("Received a validation response whilst having no device data");
+ }
+ };
+
+ match response {
+ Ok(new_device_data) => {
+ if new_device_data.pubkey == current_data.device.pubkey {
+ let new_data = DeviceData {
+ device: new_device_data,
+ ..current_data.clone()
+ };
+
+ match self.set(InnerDeviceEvent::Updated(new_data)).await {
+ Ok(_) => {
+ Self::drain_requests(&mut self.validation_requests, || Ok(()));
+ }
+ Err(err) => {
+ log::error!("Failed to save device data to disk");
+ let cloneable_err = Arc::new(err);
+ Self::drain_requests(&mut self.validation_requests, || {
+ Err(Error::ResponseFailure(cloneable_err.clone()))
+ });
+ }
+ }
+ }
+ }
+ Err(Error::InvalidAccount) => {
+ self.invalidate_current_data(|| Error::InvalidAccount).await;
+ }
+ Err(Error::InvalidDevice) => {
+ self.invalidate_current_data(|| Error::InvalidDevice).await;
+ }
+ Err(err) => {
+ log::error!("Failed to validate device: {}", err);
+ let cloneable_err = Arc::new(err);
+ Self::drain_requests(&mut self.validation_requests, || {
+ Err(Error::ResponseFailure(cloneable_err.clone()))
+ });
+ }
+ }
+
+ if !self.rotation_requests.is_empty() || !self.validation_requests.is_empty() {
+ if let Some(updated_data) = self.data.as_ref() {
+ let device_service = self.device_service.clone();
+ let token = updated_data.token.clone();
+ let device_id = updated_data.device.id.clone();
+ api_call.set_oneshot_rotation(Box::pin(async move {
+ device_service.rotate_key(token, device_id).await
+ }));
+ }
+ }
+ }
+
+ async fn consume_rotation_result(&mut self, api_result: Result<WireguardData, Error>) {
+ let mut device_data = match self.data.clone() {
+ Some(data) => data,
+ None => {
+ panic!("Received a key rotation result whilst having no data");
+ }
+ };
+
+ match api_result {
+ Ok(wg_data) => {
+ device_data.device.pubkey = wg_data.private_key.public_key();
+ device_data.wg_data = wg_data;
+ match self.set(InnerDeviceEvent::RotatedKey(device_data)).await {
+ Ok(_) => {
+ Self::drain_requests(&mut self.rotation_requests, || Ok(()));
+
+ Self::drain_requests(&mut self.validation_requests, || Ok(()));
+ }
+ Err(err) => {
+ self.drain_requests_with_err(err);
+ }
+ }
+ }
+ Err(Error::InvalidAccount) => {
+ self.invalidate_current_data(|| Error::InvalidAccount).await;
+ }
+ Err(Error::InvalidDevice) => {
+ self.invalidate_current_data(|| Error::InvalidDevice).await;
+ }
+ Err(err) => {
+ self.drain_requests_with_err(err);
+ }
+ }
+ }
+
+ fn drain_requests_with_err(&mut self, err: Error) {
+ let cloneable_err = Arc::new(err);
+ Self::drain_requests(&mut self.rotation_requests, || {
+ Err(Error::ResponseFailure(cloneable_err.clone()))
+ });
+ Self::drain_requests(&mut self.validation_requests, || {
+ Err(Error::ResponseFailure(cloneable_err.clone()))
+ });
+ }
+
+ fn drain_requests<T>(requests: &mut Vec<ResponseTx<T>>, result: impl Fn() -> Result<T, Error>) {
+ for req in requests.drain(0..) {
+ let _ = req.send(result());
+ }
+ }
+
+ fn spawn_timed_key_rotation(
+ &self,
+ ) -> Option<impl Future<Output = Result<WireguardData, Error>> + Send + 'static> {
+ let data = self.data.as_ref()?;
+ let key_rotation_timer = self.key_rotation_timer(data.wg_data.created);
+
+ let device_service = self.device_service.clone();
+ let account_token = data.token.clone();
+ let device_id = data.device.id.clone();
+
+ Some(async move {
+ key_rotation_timer.await;
+ device_service
+ .rotate_key_with_backoff(account_token, device_id)
+ .await
+ })
+ }
+
+ async fn invalidate_current_data(&mut self, err_constructor: impl Fn() -> Error) {
+ if let Err(err) = self.cacher.write(None).await {
+ log::error!(
+ "{}",
+ err.display_chain_with_msg("Failed to save device data to disk")
+ );
+ }
+ self.data = None;
+
+ Self::drain_requests(&mut self.validation_requests, || Err(err_constructor()));
+ Self::drain_requests(&mut self.rotation_requests, || Err(err_constructor()));
+
+ self.listeners
+ .retain(|listener| listener.send(InnerDeviceEvent::Revoked).is_ok());
+ }
+
+ async fn logout(&mut self, tx: ResponseTx<()>) {
+ let data = match self.data.take() {
+ Some(it) => it,
+ _ => return,
+ };
+ if let Err(err) = self.cacher.write(None).await {
+ let _ = tx.send(Err(err));
+ return;
+ }
+
+ self.listeners
+ .retain(|listener| listener.send(InnerDeviceEvent::Logout).is_ok());
+
+ let mut logout_call = Box::pin(self.logout_api_call(data).fuse());
+ tokio::spawn(async move {
+ let timeout = tokio::time::sleep(LOGOUT_TIMEOUT).fuse();
+ futures::pin_mut!(timeout);
+ futures::select! {
+ _timeout = timeout => {
+ let _ = tx.send(Ok(()));
+ logout_call.await
+ },
+ _logout = logout_call => {
+ let _ = tx.send(Ok(()));
+ }
+ }
+ });
+ }
+
+ fn logout_api_call(&self, data: DeviceData) -> impl Future<Output = ()> + 'static {
+ let service = self.device_service.clone();
+
+ async move {
+ if let Err(error) = service
+ .remove_device_with_backoff(data.token, data.device.id)
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to logout device")
+ );
+ }
+ }
+ }
+
+ async fn set(&mut self, event: InnerDeviceEvent) -> Result<(), Error> {
+ let data = event.data();
+ if data == self.data.as_ref() {
+ return Ok(());
+ }
+
+ self.cacher.write(data).await?;
+ self.last_validation = None;
+
+ if let Some(old_data) = self.data.take() {
+ if data.as_ref().map(|d| &d.device.id) == Some(&old_data.device.id) {
+ tokio::spawn(self.logout_api_call(old_data));
+ }
+ }
+
+ self.data = data.cloned();
+
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+
+ Ok(())
+ }
+
+ fn initiate_key_rotation(
+ &self,
+ ) -> Result<impl Future<Output = Result<WireguardData, Error>>, Error> {
+ let data = self.data.clone().ok_or(Error::NoDevice)?;
+ let device_service = self.device_service.clone();
+ Ok(async move { device_service.rotate_key(data.token, data.device.id).await })
+ }
+
+ fn key_rotation_timer(&self, key_created: DateTime<Utc>) -> impl Future<Output = ()> + 'static {
+ let rotation_interval = self.rotation_interval;
+
+ async move {
+ let key_age = Duration::from_secs(
+ chrono::Utc::now()
+ .signed_duration_since(key_created)
+ .num_seconds()
+ .try_into()
+ // This would only fail if the key was created in the future, in which case
+ // the duration would be negative. In this case, I think it's safe to
+ // assume the daemon should wait one whole key rotation interval.
+ .unwrap_or(0u64),
+ );
+ let time_until_next_rotation = std::cmp::max(
+ rotation_interval.as_duration().saturating_sub(key_age),
+ Duration::from_secs(60),
+ );
+
+ log::trace!(
+ "{} seconds to wait until next rotation",
+ time_until_next_rotation.as_secs(),
+ );
+ talpid_time::sleep(time_until_next_rotation).await
+ }
+ }
+
+ fn fetch_device_data(
+ &self,
+ old_data: &DeviceData,
+ ) -> impl Future<Output = Result<Device, Error>> {
+ let device_service = self.device_service.clone();
+ let account_token = old_data.token.clone();
+ let device_id = old_data.device.id.clone();
+ async move { device_service.get(account_token, device_id).await }
+ }
+
+ fn validation_call(&self) -> Result<impl Future<Output = Result<Device, Error>>, Error> {
+ let old_data = self.data.as_ref().ok_or(Error::NoDevice)?;
+ let device_request = self.fetch_device_data(old_data);
+ Ok(async move { device_request.await })
+ }
+
+ fn cached_validation(&mut self) -> bool {
+ if self.data.is_none() {
+ return false;
+ }
+
+ let now = SystemTime::now();
+
+ let elapsed = self
+ .last_validation
+ .and_then(|last_check| now.duration_since(last_check).ok())
+ .unwrap_or(VALIDITY_CACHE_TIMEOUT);
+
+ if elapsed >= VALIDITY_CACHE_TIMEOUT {
+ self.last_validation = None;
+ return false;
+ }
+
+ true
+ }
+
+ async fn shutdown(self) {
+ self.cacher.finalize().await;
+ }
+}
+pub struct DeviceCacher {
+ file: io::BufWriter<fs::File>,
+ path: std::path::PathBuf,
+}
+
+impl DeviceCacher {
+ pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, Option<DeviceData>), Error> {
+ let mut options = std::fs::OpenOptions::new();
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::OpenOptionsExt;
+ options.mode(0o600);
+ }
+ #[cfg(windows)]
+ {
+ use std::os::windows::fs::OpenOptionsExt;
+ // exclusive access
+ options.share_mode(0);
+ }
+
+ let path = settings_dir.join(DEVICE_CACHE_FILENAME);
+ let cache_exists = path.is_file();
+
+ let mut file = fs::OpenOptions::from(options)
+ .write(true)
+ .read(true)
+ .create(true)
+ .open(&path)
+ .await?;
+
+ let device: Option<DeviceData> = if cache_exists {
+ let mut reader = io::BufReader::new(&mut file);
+ let mut buffer = String::new();
+ reader.read_to_string(&mut buffer).await?;
+ if !buffer.is_empty() {
+ serde_json::from_str(&buffer)?
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ Ok((
+ DeviceCacher {
+ file: io::BufWriter::new(file),
+ path,
+ },
+ device,
+ ))
+ }
+
+ pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> {
+ let data = serde_json::to_vec_pretty(&device).unwrap();
+
+ self.file.get_mut().set_len(0).await?;
+ self.file.seek(io::SeekFrom::Start(0)).await?;
+ self.file.write_all(&data).await?;
+ self.file.flush().await?;
+ self.file.get_mut().sync_data().await?;
+
+ Ok(())
+ }
+
+ pub async fn remove(self) -> Result<(), Error> {
+ let path = {
+ let DeviceCacher { path, file } = self;
+ let std_file = file.into_inner().into_std().await;
+ let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
+ path
+ };
+ tokio::fs::remove_file(path).await?;
+ Ok(())
+ }
+
+ async fn finalize(self) {
+ let std_file = self.file.into_inner().into_std().await;
+ let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
+ }
+}
+/// Checks if the current device is valid if a WireGuard tunnel cannot be set up
+/// after multiple attempts.
+pub(crate) struct TunnelStateChangeHandler {
+ manager: AccountManagerHandle,
+ check_validity: Arc<AtomicBool>,
+ wg_retry_attempt: usize,
+}
+
+impl TunnelStateChangeHandler {
+ pub fn new(manager: AccountManagerHandle) -> Self {
+ Self {
+ manager,
+ check_validity: Arc::new(AtomicBool::new(true)),
+ wg_retry_attempt: 0,
+ }
+ }
+
+ pub fn handle_state_transition(&mut self, new_state: &TunnelStateTransition) {
+ match new_state {
+ TunnelStateTransition::Connecting(endpoint) => {
+ if endpoint.tunnel_type != TunnelType::Wireguard {
+ return;
+ }
+ self.wg_retry_attempt += 1;
+ if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 {
+ let handle = self.manager.clone();
+ let check_validity = self.check_validity.clone();
+ tokio::spawn(async move {
+ if !check_validity.swap(false, Ordering::SeqCst) {
+ return;
+ }
+ if let Err(error) = handle.validate_device().await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to check device validity")
+ );
+ if error.is_network_error() {
+ check_validity.store(true, Ordering::SeqCst);
+ }
+ }
+ });
+ }
+ }
+ TunnelStateTransition::Connected(_) | TunnelStateTransition::Disconnected => {
+ self.check_validity.store(true, Ordering::SeqCst);
+ self.wg_retry_attempt = 0;
+ }
+ _ => (),
+ }
+ }
+}
diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs
new file mode 100644
index 0000000000..36ea48ea86
--- /dev/null
+++ b/mullvad-daemon/src/device/service.rs
@@ -0,0 +1,424 @@
+use std::{future::Future, time::Duration};
+
+use chrono::{DateTime, Utc};
+use futures::future::{abortable, AbortHandle};
+use mullvad_types::{
+ account::{AccountToken, VoucherSubmission},
+ device::{Device, DeviceData, DeviceId},
+ wireguard::WireguardData,
+};
+use talpid_types::net::wireguard::PrivateKey;
+
+use super::Error;
+use mullvad_api::{
+ availability::ApiAvailabilityHandle,
+ rest::{self, Error as RestError, MullvadRestHandle},
+ AccountsProxy, DevicesProxy,
+};
+use talpid_core::future_retry::{
+ constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered,
+};
+const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO;
+const RETRY_ACTION_MAX_RETRIES: usize = 2;
+
+const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
+const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5;
+const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
+
+#[derive(Clone)]
+pub struct DeviceService {
+ api_availability: ApiAvailabilityHandle,
+ proxy: DevicesProxy,
+}
+
+impl DeviceService {
+ pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
+ Self {
+ proxy: DevicesProxy::new(handle),
+ api_availability,
+ }
+ }
+
+ /// Generate a new device for a given token
+ pub fn generate_for_account(
+ &self,
+ token: AccountToken,
+ ) -> impl Future<Output = Result<DeviceData, Error>> + Send {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ async move {
+ let (device, addresses) = retry_future_n(
+ move || proxy.create(token_copy.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+ }
+
+ pub async fn generate_for_account_with_backoff(
+ &self,
+ token: AccountToken,
+ ) -> Result<DeviceData, Error> {
+ let private_key = PrivateKey::new_from_random();
+ let pubkey = private_key.public_key();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let token_copy = token.clone();
+ let (device, addresses) = retry_future(
+ move || api_handle.when_online(proxy.create(token_copy.clone(), pubkey.clone())),
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(DeviceData {
+ token,
+ device,
+ wg_data: WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ },
+ })
+ }
+
+ pub async fn remove_device(&self, token: AccountToken, device: DeviceId) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.remove(token.clone(), device.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+ Ok(())
+ }
+
+ pub async fn remove_device_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<(), Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+
+ let retry_strategy = Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ ), // Not setting a maximum interval
+ );
+
+ retry_future(
+ // NOTE: Not honoring "paused" state, because the account may have no time on it.
+ move || api_handle.when_online(proxy.remove(token.clone(), device.clone())),
+ should_retry_backoff,
+ retry_strategy,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(())
+ }
+
+ pub async fn rotate_key(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+ let addresses = retry_future_n(
+ move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn rotate_key_with_backoff(
+ &self,
+ token: AccountToken,
+ device: DeviceId,
+ ) -> Result<WireguardData, Error> {
+ let private_key = PrivateKey::new_from_random();
+
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let pubkey = private_key.public_key();
+
+ let addresses = retry_future(
+ move || {
+ api_handle.when_bg_resumes(proxy.replace_wg_key(
+ token.clone(),
+ device.clone(),
+ pubkey.clone(),
+ ))
+ },
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)?;
+
+ Ok(WireguardData {
+ private_key,
+ addresses,
+ created: Utc::now(),
+ })
+ }
+
+ pub async fn list_devices(&self, token: AccountToken) -> Result<Vec<Device>, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.list(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+
+ pub async fn list_devices_with_backoff(
+ &self,
+ token: AccountToken,
+ ) -> Result<Vec<Device>, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+
+ retry_future(
+ move || api_handle.when_online(proxy.list(token.clone())),
+ should_retry_backoff,
+ retry_strategy(),
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+
+ pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result<Device, Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.get(token.clone(), device.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await
+ .map_err(map_rest_error)
+ }
+}
+
+#[derive(Clone)]
+pub struct AccountService {
+ api_availability: ApiAvailabilityHandle,
+ initial_check_abort_handle: AbortHandle,
+ proxy: AccountsProxy,
+}
+
+impl AccountService {
+ pub fn create_account(&self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.create_account(),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub fn get_www_auth_token(
+ &self,
+ account: AccountToken,
+ ) -> impl Future<Output = Result<String, rest::Error>> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ retry_future_n(
+ move || proxy.get_www_auth_token(account.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ }
+
+ pub async fn check_expiry(&self, token: AccountToken) -> Result<DateTime<Utc>, rest::Error> {
+ let proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.get_expiry(token.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if handle_expiry_result_inner(&result, &self.api_availability) {
+ self.initial_check_abort_handle.abort();
+ }
+ result
+ }
+
+ pub async fn submit_voucher(
+ &mut self,
+ account_token: AccountToken,
+ voucher: String,
+ ) -> Result<VoucherSubmission, rest::Error> {
+ let mut proxy = self.proxy.clone();
+ let api_handle = self.api_availability.clone();
+ let result = retry_future_n(
+ move || proxy.submit_voucher(account_token.clone(), voucher.clone()),
+ move |result| should_retry(result, &api_handle),
+ constant_interval(RETRY_ACTION_INTERVAL),
+ RETRY_ACTION_MAX_RETRIES,
+ )
+ .await;
+ if result.is_ok() {
+ self.initial_check_abort_handle.abort();
+ self.api_availability.resume_background();
+ }
+ result
+ }
+}
+
+pub fn spawn_account_service(
+ api_handle: MullvadRestHandle,
+ token: Option<String>,
+ api_availability: ApiAvailabilityHandle,
+) -> AccountService {
+ let accounts_proxy = AccountsProxy::new(api_handle);
+ api_availability.pause_background();
+
+ let api_availability_copy = api_availability.clone();
+ let accounts_proxy_copy = accounts_proxy.clone();
+
+ let (future, initial_check_abort_handle) = abortable(async move {
+ let token = if let Some(token) = token {
+ token
+ } else {
+ api_availability.pause_background();
+ return;
+ };
+
+ let future_generator = move || {
+ let expiry_fut = api_availability.when_online(accounts_proxy.get_expiry(token.clone()));
+ let api_availability_copy = api_availability.clone();
+ async move { handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) }
+ };
+ let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated };
+ retry_future(future_generator, should_retry, retry_strategy()).await;
+ });
+ tokio::spawn(future);
+
+ AccountService {
+ api_availability: api_availability_copy,
+ initial_check_abort_handle,
+ proxy: accounts_proxy_copy,
+ }
+}
+
+fn handle_expiry_result_inner(
+ result: &Result<chrono::DateTime<chrono::Utc>, mullvad_api::rest::Error>,
+ api_availability: &ApiAvailabilityHandle,
+) -> bool {
+ match result {
+ Ok(_expiry) if *_expiry >= chrono::Utc::now() => {
+ api_availability.resume_background();
+ true
+ }
+ Ok(_expiry) => {
+ api_availability.pause_background();
+ true
+ }
+ Err(mullvad_api::rest::Error::ApiError(_status, code)) => {
+ if code == mullvad_api::INVALID_ACCOUNT {
+ api_availability.pause_background();
+ return true;
+ }
+ false
+ }
+ Err(_) => false,
+ }
+}
+
+fn should_retry<T>(result: &Result<T, RestError>, api_handle: &ApiAvailabilityHandle) -> bool {
+ match result {
+ Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
+ _ => false,
+ }
+}
+
+fn should_retry_backoff<T>(result: &Result<T, RestError>) -> bool {
+ match result {
+ Ok(_) => false,
+ Err(error) => {
+ if let RestError::ApiError(status, code) = error {
+ *status != rest::StatusCode::NOT_FOUND
+ && code != mullvad_api::INVALID_ACCOUNT
+ && code != mullvad_api::MAX_DEVICES_REACHED
+ && code != mullvad_api::PUBKEY_IN_USE
+ } else {
+ true
+ }
+ }
+ }
+}
+
+fn map_rest_error(error: rest::Error) -> Error {
+ match error {
+ RestError::ApiError(status, ref code) => {
+ if status == rest::StatusCode::NOT_FOUND {
+ return Error::InvalidDevice;
+ }
+ match code.as_str() {
+ mullvad_api::INVALID_ACCOUNT => Error::InvalidAccount,
+ mullvad_api::MAX_DEVICES_REACHED => Error::MaxDevicesReached,
+ _ => Error::OtherRestError(error),
+ }
+ }
+ error => Error::OtherRestError(error),
+ }
+}
+
+fn retry_strategy() -> Jittered<ExponentialBackoff> {
+ Jittered::jitter(
+ ExponentialBackoff::new(
+ RETRY_BACKOFF_INTERVAL_INITIAL,
+ RETRY_BACKOFF_INTERVAL_FACTOR,
+ )
+ .max_delay(RETRY_BACKOFF_INTERVAL_MAX),
+ )
+}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 9aca859d9e..b4e1763b6e 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -1012,7 +1012,7 @@ where
match (&self.tunnel_state, &tunnel_state_transition) {
// only reset the API sockets if when connected or leaving the connected state
(&TunnelState::Connected { .. }, _) | (_, &TunnelStateTransition::Connected(_)) => {
- self.api_handle.service().reset().await;
+ self.api_handle.service().reset();
}
_ => (),
};
@@ -1290,8 +1290,8 @@ where
SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await,
GetRelayLocations(tx) => self.on_get_relay_locations(tx),
UpdateRelayLocations => self.on_update_relay_locations().await,
- LoginAccount(tx, account_token) => self.on_login_account(tx, account_token).await,
- LogoutAccount(tx) => self.on_logout_account(tx).await,
+ LoginAccount(tx, account_token) => self.on_login_account(tx, account_token),
+ LogoutAccount(tx) => self.on_logout_account(tx),
GetDevice(tx) => self.on_get_device(tx).await,
UpdateDevice(tx) => self.on_update_device(tx).await,
ListDevices(tx, account_token) => self.on_list_devices(tx, account_token).await,
@@ -1732,7 +1732,7 @@ where
self.relay_selector.update().await;
}
- async fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) {
+ fn on_login_account(&mut self, tx: ResponseTx<(), Error>, account_token: String) {
let account_manager = self.account_manager.clone();
tokio::spawn(async move {
let result = async {
@@ -1745,7 +1745,7 @@ where
});
}
- async fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) {
+ fn on_logout_account(&mut self, tx: ResponseTx<(), Error>) {
let account_manager = self.account_manager.clone();
tokio::spawn(async move {
let result = async {