diff options
| author | David Lönnhager <david.l@mullvad.net> | 2022-09-20 10:01:57 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2022-09-20 16:28:08 +0200 |
| commit | 473b2ab1c22c733acdff0551e3135228bfeb70e5 (patch) | |
| tree | 6a542bf7ce7eab9155cdce48cd1e42ff648911ab | |
| parent | a453fa6c1d577eb73b4fe2badb704277b99f2b30 (diff) | |
| download | mullvadvpn-473b2ab1c22c733acdff0551e3135228bfeb70e5.tar.xz mullvadvpn-473b2ab1c22c733acdff0551e3135228bfeb70e5.zip | |
Add shutdown detection for Windows service
| -rw-r--r-- | mullvad-daemon/src/cli.rs | 8 | ||||
| -rw-r--r-- | mullvad-daemon/src/lib.rs | 28 | ||||
| -rw-r--r-- | mullvad-daemon/src/main.rs | 8 | ||||
| -rw-r--r-- | mullvad-daemon/src/system_service.rs | 139 |
4 files changed, 66 insertions, 117 deletions
diff --git a/mullvad-daemon/src/cli.rs b/mullvad-daemon/src/cli.rs index e7e332c942..9a6e8eec50 100644 --- a/mullvad-daemon/src/cli.rs +++ b/mullvad-daemon/src/cli.rs @@ -9,7 +9,6 @@ pub struct Config { pub log_stdout_timestamps: bool, pub run_as_service: bool, pub register_service: bool, - pub restart_service: bool, #[cfg(target_os = "linux")] pub initialize_firewall_and_exit: bool, } @@ -38,7 +37,6 @@ pub fn create_config() -> Config { cfg!(target_os = "linux") && matches.is_present("initialize-early-boot-firewall"); let run_as_service = cfg!(windows) && matches.is_present("run_as_service"); let register_service = cfg!(windows) && matches.is_present("register_service"); - let restart_service = cfg!(windows) && matches.is_present("restart_service"); Config { #[cfg(target_os = "linux")] @@ -48,7 +46,6 @@ pub fn create_config() -> Config { log_stdout_timestamps, run_as_service, register_service, - restart_service, } } @@ -107,11 +104,6 @@ fn create_app() -> App<'static> { .long("register-service") .help("Register itself as a system service"), ) - .arg( - Arg::new("restart_service") - .long("restart-service") - .help("Restarts the existing system service"), - ) } if cfg!(target_os = "linux") { diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 49f27751cc..88dd9da735 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -300,7 +300,8 @@ pub(crate) enum InternalDaemonEvent { /// A command sent to the daemon. Command(DaemonCommand), /// Daemon shutdown triggered by a signal, ctrl-c or similar. - TriggerShutdown, + /// The boolean should indicate whether the shutdown was user-initiated. + TriggerShutdown(bool), /// The background job fetching new `AppVersionInfo`s got a new info object. NewAppVersionInfo(AppVersionInfo), /// Sent when a device is updated in any way (key rotation, login, logout, etc.). @@ -825,7 +826,7 @@ where self.handle_tunnel_state_transition(transition).await } Command(command) => self.handle_command(command).await, - TriggerShutdown => self.trigger_shutdown_event(), + TriggerShutdown(user_init_shutdown) => self.trigger_shutdown_event(user_init_shutdown), NewAppVersionInfo(app_version_info) => { self.handle_new_app_version_info(app_version_info); } @@ -1030,7 +1031,7 @@ where SetObfuscationSettings(tx, settings) => { self.on_set_obfuscation_settings(tx, settings).await } - Shutdown => self.trigger_shutdown_event(), + Shutdown => self.trigger_shutdown_event(false), PrepareRestart => self.on_prepare_restart(), #[cfg(target_os = "android")] BypassSocket(fd, tx) => self.on_bypass_socket(fd, tx), @@ -1518,7 +1519,7 @@ where } // Shut the daemon down. - self.trigger_shutdown_event(); + self.trigger_shutdown_event(false); self.shutdown_tasks.push(Box::pin(async move { if let Err(e) = cleanup::clear_directories().await { @@ -2158,11 +2159,15 @@ where } } - fn trigger_shutdown_event(&mut self) { - // If auto-connect is enabled, block all traffic before shutting down to ensure - // that no traffic can leak during boot. + #[cfg_attr(not(target_os = "windows"), allow(unused_variables))] + fn trigger_shutdown_event(&mut self, user_init_shutdown: bool) { + // Block all traffic before shutting down to ensure that no traffic can leak on boot or + // shutdown. #[cfg(windows)] - if self.settings.auto_connect { + if !user_init_shutdown + && (*self.target_state == TargetState::Secured || self.settings.auto_connect) + { + log::debug!("Blocking firewall during shutdown since system is going down"); self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true)); } @@ -2268,13 +2273,16 @@ where } } +#[derive(Clone)] pub struct DaemonShutdownHandle { tx: DaemonEventSender, } impl DaemonShutdownHandle { - pub fn shutdown(&self) { - let _ = self.tx.send(InternalDaemonEvent::TriggerShutdown); + pub fn shutdown(&self, user_init_shutdown: bool) { + let _ = self + .tx + .send(InternalDaemonEvent::TriggerShutdown(user_init_shutdown)); } } diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 13b4d8a268..1d929c2d4a 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -91,12 +91,6 @@ fn get_log_dir(config: &cli::Config) -> Result<Option<PathBuf>, String> { async fn run_platform(config: &cli::Config, log_dir: Option<PathBuf>) -> Result<(), String> { if config.run_as_service { system_service::run() - } else if config.restart_service { - let restart_result = system_service::restart_service().map_err(|e| e.display_chain()); - if restart_result.is_ok() { - log::info!("Restarted the service."); - } - restart_result } else { if config.register_service { let install_result = system_service::install_service().map_err(|e| e.display_chain()); @@ -144,7 +138,7 @@ async fn run_standalone(log_dir: Option<PathBuf>) -> Result<(), String> { let daemon = create_daemon(log_dir).await?; let shutdown_handle = daemon.shutdown_handle(); - shutdown::set_shutdown_signal_handler(move || shutdown_handle.shutdown()) + shutdown::set_shutdown_signal_handler(move || shutdown_handle.shutdown(true)) .map_err(|e| e.display_chain())?; daemon.run().await.map_err(|e| e.display_chain())?; diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs index bb22f2d04a..1a7f3ee1fd 100644 --- a/mullvad-daemon/src/system_service.rs +++ b/mullvad-daemon/src/system_service.rs @@ -3,7 +3,7 @@ use libc::c_void; use mullvad_daemon::{runtime::new_runtime_builder, DaemonShutdownHandle}; use std::{ env, - ffi::{OsStr, OsString}, + ffi::OsString, mem, ptr, slice, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -39,8 +39,6 @@ static SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const SERVICE_RECOVERY_LAST_RESTART_DELAY: Duration = Duration::from_secs(60 * 10); const SERVICE_FAILURE_RESET_PERIOD: Duration = Duration::from_secs(60 * 15); -const SERVICE_RESTART_TIMEOUT: Duration = Duration::from_secs(60 * 2); - lazy_static::lazy_static! { static ref SERVICE_ACCESS: ServiceAccess = ServiceAccess::QUERY_CONFIG | ServiceAccess::CHANGE_CONFIG @@ -96,7 +94,7 @@ pub fn handle_service_main(_arguments: Vec<OsString>) { .set_pending_start(Duration::from_secs(1)) .unwrap(); - let clean_shutdown = Arc::new(AtomicBool::new(false)); + let should_restart = Arc::new(AtomicBool::new(true)); let log_dir = crate::get_log_dir(cli::get_config()).expect("Log dir should be available here"); @@ -121,7 +119,7 @@ pub fn handle_service_main(_arguments: Vec<OsString>) { persistent_service_status.clone(), shutdown_handle, event_rx, - clean_shutdown.clone(), + should_restart.clone(), ); persistent_service_status.set_running().unwrap(); @@ -137,7 +135,7 @@ pub fn handle_service_main(_arguments: Vec<OsString>) { Ok(()) => { log::info!("Stopping service"); // check if shutdown signal was sent from the system - if clean_shutdown.load(Ordering::Acquire) { + if !should_restart.load(Ordering::Acquire) { ServiceExitCode::default() } else { // otherwise return a non-zero code so that the daemon gets restarted @@ -156,22 +154,24 @@ pub fn handle_service_main(_arguments: Vec<OsString>) { /// Start event monitor thread that polls for `ServiceControl` and translates them into calls to /// Daemon. fn start_event_monitor( - mut persistent_service_status: PersistentServiceStatus, + persistent_service_status: PersistentServiceStatus, shutdown_handle: DaemonShutdownHandle, event_rx: mpsc::Receiver<ServiceControl>, - clean_shutdown: Arc<AtomicBool>, + should_restart: Arc<AtomicBool>, ) -> thread::JoinHandle<()> { thread::spawn(move || { - let mut hibernation_detector = HibernationDetector::default(); + let mut shutdown_handle = ServiceShutdownHandle { + persistent_service_status, + shutdown_handle, + should_restart, + }; + let mut hibernation_detector = HibernationDetector::new(shutdown_handle.clone()); for event in event_rx { match event { ServiceControl::Stop | ServiceControl::Preshutdown => { - persistent_service_status - .set_pending_stop(Duration::from_secs(10)) - .unwrap(); - - clean_shutdown.store(true, Ordering::Release); - shutdown_handle.shutdown(); + // If the daemon is closing due to the system shutting down, + // keep blocking traffic after the daemon exits. + shutdown_handle.shutdown(false, event == ServiceControl::Preshutdown); } ServiceControl::PowerEvent(details) => match details { PowerEventParam::Suspend => { @@ -193,6 +193,26 @@ fn start_event_monitor( }) } +#[derive(Clone)] +struct ServiceShutdownHandle { + persistent_service_status: PersistentServiceStatus, + shutdown_handle: DaemonShutdownHandle, + /// If true, the service will be restarted by the SCM when + /// the daemon has exited. + should_restart: Arc<AtomicBool>, +} + +impl ServiceShutdownHandle { + fn shutdown(&mut self, should_restart: bool, is_system_shutdown: bool) { + self.persistent_service_status + .set_pending_stop(Duration::from_secs(10)) + .unwrap(); + + self.should_restart.store(should_restart, Ordering::Release); + self.shutdown_handle.shutdown(!is_system_shutdown); + } +} + /// Service status helper with persistent checkpoint counter. #[derive(Debug, Clone)] struct PersistentServiceStatus { @@ -385,69 +405,26 @@ fn get_service_info() -> ServiceInfo { } } -#[derive(err_derive::Error, Debug)] -#[error(no_from)] -pub enum RestartError { - #[error(display = "Unable to connect to service manager")] - ConnectServiceManager(#[error(source)] windows_service::Error), - - #[error(display = "Unable to open service")] - OpenService(#[error(source)] windows_service::Error), - - #[error(display = "Failed to query service status")] - QueryStatus(#[error(source)] windows_service::Error), - - #[error(display = "Failed to stop service")] - StopService(#[error(source)] windows_service::Error), - - #[error(display = "Failed to start service")] - StartService(#[error(source)] windows_service::Error), - - #[error(display = "Timed out while stopping service")] - Timeout, -} - -pub fn restart_service() -> Result<(), RestartError> { - let manager_access = ServiceManagerAccess::CONNECT; - let service_manager = ServiceManager::local_computer(None::<&str>, manager_access) - .map_err(RestartError::ConnectServiceManager)?; - - let service_access = ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::STOP; - let service = service_manager - .open_service(SERVICE_NAME, service_access) - .map_err(RestartError::OpenService)?; - - service.stop().map_err(RestartError::StopService)?; - - let start_time = Instant::now(); - - loop { - let status = service.query_status().map_err(RestartError::QueryStatus)?; - if status.current_state == ServiceState::Stopped { - let args: [&OsStr; 0] = []; - break service.start(&args).map_err(RestartError::StartService); - } - - if start_time.elapsed() > SERVICE_RESTART_TIMEOUT { - break Err(RestartError::Timeout); - } - - std::thread::sleep(Duration::from_secs(1)); - } -} - /// Used to track events that taken together would mean the machine is heading towards being /// hibernated. Typically, the user's session if first terminated. Moments later we should receive a /// suspension event corresponding to the hibernation of session 0 (kernel and services). -#[derive(Default)] struct HibernationDetector { logoff_time: Option<Instant>, should_restart: bool, + shutdown_handle: ServiceShutdownHandle, } const SECURITY_LOGON_TYPE_INTERACTIVE: u32 = 2; impl HibernationDetector { + fn new(shutdown_handle: ServiceShutdownHandle) -> Self { + Self { + logoff_time: None, + should_restart: false, + shutdown_handle, + } + } + /// Register a session logoff. /// The logoff event is discarded unless the session was/is interactive. fn register_logoff(&mut self, session_id: u32) { @@ -507,31 +484,9 @@ impl HibernationDetector { if self.should_restart { self.should_restart = false; log::info!("System is being restored from hibernation. Restarting daemon service"); - if let Err(err) = Self::restart_daemon() { - log::error!("{}", err); - } - } - } - /// Performs a clean shutdown and restart of the daemon. - fn restart_daemon() -> Result<(), String> { - let daemon_path = env::current_exe() - .map_err(|e| e.display_chain_with_msg("Failed to obtain daemon path"))?; - let working_dir = daemon_path - .parent() - .ok_or("Failed to obtain resource directory".to_string())? - .to_path_buf(); - let args = vec![ - "--restart-service".to_string(), - "--disable-log-to-file".to_string(), - ]; - duct::cmd(daemon_path, args) - .dir(working_dir) - .stdin_null() - .stdout_null() - .stderr_null() - .start() - .map(|_| ()) - .map_err(|e| e.display_chain_with_msg("Failed to start helper process")) + // Perform a non-clean shutdown. This will cause the daemon to restart itself. + self.shutdown_handle.shutdown(true, true); + } } } |
