summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-03-30 17:14:38 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-03-31 15:45:47 +0200
commitd4092cd05fb1199419094cfb3eeab015023a5434 (patch)
tree9f7f4bfe23deaad07a61ad3db4d4d03d801b65d4
parentad1070d5252ff06691667fecebf6a31451683616 (diff)
downloadmullvadvpn-d4092cd05fb1199419094cfb3eeab015023a5434.tar.xz
mullvadvpn-d4092cd05fb1199419094cfb3eeab015023a5434.zip
Use async file I/O in main daemon module
-rw-r--r--mullvad-daemon/src/lib.rs156
1 files changed, 86 insertions, 70 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 866bf7131c..53206249a6 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -44,12 +44,11 @@ use settings::SettingsPersister;
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
use std::{
- fs::{self, File},
- io,
marker::PhantomData,
mem,
net::IpAddr,
path::{Path, PathBuf},
+ pin::Pin,
sync::{mpsc as sync_mpsc, Arc, Weak},
time::Duration,
};
@@ -66,6 +65,7 @@ use talpid_types::{
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
};
+use tokio::{fs, io};
#[path = "wireguard.rs"]
mod wireguard;
@@ -490,7 +490,7 @@ pub struct Daemon<L: EventListener> {
last_generated_relay: Option<Relay>,
last_generated_bridge_relay: Option<Relay>,
app_version_info: Option<AppVersionInfo>,
- shutdown_callbacks: Vec<Box<dyn FnOnce()>>,
+ shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
/// oneshot channel that completes once the tunnel state machine has been shut down
tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>,
cache_dir: PathBuf,
@@ -577,23 +577,24 @@ where
// Restore the tunnel to a previous state
let target_cache = cache_dir.join(TARGET_START_STATE_FILE);
- let cached_target_state: Option<TargetState> = match File::open(&target_cache) {
- Ok(handle) => serde_json::from_reader(io::BufReader::new(handle))
- .map(Some)
- .map_err(Error::ReadCachedTargetState),
- Err(e) => {
- if e.kind() == io::ErrorKind::NotFound {
- debug!("No cached target state to load");
- Ok(None)
- } else {
- Err(Error::OpenCachedTargetState(e))
+ let cached_target_state: Option<TargetState> =
+ match fs::read_to_string(&target_cache).await {
+ Ok(content) => serde_json::from_str(&content)
+ .map(Some)
+ .map_err(Error::ReadCachedTargetState),
+ Err(e) => {
+ if e.kind() == io::ErrorKind::NotFound {
+ debug!("No cached target state to load");
+ Ok(None)
+ } else {
+ Err(Error::OpenCachedTargetState(e))
+ }
}
}
- }
- .unwrap_or_else(|error| {
- error!("{}", error.display_chain());
- Some(TargetState::Secured)
- });
+ .unwrap_or_else(|error| {
+ error!("{}", error.display_chain());
+ Some(TargetState::Secured)
+ });
if let Some(cached_target_state) = &cached_target_state {
info!(
"Loaded cached target state \"{}\" from {}",
@@ -618,7 +619,7 @@ where
} else {
TargetState::Unsecured
};
- Self::cache_target_state(&cache_dir, initial_target_state);
+ Self::cache_target_state(&cache_dir, initial_target_state).await;
let initial_api_endpoint = Endpoint::from_socket_address(
rpc_runtime.address_cache.peek_address(),
@@ -683,7 +684,7 @@ where
last_generated_relay: None,
last_generated_bridge_relay: None,
app_version_info,
- shutdown_callbacks: vec![],
+ shutdown_tasks: vec![],
tunnel_state_machine_shutdown_signal,
cache_dir,
};
@@ -729,14 +730,14 @@ where
async fn finalize(self) {
let (
event_listener,
- shutdown_callbacks,
+ shutdown_tasks,
rpc_runtime,
tunnel_state_machine_shutdown_signal,
cache_dir,
lock_target_cache,
) = self.shutdown();
- for cb in shutdown_callbacks {
- cb();
+ for future in shutdown_tasks {
+ future.await;
}
let shutdown_signal = tokio::time::timeout(
@@ -752,7 +753,7 @@ where
mem::drop(rpc_runtime);
#[cfg(any(target_os = "macos", target_os = "linux"))]
- if let Err(err) = fs::remove_file(mullvad_paths::get_rpc_socket_path()) {
+ if let Err(err) = fs::remove_file(mullvad_paths::get_rpc_socket_path()).await {
if err.kind() != std::io::ErrorKind::NotFound {
log::error!("Failed to remove old RPC socket: {}", err);
}
@@ -760,7 +761,7 @@ where
if !lock_target_cache {
let target_cache = cache_dir.join(TARGET_START_STATE_FILE);
- let _ = fs::remove_file(target_cache).map_err(|e| {
+ let _ = fs::remove_file(target_cache).await.map_err(|e| {
error!("Cannot delete target tunnel state cache: {}", e);
});
}
@@ -772,7 +773,7 @@ where
self,
) -> (
L,
- Vec<Box<dyn FnOnce()>>,
+ Vec<Pin<Box<dyn Future<Output = ()>>>>,
mullvad_rpc::MullvadRpcRuntime,
oneshot::Receiver<()>,
PathBuf,
@@ -780,7 +781,7 @@ where
) {
let Daemon {
event_listener,
- shutdown_callbacks,
+ shutdown_tasks,
rpc_runtime,
tunnel_state_machine_shutdown_signal,
cache_dir,
@@ -789,7 +790,7 @@ where
} = self;
(
event_listener,
- shutdown_callbacks,
+ shutdown_tasks,
rpc_runtime,
tunnel_state_machine_shutdown_signal,
cache_dir,
@@ -1094,7 +1095,7 @@ where
return;
}
match command {
- SetTargetState(tx, state) => self.on_set_target_state(tx, state),
+ SetTargetState(tx, state) => self.on_set_target_state(tx, state).await,
Reconnect(tx) => self.on_reconnect(tx),
GetState(tx) => self.on_get_state(tx),
GetCurrentLocation(tx) => self.on_get_current_location(tx).await,
@@ -1244,7 +1245,7 @@ where
) {
match self.set_account(Some(new_token.clone())).await {
Ok(_) => {
- self.set_target_state(TargetState::Unsecured);
+ self.set_target_state(TargetState::Unsecured).await;
let _ = tx.send(Ok(new_token));
}
Err(err) => {
@@ -1262,9 +1263,13 @@ where
self.event_listener.notify_app_version(app_version_info);
}
- fn on_set_target_state(&mut self, tx: oneshot::Sender<bool>, new_target_state: TargetState) {
+ async fn on_set_target_state(
+ &mut self,
+ tx: oneshot::Sender<bool>,
+ new_target_state: TargetState,
+ ) {
if self.state.is_running() {
- let state_change_initated = self.set_target_state(new_target_state);
+ let state_change_initated = self.set_target_state(new_target_state).await;
Self::oneshot_send(tx, state_change_initated, "state change initiated");
} else {
warn!("Ignoring target state change request due to shutdown");
@@ -1449,7 +1454,7 @@ where
}
None => {
info!("Disconnecting because account token was cleared");
- self.set_target_state(TargetState::Unsecured);
+ self.set_target_state(TargetState::Unsecured).await;
}
};
}
@@ -1507,7 +1512,7 @@ where
async fn on_clear_account_history(&mut self, tx: ResponseTx<(), Error>) {
match self.account_history.clear().await {
Ok(_) => {
- self.set_target_state(TargetState::Unsecured);
+ self.set_target_state(TargetState::Unsecured).await;
Self::oneshot_send(tx, Ok(()), "clear_account_history response");
}
Err(err) => {
@@ -1564,8 +1569,8 @@ where
// Shut the daemon down.
self.trigger_shutdown_event();
- self.shutdown_callbacks.push(Box::new(move || {
- if let Err(e) = Self::clear_cache_directory() {
+ self.shutdown_tasks.push(Box::pin(async move {
+ if let Err(e) = Self::clear_cache_directory().await {
log::error!(
"{}",
e.display_chain_with_msg("Failed to clear cache directory")
@@ -1573,7 +1578,7 @@ where
last_error = Err(Error::ClearCacheError);
}
- if let Err(e) = Self::clear_log_directory() {
+ if let Err(e) = Self::clear_log_directory().await {
log::error!(
"{}",
e.display_chain_with_msg("Failed to clear log directory")
@@ -2114,14 +2119,14 @@ where
/// Set the target state of the client. If it changed trigger the operations needed to
/// progress towards that state.
/// Returns a bool representing whether or not a state change was initiated.
- fn set_target_state(&mut self, new_state: TargetState) -> bool {
+ async fn set_target_state(&mut self, new_state: TargetState) -> bool {
if new_state != self.target_state || self.tunnel_state.is_in_error_state() {
debug!("Target state {:?} => {:?}", self.target_state, new_state);
if new_state != self.target_state {
self.target_state = new_state;
if !self.lock_target_cache {
- Self::cache_target_state(&self.cache_dir, self.target_state);
+ Self::cache_target_state(&self.cache_dir, self.target_state).await;
}
}
@@ -2135,17 +2140,23 @@ where
}
}
- fn cache_target_state(cache_dir: &Path, target_state: TargetState) {
+ async fn cache_target_state(cache_dir: &Path, target_state: TargetState) {
let cache_file = cache_dir.join(TARGET_START_STATE_FILE);
log::trace!("Saving tunnel target state to {}", cache_file.display());
- match File::create(&cache_file) {
- Ok(handle) => {
- if let Err(e) = serde_json::to_writer(io::BufWriter::new(handle), &target_state) {
- log::error!("Failed to cache target state: {}", e);
+ match serde_json::to_string(&target_state) {
+ Ok(data) => {
+ if let Err(error) = fs::write(cache_file, data).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to write cache target state")
+ );
}
}
- Err(e) => {
- log::error!("Failed to cache target state: {}", e);
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to serialize cache target state")
+ )
}
}
}
@@ -2197,49 +2208,54 @@ where
}
#[cfg(not(target_os = "android"))]
- fn clear_log_directory() -> Result<(), Error> {
+ async fn clear_log_directory() -> Result<(), Error> {
let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?;
- Self::clear_directory(&log_dir)
+ Self::clear_directory(&log_dir).await
}
#[cfg(not(target_os = "android"))]
- fn clear_cache_directory() -> Result<(), Error> {
+ async fn clear_cache_directory() -> Result<(), Error> {
let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?;
- Self::clear_directory(&cache_dir)
+ Self::clear_directory(&cache_dir).await
}
#[cfg(not(target_os = "android"))]
- fn clear_directory(path: &Path) -> Result<(), Error> {
+ async fn clear_directory(path: &Path) -> Result<(), Error> {
#[cfg(not(target_os = "windows"))]
{
fs::remove_dir_all(path)
+ .await
.map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?;
fs::create_dir_all(path)
+ .await
.map_err(|e| Error::CreateDirError(path.display().to_string(), e))
}
#[cfg(target_os = "windows")]
{
- fs::read_dir(&path)
- .map_err(Error::ReadDirError)
- .and_then(|dir_entries| {
- dir_entries
- .into_iter()
- .map(|entry| {
- let entry = entry.map_err(Error::FileEntryError)?;
- let entry_type = entry.file_type().map_err(Error::FileTypeError)?;
+ let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?;
+ let mut result = Ok(());
- let removal = if entry_type.is_file() || entry_type.is_symlink() {
- fs::remove_file(entry.path())
- } else {
- fs::remove_dir_all(entry.path())
- };
- removal.map_err(|e| {
- Error::RemoveDirError(entry.path().display().to_string(), e)
- })
- })
- .collect::<Result<(), Error>>()
- })
+ while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? {
+ let entry_type = match entry.file_type().await {
+ Ok(entry_type) => entry_type,
+ Err(error) => {
+ result = result.and(Err(Error::FileTypeError(error)));
+ continue;
+ }
+ };
+
+ let removal = if entry_type.is_file() || entry_type.is_symlink() {
+ fs::remove_file(entry.path()).await
+ } else {
+ fs::remove_dir_all(entry.path()).await
+ };
+ result = result.and(
+ removal
+ .map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)),
+ );
+ }
+ result
}
}