diff options
| -rw-r--r-- | mullvad-daemon/src/cli.rs | 11 | ||||
| -rw-r--r-- | mullvad-daemon/src/main.rs | 6 | ||||
| -rw-r--r-- | mullvad-daemon/src/system_service.rs | 55 |
3 files changed, 70 insertions, 2 deletions
diff --git a/mullvad-daemon/src/cli.rs b/mullvad-daemon/src/cli.rs index 4abf8eb44c..2693c37a15 100644 --- a/mullvad-daemon/src/cli.rs +++ b/mullvad-daemon/src/cli.rs @@ -9,6 +9,7 @@ pub struct Config { pub log_stdout_timestamps: bool, pub run_as_service: bool, pub register_service: bool, + pub restart_service: bool, } pub fn get_config() -> &'static Config { @@ -32,6 +33,7 @@ pub fn create_config() -> Config { let run_as_service = cfg!(windows) && matches.is_present("run_as_service"); let register_service = cfg!(windows) && matches.is_present("register_service"); + let restart_service = cfg!(windows) && matches.is_present("restart_service"); Config { log_level, @@ -39,6 +41,7 @@ pub fn create_config() -> Config { log_stdout_timestamps, run_as_service, register_service, + restart_service, } } @@ -91,11 +94,17 @@ fn create_app() -> App<'static, 'static> { Arg::with_name("run_as_service") .long("run-as-service") .help("Run as a system service. On Windows this option must be used when running a system service"), - ).arg( + ) + .arg( Arg::with_name("register_service") .long("register-service") .help("Register itself as a system service"), ) + .arg( + Arg::with_name("restart_service") + .long("restart-service") + .help("Restarts the existing system service"), + ) } app } diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 68fc2bd5e8..a355f94e44 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -75,6 +75,12 @@ fn get_log_dir(config: &cli::Config) -> Result<Option<PathBuf>, String> { async fn run_platform(config: &cli::Config, log_dir: Option<PathBuf>) -> Result<(), String> { if config.run_as_service { system_service::run() + } else if config.restart_service { + let restart_result = system_service::restart_service().map_err(|e| e.display_chain()); + if restart_result.is_ok() { + log::info!("Restarted the service."); + } + restart_result } else { if config.register_service { let install_result = system_service::install_service().map_err(|e| e.display_chain()); diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs index 281357c0c2..a87e3aec44 100644 --- a/mullvad-daemon/src/system_service.rs +++ b/mullvad-daemon/src/system_service.rs @@ -2,7 +2,7 @@ use crate::cli; use mullvad_daemon::{runtime::new_runtime_builder, DaemonShutdownHandle}; use std::{ env, - ffi::OsString, + ffi::{OsStr, OsString}, ptr, slice, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -46,6 +46,8 @@ static SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const SERVICE_RECOVERY_LAST_RESTART_DELAY: Duration = Duration::from_secs(60 * 10); const SERVICE_FAILURE_RESET_PERIOD: Duration = Duration::from_secs(60 * 15); +const SERVICE_RESTART_TIMEOUT: Duration = Duration::from_secs(60 * 2); + lazy_static::lazy_static! { static ref SERVICE_ACCESS: ServiceAccess = ServiceAccess::QUERY_CONFIG | ServiceAccess::CHANGE_CONFIG @@ -391,6 +393,57 @@ fn get_service_info() -> ServiceInfo { } } +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum RestartError { + #[error(display = "Unable to connect to service manager")] + ConnectServiceManager(#[error(source)] windows_service::Error), + + #[error(display = "Unable to open service")] + OpenService(#[error(source)] windows_service::Error), + + #[error(display = "Failed to query service status")] + QueryStatus(#[error(source)] windows_service::Error), + + #[error(display = "Failed to stop service")] + StopService(#[error(source)] windows_service::Error), + + #[error(display = "Failed to start service")] + StartService(#[error(source)] windows_service::Error), + + #[error(display = "Timed out while stopping service")] + Timeout, +} + +pub fn restart_service() -> Result<(), RestartError> { + let manager_access = ServiceManagerAccess::CONNECT; + let service_manager = ServiceManager::local_computer(None::<&str>, manager_access) + .map_err(RestartError::ConnectServiceManager)?; + + let service_access = ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::STOP; + let service = service_manager + .open_service(SERVICE_NAME, service_access) + .map_err(RestartError::OpenService)?; + + service.stop().map_err(RestartError::StopService)?; + + let start_time = Instant::now(); + + loop { + let status = service.query_status().map_err(RestartError::QueryStatus)?; + if status.current_state == ServiceState::Stopped { + let args: [&OsStr; 0] = []; + break service.start(&args).map_err(RestartError::StartService); + } + + if start_time.elapsed() > SERVICE_RESTART_TIMEOUT { + break Err(RestartError::Timeout); + } + + std::thread::sleep(Duration::from_secs(1)); + } +} + /// Used to track events that taken together would mean the machine is heading towards being /// hibernated. Typically, the user's session if first terminated. Moments later we should receive a /// suspension event corresponding to the hibernation of session 0 (kernel and services). |
