diff options
| -rw-r--r-- | mullvad-daemon/src/account_history.rs | 29 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 441 | ||||
| -rw-r--r-- | mullvad-daemon/src/main.rs | 2 | ||||
| -rw-r--r-- | mullvad-daemon/src/system_service.rs | 29 | ||||
| -rw-r--r-- | mullvad-daemon/src/wireguard.rs | 35 | ||||
| -rw-r--r-- | mullvad-rpc/src/rest.rs | 13 |
6 files changed, 279 insertions, 270 deletions
diff --git a/mullvad-daemon/src/account_history.rs b/mullvad-daemon/src/account_history.rs index 05ea1c7d3b..084ea271ee 100644 --- a/mullvad-daemon/src/account_history.rs +++ b/mullvad-daemon/src/account_history.rs @@ -123,7 +123,7 @@ impl AccountHistory { /// Gets account data for a certain account id and bumps it's entry to the top of the list if /// it isn't there already. Returns None if the account entry is not available. - pub fn get(&mut self, account: &AccountToken) -> Result<Option<AccountEntry>> { + pub async fn get(&mut self, account: &AccountToken) -> Result<Option<AccountEntry>> { let (idx, entry) = match self .accounts .iter() @@ -139,19 +139,19 @@ impl AccountHistory { if idx == 0 { return Ok(Some(entry)); } - self.insert(entry.clone())?; + self.insert(entry.clone()).await?; Ok(Some(entry)) } /// Bumps history of an account token. If the account token is not in history, it will be /// added. - pub fn bump_history(&mut self, account: &AccountToken) -> Result<()> { - if self.get(account)?.is_none() { + pub async fn bump_history(&mut self, account: &AccountToken) -> Result<()> { + if self.get(account).await?.is_none() { let new_entry = AccountEntry { account: account.to_string(), wireguard: None, }; - self.insert(new_entry)?; + self.insert(new_entry).await?; } Ok(()) } @@ -173,7 +173,7 @@ impl AccountHistory { } /// Always inserts a new entry at the start of the list - pub fn insert(&mut self, new_entry: AccountEntry) -> Result<()> { + pub async fn insert(&mut self, new_entry: AccountEntry) -> Result<()> { self.accounts .retain(|entry| entry.account != new_entry.account); @@ -182,9 +182,7 @@ impl AccountHistory { if self.accounts.len() > ACCOUNT_HISTORY_LIMIT { let last_entry = self.accounts.pop_back().unwrap(); if let Some(wg_data) = last_entry.wireguard { - self.rpc_handle - .service() - .spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data)); + tokio::spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data)); } } @@ -200,17 +198,15 @@ impl AccountHistory { } /// Remove account data - pub fn remove_account(&mut self, account: &str) -> Result<()> { - let entry = self.get(&String::from(account))?; + pub async fn remove_account(&mut self, account: &str) -> Result<()> { + let entry = self.get(&String::from(account)).await?; let entry = match entry { Some(entry) => entry, None => return Ok(()), }; if let Some(wg_data) = entry.wireguard { - self.rpc_handle - .service() - .spawn(self.create_remove_wg_key_rpc(account, &wg_data)) + tokio::spawn(self.create_remove_wg_key_rpc(account, &wg_data)); } let _ = self.accounts.pop_front(); @@ -218,7 +214,7 @@ impl AccountHistory { } /// Remove account history - pub fn clear(&mut self) -> Result<()> { + pub async fn clear(&mut self) -> Result<()> { log::debug!("account_history::clear"); let rpc = WireguardKeyProxy::new(self.rpc_handle.clone()); @@ -241,8 +237,7 @@ impl AccountHistory { .collect(); - let joined_futs = futures::future::join_all(removal); - self.rpc_handle.service().block_on(joined_futs); + futures::future::join_all(removal).await; self.accounts = VecDeque::new(); self.save_to_disk() diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index abb51942a6..412147c2b5 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -21,8 +21,8 @@ mod version_check; use futures::{ channel::{mpsc, oneshot}, compat::Future01CompatExt, - executor::BlockingStream, future::{abortable, AbortHandle, Future}, + StreamExt, }; use futures01::Future as Future01; use log::{debug, error, info, warn}; @@ -456,7 +456,7 @@ pub struct Daemon<L: EventListener> { state: DaemonExecutionState, #[cfg(target_os = "linux")] exclude_pids: split_tunnel::PidManager, - rx: BlockingStream<mpsc::UnboundedReceiver<InternalDaemonEvent>>, + rx: mpsc::UnboundedReceiver<InternalDaemonEvent>, tx: DaemonEventSender, reconnection_job: Option<AbortHandle>, event_listener: L, @@ -602,7 +602,7 @@ where state: DaemonExecutionState::Running, #[cfg(target_os = "linux")] exclude_pids: split_tunnel::PidManager::new().map_err(Error::InitSplitTunneling)?, - rx: futures::executor::block_on_stream(internal_event_rx), + rx: internal_event_rx, tx: internal_event_tx, reconnection_job: None, event_listener, @@ -622,19 +622,22 @@ where cache_dir, }; - daemon.ensure_wireguard_keys_for_current_account(); + daemon.ensure_wireguard_keys_for_current_account().await; if let Some(token) = daemon.settings.get_account_token() { - daemon.wireguard_key_manager.set_rotation_interval( - &mut daemon.account_history, - token, - daemon - .settings - .tunnel_options - .wireguard - .automatic_rotation - .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)), - ); + daemon + .wireguard_key_manager + .set_rotation_interval( + &mut daemon.account_history, + token, + daemon + .settings + .tunnel_options + .wireguard + .automatic_rotation + .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)), + ) + .await; } Ok(daemon) @@ -642,43 +645,38 @@ where /// Consume the `Daemon` and run the main event loop. Blocks until an error happens or a /// shutdown event is received. - pub fn run(mut self) -> Result<(), Error> { + pub async fn run(mut self) -> Result<(), Error> { if self.target_state == TargetState::Secured { self.connect_tunnel(); } - while let Some(event) = self.rx.next() { - self.handle_event(event); + while let Some(event) = self.rx.next().await { + self.handle_event(event).await; if self.state == DaemonExecutionState::Finished { break; } } - self.finalize(); + self.finalize().await; Ok(()) } - fn finalize(self) { - let ( - event_listener, - shutdown_callbacks, - mut rpc_runtime, - tunnel_state_machine_shutdown_signal, - ) = self.shutdown(); + async fn finalize(self) { + let (event_listener, shutdown_callbacks, rpc_runtime, tunnel_state_machine_shutdown_signal) = + self.shutdown(); for cb in shutdown_callbacks { cb(); } - rpc_runtime.handle().block_on(async { - let shutdown_signal = tokio::time::timeout( - TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, - tunnel_state_machine_shutdown_signal, - ); - match shutdown_signal.await { - Ok(_) => log::info!("Tunnel state machine shut down"), - Err(_) => log::error!("Tunnel state machine did not shut down gracefully"), - } - }); + let shutdown_signal = tokio::time::timeout( + TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT, + tunnel_state_machine_shutdown_signal, + ); + match shutdown_signal.await { + Ok(_) => log::info!("Tunnel state machine shut down"), + Err(_) => log::error!("Tunnel state machine did not shut down gracefully"), + } + mem::drop(event_listener); mem::drop(rpc_runtime); } @@ -709,31 +707,39 @@ where } - fn handle_event(&mut self, event: InternalDaemonEvent) { + async fn handle_event(&mut self, event: InternalDaemonEvent) { use self::InternalDaemonEvent::*; match event { - TunnelStateTransition(transition) => self.handle_tunnel_state_transition(transition), + TunnelStateTransition(transition) => { + self.handle_tunnel_state_transition(transition).await + } GenerateTunnelParameters(tunnel_parameters_tx, retry_attempt) => { self.handle_generate_tunnel_parameters(&tunnel_parameters_tx, retry_attempt) + .await } - Command(command) => self.handle_command(command), + Command(command) => self.handle_command(command).await, TriggerShutdown => self.trigger_shutdown_event(), - WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event), - NewAccountEvent(account_token, tx) => self.handle_new_account_event(account_token, tx), + WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await, + NewAccountEvent(account_token, tx) => { + self.handle_new_account_event(account_token, tx).await + } NewAppVersionInfo(app_version_info) => { self.handle_new_app_version_info(app_version_info) } } } - fn handle_tunnel_state_transition(&mut self, tunnel_state_transition: TunnelStateTransition) { + async fn handle_tunnel_state_transition( + &mut self, + tunnel_state_transition: TunnelStateTransition, + ) { match &tunnel_state_transition { TunnelStateTransition::Disconnected | TunnelStateTransition::Connected(_) | TunnelStateTransition::Error(_) => { // Reset the RPCs so that they fail immediately after the underlying socket gets // invalidated due to the tunnel either coming up or breaking. - self.rpc_handle.service().reset(); + self.rpc_handle.service().reset().await; } _ => (), }; @@ -774,7 +780,7 @@ where } if let ErrorStateCause::AuthFailed(_) = error_state.cause() { - self.schedule_reconnect(Duration::from_secs(60)) + self.schedule_reconnect(Duration::from_secs(60)).await } } _ => {} @@ -784,7 +790,7 @@ where self.event_listener.notify_new_state(tunnel_state); } - fn handle_generate_tunnel_parameters( + async fn handle_generate_tunnel_parameters( &mut self, tunnel_parameters_tx: &sync_mpsc::Sender< Result<TunnelParameters, ParameterGenerationError>, @@ -803,26 +809,30 @@ where ParameterGenerationError::CustomTunnelHostResultionError }) } - RelaySettings::Normal(constraints) => self - .relay_selector - .get_tunnel_endpoint( - &constraints, - self.settings.get_bridge_state(), - retry_attempt, - self.account_history - .get(&account_token) - .unwrap_or(None) - .and_then(|entry| entry.wireguard) - .is_some(), - ) - .map_err(|_| ParameterGenerationError::NoMatchingRelay) - .and_then(|(relay, endpoint)| { - let result = self.create_tunnel_parameters( - &relay, - endpoint, - account_token, + RelaySettings::Normal(constraints) => { + let endpoint = self + .relay_selector + .get_tunnel_endpoint( + &constraints, + self.settings.get_bridge_state(), retry_attempt, - ); + self.account_history + .get(&account_token) + .await + .unwrap_or(None) + .and_then(|entry| entry.wireguard) + .is_some(), + ) + .ok(); + if let Some((relay, endpoint)) = endpoint { + let result = self + .create_tunnel_parameters( + &relay, + endpoint, + account_token, + retry_attempt, + ) + .await; self.last_generated_relay = Some(relay); match result { Ok(result) => Ok(result), @@ -842,7 +852,10 @@ where Err(ParameterGenerationError::NoMatchingRelay) } } - }), + } else { + Err(ParameterGenerationError::NoMatchingRelay) + } + } }; if tunnel_parameters_tx.send(result).is_err() { log::error!("Failed to send tunnel parameters"); @@ -852,7 +865,7 @@ where } } - fn create_tunnel_parameters( + async fn create_tunnel_parameters( &mut self, relay: &Relay, endpoint: MullvadEndpoint, @@ -933,6 +946,7 @@ where let wg_data = self .account_history .get(&account_token) + .await .map_err(Error::AccountHistory)? .and_then(|entry| entry.wireguard) .ok_or(Error::NoKeyAvailable)?; @@ -958,7 +972,7 @@ where } } - fn schedule_reconnect(&mut self, delay: Duration) { + async fn schedule_reconnect(&mut self, delay: Duration) { let tunnel_command_tx = self.tx.to_specialized_sender(); let (future, abort_handle) = abortable(Box::pin(async move { tokio::time::delay_for(delay).await; @@ -966,7 +980,7 @@ where let _ = tunnel_command_tx.send(DaemonCommand::Reconnect); })); - self.spawn_future(future); + tokio::spawn(future); self.reconnection_job = Some(abort_handle); } @@ -976,23 +990,8 @@ where } } - fn spawn_future<F>(&mut self, fut: F) - where - F: std::future::Future + Send + 'static, - F::Output: Send, - { - self.rpc_runtime.handle().spawn(fut); - } - - fn block_on_future<F>(&mut self, fut: F) -> F::Output - where - F: std::future::Future, - { - self.rpc_runtime.handle().block_on(fut) - } - - fn handle_command(&mut self, command: DaemonCommand) { + async fn handle_command(&mut self, command: DaemonCommand) { use self::DaemonCommand::*; if !self.state.is_running() { log::trace!("Dropping daemon command because the daemon is shutting down",); @@ -1002,22 +1001,22 @@ where SetTargetState(tx, state) => self.on_set_target_state(tx, state), Reconnect => self.on_reconnect(), GetState(tx) => self.on_get_state(tx), - GetCurrentLocation(tx) => self.on_get_current_location(tx), - CreateNewAccount(tx) => self.on_create_new_account(tx), - GetAccountData(tx, account_token) => self.on_get_account_data(tx, account_token), - GetWwwAuthToken(tx) => self.on_get_www_auth_token(tx), - SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher), + GetCurrentLocation(tx) => self.on_get_current_location(tx).await, + CreateNewAccount(tx) => self.on_create_new_account(tx).await, + GetAccountData(tx, account_token) => self.on_get_account_data(tx, account_token).await, + GetWwwAuthToken(tx) => self.on_get_www_auth_token(tx).await, + SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await, GetRelayLocations(tx) => self.on_get_relay_locations(tx), - UpdateRelayLocations => self.on_update_relay_locations(), - SetAccount(tx, account_token) => self.on_set_account(tx, account_token), + UpdateRelayLocations => self.on_update_relay_locations().await, + SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await, GetAccountHistory(tx) => self.on_get_account_history(tx), RemoveAccountFromHistory(tx, account_token) => { - self.on_remove_account_from_history(tx, account_token) + self.on_remove_account_from_history(tx, account_token).await } - ClearAccountHistory(tx) => self.on_clear_account_history(tx), + ClearAccountHistory(tx) => self.on_clear_account_history(tx).await, UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update), SetAllowLan(tx, allow_lan) => self.on_set_allow_lan(tx, allow_lan), - SetShowBetaReleases(tx, enabled) => self.on_set_show_beta_releases(tx, enabled), + SetShowBetaReleases(tx, enabled) => self.on_set_show_beta_releases(tx, enabled).await, SetBlockWhenDisconnected(tx, block_when_disconnected) => { self.on_set_block_when_disconnected(tx, block_when_disconnected) } @@ -1030,16 +1029,16 @@ where SetEnableIpv6(tx, enable_ipv6) => self.on_set_enable_ipv6(tx, enable_ipv6), SetWireguardMtu(tx, mtu) => self.on_set_wireguard_mtu(tx, mtu), SetWireguardRotationInterval(tx, interval) => { - self.on_set_wireguard_rotation_interval(tx, interval) + self.on_set_wireguard_rotation_interval(tx, interval).await } GetSettings(tx) => self.on_get_settings(tx), - GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx), - GetWireguardKey(tx) => self.on_get_wireguard_key(tx), - VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx), + GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await, + GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await, + VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await, GetVersionInfo(tx) => self.on_get_version_info(tx), GetCurrentVersion(tx) => self.on_get_current_version(tx), #[cfg(not(target_os = "android"))] - FactoryReset(tx) => self.on_factory_reset(tx), + FactoryReset(tx) => self.on_factory_reset(tx).await, #[cfg(target_os = "linux")] GetSplitTunnelProcesses(tx) => self.on_get_split_tunnel_processes(tx), #[cfg(target_os = "linux")] @@ -1053,7 +1052,7 @@ where } } - fn handle_wireguard_key_event( + async fn handle_wireguard_key_event( &mut self, event: ( AccountToken, @@ -1079,6 +1078,7 @@ where let mut account_entry = self .account_history .get(&account) + .await .ok() .and_then(|entry| entry) .unwrap_or_else(|| account_history::AccountEntry { @@ -1086,10 +1086,10 @@ where wireguard: None, }); account_entry.wireguard = Some(data); - match self.account_history.insert(account_entry) { + match self.account_history.insert(account_entry).await { Ok(_) => { if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { - self.schedule_reconnect(WG_RECONNECT_DELAY); + self.schedule_reconnect(WG_RECONNECT_DELAY).await; } self.event_listener .notify_key_event(KeygenEvent::NewKey(public_key)) @@ -1121,12 +1121,12 @@ where } } - fn handle_new_account_event( + async fn handle_new_account_event( &mut self, new_token: AccountToken, tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>, ) { - match self.set_account(Some(new_token.clone())) { + match self.set_account(Some(new_token.clone())).await { Ok(_) => { self.set_target_state(TargetState::Unsecured); let _ = tx.send(Ok(new_token)); @@ -1167,13 +1167,13 @@ where Self::oneshot_send(tx, self.tunnel_state.clone(), "current state"); } - fn on_get_current_location(&mut self, tx: oneshot::Sender<Option<GeoIpLocation>>) { + async fn on_get_current_location(&mut self, tx: oneshot::Sender<Option<GeoIpLocation>>) { use self::TunnelState::*; match &self.tunnel_state { Disconnected => { let location = self.get_geo_location(); - self.rpc_runtime.handle().spawn(async { + tokio::spawn(async { Self::oneshot_send(tx, location.await.ok(), "current location"); }); } @@ -1185,11 +1185,12 @@ where } Connected { location, .. } => { let relay_location = location.clone(); - let location = self.get_geo_location(); - self.rpc_runtime.handle().spawn(async { + let location_future = self.get_geo_location(); + tokio::spawn(async { + let location = location_future.await; Self::oneshot_send( tx, - location.await.ok().map(|fetched_location| GeoIpLocation { + location.ok().map(|fetched_location| GeoIpLocation { ipv4: fetched_location.ipv4, ipv6: fetched_location.ipv6, ..relay_location.unwrap_or(fetched_location) @@ -1240,7 +1241,7 @@ where }) } - fn on_create_new_account( + async fn on_create_new_account( &mut self, tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>, ) { @@ -1261,14 +1262,14 @@ where Ok(()) }); - self.rpc_runtime.handle().spawn(async { + tokio::spawn(async { if future.compat().await.is_err() { log::error!("Failed to spawn future for creating a new account"); } }); } - fn on_get_account_data( + async fn on_get_account_data( &mut self, tx: oneshot::Sender<Result<AccountData, mullvad_rpc::rest::Error>>, account_token: AccountToken, @@ -1281,10 +1282,10 @@ where .map(|expiry| AccountData { expiry }); Self::oneshot_send(tx, result, "account data"); }; - self.rpc_runtime.handle().spawn(rpc_call); + tokio::spawn(rpc_call); } - fn on_get_www_auth_token( + async fn on_get_www_auth_token( &mut self, tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>, ) { @@ -1294,11 +1295,11 @@ where let result = old_future.compat().await; Self::oneshot_send(tx, result, "get_www_auth_token response"); }; - self.rpc_runtime.handle().spawn(rpc_call); + tokio::spawn(rpc_call); } } - fn on_submit_voucher( + async fn on_submit_voucher( &mut self, tx: oneshot::Sender<Result<VoucherSubmission, mullvad_rpc::rest::Error>>, voucher: String, @@ -1309,7 +1310,7 @@ where let result = old_future.compat().await; Self::oneshot_send(tx, result, "submit_voucher response"); }; - self.rpc_runtime.handle().spawn(rpc_call); + tokio::spawn(rpc_call); } } @@ -1317,13 +1318,12 @@ where Self::oneshot_send(tx, self.relay_selector.get_locations(), "relay locations"); } - fn on_update_relay_locations(&mut self) { - let update_future = self.relay_selector.update(); - self.block_on_future(update_future); + async fn on_update_relay_locations(&mut self) { + self.relay_selector.update().await; } - fn on_set_account(&mut self, tx: oneshot::Sender<()>, account_token: Option<String>) { - match self.set_account(account_token.clone()) { + async fn on_set_account(&mut self, tx: oneshot::Sender<()>, account_token: Option<String>) { + match self.set_account(account_token.clone()).await { Ok(account_changed) => { if account_changed { match account_token { @@ -1345,7 +1345,10 @@ where } } - fn set_account(&mut self, account_token: Option<String>) -> Result<bool, settings::Error> { + async fn set_account( + &mut self, + account_token: Option<String>, + ) -> Result<bool, settings::Error> { let account_changed = self.settings.set_account_token(account_token.clone())?; if account_changed { self.event_listener @@ -1353,17 +1356,18 @@ where // Bump account history if a token was set if let Some(token) = account_token.clone() { - if let Err(e) = self.account_history.bump_history(&token) { + if let Err(e) = self.account_history.bump_history(&token).await { log::error!("Failed to bump account history: {}", e); } } - self.ensure_wireguard_keys_for_current_account(); + self.ensure_wireguard_keys_for_current_account().await; if let Some(token) = account_token { // update automatic rotation self.wireguard_key_manager - .reset_rotation(&mut self.account_history, token); + .reset_rotation(&mut self.account_history, token) + .await; } } Ok(account_changed) @@ -1377,18 +1381,23 @@ where ); } - fn on_remove_account_from_history( + async fn on_remove_account_from_history( &mut self, tx: oneshot::Sender<()>, account_token: AccountToken, ) { - if self.account_history.remove_account(&account_token).is_ok() { + if self + .account_history + .remove_account(&account_token) + .await + .is_ok() + { Self::oneshot_send(tx, (), "remove_account_from_history response"); } } - fn on_clear_account_history(&mut self, tx: oneshot::Sender<()>) { - match self.account_history.clear() { + async fn on_clear_account_history(&mut self, tx: oneshot::Sender<()>) { + match self.account_history.clear().await { Ok(_) => { self.set_target_state(TargetState::Unsecured); Self::oneshot_send(tx, (), "clear_account_history response"); @@ -1417,7 +1426,7 @@ where } #[cfg(not(target_os = "android"))] - fn on_factory_reset(&mut self, tx: oneshot::Sender<()>) { + async fn on_factory_reset(&mut self, tx: oneshot::Sender<()>) { let mut failed = false; @@ -1426,7 +1435,7 @@ where failed = true; } - if let Err(e) = self.account_history.clear() { + if let Err(e) = self.account_history.clear().await { log::error!("Failed to clear account history - {}", e); failed = true; } @@ -1519,7 +1528,7 @@ where } } - fn on_set_show_beta_releases(&mut self, tx: oneshot::Sender<()>, enabled: bool) { + async fn on_set_show_beta_releases(&mut self, tx: oneshot::Sender<()>, enabled: bool) { let save_result = self.settings.set_show_beta_releases(enabled); match save_result { Ok(settings_changed) => { @@ -1527,9 +1536,8 @@ where if settings_changed { self.event_listener .notify_settings(self.settings.to_settings()); - let runtime = self.rpc_runtime.handle(); let mut handle = self.version_updater_handle.clone(); - runtime.block_on(handle.set_show_beta_releases(enabled)); + handle.set_show_beta_releases(enabled).await; } } Err(e) => error!("{}", e.display_chain_with_msg("Unable to save settings")), @@ -1681,7 +1689,7 @@ where } } - fn on_set_wireguard_rotation_interval( + async fn on_set_wireguard_rotation_interval( &mut self, tx: oneshot::Sender<()>, interval: Option<u32>, @@ -1694,11 +1702,14 @@ where let account_token = self.settings.get_account_token(); if let Some(token) = account_token { - self.wireguard_key_manager.set_rotation_interval( - &mut self.account_history, - token, - interval.map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)), - ); + self.wireguard_key_manager + .set_rotation_interval( + &mut self.account_history, + token, + interval + .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)), + ) + .await; } self.event_listener @@ -1709,68 +1720,89 @@ where } } - fn ensure_wireguard_keys_for_current_account(&mut self) { + async fn ensure_wireguard_keys_for_current_account(&mut self) { if let Some(account) = self.settings.get_account_token() { if self .account_history .get(&account) + .await .map(|entry| entry.map(|e| e.wireguard.is_none()).unwrap_or(true)) .unwrap_or(true) { log::info!("Automatically generating new wireguard key for account"); self.wireguard_key_manager - .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT)); + .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT)) + .await; } else { log::info!("Account already has wireguard key"); } } } - fn on_generate_wireguard_key(&mut self, tx: oneshot::Sender<KeygenEvent>) { - let mut result = || -> Result<KeygenEvent, String> { - let account_token = self - .settings - .get_account_token() - .ok_or_else(|| "No account token set".to_owned())?; + async fn on_generate_wireguard_key(&mut self, tx: oneshot::Sender<KeygenEvent>) { + match self.on_generate_wireguard_key_inner().await { + Ok(key_event) => { + Self::oneshot_send(tx, key_event, "generate_wireguard_key response"); + } + Err(e) => { + log::error!("Failed to generate new wireguard key - {}", e); + } + } + } + + async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, String> { + let account_token = self + .settings + .get_account_token() + .ok_or_else(|| "No account token set".to_owned())?; - let mut account_entry = self - .account_history - .get(&account_token) - .map_err(|e| format!("Failed to read account entry from history: {}", e)) - .map(|data| { - data.unwrap_or_else(|| { - log::error!("Account token set in settings but not in account history"); - account_history::AccountEntry { - account: account_token.clone(), - wireguard: None, - } - }) - })?; + let mut account_entry = self + .account_history + .get(&account_token) + .await + .map_err(|e| format!("Failed to read account entry from history: {}", e)) + .map(|data| { + data.unwrap_or_else(|| { + log::error!("Account token set in settings but not in account history"); + account_history::AccountEntry { + account: account_token.clone(), + wireguard: None, + } + }) + })?; - let gen_result = match &account_entry.wireguard { - Some(wireguard_data) => self - .wireguard_key_manager - .replace_key(account_token.clone(), wireguard_data.get_public_key()), - None => self - .wireguard_key_manager - .generate_key_sync(account_token.clone()), - }; + let gen_result = match &account_entry.wireguard { + Some(wireguard_data) => { + self.wireguard_key_manager + .replace_key(account_token.clone(), wireguard_data.get_public_key()) + .await + } + None => { + self.wireguard_key_manager + .generate_key_sync(account_token.clone()) + .await + } + }; - match gen_result { - Ok(new_data) => { - let public_key = new_data.get_public_key(); - account_entry.wireguard = Some(new_data); - self.account_history.insert(account_entry).map_err(|e| { + match gen_result { + Ok(new_data) => { + let public_key = new_data.get_public_key(); + account_entry.wireguard = Some(new_data); + self.account_history + .insert(account_entry) + .await + .map_err(|e| { format!("Failed to add new wireguard key to account data: {}", e) })?; - if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { - self.reconnect_tunnel(); - } - let keygen_event = KeygenEvent::NewKey(public_key); - self.event_listener.notify_key_event(keygen_event.clone()); + if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() { + self.reconnect_tunnel(); + } + let keygen_event = KeygenEvent::NewKey(public_key); + self.event_listener.notify_key_event(keygen_event.clone()); - // update automatic rotation - self.wireguard_key_manager.set_rotation_interval( + // update automatic rotation + self.wireguard_key_manager + .set_rotation_interval( &mut self.account_history, account_token, self.settings @@ -1778,39 +1810,31 @@ where .wireguard .automatic_rotation .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)), - ); + ) + .await; - Ok(keygen_event) - } - Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys), - Err(e) => Err(format!( - "Failed to generate new key - {}", - e.display_chain_with_msg("Failed to generate new wireguard key:") - )), - } - }; - - match result() { - Ok(key_event) => { - Self::oneshot_send(tx, key_event, "generate_wireguard_key response"); - } - Err(e) => { - log::error!("Failed to generate new wireguard key - {}", e); + Ok(keygen_event) } + Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys), + Err(e) => Err(format!( + "Failed to generate new key - {}", + e.display_chain_with_msg("Failed to generate new wireguard key:") + )), } } - fn on_get_wireguard_key(&mut self, tx: oneshot::Sender<Option<wireguard::PublicKey>>) { - let key = self - .settings - .get_account_token() - .and_then(|account| self.account_history.get(&account).ok()?) - .and_then(|account_entry| account_entry.wireguard.map(|wg| wg.get_public_key())); - - Self::oneshot_send(tx, key, "get_wireguard_key response"); + async fn on_get_wireguard_key(&mut self, tx: oneshot::Sender<Option<wireguard::PublicKey>>) { + let token = self.settings.get_account_token(); + if let Some(token) = token { + let entry = self.account_history.get(&token).await; + if let Ok(Some(entry)) = entry { + let key = entry.wireguard.map(|wg| wg.get_public_key()); + Self::oneshot_send(tx, key, "get_wireguard_key response"); + } + } } - fn on_verify_wireguard_key(&mut self, tx: oneshot::Sender<bool>) { + async fn on_verify_wireguard_key(&mut self, tx: oneshot::Sender<bool>) { let account = match self.settings.get_account_token() { Some(account) => account, None => { @@ -1822,6 +1846,7 @@ where let key = self .account_history .get(&account) + .await .map(|entry| entry.and_then(|e| e.wireguard.map(|wg| wg.private_key.public_key()))); let public_key = match key { @@ -1840,7 +1865,7 @@ where .wireguard_key_manager .verify_wireguard_key(account, public_key); - self.rpc_handle.service().spawn(async move { + tokio::spawn(async move { match verification_rpc.await { Ok(is_valid) => { Self::oneshot_send(tx, is_valid, "verify_wireguard_key response"); diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 1f58db5caf..b7e96d682c 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -110,7 +110,7 @@ async fn run_standalone(log_dir: Option<PathBuf>) -> Result<(), String> { shutdown::set_shutdown_signal_handler(move || shutdown_handle.shutdown()) .map_err(|e| e.display_chain())?; - daemon.run().map_err(|e| e.display_chain())?; + daemon.run().await.map_err(|e| e.display_chain())?; info!("Mullvad daemon is quitting"); thread::sleep(Duration::from_millis(500)); diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs index ee7b680de8..ef44af49a1 100644 --- a/mullvad-daemon/src/system_service.rs +++ b/mullvad-daemon/src/system_service.rs @@ -117,23 +117,24 @@ fn run_service() -> Result<(), String> { Ok(runtime) => runtime, }; - let result = runtime - .block_on(crate::create_daemon(log_dir)) - .and_then(|daemon| { - let shutdown_handle = daemon.shutdown_handle(); + let result = runtime.block_on(crate::create_daemon(log_dir)); + if let Ok(daemon) = result { + let shutdown_handle = daemon.shutdown_handle(); - // Register monitor that translates `ServiceControl` to Daemon events - start_event_monitor( - persistent_service_status.clone(), - shutdown_handle, - event_rx, - clean_shutdown.clone(), - ); + // Register monitor that translates `ServiceControl` to Daemon events + start_event_monitor( + persistent_service_status.clone(), + shutdown_handle, + event_rx, + clean_shutdown.clone(), + ); - persistent_service_status.set_running().unwrap(); + persistent_service_status.set_running().unwrap(); - daemon.run().map_err(|e| e.display_chain()) - }); + runtime + .block_on(daemon.run()) + .map_err(|e| e.display_chain()) + } let exit_code = match result { Ok(()) => { diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs index 8ad580dd55..189d1f5a6a 100644 --- a/mullvad-daemon/src/wireguard.rs +++ b/mullvad-daemon/src/wireguard.rs @@ -60,16 +60,19 @@ impl KeyManager { /// Reset key rotation, cancelling the current one and starting a new one for the specified /// account - pub fn reset_rotation( + pub async fn reset_rotation( &mut self, account_history: &mut AccountHistory, account_token: AccountToken, ) { match account_history .get(&account_token) + .await .map(|entry| entry.map(|entry| entry.wireguard.map(|wg| wg.get_public_key()))) { - Ok(Some(Some(public_key))) => self.run_automatic_rotation(account_token, public_key), + Ok(Some(Some(public_key))) => { + self.run_automatic_rotation(account_token, public_key).await + } Ok(Some(None)) => { log::error!("reset_rotation: failed to obtain public key for account entry.") } @@ -81,7 +84,7 @@ impl KeyManager { /// Update automatic key rotation interval /// Passing `None` for the interval will use the default value. /// A duration of `0` disables automatic key rotation. - pub fn set_rotation_interval( + pub async fn set_rotation_interval( &mut self, account_history: &mut AccountHistory, account_token: AccountToken, @@ -90,7 +93,7 @@ impl KeyManager { self.auto_rotation_interval = auto_rotation_interval.unwrap_or(DEFAULT_AUTOMATIC_KEY_ROTATION); - self.reset_rotation(account_history, account_token); + self.reset_rotation(account_history, account_token).await; } /// Stop current key generation @@ -101,19 +104,18 @@ impl KeyManager { } /// Generate a new private key - pub fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> { + pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> { self.reset(); let private_key = PrivateKey::new_from_random(); - self.http_handle - .service() - .block_on(self.push_future_generator(account, private_key, None)()) + self.push_future_generator(account, private_key, None)() + .await .map_err(Self::map_rpc_error) } /// Replace a key for an account synchronously - pub fn replace_key( + pub async fn replace_key( &mut self, account: AccountToken, old_key: PublicKey, @@ -121,12 +123,7 @@ impl KeyManager { self.reset(); let new_key = PrivateKey::new_from_random(); - self.http_handle.service().block_on(Self::replace_key_rpc( - self.http_handle.clone(), - account, - old_key, - new_key, - )) + Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await } /// Verifies whether a key is valid or not. @@ -151,7 +148,7 @@ impl KeyManager { /// Generate a new private key asynchronously. The new keys will be sent to the daemon channel. - pub fn generate_key_async(&mut self, account: AccountToken, timeout: Option<Duration>) { + pub async fn generate_key_async(&mut self, account: AccountToken, timeout: Option<Duration>) { self.reset(); let private_key = PrivateKey::new_from_random(); @@ -219,7 +216,7 @@ impl KeyManager { }; - self.http_handle.service().spawn(Box::pin(future)); + tokio::spawn(Box::pin(future)); self.current_job = Some(abort_handle); } @@ -372,7 +369,7 @@ impl KeyManager { } } - fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) { + async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) { self.stop_automatic_rotation(); if self.auto_rotation_interval == Duration::new(0, 0) { @@ -391,7 +388,7 @@ impl KeyManager { ); let (request, abort_handle) = abortable(Box::pin(fut)); - self.http_handle.service().spawn(request); + tokio::spawn(request); self.abort_scheduler_tx = Some(abort_handle); } diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs index 205f350ea5..0e8994b1c3 100644 --- a/mullvad-rpc/src/rest.rs +++ b/mullvad-rpc/src/rest.rs @@ -176,12 +176,10 @@ pub struct RequestServiceHandle { impl RequestServiceHandle { /// Resets the corresponding RequestService, dropping all in-flight requests. - pub fn reset(&self) { + pub async fn reset(&self) { let mut tx = self.tx.clone(); - self.handle.block_on(async move { - let _ = tx.send(RequestCommand::Reset).await; - }); + let _ = tx.send(RequestCommand::Reset).await; } /// Submits a `RestRequest` for exectuion to the request service. @@ -216,13 +214,6 @@ impl RequestServiceHandle { pub fn spawn<T: Send + 'static>(&self, future: impl Future<Output = T> + Send + 'static) { let _ = self.handle.spawn(future); } - - pub fn block_on<T: Send + 'static>( - &self, - future: impl Future<Output = T> + Send + 'static, - ) -> T { - self.handle.block_on(future) - } } #[derive(Debug)] |
