summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mullvad-daemon/src/cli.rs11
-rw-r--r--mullvad-daemon/src/main.rs6
-rw-r--r--mullvad-daemon/src/system_service.rs55
3 files changed, 70 insertions, 2 deletions
diff --git a/mullvad-daemon/src/cli.rs b/mullvad-daemon/src/cli.rs
index 4abf8eb44c..2693c37a15 100644
--- a/mullvad-daemon/src/cli.rs
+++ b/mullvad-daemon/src/cli.rs
@@ -9,6 +9,7 @@ pub struct Config {
pub log_stdout_timestamps: bool,
pub run_as_service: bool,
pub register_service: bool,
+ pub restart_service: bool,
}
pub fn get_config() -> &'static Config {
@@ -32,6 +33,7 @@ pub fn create_config() -> Config {
let run_as_service = cfg!(windows) && matches.is_present("run_as_service");
let register_service = cfg!(windows) && matches.is_present("register_service");
+ let restart_service = cfg!(windows) && matches.is_present("restart_service");
Config {
log_level,
@@ -39,6 +41,7 @@ pub fn create_config() -> Config {
log_stdout_timestamps,
run_as_service,
register_service,
+ restart_service,
}
}
@@ -91,11 +94,17 @@ fn create_app() -> App<'static, 'static> {
Arg::with_name("run_as_service")
.long("run-as-service")
.help("Run as a system service. On Windows this option must be used when running a system service"),
- ).arg(
+ )
+ .arg(
Arg::with_name("register_service")
.long("register-service")
.help("Register itself as a system service"),
)
+ .arg(
+ Arg::with_name("restart_service")
+ .long("restart-service")
+ .help("Restarts the existing system service"),
+ )
}
app
}
diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs
index 68fc2bd5e8..a355f94e44 100644
--- a/mullvad-daemon/src/main.rs
+++ b/mullvad-daemon/src/main.rs
@@ -75,6 +75,12 @@ fn get_log_dir(config: &cli::Config) -> Result<Option<PathBuf>, String> {
async fn run_platform(config: &cli::Config, log_dir: Option<PathBuf>) -> Result<(), String> {
if config.run_as_service {
system_service::run()
+ } else if config.restart_service {
+ let restart_result = system_service::restart_service().map_err(|e| e.display_chain());
+ if restart_result.is_ok() {
+ log::info!("Restarted the service.");
+ }
+ restart_result
} else {
if config.register_service {
let install_result = system_service::install_service().map_err(|e| e.display_chain());
diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs
index 281357c0c2..a87e3aec44 100644
--- a/mullvad-daemon/src/system_service.rs
+++ b/mullvad-daemon/src/system_service.rs
@@ -2,7 +2,7 @@ use crate::cli;
use mullvad_daemon::{runtime::new_runtime_builder, DaemonShutdownHandle};
use std::{
env,
- ffi::OsString,
+ ffi::{OsStr, OsString},
ptr, slice,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
@@ -46,6 +46,8 @@ static SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS;
const SERVICE_RECOVERY_LAST_RESTART_DELAY: Duration = Duration::from_secs(60 * 10);
const SERVICE_FAILURE_RESET_PERIOD: Duration = Duration::from_secs(60 * 15);
+const SERVICE_RESTART_TIMEOUT: Duration = Duration::from_secs(60 * 2);
+
lazy_static::lazy_static! {
static ref SERVICE_ACCESS: ServiceAccess = ServiceAccess::QUERY_CONFIG
| ServiceAccess::CHANGE_CONFIG
@@ -391,6 +393,57 @@ fn get_service_info() -> ServiceInfo {
}
}
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum RestartError {
+ #[error(display = "Unable to connect to service manager")]
+ ConnectServiceManager(#[error(source)] windows_service::Error),
+
+ #[error(display = "Unable to open service")]
+ OpenService(#[error(source)] windows_service::Error),
+
+ #[error(display = "Failed to query service status")]
+ QueryStatus(#[error(source)] windows_service::Error),
+
+ #[error(display = "Failed to stop service")]
+ StopService(#[error(source)] windows_service::Error),
+
+ #[error(display = "Failed to start service")]
+ StartService(#[error(source)] windows_service::Error),
+
+ #[error(display = "Timed out while stopping service")]
+ Timeout,
+}
+
+pub fn restart_service() -> Result<(), RestartError> {
+ let manager_access = ServiceManagerAccess::CONNECT;
+ let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)
+ .map_err(RestartError::ConnectServiceManager)?;
+
+ let service_access = ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::STOP;
+ let service = service_manager
+ .open_service(SERVICE_NAME, service_access)
+ .map_err(RestartError::OpenService)?;
+
+ service.stop().map_err(RestartError::StopService)?;
+
+ let start_time = Instant::now();
+
+ loop {
+ let status = service.query_status().map_err(RestartError::QueryStatus)?;
+ if status.current_state == ServiceState::Stopped {
+ let args: [&OsStr; 0] = [];
+ break service.start(&args).map_err(RestartError::StartService);
+ }
+
+ if start_time.elapsed() > SERVICE_RESTART_TIMEOUT {
+ break Err(RestartError::Timeout);
+ }
+
+ std::thread::sleep(Duration::from_secs(1));
+ }
+}
+
/// Used to track events that taken together would mean the machine is heading towards being
/// hibernated. Typically, the user's session if first terminated. Moments later we should receive a
/// suspension event corresponding to the hibernation of session 0 (kernel and services).