summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-03-03 12:57:40 +0100
committerDavid Lönnhager <david.l@mullvad.net>2022-03-14 12:08:54 +0100
commit0985a987a10d19078efa30eba7fdde0821dcf2be (patch)
treeb2c2a2f0e21d5af66d0a725dafbe998af0811137
parente7bcd784a7f25be8b3b50e75b8adfad446125bc7 (diff)
downloadmullvadvpn-0985a987a10d19078efa30eba7fdde0821dcf2be.tar.xz
mullvadvpn-0985a987a10d19078efa30eba7fdde0821dcf2be.zip
Refactor account manager into actor
-rw-r--r--mullvad-cli/src/cmds/account.rs1
-rw-r--r--mullvad-daemon/src/device.rs715
-rw-r--r--mullvad-daemon/src/lib.rs198
3 files changed, 515 insertions, 399 deletions
diff --git a/mullvad-cli/src/cmds/account.rs b/mullvad-cli/src/cmds/account.rs
index 09c268f3c0..b4ef7c7f14 100644
--- a/mullvad-cli/src/cmds/account.rs
+++ b/mullvad-cli/src/cmds/account.rs
@@ -202,7 +202,6 @@ impl Account {
async fn revoke_device(&self, matches: &clap::ArgMatches) -> Result<()> {
let mut rpc = new_rpc_client().await?;
-
let token = self.parse_account_else_current(&mut rpc, matches).await?;
let device_to_revoke = parse_device_name(matches);
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs
index f4eba59d2d..0a078c19f4 100644
--- a/mullvad-daemon/src/device.rs
+++ b/mullvad-daemon/src/device.rs
@@ -11,7 +11,7 @@ use mullvad_rpc::{
};
use mullvad_types::{
account::{AccountToken, VoucherSubmission},
- device::{Device, DeviceData, DeviceId},
+ device::{Device, DeviceData, DeviceEvent, DeviceId},
wireguard::{RotationInterval, WireguardData},
};
use std::{
@@ -31,7 +31,7 @@ use tokio::{
/// How often to check whether the key has expired.
/// A short interval is used in case the computer is ever suspended.
-const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(60);
+const KEY_CHECK_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// File that used to store account and device data.
const DEVICE_CACHE_FILENAME: &str = "device.json";
@@ -43,10 +43,9 @@ const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4);
const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5;
const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60);
-/// How long to keep the known status for [AccountManager::validate_device_cached].
-const DEVICE_VALIDITY_CACHE_DURATION: Duration = Duration::from_secs(30);
+const VALIDITY_CACHE_TIMEOUT: Duration = Duration::from_secs(10);
-pub struct DeviceKeyEvent(pub DeviceData);
+const LOGOUT_TIMEOUT: Duration = Duration::from_secs(2);
#[derive(err_derive::Error, Debug)]
pub enum Error {
@@ -66,6 +65,54 @@ pub enum Error {
OtherRestError(#[error(source)] rest::Error),
#[error(display = "The device update task is not running")]
DeviceUpdaterCancelled(#[error(source)] oneshot::Canceled),
+ #[error(display = "The account manager is down")]
+ AccountManagerDown,
+}
+
+#[derive(Clone)]
+pub(crate) enum InnerDeviceEvent {
+ /// The device was removed due to user (or daemon) action.
+ Logout,
+ /// Logged in to a new device.
+ Login(DeviceData),
+ /// The device was updated remotely, but not its key.
+ Updated(DeviceData),
+ /// The key was rotated.
+ RotatedKey(DeviceData),
+ /// Device was removed because it was not found remotely.
+ Revoked,
+}
+
+impl From<InnerDeviceEvent> for DeviceEvent {
+ fn from(event: InnerDeviceEvent) -> DeviceEvent {
+ match event {
+ InnerDeviceEvent::Logout => DeviceEvent::revoke(false),
+ InnerDeviceEvent::Login(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Updated(data) => DeviceEvent::from_device(data, true),
+ InnerDeviceEvent::RotatedKey(data) => DeviceEvent::from_device(data, false),
+ InnerDeviceEvent::Revoked => DeviceEvent::revoke(true),
+ }
+ }
+}
+
+impl InnerDeviceEvent {
+ fn data(&self) -> Option<&DeviceData> {
+ match self {
+ InnerDeviceEvent::Login(data) => Some(&data),
+ InnerDeviceEvent::Updated(data) => Some(&data),
+ InnerDeviceEvent::RotatedKey(data) => Some(&data),
+ InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
+ }
+ }
+
+ fn into_data(self) -> Option<DeviceData> {
+ match self {
+ InnerDeviceEvent::Login(data) => Some(data),
+ InnerDeviceEvent::Updated(data) => Some(data),
+ InnerDeviceEvent::RotatedKey(data) => Some(data),
+ InnerDeviceEvent::Logout | InnerDeviceEvent::Revoked => None,
+ }
+ }
}
impl Error {
@@ -89,244 +136,291 @@ pub enum ValidationResult {
Removed,
}
-pub(crate) struct AccountManager {
- account_service: AccountService,
- device_service: DeviceService,
- inner: Arc<Mutex<AccountManagerInner>>,
- cache_update_tx:
- mpsc::UnboundedSender<(Option<DeviceData>, oneshot::Sender<Result<(), Error>>)>,
- cache_task_join_handle: Option<tokio::task::JoinHandle<()>>,
- key_update_tx: DaemonEventSender<DeviceKeyEvent>,
- rotation_abort_handle: Option<AbortHandle>,
+type ResponseTx<T> = oneshot::Sender<Result<T, Error>>;
+
+enum AccountManagerCommand {
+ Login(AccountToken, ResponseTx<()>),
+ Logout(ResponseTx<()>),
+ SetData(DeviceData, ResponseTx<()>),
+ GetData(ResponseTx<Option<DeviceData>>),
+ RotateKey(ResponseTx<()>),
+ SetRotationInterval(RotationInterval, ResponseTx<()>),
+ GetRotationInterval(ResponseTx<RotationInterval>),
+ ValidateDevice(ResponseTx<ValidationResult>),
+ ReceiveEvents(Box<dyn Sender<InnerDeviceEvent> + Send>, ResponseTx<()>),
+ Shutdown(oneshot::Sender<()>),
}
-struct AccountManagerInner {
+#[derive(Clone)]
+pub(crate) struct AccountManagerHandle {
+ cmd_tx: mpsc::UnboundedSender<AccountManagerCommand>,
+ pub account_service: AccountService,
+ pub device_service: DeviceService,
+}
+
+impl AccountManagerHandle {
+ pub async fn login(&self, token: AccountToken) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Login(token, tx))
+ .await
+ }
+
+ pub async fn logout(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::Logout(tx))
+ .await
+ }
+
+ pub async fn set(&self, data: DeviceData) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetData(data, tx))
+ .await
+ }
+
+ pub async fn data(&self) -> Result<Option<DeviceData>, Error> {
+ self.send_command(|tx| AccountManagerCommand::GetData(tx))
+ .await
+ }
+
+ pub async fn rotate_key(&self) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::RotateKey(tx))
+ .await
+ }
+
+ pub async fn set_rotation_interval(&self, interval: RotationInterval) -> Result<(), Error> {
+ self.send_command(|tx| AccountManagerCommand::SetRotationInterval(interval, tx))
+ .await
+ }
+
+ pub async fn rotation_interval(&self) -> Result<RotationInterval, Error> {
+ self.send_command(|tx| AccountManagerCommand::GetRotationInterval(tx))
+ .await
+ }
+
+ pub async fn validate_device(&self) -> Result<ValidationResult, Error> {
+ self.send_command(|tx| AccountManagerCommand::ValidateDevice(tx))
+ .await
+ }
+
+ pub async fn receive_events(
+ &self,
+ events_tx: impl Sender<InnerDeviceEvent> + Send + 'static,
+ ) -> Result<(), Error> {
+ self.send_command(|tx| {
+ AccountManagerCommand::ReceiveEvents(Box::new(events_tx) as Box<_>, tx)
+ })
+ .await
+ }
+
+ pub async fn shutdown(self) {
+ let (tx, rx) = oneshot::channel();
+ let _ = self
+ .cmd_tx
+ .unbounded_send(AccountManagerCommand::Shutdown(tx));
+ let _ = rx.await;
+ }
+
+ async fn send_command<T>(
+ &self,
+ make_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> AccountManagerCommand,
+ ) -> Result<T, Error> {
+ let (tx, rx) = oneshot::channel();
+ self.cmd_tx
+ .unbounded_send(make_cmd(tx))
+ .map_err(|_| Error::AccountManagerDown)?;
+ rx.await.map_err(|_| Error::AccountManagerDown)?
+ }
+}
+
+pub(crate) struct AccountManager {
+ cacher: DeviceCacher,
+ device_service: DeviceService,
data: Option<DeviceData>,
rotation_interval: RotationInterval,
+ listeners: Vec<Box<dyn Sender<InnerDeviceEvent> + Send>>,
last_validation: Option<SystemTime>,
}
impl AccountManager {
- pub async fn new(
+ pub async fn spawn(
rest_handle: rest::MullvadRestHandle,
api_availability: ApiAvailabilityHandle,
settings_dir: &Path,
- key_update_tx: DaemonEventSender<DeviceKeyEvent>,
- ) -> Result<AccountManager, Error> {
- let (mut cacher, device_data) = DeviceCacher::new(settings_dir).await?;
- let token = device_data.as_ref().map(|state| state.token.clone());
+ initial_rotation_interval: RotationInterval,
+ ) -> Result<AccountManagerHandle, Error> {
+ let (cacher, data) = DeviceCacher::new(settings_dir).await?;
+ let token = data.as_ref().map(|state| state.token.clone());
let account_service =
spawn_account_service(rest_handle.clone(), token, api_availability.clone());
- let should_start_rotation = device_data.is_some();
- let inner = Arc::new(Mutex::new(AccountManagerInner {
- data: device_data,
- rotation_interval: RotationInterval::default(),
- last_validation: None,
- }));
- let (cache_update_tx, mut cache_update_rx): (
- _,
- mpsc::UnboundedReceiver<(_, oneshot::Sender<Result<(), Error>>)>,
- ) = mpsc::unbounded();
- let cache_task_join_handle = tokio::spawn(async move {
- while let Some((new_device, result_tx)) = cache_update_rx.next().await {
- let result = cacher.write(new_device).await;
- if let Err(error) = &result {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to update device cache")
- );
- }
- let _ = result_tx.send(result);
- }
- });
+ let (cmd_tx, cmd_rx) = mpsc::unbounded();
- let mut manager = AccountManager {
- account_service,
- device_service: DeviceService::new(rest_handle, api_availability),
- inner,
- cache_update_tx,
- cache_task_join_handle: Some(cache_task_join_handle),
- key_update_tx,
- rotation_abort_handle: None,
+ let device_service = DeviceService::new(rest_handle, api_availability);
+ let manager = AccountManager {
+ cacher,
+ device_service: device_service.clone(),
+ data,
+ rotation_interval: initial_rotation_interval,
+ listeners: vec![],
+ last_validation: None,
};
- if should_start_rotation {
- manager.start_key_rotation();
- }
-
- Ok(manager)
+ tokio::spawn(manager.run(cmd_rx));
+ let handle = AccountManagerHandle {
+ cmd_tx,
+ account_service,
+ device_service,
+ };
+ KeyUpdater::spawn(handle.clone()).await?;
+ Ok(handle)
}
- pub fn account_service(&self) -> AccountService {
- self.account_service.clone()
+ async fn run(mut self, mut cmd_rx: mpsc::UnboundedReceiver<AccountManagerCommand>) {
+ let mut shutdown_tx = None;
+ while let Some(cmd) = cmd_rx.next().await {
+ match cmd {
+ AccountManagerCommand::Shutdown(tx) => {
+ shutdown_tx = Some(tx);
+ break;
+ }
+ other => self.service_command(other).await,
+ }
+ }
+ self.shutdown().await;
+ if let Some(tx) = shutdown_tx {
+ let _ = tx.send(());
+ }
+ log::debug!("Account manager has stopped");
}
- pub fn device_service(&self) -> DeviceService {
- self.device_service.clone()
+ async fn service_command(&mut self, cmd: AccountManagerCommand) {
+ match cmd {
+ AccountManagerCommand::Login(token, tx) => {
+ let _ = tx.send(self.login(token).await);
+ }
+ AccountManagerCommand::Logout(tx) => {
+ let _ = tx.send(self.logout().await);
+ }
+ AccountManagerCommand::SetData(data, tx) => {
+ let _ = tx.send(self.set(data).await);
+ }
+ AccountManagerCommand::GetData(tx) => {
+ let _ = tx.send(Ok(self.data.clone()));
+ }
+ AccountManagerCommand::RotateKey(tx) => {
+ let _ = tx.send(self.rotate_key().await);
+ }
+ AccountManagerCommand::SetRotationInterval(interval, tx) => {
+ self.rotation_interval = interval;
+ let _ = tx.send(Ok(()));
+ }
+ AccountManagerCommand::GetRotationInterval(tx) => {
+ let _ = tx.send(Ok(self.rotation_interval));
+ }
+ AccountManagerCommand::ValidateDevice(tx) => {
+ let _ = tx.send(self.validate_device().await);
+ }
+ AccountManagerCommand::ReceiveEvents(events_tx, tx) => {
+ let _ = tx.send(Ok(self.listeners.push(events_tx)));
+ }
+ AccountManagerCommand::Shutdown(_) => unreachable!("shutdown is handled earlier"),
+ }
}
- pub async fn login(&mut self, token: AccountToken) -> Result<DeviceData, Error> {
+ async fn login(&mut self, token: AccountToken) -> Result<(), Error> {
let data = self.device_service.generate_for_account(token).await?;
- self.logout();
- let (result_tx, result_rx) = oneshot::channel();
- let _ = self
- .cache_update_tx
- .unbounded_send((Some(data.clone()), result_tx));
- {
- let mut inner = self.inner.lock().unwrap();
- inner.data.replace(data.clone());
- }
- if let Err(error) = flatten_result(result_rx.await.map_err(Error::DeviceUpdaterCancelled)) {
- // Delete the device if an I/O error occurred
- self.logout();
- return Err(error);
- }
- self.start_key_rotation();
-
- Ok(data)
+ self.set(data).await?;
+ Ok(())
}
- pub async fn set(&mut self, data: DeviceData) -> Result<(), Error> {
- self.stop_key_rotation();
+ async fn logout(&mut self) -> Result<(), Error> {
+ if self.data.is_some() {
+ self.cacher.write(None).await?;
+ let _ = tokio::time::timeout(LOGOUT_TIMEOUT, self.logout_inner()).await;
- let (result_tx, result_rx) = oneshot::channel();
- let _ = self
- .cache_update_tx
- .unbounded_send((Some(data.clone()), result_tx));
-
- let old_data = {
- let mut inner = self.inner.lock().unwrap();
- inner.data.replace(data.clone())
- };
-
- if let Err(error) = flatten_result(result_rx.await.map_err(Error::DeviceUpdaterCancelled)) {
- // Delete the device if an I/O error occurred
- self.logout();
- return Err(error);
+ let event = InnerDeviceEvent::Logout;
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
}
-
- if let Some(old_data) = old_data {
- // Log out the previous device if the id differs
- if !old_data.device.eq_id(&data.device) {
- let service = self.device_service.clone();
- tokio::spawn(async move {
- if let Err(error) = service
- .remove_device_with_backoff(old_data.token, old_data.device.id)
- .await
- {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to remove a previous device")
- );
- }
- });
- }
- }
- self.start_key_rotation();
Ok(())
}
- /// Log out without waiting for the result.
- pub fn logout(&mut self) {
- let fut = self.logout_inner(true);
+ async fn logout_inner(&mut self) -> tokio::task::JoinHandle<()> {
+ let prev_data = self.data.take();
+ let service = self.device_service.clone();
+
tokio::spawn(async move {
- let result = fut.await;
- if let Err(error) = result {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to remove a previous device")
- );
+ if let Some(data) = prev_data {
+ if let Err(error) = service
+ .remove_device_with_backoff(data.token, data.device.id)
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to remove a previous device")
+ );
+ }
}
- });
+ })
}
- /// Log out, and wait until the API has removed the device.
- #[cfg(not(target_os = "android"))]
- pub fn logout_wait(&mut self) -> impl Future<Output = Result<(), Error>> {
- self.logout_inner(false)
+ #[inline]
+ async fn set(&mut self, new_data: DeviceData) -> Result<(), Error> {
+ self.set_inner(InnerDeviceEvent::Login(new_data)).await
}
- fn logout_inner(&mut self, use_backoff: bool) -> impl Future<Output = Result<(), Error>> {
- self.stop_key_rotation();
- let data = {
- let mut inner = self.inner.lock().unwrap();
- let (result_tx, _result_rx) = oneshot::channel();
- let _ = self.cache_update_tx.unbounded_send((None, result_tx));
- // NOTE: No need to wait on cache update
- inner.data.take()
- };
- let service = self.device_service.clone();
- async move {
- if let Some(data) = data {
- if use_backoff {
- return service
- .remove_device_with_backoff(data.token, data.device.id)
- .await;
- } else {
- return service.remove_device(data.token, data.device.id).await;
- }
- }
- Ok(())
+ async fn set_inner(&mut self, event: InnerDeviceEvent) -> Result<(), Error> {
+ let data = event.data();
+ if data == self.data.as_ref() {
+ return Ok(());
}
- }
- pub async fn rotate_key(&mut self) -> Result<WireguardData, Error> {
- let mut data = {
- let inner = self.inner.lock().unwrap();
- inner.data.as_ref().ok_or(Error::NoDevice)?.clone()
- };
- self.stop_key_rotation();
- let result = self
- .device_service
- .rotate_key(data.token.clone(), data.device.id.clone())
- .await;
- if let Ok(ref wg_data) = result {
- data.wg_data = wg_data.clone();
- data.device.pubkey = wg_data.private_key.public_key();
- let mut inner = self.inner.lock().unwrap();
- inner.data.replace(data.clone());
- let (result_tx, _result_rx) = oneshot::channel();
- let _ = self
- .cache_update_tx
- .unbounded_send((Some(data.clone()), result_tx));
- // NOTE: No need to wait on cache update
- let _ = self.key_update_tx.send(DeviceKeyEvent(data));
+ self.cacher.write(data).await?;
+
+ if self
+ .data
+ .as_ref()
+ .map(|current| data.as_ref().map(|d| &d.device.id) != Some(&current.device.id))
+ .unwrap_or(false)
+ {
+ // Remove the existing device if its ID differs. Otherwise, only update
+ // the data.
+ self.logout_inner().await;
}
- self.start_key_rotation();
- result
- }
- pub fn data(&self) -> Option<DeviceData> {
- self.inner.lock().unwrap().data.clone()
- }
+ self.data = data.cloned();
- pub fn has_data(&self) -> bool {
- self.inner.lock().unwrap().data.is_some()
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+
+ Ok(())
}
- pub async fn set_rotation_interval(&mut self, interval: RotationInterval) {
- self.stop_key_rotation();
- let restart_rotation = {
- let mut inner = self.inner.lock().unwrap();
- inner.rotation_interval = interval;
- inner.data.is_some()
- };
- if restart_rotation {
- self.start_key_rotation();
- }
+ async fn rotate_key(&mut self) -> Result<(), Error> {
+ // TODO: Update all data opportunistically?
+ let data = self.data.as_ref().ok_or(Error::NoDevice)?;
+
+ let wg_data = self
+ .device_service
+ .rotate_key(data.token.clone(), data.device.id.clone())
+ .await?;
+
+ // Copy the data to keep a predictable state if an error occurs.
+ let mut new_data = data.clone();
+ new_data.device.pubkey = wg_data.private_key.public_key();
+ new_data.wg_data = wg_data;
+ self.set_inner(InnerDeviceEvent::RotatedKey(new_data)).await
}
/// Check if the device is valid for the account, and yank it if it no longer exists.
/// This also updates any associated data and returns whether it changed.
- pub async fn validate_device(&mut self) -> Result<ValidationResult, Error> {
- let mut data = {
- let inner = self.inner.lock().unwrap();
- inner.data.as_ref().ok_or(Error::NoDevice)?.clone()
- };
-
+ async fn validate_device(&mut self) -> Result<ValidationResult, Error> {
log::debug!("Checking whether the device is still valid");
+ if let Some(result) = self.cached_validation() {
+ log::debug!("The current device is still valid");
+ return Ok(result);
+ }
+
+ let data = self.data.as_ref().ok_or(Error::NoDevice)?;
+
match self
.device_service
.get(data.token.clone(), data.device.id.clone())
@@ -339,13 +433,12 @@ impl AccountManager {
Ok(ValidationResult::Valid)
} else {
log::debug!("Updating data for the current device");
- data.device = device;
- {
- let mut inner = self.inner.lock().unwrap();
- inner.data.replace(data.clone());
- let (result_tx, _result_rx) = oneshot::channel();
- let _ = self.cache_update_tx.unbounded_send((Some(data), result_tx));
- }
+ // Copy the data to keep a predictable state if an error occurs.
+ let new_data = DeviceData {
+ device,
+ ..data.clone()
+ };
+ self.set_inner(InnerDeviceEvent::Updated(new_data)).await?;
Ok(ValidationResult::Updated)
}
} else {
@@ -356,123 +449,153 @@ impl AccountManager {
}
Err(Error::InvalidAccount) | Err(Error::InvalidDevice) => {
log::debug!("The current device is no longer valid for this account");
- self.stop_key_rotation();
- {
- self.inner.lock().unwrap().data.take();
- let (result_tx, _result_rx) = oneshot::channel();
- let _ = self.cache_update_tx.unbounded_send((None, result_tx));
- // NOTE: No need to wait on cache update
- }
+
+ self.cacher.write(None).await?;
+ self.data = None;
+
+ let event = InnerDeviceEvent::Revoked;
+ self.listeners
+ .retain(|listener| listener.send(event.clone()).is_ok());
+
Ok(ValidationResult::Removed)
}
Err(error) => Err(error),
}
}
- /// Same as [Self::validate_device] but returns [ValidationResult::Valid] (or [Error::NoDevice])
- /// if the last check was recent.
- pub async fn validate_device_cached(&mut self) -> Result<ValidationResult, Error> {
- let last_validation = {
- let inner = self.inner.lock().unwrap();
- if inner.data.is_none() {
- return Err(Error::NoDevice);
- }
- inner.last_validation.clone()
- };
+ fn cached_validation(&mut self) -> Option<ValidationResult> {
+ if self.data.is_none() {
+ return None;
+ }
- if last_validation
- .and_then(|last_check| SystemTime::now().duration_since(last_check).ok())
- .map(|elapsed| elapsed < DEVICE_VALIDITY_CACHE_DURATION)
- .unwrap_or(false)
- {
- return Ok(ValidationResult::Valid);
+ let now = SystemTime::now();
+
+ let elapsed = self
+ .last_validation
+ .and_then(|last_check| now.duration_since(last_check).ok())
+ .unwrap_or(VALIDITY_CACHE_TIMEOUT);
+
+ if elapsed >= VALIDITY_CACHE_TIMEOUT {
+ self.last_validation = Some(now);
+ return None;
}
- let result = self.validate_device().await;
- let mut inner = self.inner.lock().unwrap();
- inner.last_validation = Some(SystemTime::now());
- result
+ Some(ValidationResult::Valid)
+ }
+
+ async fn shutdown(self) {
+ self.cacher.finalize().await;
}
+}
- fn start_key_rotation(&mut self) {
- self.stop_key_rotation();
+struct KeyUpdater {
+ handle: AccountManagerHandle,
+ rx: mpsc::UnboundedReceiver<InnerDeviceEvent>,
+ data: Option<DeviceData>,
+}
- let service = self.device_service.clone();
- let inner = self.inner.clone();
- let cache_update_tx = self.cache_update_tx.clone();
- let key_update_tx = self.key_update_tx.clone();
+impl KeyUpdater {
+ async fn spawn(handle: AccountManagerHandle) -> Result<(), Error> {
+ let (tx, rx) = mpsc::unbounded();
+ handle.receive_events(tx).await?;
+ let data = handle.data().await?;
+ let mut key_rotator = KeyUpdater { handle, rx, data };
- let (task, abort_handle) = abortable(async move {
+ tokio::spawn(async move {
loop {
tokio::time::sleep(KEY_CHECK_INTERVAL).await;
- let rotation_interval = { inner.lock().unwrap().rotation_interval.clone() };
-
- let mut state = {
- match inner.lock().unwrap().data.as_ref() {
- Some(device_config) => device_config.clone(),
- None => continue,
- }
- };
-
- if (chrono::Utc::now()
- .signed_duration_since(state.wg_data.created)
- .num_seconds() as u64)
- < rotation_interval.as_duration().as_secs()
- {
- continue;
- }
-
- match service
- .rotate_key_with_backoff(state.token.clone(), state.device.id.clone())
- .await
- {
- Ok(wg_data) => {
- state.device.pubkey = wg_data.private_key.public_key();
- state.wg_data = wg_data;
- {
- let mut inner = inner.lock().unwrap();
- inner.data.replace(state.clone());
- let (result_tx, _result_rx) = oneshot::channel();
- let _ =
- cache_update_tx.unbounded_send((Some(state.clone()), result_tx));
- // NOTE: No need to wait on cache update
- }
- let _ = key_update_tx.send(DeviceKeyEvent(state));
- }
- Err(error) => {
- log::debug!("{}", error.display_chain_with_msg("Stopping key rotation"));
+ if let Err(error) = key_rotator.check_key_validity().await {
+ if let Error::AccountManagerDown = error {
+ break;
}
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Stopping key rotation task due to an error")
+ );
+ break;
}
}
+ log::debug!("Stopping key updater");
});
- tokio::spawn(task);
- self.rotation_abort_handle = Some(abort_handle);
+
+ Ok(())
}
- fn stop_key_rotation(&mut self) {
- if let Some(abort_handle) = self.rotation_abort_handle.take() {
- abort_handle.abort();
+ async fn check_key_validity(&mut self) -> Result<(), Error> {
+ let rotation_interval = self.handle.rotation_interval().await?;
+ let data = self.wait_for_data().await?;
+
+ if (chrono::Utc::now()
+ .signed_duration_since(data.wg_data.created)
+ .num_seconds() as u64)
+ < rotation_interval.as_duration().as_secs()
+ {
+ return Ok(());
}
- }
- /// Consumes the object and completes when there is nothing left to write to
- /// the cache file.
- pub fn finalize(mut self) -> impl Future<Output = ()> {
- let join_handle = self.cache_task_join_handle.take();
- drop(self);
+ let mut data = data.clone();
+
+ let rotation_fut = self
+ .handle
+ .device_service
+ .rotate_key_with_backoff(data.token.clone(), data.device.id.clone());
+
+ match futures::future::select(Box::pin(rotation_fut), self.rx.next()).await {
+ futures::future::Either::Left((Ok(wg_data), _)) => {
+ log::debug!("Rotating WireGuard key");
+ data.device.pubkey = wg_data.private_key.public_key();
+ data.wg_data = wg_data;
+ self.handle.set(data).await?;
+ }
+ futures::future::Either::Left((Err(error), _)) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Stopping key rotation due to an error")
+ );
- async move {
- if let Some(join_handle) = join_handle {
- let _ = join_handle.await;
+ // Forget the current device. Key rotation will restart when
+ // it is updated in any way.
+ self.data = None;
+ }
+ futures::future::Either::Right((event, _)) => {
+ // Abort key rotation if the device changed
+ if let Some(event) = event {
+ self.data = event.into_data();
+ } else {
+ return Err(Error::AccountManagerDown);
+ }
}
}
+
+ Ok(())
}
-}
-impl Drop for AccountManager {
- fn drop(&mut self) {
- self.stop_key_rotation();
+ async fn wait_for_data(&mut self) -> Result<&DeviceData, Error> {
+ while let Ok(item) = self.rx.try_next() {
+ match item {
+ Some(event) => {
+ self.data = event.into_data();
+ }
+ None => return Err(Error::AccountManagerDown),
+ }
+ }
+
+ match self.data {
+ Some(ref data) => Ok(data),
+ None => loop {
+ let event = self.rx.next().await;
+ match event {
+ Some(event) => {
+ if let Some(data) = event.into_data() {
+ self.data = Some(data);
+ break Ok(self.data.as_ref().unwrap());
+ }
+ }
+ None => break Err(Error::AccountManagerDown),
+ }
+ },
+ }
}
}
@@ -741,7 +864,7 @@ impl DeviceCacher {
))
}
- pub async fn write(&mut self, device: Option<DeviceData>) -> Result<(), Error> {
+ pub async fn write(&mut self, device: Option<&DeviceData>) -> Result<(), Error> {
let data = serde_json::to_vec_pretty(&device).unwrap();
self.file.get_mut().set_len(0).await?;
@@ -763,6 +886,11 @@ impl DeviceCacher {
tokio::fs::remove_file(path).await?;
Ok(())
}
+
+ async fn finalize(self) {
+ let std_file = self.file.into_inner().into_std().await;
+ let _ = tokio::task::spawn_blocking(move || drop(std_file)).await;
+ }
}
#[derive(Clone)]
@@ -944,12 +1072,3 @@ fn retry_strategy() -> Jittered<ExponentialBackoff> {
.max_delay(RETRY_BACKOFF_INTERVAL_MAX),
)
}
-
-fn flatten_result<T, E>(
- result: std::result::Result<std::result::Result<T, E>, E>,
-) -> std::result::Result<T, E> {
- match result {
- Ok(value) => value,
- Err(err) => Err(err),
- }
-}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 9277d45335..2be83a3ed9 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -25,6 +25,7 @@ pub mod version;
mod version_check;
use crate::target_state::PersistentTargetState;
+use device::InnerDeviceEvent;
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Future},
@@ -62,7 +63,10 @@ use std::{
net::{IpAddr, Ipv4Addr},
path::PathBuf,
pin::Pin,
- sync::{mpsc as sync_mpsc, Arc, Weak},
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ mpsc as sync_mpsc, Arc, Weak,
+ },
time::Duration,
};
#[cfg(any(target_os = "linux", windows))]
@@ -76,8 +80,7 @@ use talpid_types::android::AndroidContext;
use talpid_types::{
net::{
openvpn::{self, ProxySettings},
- wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters,
- TunnelType,
+ wireguard, TransportProtocol, TunnelEndpoint, TunnelParameters, TunnelType,
},
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
@@ -349,8 +352,8 @@ pub(crate) enum InternalDaemonEvent {
NewAppVersionInfo(AppVersionInfo),
/// Request from REST client to use a different API endpoint.
GenerateApiConnectionMode(api::ApiConnectionModeRequest),
- /// Sent when a device key is rotated.
- DeviceKeyEvent(device::DeviceKeyEvent),
+ /// Sent when a device is updated in any way (key rotation, login, logout, etc.).
+ DeviceEvent(InnerDeviceEvent),
/// Handles updates from versions without devices.
DeviceMigrationEvent(DeviceData),
/// The split tunnel paths or state were updated.
@@ -388,9 +391,9 @@ impl From<api::ApiConnectionModeRequest> for InternalDaemonEvent {
}
}
-impl From<device::DeviceKeyEvent> for InternalDaemonEvent {
- fn from(event: device::DeviceKeyEvent) -> Self {
- InternalDaemonEvent::DeviceKeyEvent(event)
+impl From<InnerDeviceEvent> for InternalDaemonEvent {
+ fn from(event: InnerDeviceEvent) -> Self {
+ InternalDaemonEvent::DeviceEvent(event)
}
}
@@ -575,9 +578,9 @@ pub struct Daemon<L: EventListener> {
event_listener: L,
settings: SettingsPersister,
account_history: account_history::AccountHistory,
- account_manager: device::AccountManager,
+ account_manager: device::AccountManagerHandle,
wg_retry_attempt: usize,
- wg_check_validity: bool,
+ wg_check_validity: Arc<AtomicBool>,
rpc_runtime: mullvad_rpc::MullvadRpcRuntime,
rpc_handle: mullvad_rpc::rest::MullvadRestHandle,
version_updater_handle: version_check::VersionUpdaterHandle,
@@ -657,28 +660,35 @@ where
tx: internal_event_tx.clone(),
};
- let mut account_manager = device::AccountManager::new(
+ let account_manager = device::AccountManager::spawn(
rpc_handle.clone(),
api_availability.clone(),
&settings_dir,
- internal_event_tx.to_specialized_sender(),
+ settings
+ .tunnel_options
+ .wireguard
+ .rotation_interval
+ .unwrap_or_default(),
)
.await
.map_err(Error::LoadAccountManager)?;
- if let Some(rotation_interval) = settings.tunnel_options.wireguard.rotation_interval {
- account_manager
- .set_rotation_interval(rotation_interval)
- .await;
- }
+ account_manager
+ .receive_events(internal_event_tx.to_specialized_sender())
+ .await
+ .map_err(Error::LoadAccountManager)?;
+ let data = account_manager
+ .data()
+ .await
+ .map_err(Error::LoadAccountManager)?;
let account_history = account_history::AccountHistory::new(
&settings_dir,
- account_manager.data().map(|device| device.token),
+ data.as_ref().map(|device| device.token.clone()),
)
.await
.map_err(Error::LoadAccountHistory)?;
- let target_state = if !account_manager.has_data() {
+ let target_state = if data.is_none() {
PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await
} else if settings.auto_connect {
log::info!("Automatically connecting since auto-connect is turned on");
@@ -776,7 +786,7 @@ where
account_history,
account_manager,
wg_retry_attempt: 0,
- wg_check_validity: false,
+ wg_check_validity: Arc::new(AtomicBool::new(true)),
rpc_runtime,
rpc_handle,
version_updater_handle,
@@ -892,15 +902,15 @@ where
let Daemon {
event_listener,
mut shutdown_tasks,
- account_manager,
rpc_runtime,
tunnel_state_machine_handle,
target_state,
+ account_manager,
..
} = self;
shutdown_tasks.push(Box::pin(target_state.finalize()));
- shutdown_tasks.insert(0, Box::pin(account_manager.finalize()));
+ shutdown_tasks.push(Box::pin(account_manager.shutdown()));
(
event_listener,
@@ -931,7 +941,7 @@ where
GenerateApiConnectionMode(request) => {
self.handle_generate_api_connection_mode(request).await
}
- DeviceKeyEvent(event) => self.handle_device_key_event(event).await,
+ DeviceEvent(event) => self.handle_device_event(event).await,
DeviceMigrationEvent(event) => self.handle_device_migration_event(event).await,
#[cfg(windows)]
ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await,
@@ -961,6 +971,8 @@ where
TunnelStateTransition::Error(error_state) => TunnelState::Error(error_state),
};
+ self.maybe_validate_device(&tunnel_state);
+
if !tunnel_state.is_connected() {
// Cancel reconnects except when entering the connected state.
// Exempt the latter because a reconnect scheduled while connecting should not be
@@ -992,10 +1004,7 @@ where
}
self.tunnel_state = tunnel_state.clone();
- self.event_listener.notify_new_state(tunnel_state.clone());
-
- // Check device validity last so that the broadcast is not delayed.
- self.maybe_validate_device(&tunnel_state).await;
+ self.event_listener.notify_new_state(tunnel_state);
}
async fn reset_rpc_sockets_on_tunnel_state_transition(
@@ -1012,34 +1021,34 @@ where
}
/// Check whether the device is valid after a number of failed connection attempts.
- async fn maybe_validate_device(&mut self, tunnel_state: &TunnelState) {
+ fn maybe_validate_device(&mut self, tunnel_state: &TunnelState) {
match tunnel_state {
TunnelState::Connecting { endpoint, .. } => {
if endpoint.tunnel_type != TunnelType::Wireguard {
return;
}
self.wg_retry_attempt += 1;
- if self.wg_check_validity && self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0
- {
- match self.account_manager.validate_device_cached().await {
- Ok(status) => {
- self.handle_validation_result(status);
- self.wg_check_validity = false;
+ if self.wg_retry_attempt % WG_DEVICE_CHECK_THRESHOLD == 0 {
+ let handle = self.account_manager.clone();
+ let check_validity = self.wg_check_validity.clone();
+ tokio::spawn(async move {
+ if !check_validity.swap(false, Ordering::SeqCst) {
+ return;
}
- Err(error) => {
+ if let Err(error) = handle.validate_device().await {
log::error!(
"{}",
error.display_chain_with_msg("Failed to check device validity")
);
- if !error.is_network_error() {
- self.wg_check_validity = false;
+ if error.is_network_error() {
+ check_validity.store(true, Ordering::SeqCst);
}
}
- }
+ });
}
}
TunnelState::Connected { .. } | TunnelState::Disconnected => {
- self.wg_check_validity = true;
+ self.wg_check_validity.store(true, Ordering::SeqCst);
self.wg_retry_attempt = 0;
}
_ => (),
@@ -1053,7 +1062,7 @@ where
>,
retry_attempt: u32,
) {
- if let Some(device) = self.account_manager.data() {
+ if let Ok(Some(device)) = self.account_manager.data().await {
let result = match self.settings.get_relay_settings() {
RelaySettings::CustomTunnelEndpoint(custom_relay) => {
self.last_generated_relay = None;
@@ -1199,6 +1208,8 @@ where
let wg_data = self
.account_manager
.data()
+ .await
+ .map_err(|_| Error::NoKeyAvailable)?
.map(|device| device.wg_data)
.ok_or(Error::NoKeyAvailable)?;
let tunnel = wireguard::TunnelConfig {
@@ -1224,21 +1235,6 @@ where
}
}
- // Emit the appropriate events for an updated device.
- fn handle_validation_result(&mut self, result: device::ValidationResult) {
- match result {
- device::ValidationResult::RotatedKey | device::ValidationResult::Valid => (),
- device::ValidationResult::Removed => {
- self.event_listener
- .notify_device_event(DeviceEvent::revoke(true));
- }
- device::ValidationResult::Updated => {
- self.event_listener
- .notify_device_event(DeviceEvent::new(self.account_manager.data(), true));
- }
- }
- }
-
fn schedule_reconnect(&mut self, delay: Duration) {
self.unschedule_reconnect();
@@ -1438,31 +1434,21 @@ where
let _ = request.response_tx.send(config);
}
- async fn handle_device_key_event(&mut self, event: device::DeviceKeyEvent) {
- let device_id = &event.0.device.id;
- if Some(device_id)
- != self
- .account_manager
- .data()
- .map(|device| device.device.id)
- .as_ref()
- {
- // Stale config
- return;
- }
- if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY);
+ async fn handle_device_event(&mut self, event: InnerDeviceEvent) {
+ if let InnerDeviceEvent::RotatedKey(_) = &event {
+ if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() {
+ self.schedule_reconnect(WG_RECONNECT_DELAY);
+ }
}
self.event_listener
- .notify_device_event(DeviceEvent::from_device(event.0, false));
+ .notify_device_event(DeviceEvent::from(event));
}
async fn handle_device_migration_event(&mut self, data: DeviceData) {
- if self.account_manager.has_data() {
+ if let Ok(Some(_)) = self.account_manager.data().await {
// Discard stale device
return;
}
- let event = DeviceEvent::from_device(data.clone(), false);
if let Err(error) = self.account_manager.set(data).await {
log::error!(
"{}",
@@ -1470,7 +1456,6 @@ where
);
}
self.reconnect_tunnel();
- self.event_listener.notify_device_event(event);
}
#[cfg(windows)]
@@ -1604,12 +1589,12 @@ where
}
async fn on_create_new_account(&mut self, tx: ResponseTx<String, Error>) {
- if self.account_manager.has_data() {
+ if let Ok(Some(_)) = self.account_manager.data().await {
let _ = tx.send(Err(Error::AlreadyLoggedIn));
return;
}
let daemon_tx = self.tx.clone();
- let future = self.account_manager.account_service().create_account();
+ let future = self.account_manager.account_service.create_account();
tokio::spawn(async move {
match future.await {
Ok(account_token) => {
@@ -1627,7 +1612,7 @@ where
tx: ResponseTx<AccountData, mullvad_rpc::rest::Error>,
account_token: AccountToken,
) {
- let account = self.account_manager.account_service();
+ let account = self.account_manager.account_service.clone();
tokio::spawn(async move {
let result = account.check_expiry(account_token).await;
Self::oneshot_send(
@@ -1639,19 +1624,18 @@ where
}
async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) {
- if let Some(device) = self.account_manager.data() {
+ if let Ok(Some(device)) = self.account_manager.data().await {
let future = self
.account_manager
- .account_service()
+ .account_service
.get_www_auth_token(device.token);
- let rpc_call = async {
+ tokio::spawn(async {
Self::oneshot_send(
tx,
future.await.map_err(Error::RestError),
"get_www_auth_token response",
);
- };
- tokio::spawn(rpc_call);
+ });
} else {
Self::oneshot_send(
tx,
@@ -1666,8 +1650,8 @@ where
tx: ResponseTx<VoucherSubmission, Error>,
voucher: String,
) {
- if let Some(device) = self.account_manager.data() {
- let mut account = self.account_manager.account_service();
+ if let Ok(Some(device)) = self.account_manager.data().await {
+ let mut account = self.account_manager.account_service.clone();
tokio::spawn(async move {
Self::oneshot_send(
tx,
@@ -1724,25 +1708,28 @@ where
}
async fn set_account(&mut self, account_token: Option<String>) -> Result<bool, Error> {
- let previous_token = self.account_manager.data().map(|device| device.token);
+ let previous_token = self
+ .account_manager
+ .data()
+ .await
+ .unwrap_or(None)
+ .map(|device| device.token);
if previous_token == account_token {
return Ok(false);
}
match account_token.clone() {
Some(token) => {
- let device_data = self
- .account_manager
+ self.account_manager
.login(token)
.await
.map_err(Error::LoginError)?;
- self.event_listener
- .notify_device_event(DeviceEvent::from_device(device_data, false));
}
None => {
- self.account_manager.logout();
- self.event_listener
- .notify_device_event(DeviceEvent::revoke(false));
+ self.account_manager
+ .logout()
+ .await
+ .map_err(Error::LogoutError)?;
}
}
@@ -1761,8 +1748,7 @@ where
async fn on_get_device(&mut self, tx: ResponseTx<Option<DeviceConfig>, Error>) {
// Make sure the device is updated
match self.account_manager.validate_device().await {
- Ok(status) => self.handle_validation_result(status),
- Err(device::Error::NoDevice) => (),
+ Ok(_) | Err(device::Error::NoDevice) => (),
Err(error) => {
log::error!(
"{}",
@@ -1773,7 +1759,12 @@ where
Self::oneshot_send(
tx,
- Ok(self.account_manager.data().map(DeviceConfig::from)),
+ Ok(self
+ .account_manager
+ .data()
+ .await
+ .unwrap_or(None)
+ .map(DeviceConfig::from)),
"get_device response",
);
}
@@ -1782,7 +1773,7 @@ where
Self::oneshot_send(
tx,
self.account_manager
- .device_service()
+ .device_service
.list_devices(token)
.await
.map_err(Error::ListDevicesError),
@@ -1796,7 +1787,7 @@ where
token: AccountToken,
device_id: DeviceId,
) {
- let device_service = self.account_manager.device_service();
+ let device_service = self.account_manager.device_service.clone();
let event_listener = self.event_listener.clone();
tokio::spawn(async move {
@@ -1898,7 +1889,7 @@ where
async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) {
let mut last_error = Ok(());
- if let Err(error) = self.account_manager.logout_wait().await {
+ if let Err(error) = self.account_manager.logout().await {
log::error!(
"{}",
error.display_chain_with_msg("Failed to clear device cache")
@@ -2414,9 +2405,16 @@ where
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_wireguard_rotation_interval response");
if settings_changed {
- self.account_manager
+ if let Err(error) = self
+ .account_manager
.set_rotation_interval(interval.unwrap_or_default())
- .await;
+ .await
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to update rotation interval")
+ );
+ }
self.event_listener
.notify_settings(self.settings.to_settings());
}
@@ -2434,7 +2432,7 @@ where
}
async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<PublicKey>, Error>) {
- let result = if let Some(device) = self.account_manager.data() {
+ let result = if let Ok(Some(device)) = self.account_manager.data().await {
Ok(Some(device.wg_data.get_public_key()))
} else {
Err(Error::NoAccountToken)