diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-03-03 12:57:40 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-03-14 12:08:54 +0100 |
| commit | 0985a987a10d19078efa30eba7fdde0821dcf2be (patch) | |
| tree | b2c2a2f0e21d5af66d0a725dafbe998af0811137 | |
| parent | e7bcd784a7f25be8b3b50e75b8adfad446125bc7 (diff) | |
| download | mullvadvpn-0985a987a10d19078efa30eba7fdde0821dcf2be.tar.xz mullvadvpn-0985a987a10d19078efa30eba7fdde0821dcf2be.zip | |
Refactor account manager into actor
| -rw-r--r-- | mullvad-cli/src/cmds/account.rs | 1 | ||||
| -rw-r--r-- | mullvad-daemon/src/device.rs | 715 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 198 |
3 files changed, 515 insertions, 399 deletions
diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs index 09c268f3c0..b4ef7c7f14 100644 --- a/mullvad-cli/src/cmds/account.rs +++ b/mullvad-cli/src/cmds/account.rs @@ -202,7 +202,6 @@ impl Account { async fn revoke_device(&self, matches: &clap::ArgMatches) -> Result<()> { let mut rpc = new_rpc_client().await?; - let token = self.parse_account_else_current(&mut rpc, matches).await?; let device_to_revoke = parse_device_name(matches); diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs index f4eba59d2d..0a078c19f4 100644 --- a/mullvad-daemon/src/device.rs +++ b/mullvad-daemon/src/device.rs @@ -11,7 +11,7 @@ use mullvad_rpc::{ }; use mullvad_types::{ account::{AccountToken, VoucherSubmission}, - device::{Device, DeviceData, DeviceId}, + device::{Device, DeviceData, DeviceEvent, DeviceId}, wireguard::{RotationInterval, WireguardData}, }; use std::{ @@ -31,7 +31,7 @@ use tokio::{ /// How often to check whether the key has expired. /// A short interval is used in case the computer is ever suspended. -const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60); +const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(5 * 60); /// File that used to store account and device data. const DEVICE_CACHE_FILENAME: &str = "device.json"; @@ -43,10 +43,9 @@ const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4); const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5; const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); -/// How long to keep the known status for [AccountManager::validate_device_cached]. -const DEVICE_VALIDITY_CACHE_DURATION: Duration = Duration::from_secs(30); +const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10); -pub struct DeviceKeyEvent(pub DeviceData); +const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2); #[derive(err_derive::Error, Debug)] pub enum Error { @@ -66,6 +65,54 @@ pub enum Error { OtherRestError(#[error(source)] rest::Error), #[error(display = "The device update task is not running")] DeviceUpdaterCancelled(#[error(source)] oneshot::Canceled), + #[error(display = "The account manager is down")] + AccountManagerDown, +} + +#[derive(Clone)] +pub(crate) enum InnerDeviceEvent { + /// The device was removed due to user (or daemon) action. + Logout, + /// Logged in to a new device. + Login(DeviceData), + /// The device was updated remotely, but not its key. + Updated(DeviceData), + /// The key was rotated. + RotatedKey(DeviceData), + /// Device was removed because it was not found remotely. + Revoked, +} + +impl From<InnerDeviceEvent> for DeviceEvent { + fn from(event: InnerDeviceEvent) -> DeviceEvent { + match event { + InnerDeviceEvent::Logout => DeviceEvent::revoke(false), + InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false), + InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true), + InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false), + InnerDeviceEvent::Revoked => DeviceEvent::revoke(true), + } + } +} + +impl InnerDeviceEvent { + fn data(&self) -> Option<&DeviceData> { + match self { + InnerDeviceEvent::Login(data) => Some(&data), + InnerDeviceEvent::Updated(data) => Some(&data), + InnerDeviceEvent::RotatedKey(data) => Some(&data), + InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None, + } + } + + fn into_data(self) -> Option<DeviceData> { + match self { + InnerDeviceEvent::Login(data) => Some(data), + InnerDeviceEvent::Updated(data) => Some(data), + InnerDeviceEvent::RotatedKey(data) => Some(data), + InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None, + } + } } impl Error { @@ -89,244 +136,291 @@ pub enum ValidationResult { Removed, } -pub(crate) struct AccountManager { - account_service: AccountService, - device_service: DeviceService, - inner: Arc<Mutex<AccountManagerInner>>, - cache_update_tx: - mpsc::UnboundedSender<(Option<DeviceData>, oneshot::Sender<Result<(), Error>>)>, - cache_task_join_handle: Option<tokio::task::JoinHandle<()>>, - key_update_tx: DaemonEventSender<DeviceKeyEvent>, - rotation_abort_handle: Option<AbortHandle>, +type ResponseTx<T> = oneshot::Sender<Result<T, Error>>; + +enum AccountManagerCommand { + Login(AccountToken, ResponseTx<()>), + Logout(ResponseTx<()>), + SetData(DeviceData, ResponseTx<()>), + GetData(ResponseTx<Option<DeviceData>>), + RotateKey(ResponseTx<()>), + SetRotationInterval(RotationInterval, ResponseTx<()>), + GetRotationInterval(ResponseTx<RotationInterval>), + ValidateDevice(ResponseTx<ValidationResult>), + ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>), + Shutdown(oneshot::Sender<()>), } -struct AccountManagerInner { +#[derive(Clone)] +pub(crate) struct AccountManagerHandle { + cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>, + pub account_service: AccountService, + pub device_service: DeviceService, +} + +impl AccountManagerHandle { + pub async fn login(&self, token: AccountToken) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::Login(token, tx)) + .await + } + + pub async fn logout(&self) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::Logout(tx)) + .await + } + + pub async fn set(&self, data: DeviceData) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::SetData(data, tx)) + .await + } + + pub async fn data(&self) -> Result<Option<DeviceData>, Error> { + self.send_command(|tx| AccountManagerCommand::GetData(tx)) + .await + } + + pub async fn rotate_key(&self) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::RotateKey(tx)) + .await + } + + pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> { + self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx)) + .await + } + + pub async fn rotation_interval(&self) -> Result<RotationInterval, Error> { + self.send_command(|tx| AccountManagerCommand::GetRotationInterval(tx)) + .await + } + + pub async fn validate_device(&self) -> Result<ValidationResult, Error> { + self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx)) + .await + } + + pub async fn receive_events( + &self, + events_tx: impl Sender<InnerDeviceEvent> + Send + 'static, + ) -> Result<(), Error> { + self.send_command(|tx| { + AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx) + }) + .await + } + + pub async fn shutdown(self) { + let (tx, rx) = oneshot::channel(); + let _ = self + .cmd_tx + .unbounded_send(AccountManagerCommand::Shutdown(tx)); + let _ = rx.await; + } + + async fn send_command<T>( + &self, + make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand, + ) -> Result<T, Error> { + let (tx, rx) = oneshot::channel(); + self.cmd_tx + .unbounded_send(make_cmd(tx)) + .map_err(|_| Error::AccountManagerDown)?; + rx.await.map_err(|_| Error::AccountManagerDown)? + } +} + +pub(crate) struct AccountManager { + cacher: DeviceCacher, + device_service: DeviceService, data: Option<DeviceData>, rotation_interval: RotationInterval, + listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>, last_validation: Option<SystemTime>, } impl AccountManager { - pub async fn new( + pub async fn spawn( rest_handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle, settings_dir: &Path, - key_update_tx: DaemonEventSender<DeviceKeyEvent>, - ) -> Result<AccountManager, Error> { - let (mut cacher, device_data) = DeviceCacher::new(settings_dir).await?; - let token = device_data.as_ref().map(|state| state.token.clone()); + initial_rotation_interval: RotationInterval, + ) -> Result<AccountManagerHandle, Error> { + let (cacher, data) = DeviceCacher::new(settings_dir).await?; + let token = data.as_ref().map(|state| state.token.clone()); let account_service = spawn_account_service(rest_handle.clone(), token, api_availability.clone()); - let should_start_rotation = device_data.is_some(); - let inner = Arc::new(Mutex::new(AccountManagerInner { - data: device_data, - rotation_interval: RotationInterval::default(), - last_validation: None, - })); - let (cache_update_tx, mut cache_update_rx): ( - _, - mpsc::UnboundedReceiver<(_, oneshot::Sender<Result<(), Error>>)>, - ) = mpsc::unbounded(); - let cache_task_join_handle = tokio::spawn(async move { - while let Some((new_device, result_tx)) = cache_update_rx.next().await { - let result = cacher.write(new_device).await; - if let Err(error) = &result { - log::error!( - "{}", - error.display_chain_with_msg("Failed to update device cache") - ); - } - let _ = result_tx.send(result); - } - }); + let (cmd_tx, cmd_rx) = mpsc::unbounded(); - let mut manager = AccountManager { - account_service, - device_service: DeviceService::new(rest_handle, api_availability), - inner, - cache_update_tx, - cache_task_join_handle: Some(cache_task_join_handle), - key_update_tx, - rotation_abort_handle: None, + let device_service = DeviceService::new(rest_handle, api_availability); + let manager = AccountManager { + cacher, + device_service: device_service.clone(), + data, + rotation_interval: initial_rotation_interval, + listeners: vec![], + last_validation: None, }; - if should_start_rotation { - manager.start_key_rotation(); - } - - Ok(manager) + tokio::spawn(manager.run(cmd_rx)); + let handle = AccountManagerHandle { + cmd_tx, + account_service, + device_service, + }; + KeyUpdater::spawn(handle.clone()).await?; + Ok(handle) } - pub fn account_service(&self) -> AccountService { - self.account_service.clone() + async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) { + let mut shutdown_tx = None; + while let Some(cmd) = cmd_rx.next().await { + match cmd { + AccountManagerCommand::Shutdown(tx) => { + shutdown_tx = Some(tx); + break; + } + other => self.service_command(other).await, + } + } + self.shutdown().await; + if let Some(tx) = shutdown_tx { + let _ = tx.send(()); + } + log::debug!("Account manager has stopped"); } - pub fn device_service(&self) -> DeviceService { - self.device_service.clone() + async fn service_command(&mut self, cmd: AccountManagerCommand) { + match cmd { + AccountManagerCommand::Login(token, tx) => { + let _ = tx.send(self.login(token).await); + } + AccountManagerCommand::Logout(tx) => { + let _ = tx.send(self.logout().await); + } + AccountManagerCommand::SetData(data, tx) => { + let _ = tx.send(self.set(data).await); + } + AccountManagerCommand::GetData(tx) => { + let _ = tx.send(Ok(self.data.clone())); + } + AccountManagerCommand::RotateKey(tx) => { + let _ = tx.send(self.rotate_key().await); + } + AccountManagerCommand::SetRotationInterval(interval, tx) => { + self.rotation_interval = interval; + let _ = tx.send(Ok(())); + } + AccountManagerCommand::GetRotationInterval(tx) => { + let _ = tx.send(Ok(self.rotation_interval)); + } + AccountManagerCommand::ValidateDevice(tx) => { + let _ = tx.send(self.validate_device().await); + } + AccountManagerCommand::ReceiveEvents(events_tx, tx) => { + let _ = tx.send(Ok(self.listeners.push(events_tx))); + } + AccountManagerCommand::Shutdown(_) => unreachable!("shutdown is handled earlier"), + } } - pub async fn login(&mut self, token: AccountToken) -> Result<DeviceData, Error> { + async fn login(&mut self, token: AccountToken) -> Result<(), Error> { let data = self.device_service.generate_for_account(token).await?; - self.logout(); - let (result_tx, result_rx) = oneshot::channel(); - let _ = self - .cache_update_tx - .unbounded_send((Some(data.clone()), result_tx)); - { - let mut inner = self.inner.lock().unwrap(); - inner.data.replace(data.clone()); - } - if let Err(error) = flatten_result(result_rx.await.map_err(Error::DeviceUpdaterCancelled)) { - // Delete the device if an I/O error occurred - self.logout(); - return Err(error); - } - self.start_key_rotation(); - - Ok(data) + self.set(data).await?; + Ok(()) } - pub async fn set(&mut self, data: DeviceData) -> Result<(), Error> { - self.stop_key_rotation(); + async fn logout(&mut self) -> Result<(), Error> { + if self.data.is_some() { + self.cacher.write(None).await?; + let _ = tokio::time::timeout(LOGOUT_TIMEOUT, self.logout_inner()).await; - let (result_tx, result_rx) = oneshot::channel(); - let _ = self - .cache_update_tx - .unbounded_send((Some(data.clone()), result_tx)); - - let old_data = { - let mut inner = self.inner.lock().unwrap(); - inner.data.replace(data.clone()) - }; - - if let Err(error) = flatten_result(result_rx.await.map_err(Error::DeviceUpdaterCancelled)) { - // Delete the device if an I/O error occurred - self.logout(); - return Err(error); + let event = InnerDeviceEvent::Logout; + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); } - - if let Some(old_data) = old_data { - // Log out the previous device if the id differs - if !old_data.device.eq_id(&data.device) { - let service = self.device_service.clone(); - tokio::spawn(async move { - if let Err(error) = service - .remove_device_with_backoff(old_data.token, old_data.device.id) - .await - { - log::error!( - "{}", - error.display_chain_with_msg("Failed to remove a previous device") - ); - } - }); - } - } - self.start_key_rotation(); Ok(()) } - /// Log out without waiting for the result. - pub fn logout(&mut self) { - let fut = self.logout_inner(true); + async fn logout_inner(&mut self) -> tokio::task::JoinHandle<()> { + let prev_data = self.data.take(); + let service = self.device_service.clone(); + tokio::spawn(async move { - let result = fut.await; - if let Err(error) = result { - log::error!( - "{}", - error.display_chain_with_msg("Failed to remove a previous device") - ); + if let Some(data) = prev_data { + if let Err(error) = service + .remove_device_with_backoff(data.token, data.device.id) + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to remove a previous device") + ); + } } - }); + }) } - /// Log out, and wait until the API has removed the device. - #[cfg(not(target_os = "android"))] - pub fn logout_wait(&mut self) -> impl Future<Output = Result<(), Error>> { - self.logout_inner(false) + #[inline] + async fn set(&mut self, new_data: DeviceData) -> Result<(), Error> { + self.set_inner(InnerDeviceEvent::Login(new_data)).await } - fn logout_inner(&mut self, use_backoff: bool) -> impl Future<Output = Result<(), Error>> { - self.stop_key_rotation(); - let data = { - let mut inner = self.inner.lock().unwrap(); - let (result_tx, _result_rx) = oneshot::channel(); - let _ = self.cache_update_tx.unbounded_send((None, result_tx)); - // NOTE: No need to wait on cache update - inner.data.take() - }; - let service = self.device_service.clone(); - async move { - if let Some(data) = data { - if use_backoff { - return service - .remove_device_with_backoff(data.token, data.device.id) - .await; - } else { - return service.remove_device(data.token, data.device.id).await; - } - } - Ok(()) + async fn set_inner(&mut self, event: InnerDeviceEvent) -> Result<(), Error> { + let data = event.data(); + if data == self.data.as_ref() { + return Ok(()); } - } - pub async fn rotate_key(&mut self) -> Result<WireguardData, Error> { - let mut data = { - let inner = self.inner.lock().unwrap(); - inner.data.as_ref().ok_or(Error::NoDevice)?.clone() - }; - self.stop_key_rotation(); - let result = self - .device_service - .rotate_key(data.token.clone(), data.device.id.clone()) - .await; - if let Ok(ref wg_data) = result { - data.wg_data = wg_data.clone(); - data.device.pubkey = wg_data.private_key.public_key(); - let mut inner = self.inner.lock().unwrap(); - inner.data.replace(data.clone()); - let (result_tx, _result_rx) = oneshot::channel(); - let _ = self - .cache_update_tx - .unbounded_send((Some(data.clone()), result_tx)); - // NOTE: No need to wait on cache update - let _ = self.key_update_tx.send(DeviceKeyEvent(data)); + self.cacher.write(data).await?; + + if self + .data + .as_ref() + .map(|current| data.as_ref().map(|d| &d.device.id) != Some(¤t.device.id)) + .unwrap_or(false) + { + // Remove the existing device if its ID differs. Otherwise, only update + // the data. + self.logout_inner().await; } - self.start_key_rotation(); - result - } - pub fn data(&self) -> Option<DeviceData> { - self.inner.lock().unwrap().data.clone() - } + self.data = data.cloned(); - pub fn has_data(&self) -> bool { - self.inner.lock().unwrap().data.is_some() + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); + + Ok(()) } - pub async fn set_rotation_interval(&mut self, interval: RotationInterval) { - self.stop_key_rotation(); - let restart_rotation = { - let mut inner = self.inner.lock().unwrap(); - inner.rotation_interval = interval; - inner.data.is_some() - }; - if restart_rotation { - self.start_key_rotation(); - } + async fn rotate_key(&mut self) -> Result<(), Error> { + // TODO: Update all data opportunistically? + let data = self.data.as_ref().ok_or(Error::NoDevice)?; + + let wg_data = self + .device_service + .rotate_key(data.token.clone(), data.device.id.clone()) + .await?; + + // Copy the data to keep a predictable state if an error occurs. + let mut new_data = data.clone(); + new_data.device.pubkey = wg_data.private_key.public_key(); + new_data.wg_data = wg_data; + self.set_inner(InnerDeviceEvent::RotatedKey(new_data)).await } /// Check if the device is valid for the account, and yank it if it no longer exists. /// This also updates any associated data and returns whether it changed. - pub async fn validate_device(&mut self) -> Result<ValidationResult, Error> { - let mut data = { - let inner = self.inner.lock().unwrap(); - inner.data.as_ref().ok_or(Error::NoDevice)?.clone() - }; - + async fn validate_device(&mut self) -> Result<ValidationResult, Error> { log::debug!("Checking whether the device is still valid"); + if let Some(result) = self.cached_validation() { + log::debug!("The current device is still valid"); + return Ok(result); + } + + let data = self.data.as_ref().ok_or(Error::NoDevice)?; + match self .device_service .get(data.token.clone(), data.device.id.clone()) @@ -339,13 +433,12 @@ impl AccountManager { Ok(ValidationResult::Valid) } else { log::debug!("Updating data for the current device"); - data.device = device; - { - let mut inner = self.inner.lock().unwrap(); - inner.data.replace(data.clone()); - let (result_tx, _result_rx) = oneshot::channel(); - let _ = self.cache_update_tx.unbounded_send((Some(data), result_tx)); - } + // Copy the data to keep a predictable state if an error occurs. + let new_data = DeviceData { + device, + ..data.clone() + }; + self.set_inner(InnerDeviceEvent::Updated(new_data)).await?; Ok(ValidationResult::Updated) } } else { @@ -356,123 +449,153 @@ impl AccountManager { } Err(Error::InvalidAccount) | Err(Error::InvalidDevice) => { log::debug!("The current device is no longer valid for this account"); - self.stop_key_rotation(); - { - self.inner.lock().unwrap().data.take(); - let (result_tx, _result_rx) = oneshot::channel(); - let _ = self.cache_update_tx.unbounded_send((None, result_tx)); - // NOTE: No need to wait on cache update - } + + self.cacher.write(None).await?; + self.data = None; + + let event = InnerDeviceEvent::Revoked; + self.listeners + .retain(|listener| listener.send(event.clone()).is_ok()); + Ok(ValidationResult::Removed) } Err(error) => Err(error), } } - /// Same as [Self::validate_device] but returns [ValidationResult::Valid] (or [Error::NoDevice]) - /// if the last check was recent. - pub async fn validate_device_cached(&mut self) -> Result<ValidationResult, Error> { - let last_validation = { - let inner = self.inner.lock().unwrap(); - if inner.data.is_none() { - return Err(Error::NoDevice); - } - inner.last_validation.clone() - }; + fn cached_validation(&mut self) -> Option<ValidationResult> { + if self.data.is_none() { + return None; + } - if last_validation - .and_then(|last_check| SystemTime::now().duration_since(last_check).ok()) - .map(|elapsed| elapsed < DEVICE_VALIDITY_CACHE_DURATION) - .unwrap_or(false) - { - return Ok(ValidationResult::Valid); + let now = SystemTime::now(); + + let elapsed = self + .last_validation + .and_then(|last_check| now.duration_since(last_check).ok()) + .unwrap_or(VALIDITY_CACHE_TIMEOUT); + + if elapsed >= VALIDITY_CACHE_TIMEOUT { + self.last_validation = Some(now); + return None; } - let result = self.validate_device().await; - let mut inner = self.inner.lock().unwrap(); - inner.last_validation = Some(SystemTime::now()); - result + Some(ValidationResult::Valid) + } + + async fn shutdown(self) { + self.cacher.finalize().await; } +} - fn start_key_rotation(&mut self) { - self.stop_key_rotation(); +struct KeyUpdater { + handle: AccountManagerHandle, + rx: mpsc::UnboundedReceiver<InnerDeviceEvent>, + data: Option<DeviceData>, +} - let service = self.device_service.clone(); - let inner = self.inner.clone(); - let cache_update_tx = self.cache_update_tx.clone(); - let key_update_tx = self.key_update_tx.clone(); +impl KeyUpdater { + async fn spawn(handle: AccountManagerHandle) -> Result<(), Error> { + let (tx, rx) = mpsc::unbounded(); + handle.receive_events(tx).await?; + let data = handle.data().await?; + let mut key_rotator = KeyUpdater { handle, rx, data }; - let (task, abort_handle) = abortable(async move { + tokio::spawn(async move { loop { tokio::time::sleep(KEY_CHECK_INTERVAL).await; - let rotation_interval = { inner.lock().unwrap().rotation_interval.clone() }; - - let mut state = { - match inner.lock().unwrap().data.as_ref() { - Some(device_config) => device_config.clone(), - None => continue, - } - }; - - if (chrono::Utc::now() - .signed_duration_since(state.wg_data.created) - .num_seconds() as u64) - < rotation_interval.as_duration().as_secs() - { - continue; - } - - match service - .rotate_key_with_backoff(state.token.clone(), state.device.id.clone()) - .await - { - Ok(wg_data) => { - state.device.pubkey = wg_data.private_key.public_key(); - state.wg_data = wg_data; - { - let mut inner = inner.lock().unwrap(); - inner.data.replace(state.clone()); - let (result_tx, _result_rx) = oneshot::channel(); - let _ = - cache_update_tx.unbounded_send((Some(state.clone()), result_tx)); - // NOTE: No need to wait on cache update - } - let _ = key_update_tx.send(DeviceKeyEvent(state)); - } - Err(error) => { - log::debug!("{}", error.display_chain_with_msg("Stopping key rotation")); + if let Err(error) = key_rotator.check_key_validity().await { + if let Error::AccountManagerDown = error { + break; } + log::error!( + "{}", + error.display_chain_with_msg("Stopping key rotation task due to an error") + ); + break; } } + log::debug!("Stopping key updater"); }); - tokio::spawn(task); - self.rotation_abort_handle = Some(abort_handle); + + Ok(()) } - fn stop_key_rotation(&mut self) { - if let Some(abort_handle) = self.rotation_abort_handle.take() { - abort_handle.abort(); + async fn check_key_validity(&mut self) -> Result<(), Error> { + let rotation_interval = self.handle.rotation_interval().await?; + let data = self.wait_for_data().await?; + + if (chrono::Utc::now() + .signed_duration_since(data.wg_data.created) + .num_seconds() as u64) + < rotation_interval.as_duration().as_secs() + { + return Ok(()); } - } - /// Consumes the object and completes when there is nothing left to write to - /// the cache file. - pub fn finalize(mut self) -> impl Future<Output = ()> { - let join_handle = self.cache_task_join_handle.take(); - drop(self); + let mut data = data.clone(); + + let rotation_fut = self + .handle + .device_service + .rotate_key_with_backoff(data.token.clone(), data.device.id.clone()); + + match futures::future::select(Box::pin(rotation_fut), self.rx.next()).await { + futures::future::Either::Left((Ok(wg_data), _)) => { + log::debug!("Rotating WireGuard key"); + data.device.pubkey = wg_data.private_key.public_key(); + data.wg_data = wg_data; + self.handle.set(data).await?; + } + futures::future::Either::Left((Err(error), _)) => { + log::error!( + "{}", + error.display_chain_with_msg("Stopping key rotation due to an error") + ); - async move { - if let Some(join_handle) = join_handle { - let _ = join_handle.await; + // Forget the current device. Key rotation will restart when + // it is updated in any way. + self.data = None; + } + futures::future::Either::Right((event, _)) => { + // Abort key rotation if the device changed + if let Some(event) = event { + self.data = event.into_data(); + } else { + return Err(Error::AccountManagerDown); + } } } + + Ok(()) } -} -impl Drop for AccountManager { - fn drop(&mut self) { - self.stop_key_rotation(); + async fn wait_for_data(&mut self) -> Result<&DeviceData, Error> { + while let Ok(item) = self.rx.try_next() { + match item { + Some(event) => { + self.data = event.into_data(); + } + None => return Err(Error::AccountManagerDown), + } + } + + match self.data { + Some(ref data) => Ok(data), + None => loop { + let event = self.rx.next().await; + match event { + Some(event) => { + if let Some(data) = event.into_data() { + self.data = Some(data); + break Ok(self.data.as_ref().unwrap()); + } + } + None => break Err(Error::AccountManagerDown), + } + }, + } } } @@ -741,7 +864,7 @@ impl DeviceCacher { )) } - pub async fn write(&mut self, device: Option<DeviceData>) -> Result<(), Error> { + pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> { let data = serde_json::to_vec_pretty(&device).unwrap(); self.file.get_mut().set_len(0).await?; @@ -763,6 +886,11 @@ impl DeviceCacher { tokio::fs::remove_file(path).await?; Ok(()) } + + async fn finalize(self) { + let std_file = self.file.into_inner().into_std().await; + let _ = tokio::task::spawn_blocking(move || drop(std_file)).await; + } } #[derive(Clone)] @@ -944,12 +1072,3 @@ fn retry_strategy() -> Jittered<ExponentialBackoff> { .max_delay(RETRY_BACKOFF_INTERVAL_MAX), ) } - -fn flatten_result<T, E>( - result: std::result::Result<std::result::Result<T, E>, E>, -) -> std::result::Result<T, E> { - match result { - Ok(value) => value, - Err(err) => Err(err), - } -} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 9277d45335..2be83a3ed9 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -25,6 +25,7 @@ pub mod version; mod version_check; use crate::target_state::PersistentTargetState; +use device::InnerDeviceEvent; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Future}, @@ -62,7 +63,10 @@ use std::{ net::{IpAddr, Ipv4Addr}, path::PathBuf, pin::Pin, - sync::{mpsc as sync_mpsc, Arc, Weak}, + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc as sync_mpsc, Arc, Weak, + }, time::Duration, }; #[cfg(any(target_os = "linux", windows))] @@ -76,8 +80,7 @@ use talpid_types::android::AndroidContext; use talpid_types::{ net::{ openvpn::{self, ProxySettings}, - wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, - TunnelType, + wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType, }, tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, @@ -349,8 +352,8 @@ pub(crate) enum InternalDaemonEvent { NewAppVersionInfo(AppVersionInfo), /// Request from REST client to use a different API endpoint. GenerateApiConnectionMode(api::ApiConnectionModeRequest), - /// Sent when a device key is rotated. - DeviceKeyEvent(device::DeviceKeyEvent), + /// Sent when a device is updated in any way (key rotation, login, logout, etc.). + DeviceEvent(InnerDeviceEvent), /// Handles updates from versions without devices. DeviceMigrationEvent(DeviceData), /// The split tunnel paths or state were updated. @@ -388,9 +391,9 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent { } } -impl From<device::DeviceKeyEvent> for InternalDaemonEvent { - fn from(event: device::DeviceKeyEvent) -> Self { - InternalDaemonEvent::DeviceKeyEvent(event) +impl From<InnerDeviceEvent> for InternalDaemonEvent { + fn from(event: InnerDeviceEvent) -> Self { + InternalDaemonEvent::DeviceEvent(event) } } @@ -575,9 +578,9 @@ pub struct Daemon<L: EventListener> { event_listener: L, settings: SettingsPersister, account_history: account_history::AccountHistory, - account_manager: device::AccountManager, + account_manager: device::AccountManagerHandle, wg_retry_attempt: usize, - wg_check_validity: bool, + wg_check_validity: Arc<AtomicBool>, rpc_runtime: mullvad_rpc::MullvadRpcRuntime, rpc_handle: mullvad_rpc::rest::MullvadRestHandle, version_updater_handle: version_check::VersionUpdaterHandle, @@ -657,28 +660,35 @@ where tx: internal_event_tx.clone(), }; - let mut account_manager = device::AccountManager::new( + let account_manager = device::AccountManager::spawn( rpc_handle.clone(), api_availability.clone(), &settings_dir, - internal_event_tx.to_specialized_sender(), + settings + .tunnel_options + .wireguard + .rotation_interval + .unwrap_or_default(), ) .await .map_err(Error::LoadAccountManager)?; - if let Some(rotation_interval) = settings.tunnel_options.wireguard.rotation_interval { - account_manager - .set_rotation_interval(rotation_interval) - .await; - } + account_manager + .receive_events(internal_event_tx.to_specialized_sender()) + .await + .map_err(Error::LoadAccountManager)?; + let data = account_manager + .data() + .await + .map_err(Error::LoadAccountManager)?; let account_history = account_history::AccountHistory::new( &settings_dir, - account_manager.data().map(|device| device.token), + data.as_ref().map(|device| device.token.clone()), ) .await .map_err(Error::LoadAccountHistory)?; - let target_state = if !account_manager.has_data() { + let target_state = if data.is_none() { PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await } else if settings.auto_connect { log::info!("Automatically connecting since auto-connect is turned on"); @@ -776,7 +786,7 @@ where account_history, account_manager, wg_retry_attempt: 0, - wg_check_validity: false, + wg_check_validity: Arc::new(AtomicBool::new(true)), rpc_runtime, rpc_handle, version_updater_handle, @@ -892,15 +902,15 @@ where let Daemon { event_listener, mut shutdown_tasks, - account_manager, rpc_runtime, tunnel_state_machine_handle, target_state, + account_manager, .. } = self; shutdown_tasks.push(Box::pin(target_state.finalize())); - shutdown_tasks.insert(0, Box::pin(account_manager.finalize())); + shutdown_tasks.push(Box::pin(account_manager.shutdown())); ( event_listener, @@ -931,7 +941,7 @@ where GenerateApiConnectionMode(request) => { self.handle_generate_api_connection_mode(request).await } - DeviceKeyEvent(event) => self.handle_device_key_event(event).await, + DeviceEvent(event) => self.handle_device_event(event).await, DeviceMigrationEvent(event) => self.handle_device_migration_event(event).await, #[cfg(windows)] ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await, @@ -961,6 +971,8 @@ where TunnelStateTransition::Error(error_state) => TunnelState::Error(error_state), }; + self.maybe_validate_device(&tunnel_state); + if !tunnel_state.is_connected() { // Cancel reconnects except when entering the connected state. // Exempt the latter because a reconnect scheduled while connecting should not be @@ -992,10 +1004,7 @@ where } self.tunnel_state = tunnel_state.clone(); - self.event_listener.notify_new_state(tunnel_state.clone()); - - // Check device validity last so that the broadcast is not delayed. - self.maybe_validate_device(&tunnel_state).await; + self.event_listener.notify_new_state(tunnel_state); } async fn reset_rpc_sockets_on_tunnel_state_transition( @@ -1012,34 +1021,34 @@ where } /// Check whether the device is valid after a number of failed connection attempts. - async fn maybe_validate_device(&mut self, tunnel_state: &TunnelState) { + fn maybe_validate_device(&mut self, tunnel_state: &TunnelState) { match tunnel_state { TunnelState::Connecting { endpoint, .. } => { if endpoint.tunnel_type != TunnelType::Wireguard { return; } self.wg_retry_attempt += 1; - if self.wg_check_validity && self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 - { - match self.account_manager.validate_device_cached().await { - Ok(status) => { - self.handle_validation_result(status); - self.wg_check_validity = false; + if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 { + let handle = self.account_manager.clone(); + let check_validity = self.wg_check_validity.clone(); + tokio::spawn(async move { + if !check_validity.swap(false, Ordering::SeqCst) { + return; } - Err(error) => { + if let Err(error) = handle.validate_device().await { log::error!( "{}", error.display_chain_with_msg("Failed to check device validity") ); - if !error.is_network_error() { - self.wg_check_validity = false; + if error.is_network_error() { + check_validity.store(true, Ordering::SeqCst); } } - } + }); } } TunnelState::Connected { .. } | TunnelState::Disconnected => { - self.wg_check_validity = true; + self.wg_check_validity.store(true, Ordering::SeqCst); self.wg_retry_attempt = 0; } _ => (), @@ -1053,7 +1062,7 @@ where >, retry_attempt: u32, ) { - if let Some(device) = self.account_manager.data() { + if let Ok(Some(device)) = self.account_manager.data().await { let result = match self.settings.get_relay_settings() { RelaySettings::CustomTunnelEndpoint(custom_relay) => { self.last_generated_relay = None; @@ -1199,6 +1208,8 @@ where let wg_data = self .account_manager .data() + .await + .map_err(|_| Error::NoKeyAvailable)? .map(|device| device.wg_data) .ok_or(Error::NoKeyAvailable)?; let tunnel = wireguard::TunnelConfig { @@ -1224,21 +1235,6 @@ where } } - // Emit the appropriate events for an updated device. - fn handle_validation_result(&mut self, result: device::ValidationResult) { - match result { - device::ValidationResult::RotatedKey | device::ValidationResult::Valid => (), - device::ValidationResult::Removed => { - self.event_listener - .notify_device_event(DeviceEvent::revoke(true)); - } - device::ValidationResult::Updated => { - self.event_listener - .notify_device_event(DeviceEvent::new(self.account_manager.data(), true)); - } - } - } - fn schedule_reconnect(&mut self, delay: Duration) { self.unschedule_reconnect(); @@ -1438,31 +1434,21 @@ where let _ = request.response_tx.send(config); } - async fn handle_device_key_event(&mut self, event: device::DeviceKeyEvent) { - let device_id = &event.0.device.id; - if Some(device_id) - != self - .account_manager - .data() - .map(|device| device.device.id) - .as_ref() - { - // Stale config - return; - } - if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY); + async fn handle_device_event(&mut self, event: InnerDeviceEvent) { + if let InnerDeviceEvent::RotatedKey(_) = &event { + if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { + self.schedule_reconnect(WG_RECONNECT_DELAY); + } } self.event_listener - .notify_device_event(DeviceEvent::from_device(event.0, false)); + .notify_device_event(DeviceEvent::from(event)); } async fn handle_device_migration_event(&mut self, data: DeviceData) { - if self.account_manager.has_data() { + if let Ok(Some(_)) = self.account_manager.data().await { // Discard stale device return; } - let event = DeviceEvent::from_device(data.clone(), false); if let Err(error) = self.account_manager.set(data).await { log::error!( "{}", @@ -1470,7 +1456,6 @@ where ); } self.reconnect_tunnel(); - self.event_listener.notify_device_event(event); } #[cfg(windows)] @@ -1604,12 +1589,12 @@ where } async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) { - if self.account_manager.has_data() { + if let Ok(Some(_)) = self.account_manager.data().await { let _ = tx.send(Err(Error::AlreadyLoggedIn)); return; } let daemon_tx = self.tx.clone(); - let future = self.account_manager.account_service().create_account(); + let future = self.account_manager.account_service.create_account(); tokio::spawn(async move { match future.await { Ok(account_token) => { @@ -1627,7 +1612,7 @@ where tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>, account_token: AccountToken, ) { - let account = self.account_manager.account_service(); + let account = self.account_manager.account_service.clone(); tokio::spawn(async move { let result = account.check_expiry(account_token).await; Self::oneshot_send( @@ -1639,19 +1624,18 @@ where } async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) { - if let Some(device) = self.account_manager.data() { + if let Ok(Some(device)) = self.account_manager.data().await { let future = self .account_manager - .account_service() + .account_service .get_www_auth_token(device.token); - let rpc_call = async { + tokio::spawn(async { Self::oneshot_send( tx, future.await.map_err(Error::RestError), "get_www_auth_token response", ); - }; - tokio::spawn(rpc_call); + }); } else { Self::oneshot_send( tx, @@ -1666,8 +1650,8 @@ where tx: ResponseTx<VoucherSubmission, Error>, voucher: String, ) { - if let Some(device) = self.account_manager.data() { - let mut account = self.account_manager.account_service(); + if let Ok(Some(device)) = self.account_manager.data().await { + let mut account = self.account_manager.account_service.clone(); tokio::spawn(async move { Self::oneshot_send( tx, @@ -1724,25 +1708,28 @@ where } async fn set_account(&mut self, account_token: Option<String>) -> Result<bool, Error> { - let previous_token = self.account_manager.data().map(|device| device.token); + let previous_token = self + .account_manager + .data() + .await + .unwrap_or(None) + .map(|device| device.token); if previous_token == account_token { return Ok(false); } match account_token.clone() { Some(token) => { - let device_data = self - .account_manager + self.account_manager .login(token) .await .map_err(Error::LoginError)?; - self.event_listener - .notify_device_event(DeviceEvent::from_device(device_data, false)); } None => { - self.account_manager.logout(); - self.event_listener - .notify_device_event(DeviceEvent::revoke(false)); + self.account_manager + .logout() + .await + .map_err(Error::LogoutError)?; } } @@ -1761,8 +1748,7 @@ where async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceConfig>, Error>) { // Make sure the device is updated match self.account_manager.validate_device().await { - Ok(status) => self.handle_validation_result(status), - Err(device::Error::NoDevice) => (), + Ok(_) | Err(device::Error::NoDevice) => (), Err(error) => { log::error!( "{}", @@ -1773,7 +1759,12 @@ where Self::oneshot_send( tx, - Ok(self.account_manager.data().map(DeviceConfig::from)), + Ok(self + .account_manager + .data() + .await + .unwrap_or(None) + .map(DeviceConfig::from)), "get_device response", ); } @@ -1782,7 +1773,7 @@ where Self::oneshot_send( tx, self.account_manager - .device_service() + .device_service .list_devices(token) .await .map_err(Error::ListDevicesError), @@ -1796,7 +1787,7 @@ where token: AccountToken, device_id: DeviceId, ) { - let device_service = self.account_manager.device_service(); + let device_service = self.account_manager.device_service.clone(); let event_listener = self.event_listener.clone(); tokio::spawn(async move { @@ -1898,7 +1889,7 @@ where async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) { let mut last_error = Ok(()); - if let Err(error) = self.account_manager.logout_wait().await { + if let Err(error) = self.account_manager.logout().await { log::error!( "{}", error.display_chain_with_msg("Failed to clear device cache") @@ -2414,9 +2405,16 @@ where Ok(settings_changed) => { Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response"); if settings_changed { - self.account_manager + if let Err(error) = self + .account_manager .set_rotation_interval(interval.unwrap_or_default()) - .await; + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update rotation interval") + ); + } self.event_listener .notify_settings(self.settings.to_settings()); } @@ -2434,7 +2432,7 @@ where } async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<PublicKey>, Error>) { - let result = if let Some(device) = self.account_manager.data() { + let result = if let Ok(Some(device)) = self.account_manager.data().await { Ok(Some(device.wg_data.get_public_key())) } else { Err(Error::NoAccountToken) |
