summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-08-20 16:52:00 +0200
committerDavid Lönnhager <david.l@mullvad.net>2020-09-01 14:17:21 +0200
commit760f03178f2976b98b63525dfba7badcde5a55ce (patch)
treea54002f81671e0b87081509dc39d21a7c8da34c1
parenta4f5636ae7f593f6060fc7f993005d3fb5b57b4a (diff)
downloadmullvadvpn-760f03178f2976b98b63525dfba7badcde5a55ce.tar.xz
mullvadvpn-760f03178f2976b98b63525dfba7badcde5a55ce.zip
Replace tokio handle references in `Daemon` with async/await syntax
-rw-r--r--mullvad-daemon/src/account_history.rs29
-rw-r--r--mullvad-daemon/src/lib.rs441
-rw-r--r--mullvad-daemon/src/main.rs2
-rw-r--r--mullvad-daemon/src/system_service.rs29
-rw-r--r--mullvad-daemon/src/wireguard.rs35
-rw-r--r--mullvad-rpc/src/rest.rs13
6 files changed, 279 insertions, 270 deletions
diff --git a/mullvad-daemon/src/account_history.rs b/mullvad-daemon/src/account_history.rs
index 05ea1c7d3b..084ea271ee 100644
--- a/mullvad-daemon/src/account_history.rs
+++ b/mullvad-daemon/src/account_history.rs
@@ -123,7 +123,7 @@ impl AccountHistory {
/// Gets account data for a certain account id and bumps it's entry to the top of the list if
/// it isn't there already. Returns None if the account entry is not available.
- pub fn get(&mut self, account: &AccountToken) -> Result<Option<AccountEntry>> {
+ pub async fn get(&mut self, account: &AccountToken) -> Result<Option<AccountEntry>> {
let (idx, entry) = match self
.accounts
.iter()
@@ -139,19 +139,19 @@ impl AccountHistory {
if idx == 0 {
return Ok(Some(entry));
}
- self.insert(entry.clone())?;
+ self.insert(entry.clone()).await?;
Ok(Some(entry))
}
/// Bumps history of an account token. If the account token is not in history, it will be
/// added.
- pub fn bump_history(&mut self, account: &AccountToken) -> Result<()> {
- if self.get(account)?.is_none() {
+ pub async fn bump_history(&mut self, account: &AccountToken) -> Result<()> {
+ if self.get(account).await?.is_none() {
let new_entry = AccountEntry {
account: account.to_string(),
wireguard: None,
};
- self.insert(new_entry)?;
+ self.insert(new_entry).await?;
}
Ok(())
}
@@ -173,7 +173,7 @@ impl AccountHistory {
}
/// Always inserts a new entry at the start of the list
- pub fn insert(&mut self, new_entry: AccountEntry) -> Result<()> {
+ pub async fn insert(&mut self, new_entry: AccountEntry) -> Result<()> {
self.accounts
.retain(|entry| entry.account != new_entry.account);
@@ -182,9 +182,7 @@ impl AccountHistory {
if self.accounts.len() > ACCOUNT_HISTORY_LIMIT {
let last_entry = self.accounts.pop_back().unwrap();
if let Some(wg_data) = last_entry.wireguard {
- self.rpc_handle
- .service()
- .spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data));
+ tokio::spawn(self.create_remove_wg_key_rpc(&last_entry.account, &wg_data));
}
}
@@ -200,17 +198,15 @@ impl AccountHistory {
}
/// Remove account data
- pub fn remove_account(&mut self, account: &str) -> Result<()> {
- let entry = self.get(&String::from(account))?;
+ pub async fn remove_account(&mut self, account: &str) -> Result<()> {
+ let entry = self.get(&String::from(account)).await?;
let entry = match entry {
Some(entry) => entry,
None => return Ok(()),
};
if let Some(wg_data) = entry.wireguard {
- self.rpc_handle
- .service()
- .spawn(self.create_remove_wg_key_rpc(account, &wg_data))
+ tokio::spawn(self.create_remove_wg_key_rpc(account, &wg_data));
}
let _ = self.accounts.pop_front();
@@ -218,7 +214,7 @@ impl AccountHistory {
}
/// Remove account history
- pub fn clear(&mut self) -> Result<()> {
+ pub async fn clear(&mut self) -> Result<()> {
log::debug!("account_history::clear");
let rpc = WireguardKeyProxy::new(self.rpc_handle.clone());
@@ -241,8 +237,7 @@ impl AccountHistory {
.collect();
- let joined_futs = futures::future::join_all(removal);
- self.rpc_handle.service().block_on(joined_futs);
+ futures::future::join_all(removal).await;
self.accounts = VecDeque::new();
self.save_to_disk()
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index abb51942a6..412147c2b5 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -21,8 +21,8 @@ mod version_check;
use futures::{
channel::{mpsc, oneshot},
compat::Future01CompatExt,
- executor::BlockingStream,
future::{abortable, AbortHandle, Future},
+ StreamExt,
};
use futures01::Future as Future01;
use log::{debug, error, info, warn};
@@ -456,7 +456,7 @@ pub struct Daemon<L: EventListener> {
state: DaemonExecutionState,
#[cfg(target_os = "linux")]
exclude_pids: split_tunnel::PidManager,
- rx: BlockingStream<mpsc::UnboundedReceiver<InternalDaemonEvent>>,
+ rx: mpsc::UnboundedReceiver<InternalDaemonEvent>,
tx: DaemonEventSender,
reconnection_job: Option<AbortHandle>,
event_listener: L,
@@ -602,7 +602,7 @@ where
state: DaemonExecutionState::Running,
#[cfg(target_os = "linux")]
exclude_pids: split_tunnel::PidManager::new().map_err(Error::InitSplitTunneling)?,
- rx: futures::executor::block_on_stream(internal_event_rx),
+ rx: internal_event_rx,
tx: internal_event_tx,
reconnection_job: None,
event_listener,
@@ -622,19 +622,22 @@ where
cache_dir,
};
- daemon.ensure_wireguard_keys_for_current_account();
+ daemon.ensure_wireguard_keys_for_current_account().await;
if let Some(token) = daemon.settings.get_account_token() {
- daemon.wireguard_key_manager.set_rotation_interval(
- &mut daemon.account_history,
- token,
- daemon
- .settings
- .tunnel_options
- .wireguard
- .automatic_rotation
- .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)),
- );
+ daemon
+ .wireguard_key_manager
+ .set_rotation_interval(
+ &mut daemon.account_history,
+ token,
+ daemon
+ .settings
+ .tunnel_options
+ .wireguard
+ .automatic_rotation
+ .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)),
+ )
+ .await;
}
Ok(daemon)
@@ -642,43 +645,38 @@ where
/// Consume the `Daemon` and run the main event loop. Blocks until an error happens or a
/// shutdown event is received.
- pub fn run(mut self) -> Result<(), Error> {
+ pub async fn run(mut self) -> Result<(), Error> {
if self.target_state == TargetState::Secured {
self.connect_tunnel();
}
- while let Some(event) = self.rx.next() {
- self.handle_event(event);
+ while let Some(event) = self.rx.next().await {
+ self.handle_event(event).await;
if self.state == DaemonExecutionState::Finished {
break;
}
}
- self.finalize();
+ self.finalize().await;
Ok(())
}
- fn finalize(self) {
- let (
- event_listener,
- shutdown_callbacks,
- mut rpc_runtime,
- tunnel_state_machine_shutdown_signal,
- ) = self.shutdown();
+ async fn finalize(self) {
+ let (event_listener, shutdown_callbacks, rpc_runtime, tunnel_state_machine_shutdown_signal) =
+ self.shutdown();
for cb in shutdown_callbacks {
cb();
}
- rpc_runtime.handle().block_on(async {
- let shutdown_signal = tokio::time::timeout(
- TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT,
- tunnel_state_machine_shutdown_signal,
- );
- match shutdown_signal.await {
- Ok(_) => log::info!("Tunnel state machine shut down"),
- Err(_) => log::error!("Tunnel state machine did not shut down gracefully"),
- }
- });
+ let shutdown_signal = tokio::time::timeout(
+ TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT,
+ tunnel_state_machine_shutdown_signal,
+ );
+ match shutdown_signal.await {
+ Ok(_) => log::info!("Tunnel state machine shut down"),
+ Err(_) => log::error!("Tunnel state machine did not shut down gracefully"),
+ }
+
mem::drop(event_listener);
mem::drop(rpc_runtime);
}
@@ -709,31 +707,39 @@ where
}
- fn handle_event(&mut self, event: InternalDaemonEvent) {
+ async fn handle_event(&mut self, event: InternalDaemonEvent) {
use self::InternalDaemonEvent::*;
match event {
- TunnelStateTransition(transition) => self.handle_tunnel_state_transition(transition),
+ TunnelStateTransition(transition) => {
+ self.handle_tunnel_state_transition(transition).await
+ }
GenerateTunnelParameters(tunnel_parameters_tx, retry_attempt) => {
self.handle_generate_tunnel_parameters(&tunnel_parameters_tx, retry_attempt)
+ .await
}
- Command(command) => self.handle_command(command),
+ Command(command) => self.handle_command(command).await,
TriggerShutdown => self.trigger_shutdown_event(),
- WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event),
- NewAccountEvent(account_token, tx) => self.handle_new_account_event(account_token, tx),
+ WgKeyEvent(key_event) => self.handle_wireguard_key_event(key_event).await,
+ NewAccountEvent(account_token, tx) => {
+ self.handle_new_account_event(account_token, tx).await
+ }
NewAppVersionInfo(app_version_info) => {
self.handle_new_app_version_info(app_version_info)
}
}
}
- fn handle_tunnel_state_transition(&mut self, tunnel_state_transition: TunnelStateTransition) {
+ async fn handle_tunnel_state_transition(
+ &mut self,
+ tunnel_state_transition: TunnelStateTransition,
+ ) {
match &tunnel_state_transition {
TunnelStateTransition::Disconnected
| TunnelStateTransition::Connected(_)
| TunnelStateTransition::Error(_) => {
// Reset the RPCs so that they fail immediately after the underlying socket gets
// invalidated due to the tunnel either coming up or breaking.
- self.rpc_handle.service().reset();
+ self.rpc_handle.service().reset().await;
}
_ => (),
};
@@ -774,7 +780,7 @@ where
}
if let ErrorStateCause::AuthFailed(_) = error_state.cause() {
- self.schedule_reconnect(Duration::from_secs(60))
+ self.schedule_reconnect(Duration::from_secs(60)).await
}
}
_ => {}
@@ -784,7 +790,7 @@ where
self.event_listener.notify_new_state(tunnel_state);
}
- fn handle_generate_tunnel_parameters(
+ async fn handle_generate_tunnel_parameters(
&mut self,
tunnel_parameters_tx: &sync_mpsc::Sender<
Result<TunnelParameters, ParameterGenerationError>,
@@ -803,26 +809,30 @@ where
ParameterGenerationError::CustomTunnelHostResultionError
})
}
- RelaySettings::Normal(constraints) => self
- .relay_selector
- .get_tunnel_endpoint(
- &constraints,
- self.settings.get_bridge_state(),
- retry_attempt,
- self.account_history
- .get(&account_token)
- .unwrap_or(None)
- .and_then(|entry| entry.wireguard)
- .is_some(),
- )
- .map_err(|_| ParameterGenerationError::NoMatchingRelay)
- .and_then(|(relay, endpoint)| {
- let result = self.create_tunnel_parameters(
- &relay,
- endpoint,
- account_token,
+ RelaySettings::Normal(constraints) => {
+ let endpoint = self
+ .relay_selector
+ .get_tunnel_endpoint(
+ &constraints,
+ self.settings.get_bridge_state(),
retry_attempt,
- );
+ self.account_history
+ .get(&account_token)
+ .await
+ .unwrap_or(None)
+ .and_then(|entry| entry.wireguard)
+ .is_some(),
+ )
+ .ok();
+ if let Some((relay, endpoint)) = endpoint {
+ let result = self
+ .create_tunnel_parameters(
+ &relay,
+ endpoint,
+ account_token,
+ retry_attempt,
+ )
+ .await;
self.last_generated_relay = Some(relay);
match result {
Ok(result) => Ok(result),
@@ -842,7 +852,10 @@ where
Err(ParameterGenerationError::NoMatchingRelay)
}
}
- }),
+ } else {
+ Err(ParameterGenerationError::NoMatchingRelay)
+ }
+ }
};
if tunnel_parameters_tx.send(result).is_err() {
log::error!("Failed to send tunnel parameters");
@@ -852,7 +865,7 @@ where
}
}
- fn create_tunnel_parameters(
+ async fn create_tunnel_parameters(
&mut self,
relay: &Relay,
endpoint: MullvadEndpoint,
@@ -933,6 +946,7 @@ where
let wg_data = self
.account_history
.get(&account_token)
+ .await
.map_err(Error::AccountHistory)?
.and_then(|entry| entry.wireguard)
.ok_or(Error::NoKeyAvailable)?;
@@ -958,7 +972,7 @@ where
}
}
- fn schedule_reconnect(&mut self, delay: Duration) {
+ async fn schedule_reconnect(&mut self, delay: Duration) {
let tunnel_command_tx = self.tx.to_specialized_sender();
let (future, abort_handle) = abortable(Box::pin(async move {
tokio::time::delay_for(delay).await;
@@ -966,7 +980,7 @@ where
let _ = tunnel_command_tx.send(DaemonCommand::Reconnect);
}));
- self.spawn_future(future);
+ tokio::spawn(future);
self.reconnection_job = Some(abort_handle);
}
@@ -976,23 +990,8 @@ where
}
}
- fn spawn_future<F>(&mut self, fut: F)
- where
- F: std::future::Future + Send + 'static,
- F::Output: Send,
- {
- self.rpc_runtime.handle().spawn(fut);
- }
-
- fn block_on_future<F>(&mut self, fut: F) -> F::Output
- where
- F: std::future::Future,
- {
- self.rpc_runtime.handle().block_on(fut)
- }
-
- fn handle_command(&mut self, command: DaemonCommand) {
+ async fn handle_command(&mut self, command: DaemonCommand) {
use self::DaemonCommand::*;
if !self.state.is_running() {
log::trace!("Dropping daemon command because the daemon is shutting down",);
@@ -1002,22 +1001,22 @@ where
SetTargetState(tx, state) => self.on_set_target_state(tx, state),
Reconnect => self.on_reconnect(),
GetState(tx) => self.on_get_state(tx),
- GetCurrentLocation(tx) => self.on_get_current_location(tx),
- CreateNewAccount(tx) => self.on_create_new_account(tx),
- GetAccountData(tx, account_token) => self.on_get_account_data(tx, account_token),
- GetWwwAuthToken(tx) => self.on_get_www_auth_token(tx),
- SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher),
+ GetCurrentLocation(tx) => self.on_get_current_location(tx).await,
+ CreateNewAccount(tx) => self.on_create_new_account(tx).await,
+ GetAccountData(tx, account_token) => self.on_get_account_data(tx, account_token).await,
+ GetWwwAuthToken(tx) => self.on_get_www_auth_token(tx).await,
+ SubmitVoucher(tx, voucher) => self.on_submit_voucher(tx, voucher).await,
GetRelayLocations(tx) => self.on_get_relay_locations(tx),
- UpdateRelayLocations => self.on_update_relay_locations(),
- SetAccount(tx, account_token) => self.on_set_account(tx, account_token),
+ UpdateRelayLocations => self.on_update_relay_locations().await,
+ SetAccount(tx, account_token) => self.on_set_account(tx, account_token).await,
GetAccountHistory(tx) => self.on_get_account_history(tx),
RemoveAccountFromHistory(tx, account_token) => {
- self.on_remove_account_from_history(tx, account_token)
+ self.on_remove_account_from_history(tx, account_token).await
}
- ClearAccountHistory(tx) => self.on_clear_account_history(tx),
+ ClearAccountHistory(tx) => self.on_clear_account_history(tx).await,
UpdateRelaySettings(tx, update) => self.on_update_relay_settings(tx, update),
SetAllowLan(tx, allow_lan) => self.on_set_allow_lan(tx, allow_lan),
- SetShowBetaReleases(tx, enabled) => self.on_set_show_beta_releases(tx, enabled),
+ SetShowBetaReleases(tx, enabled) => self.on_set_show_beta_releases(tx, enabled).await,
SetBlockWhenDisconnected(tx, block_when_disconnected) => {
self.on_set_block_when_disconnected(tx, block_when_disconnected)
}
@@ -1030,16 +1029,16 @@ where
SetEnableIpv6(tx, enable_ipv6) => self.on_set_enable_ipv6(tx, enable_ipv6),
SetWireguardMtu(tx, mtu) => self.on_set_wireguard_mtu(tx, mtu),
SetWireguardRotationInterval(tx, interval) => {
- self.on_set_wireguard_rotation_interval(tx, interval)
+ self.on_set_wireguard_rotation_interval(tx, interval).await
}
GetSettings(tx) => self.on_get_settings(tx),
- GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx),
- GetWireguardKey(tx) => self.on_get_wireguard_key(tx),
- VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx),
+ GenerateWireguardKey(tx) => self.on_generate_wireguard_key(tx).await,
+ GetWireguardKey(tx) => self.on_get_wireguard_key(tx).await,
+ VerifyWireguardKey(tx) => self.on_verify_wireguard_key(tx).await,
GetVersionInfo(tx) => self.on_get_version_info(tx),
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
- FactoryReset(tx) => self.on_factory_reset(tx),
+ FactoryReset(tx) => self.on_factory_reset(tx).await,
#[cfg(target_os = "linux")]
GetSplitTunnelProcesses(tx) => self.on_get_split_tunnel_processes(tx),
#[cfg(target_os = "linux")]
@@ -1053,7 +1052,7 @@ where
}
}
- fn handle_wireguard_key_event(
+ async fn handle_wireguard_key_event(
&mut self,
event: (
AccountToken,
@@ -1079,6 +1078,7 @@ where
let mut account_entry = self
.account_history
.get(&account)
+ .await
.ok()
.and_then(|entry| entry)
.unwrap_or_else(|| account_history::AccountEntry {
@@ -1086,10 +1086,10 @@ where
wireguard: None,
});
account_entry.wireguard = Some(data);
- match self.account_history.insert(account_entry) {
+ match self.account_history.insert(account_entry).await {
Ok(_) => {
if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
- self.schedule_reconnect(WG_RECONNECT_DELAY);
+ self.schedule_reconnect(WG_RECONNECT_DELAY).await;
}
self.event_listener
.notify_key_event(KeygenEvent::NewKey(public_key))
@@ -1121,12 +1121,12 @@ where
}
}
- fn handle_new_account_event(
+ async fn handle_new_account_event(
&mut self,
new_token: AccountToken,
tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>,
) {
- match self.set_account(Some(new_token.clone())) {
+ match self.set_account(Some(new_token.clone())).await {
Ok(_) => {
self.set_target_state(TargetState::Unsecured);
let _ = tx.send(Ok(new_token));
@@ -1167,13 +1167,13 @@ where
Self::oneshot_send(tx, self.tunnel_state.clone(), "current state");
}
- fn on_get_current_location(&mut self, tx: oneshot::Sender<Option<GeoIpLocation>>) {
+ async fn on_get_current_location(&mut self, tx: oneshot::Sender<Option<GeoIpLocation>>) {
use self::TunnelState::*;
match &self.tunnel_state {
Disconnected => {
let location = self.get_geo_location();
- self.rpc_runtime.handle().spawn(async {
+ tokio::spawn(async {
Self::oneshot_send(tx, location.await.ok(), "current location");
});
}
@@ -1185,11 +1185,12 @@ where
}
Connected { location, .. } => {
let relay_location = location.clone();
- let location = self.get_geo_location();
- self.rpc_runtime.handle().spawn(async {
+ let location_future = self.get_geo_location();
+ tokio::spawn(async {
+ let location = location_future.await;
Self::oneshot_send(
tx,
- location.await.ok().map(|fetched_location| GeoIpLocation {
+ location.ok().map(|fetched_location| GeoIpLocation {
ipv4: fetched_location.ipv4,
ipv6: fetched_location.ipv6,
..relay_location.unwrap_or(fetched_location)
@@ -1240,7 +1241,7 @@ where
})
}
- fn on_create_new_account(
+ async fn on_create_new_account(
&mut self,
tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>,
) {
@@ -1261,14 +1262,14 @@ where
Ok(())
});
- self.rpc_runtime.handle().spawn(async {
+ tokio::spawn(async {
if future.compat().await.is_err() {
log::error!("Failed to spawn future for creating a new account");
}
});
}
- fn on_get_account_data(
+ async fn on_get_account_data(
&mut self,
tx: oneshot::Sender<Result<AccountData, mullvad_rpc::rest::Error>>,
account_token: AccountToken,
@@ -1281,10 +1282,10 @@ where
.map(|expiry| AccountData { expiry });
Self::oneshot_send(tx, result, "account data");
};
- self.rpc_runtime.handle().spawn(rpc_call);
+ tokio::spawn(rpc_call);
}
- fn on_get_www_auth_token(
+ async fn on_get_www_auth_token(
&mut self,
tx: oneshot::Sender<Result<String, mullvad_rpc::rest::Error>>,
) {
@@ -1294,11 +1295,11 @@ where
let result = old_future.compat().await;
Self::oneshot_send(tx, result, "get_www_auth_token response");
};
- self.rpc_runtime.handle().spawn(rpc_call);
+ tokio::spawn(rpc_call);
}
}
- fn on_submit_voucher(
+ async fn on_submit_voucher(
&mut self,
tx: oneshot::Sender<Result<VoucherSubmission, mullvad_rpc::rest::Error>>,
voucher: String,
@@ -1309,7 +1310,7 @@ where
let result = old_future.compat().await;
Self::oneshot_send(tx, result, "submit_voucher response");
};
- self.rpc_runtime.handle().spawn(rpc_call);
+ tokio::spawn(rpc_call);
}
}
@@ -1317,13 +1318,12 @@ where
Self::oneshot_send(tx, self.relay_selector.get_locations(), "relay locations");
}
- fn on_update_relay_locations(&mut self) {
- let update_future = self.relay_selector.update();
- self.block_on_future(update_future);
+ async fn on_update_relay_locations(&mut self) {
+ self.relay_selector.update().await;
}
- fn on_set_account(&mut self, tx: oneshot::Sender<()>, account_token: Option<String>) {
- match self.set_account(account_token.clone()) {
+ async fn on_set_account(&mut self, tx: oneshot::Sender<()>, account_token: Option<String>) {
+ match self.set_account(account_token.clone()).await {
Ok(account_changed) => {
if account_changed {
match account_token {
@@ -1345,7 +1345,10 @@ where
}
}
- fn set_account(&mut self, account_token: Option<String>) -> Result<bool, settings::Error> {
+ async fn set_account(
+ &mut self,
+ account_token: Option<String>,
+ ) -> Result<bool, settings::Error> {
let account_changed = self.settings.set_account_token(account_token.clone())?;
if account_changed {
self.event_listener
@@ -1353,17 +1356,18 @@ where
// Bump account history if a token was set
if let Some(token) = account_token.clone() {
- if let Err(e) = self.account_history.bump_history(&token) {
+ if let Err(e) = self.account_history.bump_history(&token).await {
log::error!("Failed to bump account history: {}", e);
}
}
- self.ensure_wireguard_keys_for_current_account();
+ self.ensure_wireguard_keys_for_current_account().await;
if let Some(token) = account_token {
// update automatic rotation
self.wireguard_key_manager
- .reset_rotation(&mut self.account_history, token);
+ .reset_rotation(&mut self.account_history, token)
+ .await;
}
}
Ok(account_changed)
@@ -1377,18 +1381,23 @@ where
);
}
- fn on_remove_account_from_history(
+ async fn on_remove_account_from_history(
&mut self,
tx: oneshot::Sender<()>,
account_token: AccountToken,
) {
- if self.account_history.remove_account(&account_token).is_ok() {
+ if self
+ .account_history
+ .remove_account(&account_token)
+ .await
+ .is_ok()
+ {
Self::oneshot_send(tx, (), "remove_account_from_history response");
}
}
- fn on_clear_account_history(&mut self, tx: oneshot::Sender<()>) {
- match self.account_history.clear() {
+ async fn on_clear_account_history(&mut self, tx: oneshot::Sender<()>) {
+ match self.account_history.clear().await {
Ok(_) => {
self.set_target_state(TargetState::Unsecured);
Self::oneshot_send(tx, (), "clear_account_history response");
@@ -1417,7 +1426,7 @@ where
}
#[cfg(not(target_os = "android"))]
- fn on_factory_reset(&mut self, tx: oneshot::Sender<()>) {
+ async fn on_factory_reset(&mut self, tx: oneshot::Sender<()>) {
let mut failed = false;
@@ -1426,7 +1435,7 @@ where
failed = true;
}
- if let Err(e) = self.account_history.clear() {
+ if let Err(e) = self.account_history.clear().await {
log::error!("Failed to clear account history - {}", e);
failed = true;
}
@@ -1519,7 +1528,7 @@ where
}
}
- fn on_set_show_beta_releases(&mut self, tx: oneshot::Sender<()>, enabled: bool) {
+ async fn on_set_show_beta_releases(&mut self, tx: oneshot::Sender<()>, enabled: bool) {
let save_result = self.settings.set_show_beta_releases(enabled);
match save_result {
Ok(settings_changed) => {
@@ -1527,9 +1536,8 @@ where
if settings_changed {
self.event_listener
.notify_settings(self.settings.to_settings());
- let runtime = self.rpc_runtime.handle();
let mut handle = self.version_updater_handle.clone();
- runtime.block_on(handle.set_show_beta_releases(enabled));
+ handle.set_show_beta_releases(enabled).await;
}
}
Err(e) => error!("{}", e.display_chain_with_msg("Unable to save settings")),
@@ -1681,7 +1689,7 @@ where
}
}
- fn on_set_wireguard_rotation_interval(
+ async fn on_set_wireguard_rotation_interval(
&mut self,
tx: oneshot::Sender<()>,
interval: Option<u32>,
@@ -1694,11 +1702,14 @@ where
let account_token = self.settings.get_account_token();
if let Some(token) = account_token {
- self.wireguard_key_manager.set_rotation_interval(
- &mut self.account_history,
- token,
- interval.map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)),
- );
+ self.wireguard_key_manager
+ .set_rotation_interval(
+ &mut self.account_history,
+ token,
+ interval
+ .map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)),
+ )
+ .await;
}
self.event_listener
@@ -1709,68 +1720,89 @@ where
}
}
- fn ensure_wireguard_keys_for_current_account(&mut self) {
+ async fn ensure_wireguard_keys_for_current_account(&mut self) {
if let Some(account) = self.settings.get_account_token() {
if self
.account_history
.get(&account)
+ .await
.map(|entry| entry.map(|e| e.wireguard.is_none()).unwrap_or(true))
.unwrap_or(true)
{
log::info!("Automatically generating new wireguard key for account");
self.wireguard_key_manager
- .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT));
+ .generate_key_async(account, Some(FIRST_KEY_PUSH_TIMEOUT))
+ .await;
} else {
log::info!("Account already has wireguard key");
}
}
}
- fn on_generate_wireguard_key(&mut self, tx: oneshot::Sender<KeygenEvent>) {
- let mut result = || -> Result<KeygenEvent, String> {
- let account_token = self
- .settings
- .get_account_token()
- .ok_or_else(|| "No account token set".to_owned())?;
+ async fn on_generate_wireguard_key(&mut self, tx: oneshot::Sender<KeygenEvent>) {
+ match self.on_generate_wireguard_key_inner().await {
+ Ok(key_event) => {
+ Self::oneshot_send(tx, key_event, "generate_wireguard_key response");
+ }
+ Err(e) => {
+ log::error!("Failed to generate new wireguard key - {}", e);
+ }
+ }
+ }
+
+ async fn on_generate_wireguard_key_inner(&mut self) -> Result<KeygenEvent, String> {
+ let account_token = self
+ .settings
+ .get_account_token()
+ .ok_or_else(|| "No account token set".to_owned())?;
- let mut account_entry = self
- .account_history
- .get(&account_token)
- .map_err(|e| format!("Failed to read account entry from history: {}", e))
- .map(|data| {
- data.unwrap_or_else(|| {
- log::error!("Account token set in settings but not in account history");
- account_history::AccountEntry {
- account: account_token.clone(),
- wireguard: None,
- }
- })
- })?;
+ let mut account_entry = self
+ .account_history
+ .get(&account_token)
+ .await
+ .map_err(|e| format!("Failed to read account entry from history: {}", e))
+ .map(|data| {
+ data.unwrap_or_else(|| {
+ log::error!("Account token set in settings but not in account history");
+ account_history::AccountEntry {
+ account: account_token.clone(),
+ wireguard: None,
+ }
+ })
+ })?;
- let gen_result = match &account_entry.wireguard {
- Some(wireguard_data) => self
- .wireguard_key_manager
- .replace_key(account_token.clone(), wireguard_data.get_public_key()),
- None => self
- .wireguard_key_manager
- .generate_key_sync(account_token.clone()),
- };
+ let gen_result = match &account_entry.wireguard {
+ Some(wireguard_data) => {
+ self.wireguard_key_manager
+ .replace_key(account_token.clone(), wireguard_data.get_public_key())
+ .await
+ }
+ None => {
+ self.wireguard_key_manager
+ .generate_key_sync(account_token.clone())
+ .await
+ }
+ };
- match gen_result {
- Ok(new_data) => {
- let public_key = new_data.get_public_key();
- account_entry.wireguard = Some(new_data);
- self.account_history.insert(account_entry).map_err(|e| {
+ match gen_result {
+ Ok(new_data) => {
+ let public_key = new_data.get_public_key();
+ account_entry.wireguard = Some(new_data);
+ self.account_history
+ .insert(account_entry)
+ .await
+ .map_err(|e| {
format!("Failed to add new wireguard key to account data: {}", e)
})?;
- if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
- self.reconnect_tunnel();
- }
- let keygen_event = KeygenEvent::NewKey(public_key);
- self.event_listener.notify_key_event(keygen_event.clone());
+ if let Some(TunnelType::Wireguard) = self.get_connected_tunnel_type() {
+ self.reconnect_tunnel();
+ }
+ let keygen_event = KeygenEvent::NewKey(public_key);
+ self.event_listener.notify_key_event(keygen_event.clone());
- // update automatic rotation
- self.wireguard_key_manager.set_rotation_interval(
+ // update automatic rotation
+ self.wireguard_key_manager
+ .set_rotation_interval(
&mut self.account_history,
account_token,
self.settings
@@ -1778,39 +1810,31 @@ where
.wireguard
.automatic_rotation
.map(|hours| Duration::from_secs(60u64 * 60u64 * hours as u64)),
- );
+ )
+ .await;
- Ok(keygen_event)
- }
- Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys),
- Err(e) => Err(format!(
- "Failed to generate new key - {}",
- e.display_chain_with_msg("Failed to generate new wireguard key:")
- )),
- }
- };
-
- match result() {
- Ok(key_event) => {
- Self::oneshot_send(tx, key_event, "generate_wireguard_key response");
- }
- Err(e) => {
- log::error!("Failed to generate new wireguard key - {}", e);
+ Ok(keygen_event)
}
+ Err(wireguard::Error::TooManyKeys) => Ok(KeygenEvent::TooManyKeys),
+ Err(e) => Err(format!(
+ "Failed to generate new key - {}",
+ e.display_chain_with_msg("Failed to generate new wireguard key:")
+ )),
}
}
- fn on_get_wireguard_key(&mut self, tx: oneshot::Sender<Option<wireguard::PublicKey>>) {
- let key = self
- .settings
- .get_account_token()
- .and_then(|account| self.account_history.get(&account).ok()?)
- .and_then(|account_entry| account_entry.wireguard.map(|wg| wg.get_public_key()));
-
- Self::oneshot_send(tx, key, "get_wireguard_key response");
+ async fn on_get_wireguard_key(&mut self, tx: oneshot::Sender<Option<wireguard::PublicKey>>) {
+ let token = self.settings.get_account_token();
+ if let Some(token) = token {
+ let entry = self.account_history.get(&token).await;
+ if let Ok(Some(entry)) = entry {
+ let key = entry.wireguard.map(|wg| wg.get_public_key());
+ Self::oneshot_send(tx, key, "get_wireguard_key response");
+ }
+ }
}
- fn on_verify_wireguard_key(&mut self, tx: oneshot::Sender<bool>) {
+ async fn on_verify_wireguard_key(&mut self, tx: oneshot::Sender<bool>) {
let account = match self.settings.get_account_token() {
Some(account) => account,
None => {
@@ -1822,6 +1846,7 @@ where
let key = self
.account_history
.get(&account)
+ .await
.map(|entry| entry.and_then(|e| e.wireguard.map(|wg| wg.private_key.public_key())));
let public_key = match key {
@@ -1840,7 +1865,7 @@ where
.wireguard_key_manager
.verify_wireguard_key(account, public_key);
- self.rpc_handle.service().spawn(async move {
+ tokio::spawn(async move {
match verification_rpc.await {
Ok(is_valid) => {
Self::oneshot_send(tx, is_valid, "verify_wireguard_key response");
diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs
index 1f58db5caf..b7e96d682c 100644
--- a/mullvad-daemon/src/main.rs
+++ b/mullvad-daemon/src/main.rs
@@ -110,7 +110,7 @@ async fn run_standalone(log_dir: Option<PathBuf>) -> Result<(), String> {
shutdown::set_shutdown_signal_handler(move || shutdown_handle.shutdown())
.map_err(|e| e.display_chain())?;
- daemon.run().map_err(|e| e.display_chain())?;
+ daemon.run().await.map_err(|e| e.display_chain())?;
info!("Mullvad daemon is quitting");
thread::sleep(Duration::from_millis(500));
diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs
index ee7b680de8..ef44af49a1 100644
--- a/mullvad-daemon/src/system_service.rs
+++ b/mullvad-daemon/src/system_service.rs
@@ -117,23 +117,24 @@ fn run_service() -> Result<(), String> {
Ok(runtime) => runtime,
};
- let result = runtime
- .block_on(crate::create_daemon(log_dir))
- .and_then(|daemon| {
- let shutdown_handle = daemon.shutdown_handle();
+ let result = runtime.block_on(crate::create_daemon(log_dir));
+ if let Ok(daemon) = result {
+ let shutdown_handle = daemon.shutdown_handle();
- // Register monitor that translates `ServiceControl` to Daemon events
- start_event_monitor(
- persistent_service_status.clone(),
- shutdown_handle,
- event_rx,
- clean_shutdown.clone(),
- );
+ // Register monitor that translates `ServiceControl` to Daemon events
+ start_event_monitor(
+ persistent_service_status.clone(),
+ shutdown_handle,
+ event_rx,
+ clean_shutdown.clone(),
+ );
- persistent_service_status.set_running().unwrap();
+ persistent_service_status.set_running().unwrap();
- daemon.run().map_err(|e| e.display_chain())
- });
+ runtime
+ .block_on(daemon.run())
+ .map_err(|e| e.display_chain())
+ }
let exit_code = match result {
Ok(()) => {
diff --git a/mullvad-daemon/src/wireguard.rs b/mullvad-daemon/src/wireguard.rs
index 8ad580dd55..189d1f5a6a 100644
--- a/mullvad-daemon/src/wireguard.rs
+++ b/mullvad-daemon/src/wireguard.rs
@@ -60,16 +60,19 @@ impl KeyManager {
/// Reset key rotation, cancelling the current one and starting a new one for the specified
/// account
- pub fn reset_rotation(
+ pub async fn reset_rotation(
&mut self,
account_history: &mut AccountHistory,
account_token: AccountToken,
) {
match account_history
.get(&account_token)
+ .await
.map(|entry| entry.map(|entry| entry.wireguard.map(|wg| wg.get_public_key())))
{
- Ok(Some(Some(public_key))) => self.run_automatic_rotation(account_token, public_key),
+ Ok(Some(Some(public_key))) => {
+ self.run_automatic_rotation(account_token, public_key).await
+ }
Ok(Some(None)) => {
log::error!("reset_rotation: failed to obtain public key for account entry.")
}
@@ -81,7 +84,7 @@ impl KeyManager {
/// Update automatic key rotation interval
/// Passing `None` for the interval will use the default value.
/// A duration of `0` disables automatic key rotation.
- pub fn set_rotation_interval(
+ pub async fn set_rotation_interval(
&mut self,
account_history: &mut AccountHistory,
account_token: AccountToken,
@@ -90,7 +93,7 @@ impl KeyManager {
self.auto_rotation_interval =
auto_rotation_interval.unwrap_or(DEFAULT_AUTOMATIC_KEY_ROTATION);
- self.reset_rotation(account_history, account_token);
+ self.reset_rotation(account_history, account_token).await;
}
/// Stop current key generation
@@ -101,19 +104,18 @@ impl KeyManager {
}
/// Generate a new private key
- pub fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> {
+ pub async fn generate_key_sync(&mut self, account: AccountToken) -> Result<WireguardData> {
self.reset();
let private_key = PrivateKey::new_from_random();
- self.http_handle
- .service()
- .block_on(self.push_future_generator(account, private_key, None)())
+ self.push_future_generator(account, private_key, None)()
+ .await
.map_err(Self::map_rpc_error)
}
/// Replace a key for an account synchronously
- pub fn replace_key(
+ pub async fn replace_key(
&mut self,
account: AccountToken,
old_key: PublicKey,
@@ -121,12 +123,7 @@ impl KeyManager {
self.reset();
let new_key = PrivateKey::new_from_random();
- self.http_handle.service().block_on(Self::replace_key_rpc(
- self.http_handle.clone(),
- account,
- old_key,
- new_key,
- ))
+ Self::replace_key_rpc(self.http_handle.clone(), account, old_key, new_key).await
}
/// Verifies whether a key is valid or not.
@@ -151,7 +148,7 @@ impl KeyManager {
/// Generate a new private key asynchronously. The new keys will be sent to the daemon channel.
- pub fn generate_key_async(&mut self, account: AccountToken, timeout: Option<Duration>) {
+ pub async fn generate_key_async(&mut self, account: AccountToken, timeout: Option<Duration>) {
self.reset();
let private_key = PrivateKey::new_from_random();
@@ -219,7 +216,7 @@ impl KeyManager {
};
- self.http_handle.service().spawn(Box::pin(future));
+ tokio::spawn(Box::pin(future));
self.current_job = Some(abort_handle);
}
@@ -372,7 +369,7 @@ impl KeyManager {
}
}
- fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) {
+ async fn run_automatic_rotation(&mut self, account_token: AccountToken, public_key: PublicKey) {
self.stop_automatic_rotation();
if self.auto_rotation_interval == Duration::new(0, 0) {
@@ -391,7 +388,7 @@ impl KeyManager {
);
let (request, abort_handle) = abortable(Box::pin(fut));
- self.http_handle.service().spawn(request);
+ tokio::spawn(request);
self.abort_scheduler_tx = Some(abort_handle);
}
diff --git a/mullvad-rpc/src/rest.rs b/mullvad-rpc/src/rest.rs
index 205f350ea5..0e8994b1c3 100644
--- a/mullvad-rpc/src/rest.rs
+++ b/mullvad-rpc/src/rest.rs
@@ -176,12 +176,10 @@ pub struct RequestServiceHandle {
impl RequestServiceHandle {
/// Resets the corresponding RequestService, dropping all in-flight requests.
- pub fn reset(&self) {
+ pub async fn reset(&self) {
let mut tx = self.tx.clone();
- self.handle.block_on(async move {
- let _ = tx.send(RequestCommand::Reset).await;
- });
+ let _ = tx.send(RequestCommand::Reset).await;
}
/// Submits a `RestRequest` for exectuion to the request service.
@@ -216,13 +214,6 @@ impl RequestServiceHandle {
pub fn spawn<T: Send + 'static>(&self, future: impl Future<Output = T> + Send + 'static) {
let _ = self.handle.spawn(future);
}
-
- pub fn block_on<T: Send + 'static>(
- &self,
- future: impl Future<Output = T> + Send + 'static,
- ) -> T {
- self.handle.block_on(future)
- }
}
#[derive(Debug)]