diff options
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 150 | ||||
| -rw-r--r-- | mullvad-daemon/src/target_state.rs | 151 |
2 files changed, 186 insertions, 115 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 84b92e910e..b2b3468a54 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -19,9 +19,11 @@ mod relays; pub mod rpc_uniqueness_check; pub mod runtime; pub mod settings; +mod target_state; pub mod version; mod version_check; +use crate::target_state::PersistentTargetState; use futures::{ channel::{mpsc, oneshot}, future::{abortable, AbortHandle, Future}, @@ -45,13 +47,15 @@ use mullvad_types::{ use settings::SettingsPersister; #[cfg(target_os = "android")] use std::os::unix::io::RawFd; +#[cfg(not(target_os = "android"))] +use std::path::Path; #[cfg(target_os = "windows")] use std::{collections::HashSet, ffi::OsString}; use std::{ marker::PhantomData, mem, net::{IpAddr, Ipv4Addr}, - path::{Path, PathBuf}, + path::PathBuf, pin::Pin, sync::{mpsc as sync_mpsc, Arc, Weak}, time::Duration, @@ -72,13 +76,13 @@ use talpid_types::{ tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, ErrorExt, }; -use tokio::{fs, io}; +#[cfg(not(target_os = "android"))] +use tokio::fs; +use tokio::io; #[path = "wireguard.rs"] mod wireguard; -const TARGET_START_STATE_FILE: &str = "target-start-state.json"; - const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); /// Timeout for first WireGuard key pushing @@ -186,12 +190,6 @@ pub enum Error { #[error(display = "Failed to read dir entries")] ReadDirError(#[error(source)] io::Error), - #[error(display = "Failed to read cached target tunnel state")] - ReadCachedTargetState(#[error(source)] serde_json::Error), - - #[error(display = "Failed to open cached target tunnel state")] - OpenCachedTargetState(#[error(source)] io::Error), - #[cfg(target_os = "macos")] #[error(display = "Failed to set exclusion group")] GroupIdError(#[error(source)] io::Error), @@ -525,8 +523,7 @@ pub trait EventListener { pub struct Daemon<L: EventListener> { tunnel_command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>, tunnel_state: TunnelState, - target_state: TargetState, - lock_target_cache: bool, + target_state: PersistentTargetState, state: DaemonExecutionState, #[cfg(target_os = "linux")] exclude_pids: split_tunnel::PidManager, @@ -549,7 +546,6 @@ pub struct Daemon<L: EventListener> { shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>, /// oneshot channel that completes once the tunnel state machine has been shut down tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>, - cache_dir: PathBuf, } impl<L> Daemon<L> @@ -589,51 +585,19 @@ where let _ = settings.set_show_beta_releases(true).await; } - // Restore the tunnel to a previous state - let target_cache = cache_dir.join(TARGET_START_STATE_FILE); - let cached_target_state: Option<TargetState> = - match fs::read_to_string(&target_cache).await { - Ok(content) => serde_json::from_str(&content) - .map(Some) - .map_err(Error::ReadCachedTargetState), - Err(e) => { - if e.kind() == io::ErrorKind::NotFound { - log::debug!("No cached target state to load"); - Ok(None) - } else { - Err(Error::OpenCachedTargetState(e)) - } - } - } - .unwrap_or_else(|error| { - log::error!("{}", error.display_chain()); - Some(TargetState::Secured) - }); - if let Some(cached_target_state) = &cached_target_state { - log::info!( - "Loaded cached target state \"{}\" from {}", - cached_target_state, - target_cache.display() - ); - } + let target_state = if settings.get_account_token().is_none() { + PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await + } else if settings.auto_connect { + log::info!("Automatically connecting since auto-connect is turned on"); + PersistentTargetState::force(&cache_dir, TargetState::Secured).await + } else { + PersistentTargetState::new(&cache_dir).await + }; let tunnel_parameters_generator = MullvadTunnelParametersGenerator { tx: internal_event_tx.clone(), }; - let initial_target_state = if settings.get_account_token().is_some() { - if settings.auto_connect { - // Note: Auto-connect overrides the cached target state - log::info!("Automatically connecting since auto-connect is turned on"); - TargetState::Secured - } else { - cached_target_state.unwrap_or(TargetState::Unsecured) - } - } else { - TargetState::Unsecured - }; - Self::cache_target_state(&cache_dir, initial_target_state).await; - #[cfg(windows)] let exclude_paths = if settings.split_tunnel.enable_exclusions { settings @@ -669,7 +633,7 @@ where block_when_disconnected: settings.block_when_disconnected, dns_servers: Self::get_dns_resolvers(&settings.tunnel_options.dns_options), allowed_endpoint: initial_api_endpoint, - reset_firewall: initial_target_state != TargetState::Secured, + reset_firewall: *target_state != TargetState::Secured, #[cfg(windows)] exclude_paths, }, @@ -756,8 +720,7 @@ where let mut daemon = Daemon { tunnel_command_tx, tunnel_state: TunnelState::Disconnected, - target_state: initial_target_state, - lock_target_cache: false, + target_state, state: DaemonExecutionState::Running, #[cfg(target_os = "linux")] exclude_pids: split_tunnel::PidManager::new().map_err(Error::InitSplitTunneling)?, @@ -779,7 +742,6 @@ where app_version_info, shutdown_tasks: vec![], tunnel_state_machine_shutdown_signal, - cache_dir, }; daemon.ensure_wireguard_keys_for_current_account().await; @@ -850,7 +812,7 @@ where /// Consume the `Daemon` and run the main event loop. Blocks until an error happens or a /// shutdown event is received. pub async fn run(mut self) -> Result<(), Error> { - if self.target_state == TargetState::Secured { + if *self.target_state == TargetState::Secured { self.connect_tunnel(); } @@ -873,14 +835,8 @@ where } async fn finalize(self) { - let ( - event_listener, - shutdown_tasks, - rpc_runtime, - tunnel_state_machine_shutdown_signal, - cache_dir, - lock_target_cache, - ) = self.shutdown(); + let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal) = + self.shutdown(); for future in shutdown_tasks { future.await; } @@ -903,13 +859,6 @@ where log::error!("Failed to remove old RPC socket: {}", err); } } - - if !lock_target_cache { - let target_cache = cache_dir.join(TARGET_START_STATE_FILE); - let _ = fs::remove_file(target_cache).await.map_err(|e| { - log::error!("Cannot delete target tunnel state cache: {}", e); - }); - } } /// Shuts down the daemon without shutting down the underlying event listener and the shutdown @@ -921,25 +870,23 @@ where Vec<Pin<Box<dyn Future<Output = ()>>>>, mullvad_rpc::MullvadRpcRuntime, oneshot::Receiver<()>, - PathBuf, - bool, ) { let Daemon { event_listener, - shutdown_tasks, + mut shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal, - cache_dir, - lock_target_cache, + target_state, .. } = self; + + shutdown_tasks.push(Box::pin(target_state.finalize())); + ( event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal, - cache_dir, - lock_target_cache, ) } @@ -1446,7 +1393,7 @@ where } fn on_reconnect(&mut self, tx: oneshot::Sender<bool>) { - if self.target_state == TargetState::Secured || self.tunnel_state.is_in_error_state() { + if *self.target_state == TargetState::Secured || self.tunnel_state.is_in_error_state() { self.connect_tunnel(); Self::oneshot_send(tx, true, "reconnect issued"); } else { @@ -2458,11 +2405,10 @@ where // TODO: See if this can be made to also shut down the daemon // without causing the service to be restarted. - if self.target_state == TargetState::Secured { + if *self.target_state == TargetState::Secured { self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true)); } - - self.lock_target_cache = true; + self.target_state.lock(); } #[cfg(target_os = "android")] @@ -2513,17 +2459,12 @@ where /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. async fn set_target_state(&mut self, new_state: TargetState) -> bool { - if new_state != self.target_state || self.tunnel_state.is_in_error_state() { - log::debug!("Target state {:?} => {:?}", self.target_state, new_state); + if new_state != *self.target_state || self.tunnel_state.is_in_error_state() { + log::debug!("Target state {:?} => {:?}", *self.target_state, new_state); - if new_state != self.target_state { - self.target_state = new_state; - if !self.lock_target_cache { - Self::cache_target_state(&self.cache_dir, self.target_state).await; - } - } + self.target_state.set(new_state).await; - match self.target_state { + match *self.target_state { TargetState::Secured => self.connect_tunnel(), TargetState::Unsecured => self.disconnect_tunnel(), } @@ -2533,27 +2474,6 @@ where } } - async fn cache_target_state(cache_dir: &Path, target_state: TargetState) { - let cache_file = cache_dir.join(TARGET_START_STATE_FILE); - log::trace!("Saving tunnel target state to {}", cache_file.display()); - match serde_json::to_string(&target_state) { - Ok(data) => { - if let Err(error) = fs::write(cache_file, data).await { - log::error!( - "{}", - error.display_chain_with_msg("Failed to write cache target state") - ); - } - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to serialize cache target state") - ) - } - } - } - fn connect_tunnel(&mut self) { self.send_tunnel_command(TunnelCommand::Connect); } @@ -2563,7 +2483,7 @@ where } fn reconnect_tunnel(&mut self) { - if self.target_state == TargetState::Secured { + if *self.target_state == TargetState::Secured { self.connect_tunnel(); } } diff --git a/mullvad-daemon/src/target_state.rs b/mullvad-daemon/src/target_state.rs new file mode 100644 index 0000000000..fcd4eacde6 --- /dev/null +++ b/mullvad-daemon/src/target_state.rs @@ -0,0 +1,151 @@ +use mullvad_types::states::TargetState; +use std::{ + ops::Deref, + path::{Path, PathBuf}, +}; +use talpid_types::ErrorExt; +use tokio::{fs, io}; + +/// State to use by default if there is no cache. +const DEFAULT_TARGET_STATE: TargetState = TargetState::Unsecured; +const TARGET_START_STATE_FILE: &str = "target-start-state.json"; + +/// Persists the target state to a file, which is only removed if the instance is dropped cleanly. +pub struct PersistentTargetState { + state: TargetState, + cache_path: PathBuf, + locked: bool, +} + +impl PersistentTargetState { + /// Initialize using the current target state (if there is one) + pub async fn new(cache_dir: &Path) -> Self { + let cache_path = cache_dir.join(TARGET_START_STATE_FILE); + let mut update_cache = false; + let state = match fs::read_to_string(&cache_path).await { + Ok(content) => serde_json::from_str(&content) + .map(|state| { + log::info!( + "Loaded cached target state \"{}\" from {}", + state, + cache_path.display() + ); + state + }) + .unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to parse cached target tunnel state") + ); + update_cache = true; + TargetState::Secured + }), + Err(error) => { + if error.kind() == io::ErrorKind::NotFound { + log::debug!("No cached target state to load"); + DEFAULT_TARGET_STATE + } else { + log::error!( + "{}", + error.display_chain_with_msg("Failed to read cached target tunnel state") + ); + update_cache = true; + TargetState::Secured + } + } + }; + let state = PersistentTargetState { + state, + cache_path, + locked: false, + }; + if update_cache { + state.save().await; + } + state + } + + /// Override the current target state, if there is one + pub async fn force(cache_dir: &Path, state: TargetState) -> Self { + let cache_path = cache_dir.join(TARGET_START_STATE_FILE); + let state = PersistentTargetState { + state, + cache_path, + locked: false, + }; + state.save().await; + state + } + + pub async fn set(&mut self, new_state: TargetState) { + if new_state != self.state { + self.state = new_state; + self.save().await; + } + } + + /// Prevent the file from being removed when the instance is dropped. + pub fn lock(&mut self) { + self.locked = true; + } + + /// Async destructor + pub async fn finalize(mut self) { + if self.locked { + return; + } + let _ = fs::remove_file(&self.cache_path).await.map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Cannot delete target tunnel state cache") + ); + }); + // prevent the sync destructor from running + self.locked = true; + } + + async fn save(&self) { + log::trace!( + "Saving tunnel target state to {}", + self.cache_path.display() + ); + match serde_json::to_string(&self.state) { + Ok(data) => { + if let Err(error) = fs::write(&self.cache_path, data).await { + log::error!( + "{}", + error.display_chain_with_msg("Failed to write cache target state") + ); + } + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to serialize cache target state") + ) + } + } + } +} + +impl Drop for PersistentTargetState { + fn drop(&mut self) { + if self.locked { + return; + } + let _ = std::fs::remove_file(&self.cache_path).map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Cannot delete target tunnel state cache") + ); + }); + } +} + +impl Deref for PersistentTargetState { + type Target = TargetState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} |
