diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-03-30 17:14:38 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-03-31 15:45:47 +0200 |
| commit | d4092cd05fb1199419094cfb3eeab015023a5434 (patch) | |
| tree | 9f7f4bfe23deaad07a61ad3db4d4d03d801b65d4 | |
| parent | ad1070d5252ff06691667fecebf6a31451683616 (diff) | |
| download | mullvadvpn-d4092cd05fb1199419094cfb3eeab015023a5434.tar.xz mullvadvpn-d4092cd05fb1199419094cfb3eeab015023a5434.zip | |
Use async file I/O in main daemon module
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 156 |
1 files changed, 86 insertions, 70 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 866bf7131c..53206249a6 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -44,12 +44,11 @@ use settings::SettingsPersister; #[cfg(target_os = "android")] use std::os::unix::io::RawFd; use std::{ - fs::{self, File}, - io, marker::PhantomData, mem, net::IpAddr, path::{Path, PathBuf}, + pin::Pin, sync::{mpsc as sync_mpsc, Arc, Weak}, time::Duration, }; @@ -66,6 +65,7 @@ use talpid_types::{ tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, }; +use tokio::{fs, io}; #[path = "wireguard.rs"] mod wireguard; @@ -490,7 +490,7 @@ pub struct Daemon<L: EventListener> { last_generated_relay: Option<Relay>, last_generated_bridge_relay: Option<Relay>, app_version_info: Option<AppVersionInfo>, - shutdown_callbacks: Vec<Box<dyn FnOnce()>>, + shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>, /// oneshot channel that completes once the tunnel state machine has been shut down tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>, cache_dir: PathBuf, @@ -577,23 +577,24 @@ where // Restore the tunnel to a previous state let target_cache = cache_dir.join(TARGET_START_STATE_FILE); - let cached_target_state: Option<TargetState> = match File::open(&target_cache) { - Ok(handle) => serde_json::from_reader(io::BufReader::new(handle)) - .map(Some) - .map_err(Error::ReadCachedTargetState), - Err(e) => { - if e.kind() == io::ErrorKind::NotFound { - debug!("No cached target state to load"); - Ok(None) - } else { - Err(Error::OpenCachedTargetState(e)) + let cached_target_state: Option<TargetState> = + match fs::read_to_string(&target_cache).await { + Ok(content) => serde_json::from_str(&content) + .map(Some) + .map_err(Error::ReadCachedTargetState), + Err(e) => { + if e.kind() == io::ErrorKind::NotFound { + debug!("No cached target state to load"); + Ok(None) + } else { + Err(Error::OpenCachedTargetState(e)) + } } } - } - .unwrap_or_else(|error| { - error!("{}", error.display_chain()); - Some(TargetState::Secured) - }); + .unwrap_or_else(|error| { + error!("{}", error.display_chain()); + Some(TargetState::Secured) + }); if let Some(cached_target_state) = &cached_target_state { info!( "Loaded cached target state \"{}\" from {}", @@ -618,7 +619,7 @@ where } else { TargetState::Unsecured }; - Self::cache_target_state(&cache_dir, initial_target_state); + Self::cache_target_state(&cache_dir, initial_target_state).await; let initial_api_endpoint = Endpoint::from_socket_address( rpc_runtime.address_cache.peek_address(), @@ -683,7 +684,7 @@ where last_generated_relay: None, last_generated_bridge_relay: None, app_version_info, - shutdown_callbacks: vec![], + shutdown_tasks: vec![], tunnel_state_machine_shutdown_signal, cache_dir, }; @@ -729,14 +730,14 @@ where async fn finalize(self) { let ( event_listener, - shutdown_callbacks, + shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal, cache_dir, lock_target_cache, ) = self.shutdown(); - for cb in shutdown_callbacks { - cb(); + for future in shutdown_tasks { + future.await; } let shutdown_signal = tokio::time::timeout( @@ -752,7 +753,7 @@ where mem::drop(rpc_runtime); #[cfg(any(target_os = "macos", target_os = "linux"))] - if let Err(err) = fs::remove_file(mullvad_paths::get_rpc_socket_path()) { + if let Err(err) = fs::remove_file(mullvad_paths::get_rpc_socket_path()).await { if err.kind() != std::io::ErrorKind::NotFound { log::error!("Failed to remove old RPC socket: {}", err); } @@ -760,7 +761,7 @@ where if !lock_target_cache { let target_cache = cache_dir.join(TARGET_START_STATE_FILE); - let _ = fs::remove_file(target_cache).map_err(|e| { + let _ = fs::remove_file(target_cache).await.map_err(|e| { error!("Cannot delete target tunnel state cache: {}", e); }); } @@ -772,7 +773,7 @@ where self, ) -> ( L, - Vec<Box<dyn FnOnce()>>, + Vec<Pin<Box<dyn Future<Output = ()>>>>, mullvad_rpc::MullvadRpcRuntime, oneshot::Receiver<()>, PathBuf, @@ -780,7 +781,7 @@ where ) { let Daemon { event_listener, - shutdown_callbacks, + shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal, cache_dir, @@ -789,7 +790,7 @@ where } = self; ( event_listener, - shutdown_callbacks, + shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal, cache_dir, @@ -1094,7 +1095,7 @@ where return; } match command { - SetTargetState(tx, state) => self.on_set_target_state(tx, state), + SetTargetState(tx, state) => self.on_set_target_state(tx, state).await, Reconnect(tx) => self.on_reconnect(tx), GetState(tx) => self.on_get_state(tx), GetCurrentLocation(tx) => self.on_get_current_location(tx).await, @@ -1244,7 +1245,7 @@ where ) { match self.set_account(Some(new_token.clone())).await { Ok(_) => { - self.set_target_state(TargetState::Unsecured); + self.set_target_state(TargetState::Unsecured).await; let _ = tx.send(Ok(new_token)); } Err(err) => { @@ -1262,9 +1263,13 @@ where self.event_listener.notify_app_version(app_version_info); } - fn on_set_target_state(&mut self, tx: oneshot::Sender<bool>, new_target_state: TargetState) { + async fn on_set_target_state( + &mut self, + tx: oneshot::Sender<bool>, + new_target_state: TargetState, + ) { if self.state.is_running() { - let state_change_initated = self.set_target_state(new_target_state); + let state_change_initated = self.set_target_state(new_target_state).await; Self::oneshot_send(tx, state_change_initated, "state change initiated"); } else { warn!("Ignoring target state change request due to shutdown"); @@ -1449,7 +1454,7 @@ where } None => { info!("Disconnecting because account token was cleared"); - self.set_target_state(TargetState::Unsecured); + self.set_target_state(TargetState::Unsecured).await; } }; } @@ -1507,7 +1512,7 @@ where async fn on_clear_account_history(&mut self, tx: ResponseTx<(), Error>) { match self.account_history.clear().await { Ok(_) => { - self.set_target_state(TargetState::Unsecured); + self.set_target_state(TargetState::Unsecured).await; Self::oneshot_send(tx, Ok(()), "clear_account_history response"); } Err(err) => { @@ -1564,8 +1569,8 @@ where // Shut the daemon down. self.trigger_shutdown_event(); - self.shutdown_callbacks.push(Box::new(move || { - if let Err(e) = Self::clear_cache_directory() { + self.shutdown_tasks.push(Box::pin(async move { + if let Err(e) = Self::clear_cache_directory().await { log::error!( "{}", e.display_chain_with_msg("Failed to clear cache directory") @@ -1573,7 +1578,7 @@ where last_error = Err(Error::ClearCacheError); } - if let Err(e) = Self::clear_log_directory() { + if let Err(e) = Self::clear_log_directory().await { log::error!( "{}", e.display_chain_with_msg("Failed to clear log directory") @@ -2114,14 +2119,14 @@ where /// Set the target state of the client. If it changed trigger the operations needed to /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. - fn set_target_state(&mut self, new_state: TargetState) -> bool { + async fn set_target_state(&mut self, new_state: TargetState) -> bool { if new_state != self.target_state || self.tunnel_state.is_in_error_state() { debug!("Target state {:?} => {:?}", self.target_state, new_state); if new_state != self.target_state { self.target_state = new_state; if !self.lock_target_cache { - Self::cache_target_state(&self.cache_dir, self.target_state); + Self::cache_target_state(&self.cache_dir, self.target_state).await; } } @@ -2135,17 +2140,23 @@ where } } - fn cache_target_state(cache_dir: &Path, target_state: TargetState) { + async fn cache_target_state(cache_dir: &Path, target_state: TargetState) { let cache_file = cache_dir.join(TARGET_START_STATE_FILE); log::trace!("Saving tunnel target state to {}", cache_file.display()); - match File::create(&cache_file) { - Ok(handle) => { - if let Err(e) = serde_json::to_writer(io::BufWriter::new(handle), &target_state) { - log::error!("Failed to cache target state: {}", e); + match serde_json::to_string(&target_state) { + Ok(data) => { + if let Err(error) = fs::write(cache_file, data).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to write cache target state") + ); } } - Err(e) => { - log::error!("Failed to cache target state: {}", e); + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to serialize cache target state") + ) } } } @@ -2197,49 +2208,54 @@ where } #[cfg(not(target_os = "android"))] - fn clear_log_directory() -> Result<(), Error> { + async fn clear_log_directory() -> Result<(), Error> { let log_dir = mullvad_paths::get_log_dir().map_err(Error::PathError)?; - Self::clear_directory(&log_dir) + Self::clear_directory(&log_dir).await } #[cfg(not(target_os = "android"))] - fn clear_cache_directory() -> Result<(), Error> { + async fn clear_cache_directory() -> Result<(), Error> { let cache_dir = mullvad_paths::cache_dir().map_err(Error::PathError)?; - Self::clear_directory(&cache_dir) + Self::clear_directory(&cache_dir).await } #[cfg(not(target_os = "android"))] - fn clear_directory(path: &Path) -> Result<(), Error> { + async fn clear_directory(path: &Path) -> Result<(), Error> { #[cfg(not(target_os = "windows"))] { fs::remove_dir_all(path) + .await .map_err(|e| Error::RemoveDirError(path.display().to_string(), e))?; fs::create_dir_all(path) + .await .map_err(|e| Error::CreateDirError(path.display().to_string(), e)) } #[cfg(target_os = "windows")] { - fs::read_dir(&path) - .map_err(Error::ReadDirError) - .and_then(|dir_entries| { - dir_entries - .into_iter() - .map(|entry| { - let entry = entry.map_err(Error::FileEntryError)?; - let entry_type = entry.file_type().map_err(Error::FileTypeError)?; + let mut dir = fs::read_dir(&path).await.map_err(Error::ReadDirError)?; + let mut result = Ok(()); - let removal = if entry_type.is_file() || entry_type.is_symlink() { - fs::remove_file(entry.path()) - } else { - fs::remove_dir_all(entry.path()) - }; - removal.map_err(|e| { - Error::RemoveDirError(entry.path().display().to_string(), e) - }) - }) - .collect::<Result<(), Error>>() - }) + while let Some(entry) = dir.next_entry().await.map_err(Error::FileEntryError)? { + let entry_type = match entry.file_type().await { + Ok(entry_type) => entry_type, + Err(error) => { + result = result.and(Err(Error::FileTypeError(error))); + continue; + } + }; + + let removal = if entry_type.is_file() || entry_type.is_symlink() { + fs::remove_file(entry.path()).await + } else { + fs::remove_dir_all(entry.path()).await + }; + result = result.and( + removal + .map_err(|e| Error::RemoveDirError(entry.path().display().to_string(), e)), + ); + } + result } } |
