summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorOdd Stranne <odd@mullvad.net>2020-01-28 13:00:39 +0100
committerOdd Stranne <odd@mullvad.net>2020-01-28 13:00:39 +0100
commited5a61312692268e774abe39599d6099e47f0cae (patch)
tree70f49902f776cbc3e27fce72d6ff2797ac03b51c
parentfd06f0af58f1c59ab72394ac91dbbecf242407b2 (diff)
parent00f2e5373d7f3fb6757bfa9454236d291866bb7b (diff)
downloadmullvadvpn-ed5a61312692268e774abe39599d6099e47f0cae.tar.xz
mullvadvpn-ed5a61312692268e774abe39599d6099e47f0cae.zip
Merge branch 'win-detect-hibernation'
-rw-r--r--CHANGELOG.md2
-rw-r--r--Cargo.lock1
-rw-r--r--mullvad-daemon/Cargo.toml3
-rw-r--r--mullvad-daemon/src/system_service.rs163
4 files changed, 161 insertions, 8 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f92b886341..970c924031 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,8 @@ Line wrap the file at 100 chars. Th
- Fix occasional failure to shut down the old daemon process during installation by killing it if
necessary.
- Make WireGuard work with IPv6 enabled even if there is no functioning TAP adapter for OpenVPN.
+- Restart daemon when coming back from system hibernation with terminated user session, since
+ it's perceived as a cold boot from the user's perspective, so the app should act accordingly.
#### Android
- Fix crash when starting the app right after quitting it.
diff --git a/Cargo.lock b/Cargo.lock
index 31d0d67a22..599fcf0c47 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1253,6 +1253,7 @@ dependencies = [
"chrono 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)",
"clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
"ctrlc 3.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "duct 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)",
"err-derive 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"fern 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml
index 8077dd7bcb..6ccd9089fe 100644
--- a/mullvad-daemon/Cargo.toml
+++ b/mullvad-daemon/Cargo.toml
@@ -49,8 +49,9 @@ simple-signal = "1.1"
[target.'cfg(windows)'.dependencies]
ctrlc = "3.0"
+duct = "0.13"
windows-service = { git = "https://github.com/mullvad/windows-service-rs.git", rev = "1d5f9cc65658429414f2d62e4581e5a3e2532b99" }
-winapi = { version = "0.3", features = ["errhandlingapi", "handleapi", "libloaderapi", "synchapi", "tlhelp32", "winbase", "winerror", "winuser"] }
+winapi = { version = "0.3", features = ["errhandlingapi", "handleapi", "libloaderapi", "ntlsa", "synchapi", "tlhelp32", "winbase", "winerror", "winuser"] }
[target.'cfg(windows)'.build-dependencies]
winres = "0.1"
diff --git a/mullvad-daemon/src/system_service.rs b/mullvad-daemon/src/system_service.rs
index 4483756f63..4e4486f1c4 100644
--- a/mullvad-daemon/src/system_service.rs
+++ b/mullvad-daemon/src/system_service.rs
@@ -3,20 +3,36 @@ use mullvad_daemon::DaemonShutdownHandle;
use std::{
env,
ffi::OsString,
+ ptr, slice,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
mpsc, Arc,
},
thread,
- time::Duration,
+ time::{Duration, Instant},
};
use talpid_types::ErrorExt;
+use winapi::{
+ ctypes::c_void,
+ shared::{
+ minwindef::ULONG,
+ ntdef::{LUID, PVOID, WCHAR},
+ ntstatus::STATUS_SUCCESS,
+ },
+ um::{
+ ntlsa::{
+ LsaEnumerateLogonSessions, LsaFreeReturnBuffer, LsaGetLogonSessionData,
+ SECURITY_LOGON_SESSION_DATA,
+ },
+ sysinfoapi::GetSystemDirectoryW,
+ },
+};
use windows_service::{
service::{
- Service, ServiceAccess, ServiceAction, ServiceActionType, ServiceControl,
+ PowerEventParam, Service, ServiceAccess, ServiceAction, ServiceActionType, ServiceControl,
ServiceControlAccept, ServiceDependency, ServiceErrorControl, ServiceExitCode,
ServiceFailureActions, ServiceFailureResetPeriod, ServiceInfo, ServiceSidType,
- ServiceStartType, ServiceState, ServiceStatus, ServiceType,
+ ServiceStartType, ServiceState, ServiceStatus, ServiceType, SessionChangeReason,
},
service_control_handler::{self, ServiceControlHandlerResult, ServiceStatusHandle},
service_dispatcher,
@@ -65,7 +81,10 @@ fn run_service() -> Result<(), String> {
// control manager. Always return NO_ERROR even if not implemented.
ServiceControl::Interrogate => ServiceControlHandlerResult::NoError,
- ServiceControl::Stop | ServiceControl::Preshutdown => {
+ ServiceControl::Stop
+ | ServiceControl::Preshutdown
+ | ServiceControl::PowerEvent(_)
+ | ServiceControl::SessionChange(_) => {
event_tx.send(control_event).unwrap();
ServiceControlHandlerResult::NoError
}
@@ -99,7 +118,6 @@ fn run_service() -> Result<(), String> {
daemon.run().map_err(|e| e.display_chain())
});
-
let exit_code = match result {
Ok(()) => {
// check if shutdown signal was sent from the system
@@ -127,6 +145,7 @@ fn start_event_monitor(
clean_shutdown: Arc<AtomicBool>,
) -> thread::JoinHandle<()> {
thread::spawn(move || {
+ let mut hibernation_detector = HibernationDetector::default();
for event in event_rx {
match event {
ServiceControl::Stop | ServiceControl::Preshutdown => {
@@ -137,6 +156,20 @@ fn start_event_monitor(
clean_shutdown.store(true, Ordering::Release);
shutdown_handle.shutdown();
}
+ ServiceControl::PowerEvent(details) => match details {
+ PowerEventParam::Suspend => {
+ hibernation_detector.register_suspend();
+ }
+ PowerEventParam::ResumeAutomatic | PowerEventParam::ResumeSuspend => {
+ hibernation_detector.register_resume();
+ }
+ _ => (),
+ },
+ ServiceControl::SessionChange(details) => {
+ if details.reason == SessionChangeReason::SessionLogoff {
+ hibernation_detector.register_logoff(details.notification.session_id);
+ }
+ }
_ => (),
}
}
@@ -233,12 +266,17 @@ impl PersistentServiceStatus {
/// Returns the list of accepted service events at each stage of the service lifecycle.
fn accepted_controls_by_state(state: ServiceState) -> ServiceControlAccept {
+ let always_accepted = ServiceControlAccept::POWER_EVENT | ServiceControlAccept::SESSION_CHANGE;
match state {
ServiceState::StartPending | ServiceState::PausePending | ServiceState::ContinuePending => {
ServiceControlAccept::empty()
}
- ServiceState::Running => ServiceControlAccept::STOP | ServiceControlAccept::PRESHUTDOWN,
- ServiceState::Paused => ServiceControlAccept::STOP | ServiceControlAccept::PRESHUTDOWN,
+ ServiceState::Running => {
+ always_accepted | ServiceControlAccept::STOP | ServiceControlAccept::PRESHUTDOWN
+ }
+ ServiceState::Paused => {
+ always_accepted | ServiceControlAccept::STOP | ServiceControlAccept::PRESHUTDOWN
+ }
ServiceState::StopPending | ServiceState::Stopped => ServiceControlAccept::empty(),
}
}
@@ -329,3 +367,114 @@ fn get_service_info() -> ServiceInfo {
account_password: None,
}
}
+
+/// Used to track events that taken together would mean the machine is heading towards being
+/// hibernated. Typically, the user's session if first terminated. Moments later we should receive a
+/// suspension event corresponding to the hibernation of session 0 (kernel and services).
+#[derive(Default)]
+struct HibernationDetector {
+ logoff_time: Option<Instant>,
+ should_restart: bool,
+}
+
+const SECURITY_LOGON_TYPE_INTERACTIVE: u32 = 2;
+
+impl HibernationDetector {
+ /// Register a session logoff.
+ /// The logoff event is discarded unless the session was/is interactive.
+ fn register_logoff(&mut self, session_id: u32) {
+ if unsafe { Self::interactive_session(session_id) } {
+ self.logoff_time = Some(Instant::now());
+ }
+ }
+
+ unsafe fn interactive_session(session_id: u32) -> bool {
+ let mut logon_session_count: ULONG = 0;
+ let mut logon_session_list: *mut LUID = ptr::null_mut();
+ let status = LsaEnumerateLogonSessions(&mut logon_session_count, &mut logon_session_list);
+ if status != STATUS_SUCCESS {
+ log::warn!("LsaEnumerateLogonSessions() failed, error code: {}", status);
+ return false;
+ }
+ let logons = slice::from_raw_parts(logon_session_list, logon_session_count as usize);
+ let mut interactive = false;
+ for logon in logons {
+ let mut session_data: *mut SECURITY_LOGON_SESSION_DATA = ptr::null_mut();
+ let status = LsaGetLogonSessionData(logon as *const _ as *mut LUID, &mut session_data);
+ if status != STATUS_SUCCESS {
+ log::warn!("LsaGetLogonSessionData() failed, error code: {}", status);
+ continue;
+ }
+ let candidate_correct_session = (*session_data).Session == session_id;
+ let candidate_interactive =
+ (*session_data).LogonType == SECURITY_LOGON_TYPE_INTERACTIVE;
+ LsaFreeReturnBuffer(session_data as *mut c_void as PVOID);
+ if candidate_correct_session {
+ interactive = candidate_interactive;
+ break;
+ }
+ }
+ LsaFreeReturnBuffer(logon_session_list as *mut c_void as PVOID);
+ interactive
+ }
+
+ /// Register a machine suspend event.
+ fn register_suspend(&mut self) {
+ if let Some(logoff_time) = &self.logoff_time {
+ if logoff_time.elapsed() < Duration::from_secs(5) {
+ log::info!("Pending hibernation detected");
+ self.should_restart = true;
+ }
+ }
+ }
+
+ /// Register a machine resume event.
+ /// This will restart the service if we are coming back from hibernation.
+ fn register_resume(&mut self) {
+ if self.should_restart {
+ self.should_restart = false;
+ log::info!("System is being restored from hibernation. Restarting daemon service");
+ if let Err(err) = Self::restart_daemon() {
+ log::error!("{}", err);
+ }
+ }
+ }
+
+ /// Performs a clean shutdown and restart of the daemon.
+ fn restart_daemon() -> Result<(), String> {
+ let sysdir = unsafe { Self::get_system_directory() }?;
+ let cmd_path = format!("{}cmd.exe", sysdir);
+ let commands = vec!["net stop", SERVICE_NAME, "& net start", SERVICE_NAME];
+ let args = vec!["/C".to_string(), commands.join(" ")];
+ duct::cmd(cmd_path, args)
+ .dir(sysdir)
+ .stdin_null()
+ .stdout_null()
+ .stderr_null()
+ .start()
+ .map(|_| ())
+ .map_err(|e| e.display_chain_with_msg("Failed to start helper process"))
+ }
+
+ /// Returns the absolute path of the system directory.
+ /// Always includes a terminating backslash.
+ unsafe fn get_system_directory() -> Result<String, String> {
+ // Returned count is including null terminator.
+ let chars_required = GetSystemDirectoryW(ptr::null_mut(), 0);
+ if chars_required != 0 {
+ let mut buffer: Vec::<WCHAR> = Vec::with_capacity(chars_required as usize);
+ // Returned count is excluding null terminator.
+ let chars_written = GetSystemDirectoryW(buffer.as_mut_ptr(), chars_required);
+ if chars_written == (chars_required - 1) {
+ buffer.set_len(chars_written as usize);
+ let mut path = String::from_utf16(&buffer)
+ .map_err(|e| e.display_chain_with_msg("Failed to convert system directory path string"))?;
+ if !path.ends_with("\\") {
+ path.push('\\');
+ }
+ return Ok(path);
+ }
+ }
+ Err("Failed to resolve system directory".into())
+ }
+}