summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-daemon/src/device.rs63
-rw-r--r--mullvad-daemon/src/lib.rs7
-rw-r--r--mullvad-daemon/src/management_interface.rs3
3 files changed, 60 insertions, 13 deletions
diff --git a/mullvad-daemon/src/device.rs b/mullvad-daemon/src/device.rs
index 111d0df9e1..585855504d 100644
--- a/mullvad-daemon/src/device.rs
+++ b/mullvad-daemon/src/device.rs
@@ -1,7 +1,7 @@
use crate::DaemonEventSender;
use chrono::{DateTime, Utc};
use futures::{
- channel::mpsc,
+ channel::{mpsc, oneshot},
future::{abortable, AbortHandle},
stream::StreamExt,
};
@@ -63,6 +63,8 @@ pub enum Error {
ParseDeviceCache(#[error(source)] serde_json::Error),
#[error(display = "Unexpected HTTP request error")]
OtherRestError(#[error(source)] rest::Error),
+ #[error(display = "The device update task is not running")]
+ DeviceUpdaterCancelled(#[error(source)] oneshot::Canceled),
}
impl Error {
@@ -89,7 +91,8 @@ pub(crate) struct AccountManager {
account_service: AccountService,
device_service: DeviceService,
inner: Arc<Mutex<AccountManagerInner>>,
- cache_update_tx: mpsc::UnboundedSender<Option<DeviceData>>,
+ cache_update_tx:
+ mpsc::UnboundedSender<(Option<DeviceData>, oneshot::Sender<Result<(), Error>>)>,
cache_task_join_handle: Option<tokio::task::JoinHandle<()>>,
key_update_tx: DaemonEventSender<DeviceKeyEvent>,
rotation_abort_handle: Option<AbortHandle>,
@@ -122,15 +125,20 @@ impl AccountManager {
rotation_interval: RotationInterval::default(),
}));
- let (cache_update_tx, mut cache_update_rx) = mpsc::unbounded();
+ let (cache_update_tx, mut cache_update_rx): (
+ _,
+ mpsc::UnboundedReceiver<(_, oneshot::Sender<Result<(), Error>>)>,
+ ) = mpsc::unbounded();
let cache_task_join_handle = runtime.spawn(async move {
- while let Some(new_device) = cache_update_rx.next().await {
- if let Err(error) = cacher.write(new_device).await {
+ while let Some((new_device, result_tx)) = cache_update_rx.next().await {
+ let result = cacher.write(new_device).await;
+ if let Err(error) = &result {
log::error!(
"{}",
error.display_chain_with_msg("Failed to update device cache")
);
}
+ let _ = result_tx.send(result);
}
});
@@ -163,24 +171,37 @@ impl AccountManager {
pub async fn login(&mut self, token: AccountToken) -> Result<DeviceData, Error> {
let data = self.device_service.generate_for_account(token).await?;
self.logout();
+ let (result_tx, result_rx) = oneshot::channel();
+ let _ = self
+ .cache_update_tx
+ .unbounded_send((Some(data.clone()), result_tx));
{
let mut inner = self.inner.lock().unwrap();
inner.data.replace(data.clone());
- let _ = self.cache_update_tx.unbounded_send(Some(data.clone()));
+ }
+ if let Err(error) = flatten_result(result_rx.await.map_err(Error::DeviceUpdaterCancelled)) {
+ // Delete the device if an I/O error occurred
+ self.logout();
+ return Err(error);
}
self.start_key_rotation();
Ok(data)
}
- pub fn set(&mut self, data: DeviceData) {
+ pub async fn set(&mut self, data: DeviceData) -> Result<(), Error> {
self.logout();
+ let (result_tx, result_rx) = oneshot::channel();
+ let _ = self
+ .cache_update_tx
+ .unbounded_send((Some(data.clone()), result_tx));
{
let mut inner = self.inner.lock().unwrap();
inner.data.replace(data.clone());
- let _ = self.cache_update_tx.unbounded_send(Some(data));
}
+ result_rx.await.map_err(Error::DeviceUpdaterCancelled)??;
self.start_key_rotation();
+ Ok(())
}
/// Log out without waiting for the result.
@@ -207,7 +228,9 @@ impl AccountManager {
self.stop_key_rotation();
let data = {
let mut inner = self.inner.lock().unwrap();
- let _ = self.cache_update_tx.unbounded_send(None);
+ let (result_tx, _result_rx) = oneshot::channel();
+ let _ = self.cache_update_tx.unbounded_send((None, result_tx));
+ // NOTE: No need to wait on cache update
inner.data.take()
};
let service = self.device_service.clone();
@@ -240,7 +263,9 @@ impl AccountManager {
data.device.pubkey = wg_data.private_key.public_key();
let mut inner = self.inner.lock().unwrap();
inner.data.replace(data.clone());
- let _ = self.cache_update_tx.unbounded_send(Some(data));
+ let (result_tx, _result_rx) = oneshot::channel();
+ let _ = self.cache_update_tx.unbounded_send((Some(data), result_tx));
+ // NOTE: No need to wait on cache update
}
self.start_key_rotation();
result
@@ -288,7 +313,9 @@ impl AccountManager {
self.stop_key_rotation();
{
self.inner.lock().unwrap().data.take();
- let _ = self.cache_update_tx.unbounded_send(None);
+ let (result_tx, _result_rx) = oneshot::channel();
+ let _ = self.cache_update_tx.unbounded_send((None, result_tx));
+ // NOTE: No need to wait on cache update
}
Ok(ValidationResult::Removed)
}
@@ -335,7 +362,10 @@ impl AccountManager {
{
let mut inner = inner.lock().unwrap();
inner.data.replace(state.clone());
- let _ = cache_update_tx.unbounded_send(Some(state.clone()));
+ let (result_tx, _result_rx) = oneshot::channel();
+ let _ =
+ cache_update_tx.unbounded_send((Some(state.clone()), result_tx));
+ // NOTE: No need to wait on cache update
}
let _ = key_update_tx.send(DeviceKeyEvent(state));
}
@@ -836,3 +866,12 @@ fn retry_strategy() -> Jittered<ExponentialBackoff> {
.max_delay(RETRY_BACKOFF_INTERVAL_MAX),
)
}
+
+fn flatten_result<T, E>(
+ result: std::result::Result<std::result::Result<T, E>, E>,
+) -> std::result::Result<T, E> {
+ match result {
+ Ok(value) => value,
+ Err(err) => Err(err),
+ }
+}
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index b573ca452b..6d82708a88 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -1441,7 +1441,12 @@ where
return;
}
let event = DeviceEvent::from_device(data.clone(), false);
- self.account_manager.set(data);
+ if let Err(error) = self.account_manager.set(data).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to move over account from old settings")
+ );
+ }
self.reconnect_tunnel();
self.event_listener.notify_device_event(event);
}
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 7800f09381..8a15474257 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -983,6 +983,9 @@ fn map_device_error(error: device::Error) -> Status {
device::Error::InvalidDevice | device::Error::NoDevice => {
Status::new(Code::NotFound, error.to_string())
}
+ device::Error::DeviceIoError(ref _error) => {
+ Status::new(Code::Unavailable, error.to_string())
+ }
device::Error::OtherRestError(error) => map_rest_error(error),
_ => Status::new(Code::Unknown, error.to_string()),
}