diff options
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 37 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/service.rs | 72 | ||||
| -rw-r--r-- | windows/driverlogic/src/service.cpp | 5 |
3 files changed, 94 insertions, 20 deletions
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 175957621f..ed5102ab45 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -48,6 +48,10 @@ pub enum Error { #[error(display = "Failed to initialize driver")] InitializationError(#[error(source)] driver::DeviceHandleError), + /// Failed to reset the driver + #[error(display = "Failed to reset driver")] + ResetError(#[error(source)] io::Error), + /// Failed to set paths to excluded applications #[error(display = "Failed to set list of excluded applications")] SetConfiguration(#[error(source)] io::Error), @@ -122,6 +126,7 @@ enum Request { SetPaths(Vec<OsString>), RegisterIps(InterfaceAddresses), Restart, + Stop, } type RequestResponseTx = sync_mpsc::Sender<Result<(), Error>>; type RequestTx = sync_mpsc::Sender<(Request, RequestResponseTx)>; @@ -529,6 +534,20 @@ impl SplitTunnel { Ok(()) })() } + Request::Stop => { + if let Err(error) = handle.reset().map_err(Error::ResetError) { + let _ = response_tx.send(Err(error)); + continue; + } + + monitored_paths.lock().unwrap().clear(); + excluded_processes.write().unwrap().clear(); + + let _ = response_tx.send(Ok(())); + + // Stop listening to commands + break; + } }; if response_tx.send(response).is_err() { log::error!("A response could not be sent for a completed request"); @@ -542,6 +561,16 @@ impl SplitTunnel { error.display_chain_with_msg("Failed to shut down path monitor") ); } + + drop(handle); + + log::debug!("Stopping ST service"); + if let Err(error) = service::stop_driver_service() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to stop ST service") + ); + } }); let handle = init_rx @@ -729,9 +758,11 @@ impl Drop for SplitTunnel { // Not joining `event_thread`: It may be unresponsive. } - let paths: [&OsStr; 0] = []; - if let Err(error) = self.set_paths_sync(&paths) { - log::error!("{}", error.display_chain()); + if let Err(error) = self.send_request(Request::Stop) { + log::error!( + "{}", + error.display_chain_with_msg("Failed to stop ST driver service") + ); } } } diff --git a/talpid-core/src/split_tunnel/windows/service.rs b/talpid-core/src/split_tunnel/windows/service.rs index 0618a1cfbf..ba67425489 100644 --- a/talpid-core/src/split_tunnel/windows/service.rs +++ b/talpid-core/src/split_tunnel/windows/service.rs @@ -13,7 +13,7 @@ use windows::{ QueryServiceConfigW, QUERY_SERVICE_CONFIGW, SC_MANAGER_ALL_ACCESS, SERVICE_ALL_ACCESS, SERVICE_DEMAND_START, SERVICE_ERROR_NORMAL, SERVICE_KERNEL_DRIVER, DeleteService, StartServiceW, QueryServiceStatus, - SERVICE_STATUS, SERVICE_RUNNING, ControlService, SERVICE_CONTROL_STOP, + SERVICE_STATUS, SERVICE_RUNNING, SERVICE_STOPPED, ControlService, SERVICE_CONTROL_STOP, SERVICE_STATUS_CURRENT_STATE, }, SystemServices::GENERIC_READ, }, @@ -27,7 +27,7 @@ const SERVICE_DISPLAY_NAME: &[u8] = b"M\0u\0l\0l\0v\0a\0d\0 \0S\0p\0l\0i\0t\0 \0T\0u\0n\0n\0e\0l\0 \0S\0e\0r\0v\0i\0c\0e\0\0\0"; const DRIVER_FILENAME: &str = "mullvad-split-tunnel.sys"; -const START_TIMEOUT: Duration = Duration::from_secs(8); +const WAIT_STATUS_TIMEOUT: Duration = Duration::from_secs(8); #[derive(err_derive::Error, Debug)] #[error(no_from)] @@ -122,24 +122,60 @@ pub fn install_driver_if_required(resource_dir: &Path) -> Result<(), Error> { unsafe { install_driver(scm.0, &expected_syspath) } } -unsafe fn remove_device(service: SC_HANDLE) -> Result<(), Error> { - if let Err(error) = (|| -> Result<(), Error> { - let old_handle = super::driver::DeviceHandle::new_handle_only() - .map_err(Error::ConnectOldService)?; - old_handle.reset().map_err(Error::ResetOldDriver)?; - Ok(()) - })() { - log::warn!("{}", error.display_chain_with_msg("Failed to reset existing ST service")); - } +pub fn stop_driver_service() -> Result<(), Error> { + let scm = + ScopedServiceHandle(unsafe { OpenSCManagerW(PCWSTR::default(), PCWSTR::default(), SC_MANAGER_ALL_ACCESS) } + .map_err(Error::OpenServiceControlManager)?); + let service = unsafe { + OpenServiceW( + scm.0, + PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16), + SERVICE_ALL_ACCESS, + ) + .map(ScopedServiceHandle) + }; + let service = match service { + Ok(service) => service, + Err(error) => { + return if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() { + return Ok(()); + } else { + Err(Error::OpenServiceHandle(windows::core::Error::from(error))) + }; + } + }; + + unsafe { stop_service(service.0) } +} + +unsafe fn stop_service(service: SC_HANDLE) -> Result<(), Error> { let mut service_status = SERVICE_STATUS::default(); ControlService(service, SERVICE_CONTROL_STOP, &mut service_status); + wait_for_status(service, SERVICE_STOPPED) +} - // TODO: wait? - +unsafe fn remove_device(service: SC_HANDLE) -> Result<(), Error> { + reset_driver(service)?; + stop_service(service)?; DeleteService(service); + Ok(()) +} + +unsafe fn reset_driver(service: SC_HANDLE) -> Result<(), Error> { + let mut service_status = SERVICE_STATUS::default(); - // TODO: handle error + if !QueryServiceStatus(service, &mut service_status).as_bool() { + return Err(Error::QueryServiceStatus(windows::core::Error::from( + GetLastError(), + ))); + } + + if service_status.dwCurrentState == SERVICE_RUNNING { + let old_handle = super::driver::DeviceHandle::new_handle_only() + .map_err(Error::ConnectOldService)?; + old_handle.reset().map_err(Error::ResetOldDriver)?; + } Ok(()) } @@ -189,6 +225,10 @@ unsafe fn start_and_wait_for_service(service: SC_HANDLE) -> Result<(), Error> { log::debug!("Starting split tunnel service"); + wait_for_status(service, SERVICE_RUNNING) +} + +unsafe fn wait_for_status(service: SC_HANDLE, target_state: SERVICE_STATUS_CURRENT_STATE) -> Result<(), Error> { let mut service_status = SERVICE_STATUS::default(); let initial_time = std::time::Instant::now(); loop { @@ -198,11 +238,11 @@ unsafe fn start_and_wait_for_service(service: SC_HANDLE) -> Result<(), Error> { ))); } - if service_status.dwCurrentState == SERVICE_RUNNING { + if service_status.dwCurrentState == target_state { break; } - if initial_time.elapsed() >= START_TIMEOUT { + if initial_time.elapsed() >= WAIT_STATUS_TIMEOUT { return Err(Error::StartTimeout); } diff --git a/windows/driverlogic/src/service.cpp b/windows/driverlogic/src/service.cpp index f1daa1abfb..cdaaf77ea9 100644 --- a/windows/driverlogic/src/service.cpp +++ b/windows/driverlogic/src/service.cpp @@ -124,7 +124,10 @@ bool ServiceIsRunning(const std::wstring &serviceName) return false; } - CloseServiceHandle(service); + dtor += [service]() + { + CloseServiceHandle(service); + }; return GetServiceProcessStatus(service).dwCurrentState == SERVICE_RUNNING; } |
