summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-daemon/src/lib.rs150
-rw-r--r--mullvad-daemon/src/target_state.rs151
2 files changed, 186 insertions, 115 deletions
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 84b92e910e..b2b3468a54 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -19,9 +19,11 @@ mod relays;
pub mod rpc_uniqueness_check;
pub mod runtime;
pub mod settings;
+mod target_state;
pub mod version;
mod version_check;
+use crate::target_state::PersistentTargetState;
use futures::{
channel::{mpsc, oneshot},
future::{abortable, AbortHandle, Future},
@@ -45,13 +47,15 @@ use mullvad_types::{
use settings::SettingsPersister;
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
+#[cfg(not(target_os = "android"))]
+use std::path::Path;
#[cfg(target_os = "windows")]
use std::{collections::HashSet, ffi::OsString};
use std::{
marker::PhantomData,
mem,
net::{IpAddr, Ipv4Addr},
- path::{Path, PathBuf},
+ path::PathBuf,
pin::Pin,
sync::{mpsc as sync_mpsc, Arc, Weak},
time::Duration,
@@ -72,13 +76,13 @@ use talpid_types::{
tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition},
ErrorExt,
};
-use tokio::{fs, io};
+#[cfg(not(target_os = "android"))]
+use tokio::fs;
+use tokio::io;
#[path = "wireguard.rs"]
mod wireguard;
-const TARGET_START_STATE_FILE: &str = "target-start-state.json";
-
const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout for first WireGuard key pushing
@@ -186,12 +190,6 @@ pub enum Error {
#[error(display = "Failed to read dir entries")]
ReadDirError(#[error(source)] io::Error),
- #[error(display = "Failed to read cached target tunnel state")]
- ReadCachedTargetState(#[error(source)] serde_json::Error),
-
- #[error(display = "Failed to open cached target tunnel state")]
- OpenCachedTargetState(#[error(source)] io::Error),
-
#[cfg(target_os = "macos")]
#[error(display = "Failed to set exclusion group")]
GroupIdError(#[error(source)] io::Error),
@@ -525,8 +523,7 @@ pub trait EventListener {
pub struct Daemon<L: EventListener> {
tunnel_command_tx: Arc<mpsc::UnboundedSender<TunnelCommand>>,
tunnel_state: TunnelState,
- target_state: TargetState,
- lock_target_cache: bool,
+ target_state: PersistentTargetState,
state: DaemonExecutionState,
#[cfg(target_os = "linux")]
exclude_pids: split_tunnel::PidManager,
@@ -549,7 +546,6 @@ pub struct Daemon<L: EventListener> {
shutdown_tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
/// oneshot channel that completes once the tunnel state machine has been shut down
tunnel_state_machine_shutdown_signal: oneshot::Receiver<()>,
- cache_dir: PathBuf,
}
impl<L> Daemon<L>
@@ -589,51 +585,19 @@ where
let _ = settings.set_show_beta_releases(true).await;
}
- // Restore the tunnel to a previous state
- let target_cache = cache_dir.join(TARGET_START_STATE_FILE);
- let cached_target_state: Option<TargetState> =
- match fs::read_to_string(&target_cache).await {
- Ok(content) => serde_json::from_str(&content)
- .map(Some)
- .map_err(Error::ReadCachedTargetState),
- Err(e) => {
- if e.kind() == io::ErrorKind::NotFound {
- log::debug!("No cached target state to load");
- Ok(None)
- } else {
- Err(Error::OpenCachedTargetState(e))
- }
- }
- }
- .unwrap_or_else(|error| {
- log::error!("{}", error.display_chain());
- Some(TargetState::Secured)
- });
- if let Some(cached_target_state) = &cached_target_state {
- log::info!(
- "Loaded cached target state \"{}\" from {}",
- cached_target_state,
- target_cache.display()
- );
- }
+ let target_state = if settings.get_account_token().is_none() {
+ PersistentTargetState::force(&cache_dir, TargetState::Unsecured).await
+ } else if settings.auto_connect {
+ log::info!("Automatically connecting since auto-connect is turned on");
+ PersistentTargetState::force(&cache_dir, TargetState::Secured).await
+ } else {
+ PersistentTargetState::new(&cache_dir).await
+ };
let tunnel_parameters_generator = MullvadTunnelParametersGenerator {
tx: internal_event_tx.clone(),
};
- let initial_target_state = if settings.get_account_token().is_some() {
- if settings.auto_connect {
- // Note: Auto-connect overrides the cached target state
- log::info!("Automatically connecting since auto-connect is turned on");
- TargetState::Secured
- } else {
- cached_target_state.unwrap_or(TargetState::Unsecured)
- }
- } else {
- TargetState::Unsecured
- };
- Self::cache_target_state(&cache_dir, initial_target_state).await;
-
#[cfg(windows)]
let exclude_paths = if settings.split_tunnel.enable_exclusions {
settings
@@ -669,7 +633,7 @@ where
block_when_disconnected: settings.block_when_disconnected,
dns_servers: Self::get_dns_resolvers(&settings.tunnel_options.dns_options),
allowed_endpoint: initial_api_endpoint,
- reset_firewall: initial_target_state != TargetState::Secured,
+ reset_firewall: *target_state != TargetState::Secured,
#[cfg(windows)]
exclude_paths,
},
@@ -756,8 +720,7 @@ where
let mut daemon = Daemon {
tunnel_command_tx,
tunnel_state: TunnelState::Disconnected,
- target_state: initial_target_state,
- lock_target_cache: false,
+ target_state,
state: DaemonExecutionState::Running,
#[cfg(target_os = "linux")]
exclude_pids: split_tunnel::PidManager::new().map_err(Error::InitSplitTunneling)?,
@@ -779,7 +742,6 @@ where
app_version_info,
shutdown_tasks: vec![],
tunnel_state_machine_shutdown_signal,
- cache_dir,
};
daemon.ensure_wireguard_keys_for_current_account().await;
@@ -850,7 +812,7 @@ where
/// Consume the `Daemon` and run the main event loop. Blocks until an error happens or a
/// shutdown event is received.
pub async fn run(mut self) -> Result<(), Error> {
- if self.target_state == TargetState::Secured {
+ if *self.target_state == TargetState::Secured {
self.connect_tunnel();
}
@@ -873,14 +835,8 @@ where
}
async fn finalize(self) {
- let (
- event_listener,
- shutdown_tasks,
- rpc_runtime,
- tunnel_state_machine_shutdown_signal,
- cache_dir,
- lock_target_cache,
- ) = self.shutdown();
+ let (event_listener, shutdown_tasks, rpc_runtime, tunnel_state_machine_shutdown_signal) =
+ self.shutdown();
for future in shutdown_tasks {
future.await;
}
@@ -903,13 +859,6 @@ where
log::error!("Failed to remove old RPC socket: {}", err);
}
}
-
- if !lock_target_cache {
- let target_cache = cache_dir.join(TARGET_START_STATE_FILE);
- let _ = fs::remove_file(target_cache).await.map_err(|e| {
- log::error!("Cannot delete target tunnel state cache: {}", e);
- });
- }
}
/// Shuts down the daemon without shutting down the underlying event listener and the shutdown
@@ -921,25 +870,23 @@ where
Vec<Pin<Box<dyn Future<Output = ()>>>>,
mullvad_rpc::MullvadRpcRuntime,
oneshot::Receiver<()>,
- PathBuf,
- bool,
) {
let Daemon {
event_listener,
- shutdown_tasks,
+ mut shutdown_tasks,
rpc_runtime,
tunnel_state_machine_shutdown_signal,
- cache_dir,
- lock_target_cache,
+ target_state,
..
} = self;
+
+ shutdown_tasks.push(Box::pin(target_state.finalize()));
+
(
event_listener,
shutdown_tasks,
rpc_runtime,
tunnel_state_machine_shutdown_signal,
- cache_dir,
- lock_target_cache,
)
}
@@ -1446,7 +1393,7 @@ where
}
fn on_reconnect(&mut self, tx: oneshot::Sender<bool>) {
- if self.target_state == TargetState::Secured || self.tunnel_state.is_in_error_state() {
+ if *self.target_state == TargetState::Secured || self.tunnel_state.is_in_error_state() {
self.connect_tunnel();
Self::oneshot_send(tx, true, "reconnect issued");
} else {
@@ -2458,11 +2405,10 @@ where
// TODO: See if this can be made to also shut down the daemon
// without causing the service to be restarted.
- if self.target_state == TargetState::Secured {
+ if *self.target_state == TargetState::Secured {
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true));
}
-
- self.lock_target_cache = true;
+ self.target_state.lock();
}
#[cfg(target_os = "android")]
@@ -2513,17 +2459,12 @@ where
/// progress towards that state.
/// Returns a bool representing whether or not a state change was initiated.
async fn set_target_state(&mut self, new_state: TargetState) -> bool {
- if new_state != self.target_state || self.tunnel_state.is_in_error_state() {
- log::debug!("Target state {:?} => {:?}", self.target_state, new_state);
+ if new_state != *self.target_state || self.tunnel_state.is_in_error_state() {
+ log::debug!("Target state {:?} => {:?}", *self.target_state, new_state);
- if new_state != self.target_state {
- self.target_state = new_state;
- if !self.lock_target_cache {
- Self::cache_target_state(&self.cache_dir, self.target_state).await;
- }
- }
+ self.target_state.set(new_state).await;
- match self.target_state {
+ match *self.target_state {
TargetState::Secured => self.connect_tunnel(),
TargetState::Unsecured => self.disconnect_tunnel(),
}
@@ -2533,27 +2474,6 @@ where
}
}
- async fn cache_target_state(cache_dir: &Path, target_state: TargetState) {
- let cache_file = cache_dir.join(TARGET_START_STATE_FILE);
- log::trace!("Saving tunnel target state to {}", cache_file.display());
- match serde_json::to_string(&target_state) {
- Ok(data) => {
- if let Err(error) = fs::write(cache_file, data).await {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to write cache target state")
- );
- }
- }
- Err(error) => {
- log::error!(
- "{}",
- error.display_chain_with_msg("Failed to serialize cache target state")
- )
- }
- }
- }
-
fn connect_tunnel(&mut self) {
self.send_tunnel_command(TunnelCommand::Connect);
}
@@ -2563,7 +2483,7 @@ where
}
fn reconnect_tunnel(&mut self) {
- if self.target_state == TargetState::Secured {
+ if *self.target_state == TargetState::Secured {
self.connect_tunnel();
}
}
diff --git a/mullvad-daemon/src/target_state.rs b/mullvad-daemon/src/target_state.rs
new file mode 100644
index 0000000000..fcd4eacde6
--- /dev/null
+++ b/mullvad-daemon/src/target_state.rs
@@ -0,0 +1,151 @@
+use mullvad_types::states::TargetState;
+use std::{
+ ops::Deref,
+ path::{Path, PathBuf},
+};
+use talpid_types::ErrorExt;
+use tokio::{fs, io};
+
+/// State to use by default if there is no cache.
+const DEFAULT_TARGET_STATE: TargetState = TargetState::Unsecured;
+const TARGET_START_STATE_FILE: &str = "target-start-state.json";
+
+/// Persists the target state to a file, which is only removed if the instance is dropped cleanly.
+pub struct PersistentTargetState {
+ state: TargetState,
+ cache_path: PathBuf,
+ locked: bool,
+}
+
+impl PersistentTargetState {
+ /// Initialize using the current target state (if there is one)
+ pub async fn new(cache_dir: &Path) -> Self {
+ let cache_path = cache_dir.join(TARGET_START_STATE_FILE);
+ let mut update_cache = false;
+ let state = match fs::read_to_string(&cache_path).await {
+ Ok(content) => serde_json::from_str(&content)
+ .map(|state| {
+ log::info!(
+ "Loaded cached target state \"{}\" from {}",
+ state,
+ cache_path.display()
+ );
+ state
+ })
+ .unwrap_or_else(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to parse cached target tunnel state")
+ );
+ update_cache = true;
+ TargetState::Secured
+ }),
+ Err(error) => {
+ if error.kind() == io::ErrorKind::NotFound {
+ log::debug!("No cached target state to load");
+ DEFAULT_TARGET_STATE
+ } else {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to read cached target tunnel state")
+ );
+ update_cache = true;
+ TargetState::Secured
+ }
+ }
+ };
+ let state = PersistentTargetState {
+ state,
+ cache_path,
+ locked: false,
+ };
+ if update_cache {
+ state.save().await;
+ }
+ state
+ }
+
+ /// Override the current target state, if there is one
+ pub async fn force(cache_dir: &Path, state: TargetState) -> Self {
+ let cache_path = cache_dir.join(TARGET_START_STATE_FILE);
+ let state = PersistentTargetState {
+ state,
+ cache_path,
+ locked: false,
+ };
+ state.save().await;
+ state
+ }
+
+ pub async fn set(&mut self, new_state: TargetState) {
+ if new_state != self.state {
+ self.state = new_state;
+ self.save().await;
+ }
+ }
+
+ /// Prevent the file from being removed when the instance is dropped.
+ pub fn lock(&mut self) {
+ self.locked = true;
+ }
+
+ /// Async destructor
+ pub async fn finalize(mut self) {
+ if self.locked {
+ return;
+ }
+ let _ = fs::remove_file(&self.cache_path).await.map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Cannot delete target tunnel state cache")
+ );
+ });
+ // prevent the sync destructor from running
+ self.locked = true;
+ }
+
+ async fn save(&self) {
+ log::trace!(
+ "Saving tunnel target state to {}",
+ self.cache_path.display()
+ );
+ match serde_json::to_string(&self.state) {
+ Ok(data) => {
+ if let Err(error) = fs::write(&self.cache_path, data).await {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to write cache target state")
+ );
+ }
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to serialize cache target state")
+ )
+ }
+ }
+ }
+}
+
+impl Drop for PersistentTargetState {
+ fn drop(&mut self) {
+ if self.locked {
+ return;
+ }
+ let _ = std::fs::remove_file(&self.cache_path).map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Cannot delete target tunnel state cache")
+ );
+ });
+ }
+}
+
+impl Deref for PersistentTargetState {
+ type Target = TargetState;
+
+ fn deref(&self) -> &Self::Target {
+ &self.state
+ }
+}