diff options
| -rw-r--r-- | mullvad-daemon/src/device/mod.rs | 153 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 30 | ||||
| -rw-r--r-- | mullvad-daemon/src/tunnel.rs | 1 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 4 |
4 files changed, 125 insertions, 63 deletions
diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs index 5b2c21ae71..cdee636638 100644 --- a/mullvad-daemon/src/device/mod.rs +++ b/mullvad-daemon/src/device/mod.rs @@ -70,6 +70,57 @@ pub enum Error { AccountManagerDown, } +/// Contains the current device state. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum PrivateDeviceState { + LoggedIn(PrivateAccountAndDevice), + LoggedOut, + Revoked, +} + +impl PrivateDeviceState { + /// Returns whether the device is in the logged in state. + pub fn logged_in(&self) -> bool { + matches!(self, PrivateDeviceState::LoggedIn(_)) + } + + /// Returns whether the state is logged out, as opposed to + /// logged in or revoked. + pub fn logged_out(&self) -> bool { + matches!(self, PrivateDeviceState::LoggedOut) + } + + /// Returns the logged in device config. + pub fn device(&self) -> Option<&PrivateAccountAndDevice> { + match self { + PrivateDeviceState::LoggedIn(device) => Some(device), + _ => None, + } + } + + /// Returns the logged in device config. + pub fn into_device(self) -> Option<PrivateAccountAndDevice> { + match self { + PrivateDeviceState::LoggedIn(device) => Some(device), + _ => None, + } + } + + /// Sets the state to `Revoked`. + fn revoke(&mut self) { + *self = PrivateDeviceState::Revoked; + } + + /// Sets the state to `LoggedOut` and returns the logged-in device, if one exists. + fn logout(&mut self) -> Option<PrivateAccountAndDevice> { + match std::mem::replace(self, PrivateDeviceState::LoggedOut) { + PrivateDeviceState::LoggedIn(data) => Some(data), + _ => None, + } + } +} + /// Same as [PrivateDevice] but also contains the associated account token. #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct PrivateAccountAndDevice { @@ -170,12 +221,13 @@ impl From<PrivateDeviceEvent> for DeviceEvent { } impl PrivateDeviceEvent { - pub fn data(&self) -> Option<&PrivateAccountAndDevice> { + pub fn state(self) -> PrivateDeviceState { match self { - PrivateDeviceEvent::Login(config) => Some(config), - PrivateDeviceEvent::Updated(config) => Some(config), - PrivateDeviceEvent::RotatedKey(config) => Some(config), - PrivateDeviceEvent::Logout | PrivateDeviceEvent::Revoked => None, + PrivateDeviceEvent::Login(config) => PrivateDeviceState::LoggedIn(config), + PrivateDeviceEvent::Updated(config) => PrivateDeviceState::LoggedIn(config), + PrivateDeviceEvent::RotatedKey(config) => PrivateDeviceState::LoggedIn(config), + PrivateDeviceEvent::Logout => PrivateDeviceState::LoggedOut, + PrivateDeviceEvent::Revoked => PrivateDeviceState::Revoked, } } } @@ -212,8 +264,8 @@ enum AccountManagerCommand { Login(AccountToken, ResponseTx<()>), Logout(ResponseTx<()>), SetData(PrivateAccountAndDevice, ResponseTx<()>), - GetData(ResponseTx<Option<PrivateAccountAndDevice>>), - GetDataAfterLogin(ResponseTx<Option<PrivateAccountAndDevice>>), + GetData(ResponseTx<PrivateDeviceState>), + GetDataAfterLogin(ResponseTx<PrivateDeviceState>), RotateKey(ResponseTx<()>), SetRotationInterval(RotationInterval, ResponseTx<()>), ValidateDevice(ResponseTx<()>), @@ -243,12 +295,12 @@ impl AccountManagerHandle { .await } - pub async fn data(&self) -> Result<Option<PrivateAccountAndDevice>, Error> { + pub async fn data(&self) -> Result<PrivateDeviceState, Error> { self.send_command(|tx| AccountManagerCommand::GetData(tx)) .await } - pub async fn data_after_login(&self) -> Result<Option<PrivateAccountAndDevice>, Error> { + pub async fn data_after_login(&self) -> Result<PrivateDeviceState, Error> { self.send_command(|tx| AccountManagerCommand::GetDataAfterLogin(tx)) .await } @@ -291,13 +343,13 @@ impl AccountManagerHandle { pub(crate) struct AccountManager { cacher: DeviceCacher, device_service: DeviceService, - data: Option<PrivateAccountAndDevice>, + data: PrivateDeviceState, rotation_interval: RotationInterval, listeners: Vec<Box<dyn Sender<PrivateDeviceEvent> + Send>>, last_validation: Option<SystemTime>, validation_requests: Vec<ResponseTx<()>>, rotation_requests: Vec<ResponseTx<()>>, - data_requests: Vec<ResponseTx<Option<PrivateAccountAndDevice>>>, + data_requests: Vec<ResponseTx<PrivateDeviceState>>, } impl AccountManager { @@ -308,9 +360,9 @@ impl AccountManager { settings_dir: &Path, initial_rotation_interval: RotationInterval, listener_tx: impl Sender<PrivateDeviceEvent> + Send + 'static, - ) -> Result<(AccountManagerHandle, Option<PrivateAccountAndDevice>), Error> { + ) -> Result<(AccountManagerHandle, PrivateDeviceState), Error> { let (cacher, data) = DeviceCacher::new(settings_dir).await?; - let token = data.as_ref().map(|state| state.account_token.clone()); + let token = data.device().map(|state| state.account_token.clone()); let api_availability = rest_handle.availability.clone(); let account_service = service::spawn_account_service(rest_handle.clone(), token, api_availability.clone()); @@ -485,12 +537,10 @@ impl AccountManager { response: Result<Device, Error>, api_call: &mut api::CurrentApiCall, ) { - let current_config = match self.data.as_ref() { - Some(data) => data, - None => { - panic!("Received a validation response whilst having no device data"); - } - }; + let current_config = self + .data + .device() + .expect("Received a validation response whilst having no device data"); match response { Ok(new_device) => { @@ -502,7 +552,7 @@ impl AccountManager { .update(new_device) .expect("pubkey must match privkey"); - if Some(&new_data) != self.data.as_ref() { + if Some(&new_data) != self.data.device() { log::debug!("Updating data for the current device"); } else { log::debug!("The current device is still valid"); @@ -540,7 +590,7 @@ impl AccountManager { } if !self.rotation_requests.is_empty() || !self.validation_requests.is_empty() { - if let Some(updated_config) = self.data.as_ref() { + if let Some(updated_config) = self.data.device() { let device_service = self.device_service.clone(); let token = updated_config.account_token.clone(); let device_id = updated_config.device.id.clone(); @@ -554,7 +604,8 @@ impl AccountManager { async fn consume_rotation_result(&mut self, api_result: Result<WireguardData, Error>) { let mut config = self .data - .clone() + .device() + .cloned() .expect("Received a key rotation result whilst having no data"); match api_result { @@ -603,7 +654,7 @@ impl AccountManager { fn spawn_timed_key_rotation( &self, ) -> Option<impl Future<Output = Result<WireguardData, Error>> + Send + 'static> { - let config = self.data.as_ref()?; + let config = self.data.device()?; let key_rotation_timer = self.key_rotation_timer(config.device.wg_data.created); let device_service = self.device_service.clone(); @@ -621,13 +672,13 @@ impl AccountManager { async fn invalidate_current_data(&mut self, err_constructor: impl Fn() -> Error) { log::debug!("Invalidating the current device"); - if let Err(err) = self.cacher.write(None).await { + if let Err(err) = self.cacher.write(&PrivateDeviceState::Revoked).await { log::error!( "{}", err.display_chain_with_msg("Failed to save device data to disk") ); } - self.data = None; + self.data.revoke(); Self::drain_requests(&mut self.validation_requests, || Err(err_constructor())); Self::drain_requests(&mut self.rotation_requests, || Err(err_constructor())); @@ -638,27 +689,31 @@ impl AccountManager { async fn logout(&mut self, tx: ResponseTx<()>) { Self::drain_requests(&mut self.data_requests, || Err(Error::AccountChange)); - if self.data.is_none() { + if self.data.logged_out() { let _ = tx.send(Ok(())); return; } - if let Err(err) = self.cacher.write(None).await { + if let Err(err) = self.cacher.write(&PrivateDeviceState::LoggedOut).await { let _ = tx.send(Err(err)); return; } - // Cannot panic: `data.is_none() == false`. - let old_config = self.data.take().unwrap(); + let old_config = self.data.logout(); self.listeners .retain(|listener| listener.send(PrivateDeviceEvent::Logout).is_ok()); - let logout_call = tokio::spawn(Box::pin(self.logout_api_call(old_config))); + if let Some(old_config) = old_config { + let logout_call = tokio::spawn(Box::pin(self.logout_api_call(old_config))); - tokio::spawn(async move { - let _response = tokio::time::timeout(LOGOUT_TIMEOUT, logout_call).await; + tokio::spawn(async move { + let _response = tokio::time::timeout(LOGOUT_TIMEOUT, logout_call).await; + let _ = tx.send(Ok(())); + }); + } else { + // The state was `revoked`. let _ = tx.send(Ok(())); - }); + } } fn logout_api_call(&self, data: PrivateAccountAndDevice) -> impl Future<Output = ()> + 'static { @@ -678,21 +733,21 @@ impl AccountManager { } async fn set(&mut self, event: PrivateDeviceEvent) -> Result<(), Error> { - let data = event.data(); - if data == self.data.as_ref() { + let device_state = event.clone().state(); + if device_state == self.data { return Ok(()); } - self.cacher.write(data).await?; + self.cacher.write(&device_state).await?; self.last_validation = None; - if let Some(old_config) = self.data.take() { - if data.as_ref().map(|d| &d.device.id) != Some(&old_config.device.id) { + if let Some(old_config) = self.data.logout() { + if device_state.device().map(|d| &d.device.id) != Some(&old_config.device.id) { tokio::spawn(self.logout_api_call(old_config)); } } - self.data = data.cloned(); + self.data = device_state; self.listeners .retain(|listener| listener.send(event.clone()).is_ok()); @@ -703,7 +758,7 @@ impl AccountManager { fn initiate_key_rotation( &self, ) -> Result<impl Future<Output = Result<WireguardData, Error>>, Error> { - let data = self.data.clone().ok_or(Error::NoDevice)?; + let data = self.data.device().cloned().ok_or(Error::NoDevice)?; let device_service = self.device_service.clone(); Ok(async move { device_service @@ -750,13 +805,13 @@ impl AccountManager { } fn validation_call(&self) -> Result<impl Future<Output = Result<Device, Error>>, Error> { - let old_config = self.data.as_ref().ok_or(Error::NoDevice)?; + let old_config = self.data.device().ok_or(Error::NoDevice)?; let device_request = self.fetch_device_config(old_config); Ok(async move { device_request.await }) } fn needs_validation(&mut self) -> bool { - if self.data.is_none() { + if !self.data.logged_in() { return true; } @@ -785,9 +840,7 @@ pub struct DeviceCacher { } impl DeviceCacher { - pub async fn new( - settings_dir: &Path, - ) -> Result<(DeviceCacher, Option<PrivateAccountAndDevice>), Error> { + pub async fn new(settings_dir: &Path) -> Result<(DeviceCacher, PrivateDeviceState), Error> { let path = settings_dir.join(DEVICE_CACHE_FILENAME); let cache_exists = path.is_file(); @@ -798,7 +851,7 @@ impl DeviceCacher { .open(&path) .await?; - let device: Option<PrivateAccountAndDevice> = if cache_exists { + let device: PrivateDeviceState = if cache_exists { let mut reader = io::BufReader::new(&mut file); let mut buffer = String::new(); reader.read_to_string(&mut buffer).await?; @@ -808,13 +861,13 @@ impl DeviceCacher { "{}", error.display_chain_with_msg("Wiping device config due to an error") ); - None + PrivateDeviceState::LoggedOut }) } else { - None + PrivateDeviceState::LoggedOut } } else { - None + PrivateDeviceState::LoggedOut }; Ok(( @@ -842,7 +895,7 @@ impl DeviceCacher { options } - pub async fn write(&mut self, device: Option<&PrivateAccountAndDevice>) -> Result<(), Error> { + pub async fn write(&mut self, device: &PrivateDeviceState) -> Result<(), Error> { let data = serde_json::to_vec_pretty(&device).unwrap(); self.file.get_mut().set_len(0).await?; diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d950f20663..a7eee39a70 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -611,7 +611,7 @@ where let account_history = account_history::AccountHistory::new( &settings_dir, - data.as_ref().map(|device| device.account_token.clone()), + data.device().map(|device| device.account_token.clone()), ) .await .map_err(Error::LoadAccountHistory)?; @@ -1070,7 +1070,11 @@ where let account_manager = self.account_manager.clone(); let event_listener = self.event_listener.clone(); tokio::spawn(async move { - if let Ok(Some(_)) = account_manager.data_after_login().await { + if let Ok(Some(_)) = account_manager + .data_after_login() + .await + .map(|s| s.into_device()) + { // Discard stale device return; } @@ -1202,8 +1206,10 @@ where let account_manager = self.account_manager.clone(); tokio::spawn(async move { let result = async { - if let Ok(Some(_)) = account_manager.data().await { - return Err(Error::AlreadyLoggedIn); + if let Ok(data) = account_manager.data().await { + if data.logged_in() { + return Err(Error::AlreadyLoggedIn); + } } let token = account_manager .account_service @@ -1243,7 +1249,7 @@ where } async fn on_get_www_auth_token(&mut self, tx: ResponseTx<String, Error>) { - if let Ok(Some(device)) = self.account_manager.data().await { + if let Ok(Some(device)) = self.account_manager.data().await.map(|s| s.into_device()) { let future = self .account_manager .account_service @@ -1269,7 +1275,7 @@ where tx: ResponseTx<VoucherSubmission, Error>, voucher: String, ) { - if let Ok(Some(device)) = self.account_manager.data().await { + if let Ok(Some(device)) = self.account_manager.data().await.map(|s| s.into_device()) { let mut account = self.account_manager.account_service.clone(); tokio::spawn(async move { Self::oneshot_send( @@ -1328,6 +1334,7 @@ where Ok(account_manager .data() .await + .map(|s| s.into_device()) .unwrap_or(None) .map(AccountAndDevice::from)), "get_device response", @@ -2041,11 +2048,12 @@ where } async fn on_get_wireguard_key(&self, tx: ResponseTx<Option<PublicKey>, Error>) { - let result = if let Ok(Some(config)) = self.account_manager.data().await { - Ok(Some(config.device.wg_data.get_public_key())) - } else { - Err(Error::NoAccountToken) - }; + let result = + if let Ok(Some(config)) = self.account_manager.data().await.map(|s| s.into_device()) { + Ok(Some(config.device.wg_data.get_public_key())) + } else { + Err(Error::NoAccountToken) + }; Self::oneshot_send(tx, result, "get_wireguard_key response"); } diff --git a/mullvad-daemon/src/tunnel.rs b/mullvad-daemon/src/tunnel.rs index 9e8919a41b..e95850c4f4 100644 --- a/mullvad-daemon/src/tunnel.rs +++ b/mullvad-daemon/src/tunnel.rs @@ -232,6 +232,7 @@ impl InnerParametersGenerator { self.account_manager .data() .await + .map(|s| s.into_device()) .ok() .flatten() .ok_or(Error::NoAuthDetails) diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 15a0c7dd6d..b257d92080 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -164,10 +164,10 @@ async fn reset_firewall() -> Result<(), Error> { async fn remove_device() -> Result<(), Error> { let (cache_path, settings_path) = get_paths()?; - let (cacher, data) = mullvad_daemon::device::DeviceCacher::new(&settings_path) + let (cacher, state) = mullvad_daemon::device::DeviceCacher::new(&settings_path) .await .map_err(Error::ReadDeviceCacheError)?; - if let Some(device) = data { + if let Some(device) = state.into_device() { let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false) .await .map_err(Error::RpcInitializationError)?; |
