diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-06-16 10:51:14 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-06-18 10:28:49 +0200 |
| commit | 336809f41866b2caf8e9a3bc344819e6be0757d2 (patch) | |
| tree | daf49445131d1648393b50af9d95052d0f63b46a | |
| parent | 7a36083a1156dcb26654e92a08762eeb5d75a8dc (diff) | |
| download | mullvadvpn-336809f41866b2caf8e9a3bc344819e6be0757d2.tar.xz mullvadvpn-336809f41866b2caf8e9a3bc344819e6be0757d2.zip | |
Store the WireGuard key in the settings and store a single token in the account history
| -rw-r--r-- | mullvad-daemon/src/account_history.rs | 222 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 276 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 39 | ||||
| -rw-r--r-- | mullvad-setup/src/main.rs | 67 |
4 files changed, 258 insertions, 346 deletions
diff --git a/mullvad-daemon/src/account_history.rs b/mullvad-daemon/src/account_history.rs index 1c5095321f..57d1330e3a 100644 --- a/mullvad-daemon/src/account_history.rs +++ b/mullvad-daemon/src/account_history.rs @@ -1,10 +1,8 @@ -use mullvad_rpc::{rest::MullvadRestHandle, WireguardKeyProxy}; use mullvad_types::{account::AccountToken, wireguard::WireguardData}; +use regex::Regex; use std::{ - collections::VecDeque, fs, - future::Future, - io::{self, Seek, Write}, + io::{self, Read, Seek, Write}, path::Path, sync::{Arc, Mutex}, }; @@ -29,22 +27,19 @@ pub enum Error { } static ACCOUNT_HISTORY_FILE: &str = "account-history.json"; -static ACCOUNT_HISTORY_LIMIT: usize = 1; -/// A trivial MRU cache of account data pub struct AccountHistory { file: Arc<Mutex<io::BufWriter<fs::File>>>, - accounts: Arc<Mutex<VecDeque<AccountEntry>>>, - rpc_handle: MullvadRestHandle, + token: Option<AccountToken>, +} + +lazy_static::lazy_static! { + static ref ACCOUNT_REGEX: Regex = Regex::new(r"^[0-9]+$").unwrap(); } impl AccountHistory { - pub async fn new( - cache_dir: &Path, - settings_dir: &Path, - rpc_handle: MullvadRestHandle, - ) -> Result<AccountHistory> { + pub async fn new(cache_dir: &Path, settings_dir: &Path) -> Result<AccountHistory> { Self::migrate_from_old_file_location(cache_dir, settings_dir).await; let mut options = fs::OpenOptions::new(); @@ -60,7 +55,7 @@ impl AccountHistory { options.share_mode(0); } let path = settings_dir.join(ACCOUNT_HISTORY_FILE); - let (file, accounts) = if path.is_file() { + let (file, token) = if path.is_file() { log::info!("Opening account history file in {}", path.display()); let mut reader = options .write(true) @@ -69,24 +64,20 @@ impl AccountHistory { .map(io::BufReader::new) .map_err(Error::Read)?; - let accounts: VecDeque<AccountEntry> = match serde_json::from_reader(&mut reader) { - Err(e) => { - log::warn!( - "{}", - e.display_chain_with_msg("Failed to read+deserialize account history") - ); - Self::try_old_format(&mut reader)? - .into_iter() - .map(|account| AccountEntry { - account, - wireguard: None, - }) - .collect() + let mut buffer = String::new(); + let token: Option<AccountToken> = match reader.read_to_string(&mut buffer) { + Ok(0) => None, + Ok(_) if ACCOUNT_REGEX.is_match(&buffer) => Some(buffer), + Ok(_) | Err(_) => { + log::warn!("Failed to parse account history. Trying old formats",); + match Self::try_format_v2(&mut reader)? { + Some(token) => Some(token), + None => Self::try_format_v1(&mut reader)?, + } } - Ok(accounts) => accounts, }; - (reader.into_inner(), accounts) + (reader.into_inner(), token) } else { log::info!("Creating account history file in {}", path.display()); ( @@ -95,14 +86,13 @@ impl AccountHistory { .create(true) .open(path) .map_err(Error::Read)?, - VecDeque::new(), + None, ) }; let file = io::BufWriter::new(file); let mut history = AccountHistory { file: Arc::new(Mutex::new(file)), - accounts: Arc::new(Mutex::new(accounts)), - rpc_handle, + token, }; if let Err(e) = history.save_to_disk().await { log::error!("Failed to save account cache after opening it: {}", e); @@ -129,175 +119,61 @@ impl AccountHistory { } } - fn try_old_format(reader: &mut io::BufReader<fs::File>) -> Result<Vec<AccountToken>> { + fn try_format_v1(reader: &mut io::BufReader<fs::File>) -> Result<Option<AccountToken>> { #[derive(Deserialize)] struct OldFormat { accounts: Vec<AccountToken>, } reader.seek(io::SeekFrom::Start(0)).map_err(Error::Read)?; Ok(serde_json::from_reader(reader) - .map(|old_format: OldFormat| old_format.accounts) - .unwrap_or_else(|_| Vec::new())) + .map(|old_format: OldFormat| old_format.accounts.first().cloned()) + .unwrap_or_else(|_| None)) } - /// Gets account data for a certain account id and bumps it's entry to the top of the list if - /// it isn't there already. Returns None if the account entry is not available. - pub async fn get(&mut self, account: &AccountToken) -> Result<Option<AccountEntry>> { - let (idx, entry) = match self - .accounts - .lock() - .unwrap() - .iter() - .enumerate() - .find(|(_idx, entry)| &entry.account == account) - { - Some((idx, entry)) => (idx, entry.clone()), - None => { - return Ok(None); - } - }; - // this account is already on top - if idx == 0 { - return Ok(Some(entry)); + fn try_format_v2(reader: &mut io::BufReader<fs::File>) -> Result<Option<AccountToken>> { + #[derive(Serialize, Deserialize, Clone, Debug)] + pub struct AccountEntry { + pub account: AccountToken, + pub wireguard: Option<WireguardData>, } - self.insert(entry.clone()).await?; - Ok(Some(entry)) - } - - /// Bumps history of an account token. If the account token is not in history, it will be - /// added. - pub async fn bump_history(&mut self, account: &AccountToken) -> Result<()> { - if self.get(account).await?.is_none() { - let new_entry = AccountEntry { - account: account.to_string(), - wireguard: None, - }; - self.insert(new_entry).await?; - } - Ok(()) - } - - fn create_remove_wg_key_rpc( - &self, - account: &str, - wg_data: &WireguardData, - ) -> impl Future<Output = ()> + 'static { - let mut rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); - let pub_key = wg_data.private_key.public_key(); - let account = String::from(account); - - async move { - if let Err(err) = rpc.remove_wireguard_key(account, &pub_key).await { - log::error!("Failed to remove WireGuard key: {}", err); - } - } - } - - /// Always inserts a new entry at the start of the list - pub async fn insert(&mut self, new_entry: AccountEntry) -> Result<()> { - let mut accounts = self.accounts.lock().unwrap(); - accounts.retain(|entry| entry.account != new_entry.account); - accounts.push_front(new_entry); - - while accounts.len() > ACCOUNT_HISTORY_LIMIT { - let last_entry = accounts.pop_back().unwrap(); - if let Some(wg_data) = last_entry.wireguard { - tokio::spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data)); - } - } - - std::mem::drop(accounts); - self.save_to_disk().await + reader.seek(io::SeekFrom::Start(0)).map_err(Error::Read)?; + Ok(serde_json::from_reader(reader) + .map(|entries: Vec<AccountEntry>| entries.first().map(|entry| entry.account.clone())) + .unwrap_or_else(|_| None)) } - /// Retrieve account history. - pub fn get_account_history(&self) -> Vec<AccountToken> { - self.accounts - .lock() - .unwrap() - .iter() - .map(|entry| entry.account.clone()) - .collect() + /// Gets the account token in the history + pub fn get(&self) -> Option<AccountToken> { + self.token.clone() } - /// Remove account data - pub async fn remove_account(&mut self, account: &str) -> Result<()> { - let entry = self.get(&String::from(account)).await?; - let entry = match entry { - Some(entry) => entry, - None => return Ok(()), - }; - - if let Some(wg_data) = entry.wireguard { - tokio::spawn(self.create_remove_wg_key_rpc(account, &wg_data)); - } - - let _ = self.accounts.lock().unwrap().pop_front(); + /// Replace the account token in the history + pub async fn set(&mut self, new_entry: AccountToken) -> Result<()> { + self.token = Some(new_entry); self.save_to_disk().await } /// Remove account history pub async fn clear(&mut self) -> Result<()> { - log::debug!("account_history::clear"); - - let rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); - - let removal: Vec<_> = self - .accounts - .lock() - .unwrap() - .drain(0..) - .filter_map(move |entry| { - let account = entry.account.clone(); - let mut rpc = rpc.clone(); - entry.wireguard.map(move |wg_data| { - let public_key = wg_data.private_key.public_key(); - async move { - if let Err(err) = rpc.remove_wireguard_key(account, &public_key).await { - log::error!("Failed to remove WireGuard key: {}", err); - } - } - }) - }) - .collect(); - - - futures::future::join_all(removal).await; - - { - let mut accounts = self.accounts.lock().unwrap(); - *accounts = VecDeque::new(); - } + self.token = None; self.save_to_disk().await } async fn save_to_disk(&mut self) -> Result<()> { let file = self.file.clone(); - let accounts = self.accounts.clone(); + let token = self.token.clone(); tokio::task::spawn_blocking(move || { let mut file = file.lock().unwrap(); - let accounts = accounts.lock().unwrap(); - Self::save_to_disk_inner(&mut *file, &*accounts) + file.get_mut().set_len(0).map_err(Error::Write)?; + file.seek(io::SeekFrom::Start(0)).map_err(Error::Write)?; + if let Some(token) = token { + write!(&mut file, "{}", token).map_err(Error::Write)?; + } + file.flush().map_err(Error::Write)?; + file.get_mut().sync_all().map_err(Error::Write) }) .await .map_err(Error::WriteCancelled)? } - - fn save_to_disk_inner( - mut file: &mut io::BufWriter<fs::File>, - accounts: &VecDeque<AccountEntry>, - ) -> Result<()> { - file.get_mut().set_len(0).map_err(Error::Write)?; - file.seek(io::SeekFrom::Start(0)).map_err(Error::Write)?; - serde_json::to_writer_pretty(&mut file, accounts).map_err(Error::Serialize)?; - file.flush().map_err(Error::Write)?; - file.get_mut().sync_all().map_err(Error::Write) - } -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct AccountEntry { - pub account: AccountToken, - pub wireguard: Option<WireguardData>, } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 546f06647b..bccd1dea77 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -103,13 +103,16 @@ pub enum Error { #[error(display = "REST request failed")] RestError(#[error(source)] mullvad_rpc::rest::Error), - #[error(display = "Unable to load account history with wireguard key cache")] + #[error(display = "Unable to load account history")] LoadAccountHistory(#[error(source)] account_history::Error), #[cfg(target_os = "linux")] #[error(display = "Unable to initialize split tunneling")] InitSplitTunneling(#[error(source)] split_tunnel::Error), + #[error(display = "The account has too many wireguard keys")] + TooManyKeys, + #[error(display = "No wireguard private key available")] NoKeyAvailable, @@ -579,10 +582,9 @@ where settings.show_beta_releases, ); tokio::spawn(version_updater.run()); - let account_history = - account_history::AccountHistory::new(&cache_dir, &settings_dir, rpc_handle.clone()) - .await - .map_err(Error::LoadAccountHistory)?; + let account_history = account_history::AccountHistory::new(&cache_dir, &settings_dir) + .await + .map_err(Error::LoadAccountHistory)?; // Restore the tunnel to a previous state let target_cache = cache_dir.join(TARGET_START_STATE_FILE); @@ -938,12 +940,7 @@ where &constraints, self.settings.get_bridge_state(), retry_attempt, - self.account_history - .get(&account_token) - .await - .unwrap_or(None) - .and_then(|entry| entry.wireguard) - .is_some(), + self.settings.get_wireguard().is_some(), ) .ok(); if let Some((relay, endpoint)) = endpoint { @@ -1087,13 +1084,7 @@ where let exit_peer = entry_peer.as_ref().map(|_| peer.clone()); let entry_peer = entry_peer.unwrap_or(peer); - let wg_data = self - .account_history - .get(&account_token) - .await - .map_err(Error::AccountHistory)? - .and_then(|entry| entry.wireguard) - .ok_or(Error::NoKeyAvailable)?; + let wg_data = self.settings.get_wireguard().ok_or(Error::NoKeyAvailable)?; let tunnel = wireguard::TunnelConfig { private_key: wg_data.private_key, addresses: vec![ @@ -1225,27 +1216,15 @@ where match result { Ok(data) => { let public_key = data.get_public_key(); - let mut account_entry = self - .account_history - .get(&account) - .await - .ok() - .and_then(|entry| entry) - .unwrap_or_else(|| account_history::AccountEntry { - account: account.clone(), - wireguard: None, - }); - // if no key existed before - let first_key_for_account_on_host = account_entry.wireguard.is_none(); - account_entry.wireguard = Some(data); - match self.account_history.insert(account_entry).await { + let is_first_key = self.settings.get_wireguard().is_none(); + match self.settings.set_wireguard(Some(data)).await { Ok(_) => { if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { self.schedule_reconnect(WG_RECONNECT_DELAY).await; } self.event_listener .notify_key_event(KeygenEvent::NewKey(public_key)); - if first_key_for_account_on_host { + if is_first_key { self.ensure_key_rotation().await; } } @@ -1277,15 +1256,21 @@ where } async fn ensure_key_rotation(&mut self) { - if let Some(token) = self.settings.get_account_token() { - self.wireguard_key_manager - .set_rotation_interval( - &mut self.account_history, - token, - self.settings.tunnel_options.wireguard.rotation_interval, - ) - .await; - } + let token = match self.settings.get_account_token() { + Some(token) => token, + None => return, + }; + let public_key = match self.settings.get_wireguard() { + Some(data) => data.get_public_key(), + None => return, + }; + self.wireguard_key_manager + .set_rotation_interval( + public_key, + token, + self.settings.tunnel_options.wireguard.rotation_interval, + ) + .await; } async fn handle_new_account_event( @@ -1521,6 +1506,7 @@ where &mut self, account_token: Option<String>, ) -> Result<bool, settings::Error> { + let previous_token = self.settings.get_account_token(); let account_changed = self .settings .set_account_token(account_token.clone()) @@ -1529,13 +1515,44 @@ where self.event_listener .notify_settings(self.settings.to_settings()); - // Bump account history if a token was set - if let Some(token) = account_token.clone() { - if let Err(e) = self.account_history.bump_history(&token).await { - log::error!("Failed to bump account history: {}", e); - } + let history_token = match account_token { + Some(token) => token, + None => previous_token.clone().unwrap_or("".to_string()), + }; + if let Err(error) = self.account_history.set(history_token).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update account history") + ); } + if let Some(previous_token) = previous_token { + if let Some(previous_key) = self + .settings + .get_wireguard() + .map(|data| data.private_key.public_key()) + { + let remove_key = self + .wireguard_key_manager + .remove_key(previous_token, previous_key); + tokio::spawn(async move { + if let Err(error) = remove_key.await { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to remove WireGuard key for previous account" + ) + ); + } + }); + } + } + if let Err(error) = self.settings.set_wireguard(None).await { + log::error!( + "{}", + error.display_chain_with_msg("Error resetting WireGuard key") + ); + } self.ensure_wireguard_keys_for_current_account().await; } Ok(account_changed) @@ -1544,7 +1561,10 @@ where fn on_get_account_history(&mut self, tx: oneshot::Sender<Vec<AccountToken>>) { Self::oneshot_send( tx, - self.account_history.get_account_history(), + self.account_history + .get() + .map(|token| vec![token]) + .unwrap_or(vec![]), "get_account_history response", ); } @@ -1554,16 +1574,63 @@ where tx: ResponseTx<(), Error>, account_token: AccountToken, ) { - let result = self - .account_history - .remove_account(&account_token) - .await - .map_err(Error::AccountHistory); + let result = if self.account_history.get() == Some(account_token) { + self.account_history + .clear() + .await + .map_err(Error::AccountHistory) + } else { + Ok(()) + }; Self::oneshot_send(tx, result, "remove_account_from_history response"); } + // Remove the key associated with the current account, if there is one. + // This does not modify settings or account history. + #[cfg(not(target_os = "android"))] + fn remove_current_key_rpc(&self) -> impl std::future::Future<Output = Result<(), Error>> { + let remove_key = if let Some(token) = self.settings.get_account_token() { + if let Some(wg_data) = self.settings.get_wireguard() { + Some( + self.wireguard_key_manager + .remove_key(token, wg_data.private_key.public_key()), + ) + } else { + None + } + } else { + None + }; + + async move { + if let Some(task) = remove_key { + match task.await { + Err(wireguard::Error::RestError(error)) => Err(Error::RestError(error)), + // This result should never occur + Err(wireguard::Error::TooManyKeys) => Err(Error::TooManyKeys), + _ => Ok(()), + } + } else { + Ok(()) + } + } + } + async fn on_clear_account_history(&mut self, tx: ResponseTx<(), Error>) { - match self.account_history.clear().await { + if let Err(error) = self.remove_current_key_rpc().await { + Self::oneshot_send(tx, Err(error), "clear_account_history response"); + return; + } + if let Err(error) = self.account_history.clear().await { + Self::oneshot_send( + tx, + Err(Error::ClearAccountHistoryError(error)), + "clear_account_history response", + ); + return; + } + + match self.settings.set_wireguard(None).await { Ok(_) => { self.set_target_state(TargetState::Unsecured).await; Self::oneshot_send(tx, Ok(()), "clear_account_history response"); @@ -1575,7 +1642,7 @@ where ); Self::oneshot_send( tx, - Err(Error::AccountHistory(err)), + Err(Error::SettingsError(err)), "clear_account_history response", ); } @@ -1608,17 +1675,31 @@ where async fn on_factory_reset(&mut self, tx: ResponseTx<(), Error>) { let mut last_error = Ok(()); + let remove_key = self.remove_current_key_rpc(); + tokio::spawn(async move { + if let Err(error) = remove_key.await { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to remove WireGuard key for previous account" + ) + ); + } + }); + + if let Err(error) = self.account_history.clear().await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to clear account history") + ); + last_error = Err(Error::ClearAccountHistoryError(error)); + } if let Err(e) = self.settings.reset().await { log::error!("Failed to reset settings - {}", e); last_error = Err(Error::ClearSettingsError(e)); } - if let Err(e) = self.account_history.clear().await { - log::error!("Failed to clear account history - {}", e); - last_error = Err(Error::ClearAccountHistoryError(e)); - } - // Shut the daemon down. self.trigger_shutdown_event(); @@ -1969,13 +2050,7 @@ where async fn ensure_wireguard_keys_for_current_account(&mut self) { if let Some(account) = self.settings.get_account_token() { - if self - .account_history - .get(&account) - .await - .map(|entry| entry.map(|e| e.wireguard.is_none()).unwrap_or(true)) - .unwrap_or(true) - { + if self.settings.get_wireguard().is_none() { log::info!("Generating new WireGuard key for account"); self.wireguard_key_manager .spawn_key_generation_task(account, Some(FIRST_KEY_PUSH_TIMEOUT)) @@ -2007,23 +2082,9 @@ where .settings .get_account_token() .ok_or(Error::NoAccountToken)?; + let wireguard_data = self.settings.get_wireguard(); - let mut account_entry = self - .account_history - .get(&account_token) - .await - .map_err(Error::AccountHistory) - .map(|data| { - data.unwrap_or_else(|| { - log::error!("Account token set in settings but not in account history"); - account_history::AccountEntry { - account: account_token.clone(), - wireguard: None, - } - }) - })?; - - let gen_result = match &account_entry.wireguard { + let gen_result = match &wireguard_data { Some(wireguard_data) => { self.wireguard_key_manager .replace_key(account_token.clone(), wireguard_data.get_public_key()) @@ -2039,21 +2100,20 @@ where match gen_result { Ok(new_data) => { let public_key = new_data.get_public_key(); - account_entry.wireguard = Some(new_data); - self.account_history - .insert(account_entry) + self.settings + .set_wireguard(Some(new_data)) .await - .map_err(Error::AccountHistory)?; + .map_err(Error::SettingsError)?; if let Some(TunnelType::Wireguard) = self.get_target_tunnel_type() { self.schedule_reconnect(WG_RECONNECT_DELAY).await; } - let keygen_event = KeygenEvent::NewKey(public_key); + let keygen_event = KeygenEvent::NewKey(public_key.clone()); self.event_listener.notify_key_event(keygen_event.clone()); // update automatic rotation self.wireguard_key_manager .set_rotation_interval( - &mut self.account_history, + public_key, account_token, self.settings.tunnel_options.wireguard.rotation_interval, ) @@ -2067,17 +2127,11 @@ where } async fn on_get_wireguard_key(&mut self, tx: ResponseTx<Option<wireguard::PublicKey>, Error>) { - let token = self.settings.get_account_token(); - let result = if let Some(token) = token { - let entry = self.account_history.get(&token).await; - match entry { - Ok(Some(entry)) => { - let key = entry.wireguard.map(|wg| wg.get_public_key()); - Ok(key) - } - Ok(None) => Err(Error::NoAccountTokenHistory), - Err(error) => Err(Error::AccountHistory(error)), - } + let result = if self.settings.get_account_token().is_some() { + Ok(self + .settings + .get_wireguard() + .map(|data| data.get_public_key())) } else { Err(Error::NoAccountToken) }; @@ -2092,28 +2146,12 @@ where return; } }; - - let key = self - .account_history - .get(&account) - .await - .map(|entry| entry.and_then(|e| e.wireguard.map(|wg| wg.private_key.public_key()))); - - let public_key = match key { - Ok(Some(public_key)) => public_key, - Ok(None) => { + let public_key = match self.settings.get_wireguard() { + Some(wg_data) => wg_data.private_key.public_key(), + None => { Self::oneshot_send(tx, Ok(false), "verify_wireguard_key response"); return; } - Err(e) => { - log::error!("Failed to read key data: {}", e); - Self::oneshot_send( - tx, - Err(Error::AccountHistory(e)), - "verify_wireguard_key response", - ); - return; - } }; let verification_rpc = self diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs index 5c2ebc16d3..575ffa261c 100644 --- a/mullvad-daemon/src/wireguard.rs +++ b/mullvad-daemon/src/wireguard.rs @@ -1,4 +1,4 @@ -use crate::{account_history::AccountHistory, DaemonEventSender, InternalDaemonEvent}; +use crate::{DaemonEventSender, InternalDaemonEvent}; use chrono::offset::Utc; use mullvad_rpc::rest::{Error as RestError, MullvadRestHandle}; use mullvad_types::account::AccountToken; @@ -59,38 +59,21 @@ impl KeyManager { /// Reset key rotation, cancelling the current one and starting a new one for the specified /// account - pub async fn reset_rotation( - &mut self, - account_history: &mut AccountHistory, - account_token: AccountToken, - ) { - match account_history - .get(&account_token) + pub async fn reset_rotation(&mut self, current_key: PublicKey, account_token: AccountToken) { + self.run_automatic_rotation(account_token, current_key) .await - .map(|entry| entry.map(|entry| entry.wireguard.map(|wg| wg.get_public_key()))) - { - Ok(Some(Some(public_key))) => { - self.run_automatic_rotation(account_token, public_key).await - } - Ok(Some(None)) => { - log::error!("reset_rotation: failed to obtain public key for account entry.") - } - Ok(None) => log::error!("reset_rotation: account entry not found."), - Err(e) => log::error!("reset_rotation: failed to obtain account entry. {}", e), - }; } /// Update automatic key rotation interval /// Passing `None` for the interval will cause the default value to be used. pub async fn set_rotation_interval( &mut self, - account_history: &mut AccountHistory, + current_key: PublicKey, account_token: AccountToken, auto_rotation_interval: Option<RotationInterval>, ) { self.auto_rotation_interval = auto_rotation_interval.unwrap_or_default(); - - self.reset_rotation(account_history, account_token).await; + self.reset_rotation(current_key, account_token).await; } /// Stop current key generation @@ -144,15 +127,17 @@ impl KeyManager { } /// Removes a key from an account - pub async fn remove_key( + pub fn remove_key( &self, account: AccountToken, key: talpid_types::net::wireguard::PublicKey, - ) -> Result<()> { + ) -> impl Future<Output = Result<()>> { let mut rpc = mullvad_rpc::WireguardKeyProxy::new(self.http_handle.clone()); - rpc.remove_wireguard_key(account, &key) - .await - .map_err(Self::map_rpc_error) + async move { + rpc.remove_wireguard_key(account, &key) + .await + .map_err(Self::map_rpc_error) + } } fn should_retry(error: &RestError) -> bool { diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 45186dd1a7..4c7fde8e4d 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -1,11 +1,15 @@ use clap::{crate_authors, crate_description, crate_name, SubCommand}; -use mullvad_daemon::account_history; use mullvad_management_interface::new_rpc_client; use mullvad_rpc::MullvadRpcRuntime; -use mullvad_types::version::ParsedAppVersion; -use std::{path::PathBuf, process}; +use mullvad_types::{settings, version::ParsedAppVersion}; +use std::{ + io, + path::{Path, PathBuf}, + process, +}; use talpid_core::firewall::{self, Firewall, FirewallArguments}; use talpid_types::ErrorExt; +use tokio::fs; pub const PRODUCT_VERSION: &str = include_str!(concat!(env!("OUT_DIR"), "/product-version.txt")); @@ -58,17 +62,20 @@ pub enum Error { #[error(display = "Failed to initialize mullvad RPC runtime")] RpcInitializationError(#[error(source)] mullvad_rpc::Error), + #[error(display = "Failed to remove WireGuard key for account")] + RemoveKeyError(#[error(source)] mullvad_rpc::rest::Error), + #[error(display = "Failed to obtain settings directory path")] SettingsPathError(#[error(source)] SettingsPathErrorType), #[error(display = "Failed to obtain cache directory path")] CachePathError(#[error(source)] mullvad_paths::Error), - #[error(display = "Failed to initialize account history")] - InitializeAccountHistoryError(#[error(source)] account_history::Error), + #[error(display = "Failed to load settings")] + LoadSettingsError(#[error(source)] io::Error), - #[error(display = "Failed to initialize account history")] - ClearAccountHistoryError(#[error(source)] account_history::Error), + #[error(display = "Failed to parse settings")] + ParseSettingsError(#[error(source)] settings::Error), #[error(display = "Cannot parse the version string")] ParseVersionStringError, @@ -163,31 +170,37 @@ async fn reset_firewall() -> Result<(), Error> { async fn clear_history() -> Result<(), Error> { let (cache_path, settings_path) = get_paths()?; + let settings = load_settings(&settings_path).await?; - let mut rpc_runtime = MullvadRpcRuntime::with_cache( - tokio::runtime::Handle::current(), - None, - &cache_path, - false, - |_| Ok(()), - ) - .await - .map_err(Error::RpcInitializationError)?; + if let Some(token) = settings.get_account_token() { + if let Some(wg_data) = settings.get_wireguard() { + let mut rpc_runtime = MullvadRpcRuntime::with_cache( + tokio::runtime::Handle::current(), + None, + &cache_path, + false, + |_| Ok(()), + ) + .await + .map_err(Error::RpcInitializationError)?; + let mut key_proxy = + mullvad_rpc::WireguardKeyProxy::new(rpc_runtime.mullvad_rest_handle()); + key_proxy + .remove_wireguard_key(token, &wg_data.private_key.public_key()) + .await + .map_err(Error::RemoveKeyError)?; + } + } - let mut account_history = account_history::AccountHistory::new( - &cache_path, - &settings_path, - rpc_runtime.mullvad_rest_handle(), - ) - .await - .map_err(Error::InitializeAccountHistoryError)?; - account_history - .clear() - .await - .map_err(Error::ClearAccountHistoryError)?; Ok(()) } +async fn load_settings(settings_dir: &Path) -> Result<settings::Settings, Error> { + let path = settings_dir.join("settings.json"); + let settings_bytes = fs::read(path).await.map_err(Error::LoadSettingsError)?; + settings::Settings::load_from_bytes(&settings_bytes).map_err(Error::ParseSettingsError) +} + #[cfg(not(windows))] fn get_paths() -> Result<(PathBuf, PathBuf), Error> { let cache_path = mullvad_paths::cache_dir().map_err(Error::CachePathError)?; |
