diff options
| author | Emīls Piņķis <emils@mullvad.net> | 2019-05-14 18:35:28 +0100 |
|---|---|---|
| committer | Emīls Piņķis <emils@mullvad.net> | 2019-05-21 11:14:14 +0100 |
| commit | 34a205895c72725becc65079fe9f5176d89dba1a (patch) | |
| tree | 1a67990783789f7cd54c6e86101b79793c93efd5 | |
| parent | bd7beaa1911e8c843eb4f80ca8fcd71d1e1eebce (diff) | |
| download | mullvadvpn-34a205895c72725becc65079fe9f5176d89dba1a.tar.xz mullvadvpn-34a205895c72725becc65079fe9f5176d89dba1a.zip | |
Add new windows connectivity checks to daemon
| -rw-r--r-- | talpid-core/src/offline/windows.rs | 154 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 55 |
2 files changed, 165 insertions, 44 deletions
diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs index e40d8fe44f..81a9344d3f 100644 --- a/talpid-core/src/offline/windows.rs +++ b/talpid-core/src/offline/windows.rs @@ -6,17 +6,20 @@ //! GNU General Public License as published by the Free Software Foundation, either version 3 of //! the License, or (at your option) any later version. -use crate::tunnel_state_machine::TunnelCommand; +use crate::{tunnel_state_machine::TunnelCommand, winnet}; use futures::sync::mpsc::UnboundedSender; -use log::debug; +use parking_lot::Mutex; use std::{ ffi::c_void, io, mem::zeroed, os::windows::io::{IntoRawHandle, RawHandle}, - ptr, thread, + ptr, + sync::Arc, + thread, time::Duration, }; +use talpid_types::ErrorExt; use winapi::{ shared::{ basetsd::LONG_PTR, @@ -46,32 +49,61 @@ const REQUEST_THREAD_SHUTDOWN: UINT = WM_USER + 1; pub enum Error { #[error(display = "Unable to create listener thread")] ThreadCreationError(#[error(cause)] io::Error), + #[error(display = "Failed to start connectivity monitor")] + ConnectivityMonitorError, } pub struct BroadcastListener { thread_handle: RawHandle, thread_id: DWORD, + _system_state: Arc<Mutex<SystemState>>, } unsafe impl Send for BroadcastListener {} impl BroadcastListener { - pub fn start<F>(client_callback: F) -> Result<Self, Error> - where - F: Fn(UINT, WPARAM, LPARAM) + 'static + Send, - { + pub fn start(sender: UnboundedSender<TunnelCommand>) -> Result<Self, Error> { + let mut system_state = Arc::new(Mutex::new(SystemState { + network_connectivity: false, + suspended: false, + daemon_channel: sender, + })); + + let power_broadcast_state_ref = system_state.clone(); + + let power_broadcast_callback = move |message: UINT, wparam: WPARAM, _lparam: LPARAM| { + let state = power_broadcast_state_ref.clone(); + if message == WM_POWERBROADCAST { + if wparam == PBT_APMSUSPEND { + log::debug!("Machine is preparing to enter sleep mode"); + apply_system_state_change(state, StateChange::Suspended(true)); + } else if wparam == PBT_APMRESUMEAUTOMATIC { + log::debug!("Machine is returning from sleep mode"); + thread::spawn(move || { + // TAP will be unavailable for approximately 2 seconds on a healthy machine. + thread::sleep(Duration::from_secs(5)); + log::debug!("TAP is presumed to have been re-initialized"); + apply_system_state_change(state, StateChange::Suspended(false)); + }); + } + } + }; + let join_handle = thread::Builder::new() .spawn(move || unsafe { - Self::message_pump(client_callback); + Self::message_pump(power_broadcast_callback); }) .map_err(Error::ThreadCreationError)?; let real_handle = join_handle.into_raw_handle(); + unsafe { Self::setup_network_connectivity_listener(&mut system_state)? }; + Ok(BroadcastListener { thread_handle: real_handle, thread_id: unsafe { GetThreadId(real_handle) }, + _system_state: system_state, }) } @@ -156,6 +188,33 @@ impl BroadcastListener { DefWindowProcW(window, message, wparam, lparam) } + + /// The caller must make sure the `system_state` reference is valid + /// until after `WinNet_DeactivateConnectivityMonitor` has been called. + unsafe fn setup_network_connectivity_listener( + system_state: &Mutex<SystemState>, + ) -> Result<(), Error> { + let callback_context = system_state as *const _ as *mut libc::c_void; + let mut state = system_state.lock(); + let mut current_connectivity = true; + if !winnet::WinNet_ActivateConnectivityMonitor( + Some(Self::connectivity_callback), + callback_context, + &mut current_connectivity as *mut _, + Some(winnet::error_sink), + ptr::null_mut(), + ) { + return Err(Error::ConnectivityMonitorError); + } + state.network_connectivity = current_connectivity; + Ok(()) + } + + unsafe extern "system" fn connectivity_callback(connectivity: bool, context: *mut c_void) { + let state_lock: &mut Mutex<SystemState> = &mut *(context as *mut _); + let mut state = state_lock.lock(); + state.apply_change(StateChange::NetworkConnectivity(connectivity)); + } } impl Drop for BroadcastListener { @@ -164,35 +223,74 @@ impl Drop for BroadcastListener { PostThreadMessageW(self.thread_id, REQUEST_THREAD_SHUTDOWN, 0, 0); WaitForSingleObject(self.thread_handle, INFINITE); CloseHandle(self.thread_handle); + if !winnet::WinNet_DeactivateConnectivityMonitor() { + log::error!("Failed to deactivate connectivity monitor"); + } } } } +#[derive(Debug)] +enum StateChange { + NetworkConnectivity(bool), + Suspended(bool), +} + +struct SystemState { + network_connectivity: bool, + suspended: bool, + daemon_channel: UnboundedSender<TunnelCommand>, +} + +impl SystemState { + fn apply_change(&mut self, change: StateChange) { + let old_state = self.is_offline_currently(); + match change { + StateChange::NetworkConnectivity(connectivity) => { + self.network_connectivity = connectivity; + } + + StateChange::Suspended(suspended) => { + self.suspended = suspended; + } + }; + + let new_state = self.is_offline_currently(); + if old_state != new_state { + if let Err(e) = self + .daemon_channel + .unbounded_send(TunnelCommand::IsOffline(new_state)) + { + log::error!("Failed to send new offline state to daemon: {}", e); + } + } + } + + fn is_offline_currently(&self) -> bool { + !self.network_connectivity || self.suspended + } +} + pub type MonitorHandle = BroadcastListener; pub fn spawn_monitor(sender: UnboundedSender<TunnelCommand>) -> Result<MonitorHandle, Error> { - let listener = - BroadcastListener::start(move |message: UINT, wparam: WPARAM, _lparam: LPARAM| { - if message == WM_POWERBROADCAST { - if wparam == PBT_APMSUSPEND { - debug!("Machine is preparing to enter sleep mode"); - let _ = sender.unbounded_send(TunnelCommand::IsOffline(true)); - } else if wparam == PBT_APMRESUMEAUTOMATIC { - debug!("Machine is returning from sleep mode"); - let cloned_sender = sender.clone(); - thread::spawn(move || { - // TAP will be unavailable for approximately 2 seconds on a healthy machine. - thread::sleep(Duration::from_secs(5)); - debug!("TAP is presumed to have been re-initialized"); - let _ = cloned_sender.unbounded_send(TunnelCommand::IsOffline(false)); - }); - } - } - })?; + BroadcastListener::start(sender) +} - Ok(listener) +fn apply_system_state_change(state: Arc<Mutex<SystemState>>, change: StateChange) { + let mut state = state.lock(); + state.apply_change(change); } pub fn is_offline() -> bool { - false + match winnet::is_offline() { + Ok(state) => state, + Err(e) => { + log::error!( + "{}", + e.display_chain_with_msg("Failed to get current connectivity") + ); + false + } + } } diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index d659534e1f..eb8694bf27 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -1,5 +1,7 @@ -pub use self::api::ErrorSink; use self::api::*; +pub use self::api::{ + ErrorSink, WinNet_ActivateConnectivityMonitor, WinNet_DeactivateConnectivityMonitor, +}; use libc::{c_char, c_void, wchar_t}; use std::{ffi::OsString, ptr}; use widestring::WideCString; @@ -22,6 +24,10 @@ pub enum Error { /// Failed to determine alias of TAP adapter. #[error(display = "Failed to determine alias of TAP adapter")] GetTapAlias, + + /// Can't establish whether host is connected to a non-virtual network + #[error(display = "Network connectivity undecideable")] + ConnectivityUnkown, } /// Error callback used with `winnet.dll`. @@ -92,14 +98,7 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> { WinNet_GetTapInterfaceAlias(&mut alias_ptr as *mut _, Some(error_sink), ptr::null_mut()) }; - if status != 0 { - if status != 1 { - log::error!( - "Unexpected return code from WinNet_GetTapInterfaceAlias: {}", - status - ); - } - + if !status { return Err(Error::GetTapAlias); } @@ -109,6 +108,19 @@ pub fn get_tap_interface_alias() -> Result<OsString, Error> { Ok(alias.to_os_string()) } +/// Returns true if current host is not connected to any network +pub fn is_offline() -> Result<bool, Error> { + match unsafe { WinNet_CheckConnectivity(Some(error_sink), ptr::null_mut()) } { + // Not connected + 0 => Ok(true), + // Connected + 1 => Ok(false), + // 2 means that connectivity can't be determined, but any other return value is unexpected + // and as such, is considered to be an error. + _ => Err(Error::ConnectivityUnkown), + } +} + #[allow(non_snake_case)] mod api { use libc::{c_char, c_void, wchar_t}; @@ -116,6 +128,8 @@ mod api { /// Error callback type for use with `winnet.dll`. pub type ErrorSink = extern "system" fn(msg: *const c_char, ctx: *mut c_void); + pub type ConnectivityCallback = unsafe extern "system" fn(is_connected: bool, ctx: *mut c_void); + extern "system" { #[link_name = "WinNet_EnsureTopMetric"] pub fn WinNet_EnsureTopMetric( @@ -123,27 +137,36 @@ mod api { sink: Option<ErrorSink>, sink_context: *mut c_void, ) -> u32; - } - extern "system" { #[link_name = "WinNet_GetTapInterfaceIpv6Status"] pub fn WinNet_GetTapInterfaceIpv6Status( sink: Option<ErrorSink>, sink_context: *mut c_void, ) -> u32; - } - extern "system" { #[link_name = "WinNet_GetTapInterfaceAlias"] pub fn WinNet_GetTapInterfaceAlias( tunnel_interface_alias: *mut *mut wchar_t, sink: Option<ErrorSink>, sink_context: *mut c_void, - ) -> u32; - } + ) -> bool; - extern "system" { #[link_name = "WinNet_ReleaseString"] pub fn WinNet_ReleaseString(string: *mut wchar_t) -> u32; + + #[link_name = "WinNet_ActivateConnectivityMonitor"] + pub fn WinNet_ActivateConnectivityMonitor( + callback: Option<ConnectivityCallback>, + callbackContext: *mut libc::c_void, + currentConnectivity: *mut bool, + sink: Option<ErrorSink>, + sink_context: *mut c_void, + ) -> bool; + + #[link_name = "WinNet_DeactivateConnectivityMonitor"] + pub fn WinNet_DeactivateConnectivityMonitor() -> bool; + + #[link_name = "WinNet_CheckConnectivity"] + pub fn WinNet_CheckConnectivity(sink: Option<ErrorSink>, sink_context: *mut c_void) -> u32; } } |
