summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2022-07-22 11:49:14 +0200
committerDavid Lönnhager <david.l@mullvad.net>2022-08-25 15:34:30 +0200
commitb301a5d64d863f28947a4c133fe152c70cabe088 (patch)
tree8e2cfefe16acfb061c63a912018feb4635d9a267
parent4406785dec6c4235cab411eba44b46a8734ae0d6 (diff)
downloadmullvadvpn-b301a5d64d863f28947a4c133fe152c70cabe088.tar.xz
mullvadvpn-b301a5d64d863f28947a4c133fe152c70cabe088.zip
Rewrite ST service code to use windows-service crate
-rw-r--r--Cargo.lock1
-rw-r--r--talpid-core/Cargo.toml1
-rw-r--r--talpid-core/src/split_tunnel/windows/service.rs290
3 files changed, 105 insertions, 187 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 3be1fd40c4..ae54ff53cc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3120,6 +3120,7 @@ dependencies = [
"which",
"widestring 0.5.1",
"winapi",
+ "windows-service",
"windows-sys",
"winreg",
"zeroize",
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 2ba2675e8f..d36460cf41 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -84,6 +84,7 @@ winreg = { version = "0.7", features = ["transactions"] }
winapi = { version = "0.3.6", features = ["ws2def"] }
talpid-platform-metadata = { path = "../talpid-platform-metadata" }
memoffset = "0.6"
+windows-service = "0.5.0"
[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.36.1"
diff --git a/talpid-core/src/split_tunnel/windows/service.rs b/talpid-core/src/split_tunnel/windows/service.rs
index ba67425489..9aaf787f2f 100644
--- a/talpid-core/src/split_tunnel/windows/service.rs
+++ b/talpid-core/src/split_tunnel/windows/service.rs
@@ -1,30 +1,20 @@
-use std::{io, os::windows::prelude::OsStrExt, path::Path, ptr, time::Duration, ffi::OsString};
-use widestring::{WideCStr, WideCString};
-use windows::{
- core::PCWSTR,
- Win32::{
- Foundation::{
- GetLastError, ERROR_INSUFFICIENT_BUFFER, ERROR_SERVICE_DOES_NOT_EXIST, HANDLE, ERROR_SERVICE_ALREADY_RUNNING,
- },
- Security::SC_HANDLE,
- System::{
- Services::{
- CloseServiceHandle, CreateServiceW, OpenSCManagerW, OpenServiceW,
- QueryServiceConfigW, QUERY_SERVICE_CONFIGW, SC_MANAGER_ALL_ACCESS,
- SERVICE_ALL_ACCESS, SERVICE_DEMAND_START, SERVICE_ERROR_NORMAL,
- SERVICE_KERNEL_DRIVER, DeleteService, StartServiceW, QueryServiceStatus,
- SERVICE_STATUS, SERVICE_RUNNING, SERVICE_STOPPED, ControlService, SERVICE_CONTROL_STOP, SERVICE_STATUS_CURRENT_STATE,
- },
- SystemServices::GENERIC_READ,
- },
+use std::{
+ ffi::{OsStr, OsString},
+ io,
+ path::{Path, PathBuf},
+ time::Duration,
+};
+use windows::Win32::Foundation::{ERROR_SERVICE_ALREADY_RUNNING, ERROR_SERVICE_DOES_NOT_EXIST};
+use windows_service::{
+ service::{
+ Service, ServiceAccess, ServiceErrorControl, ServiceInfo, ServiceStartType, ServiceState,
+ ServiceType,
},
+ service_manager::{ServiceManager, ServiceManagerAccess},
};
-use talpid_types::ErrorExt;
-const SPLIT_TUNNEL_SERVICE: &[u8] =
- b"m\0u\0l\0l\0v\0a\0d\0-\0s\0p\0l\0i\0t\0-\0t\0u\0n\0n\0e\0l\0\0\0";
-const SERVICE_DISPLAY_NAME: &[u8] =
- b"M\0u\0l\0l\0v\0a\0d\0 \0S\0p\0l\0i\0t\0 \0T\0u\0n\0n\0e\0l\0 \0S\0e\0r\0v\0i\0c\0e\0\0\0";
+const SPLIT_TUNNEL_SERVICE: &str = "mullvad-split-tunnel";
+const SPLIT_TUNNEL_DISPLAY_NAME: &str = "Mullvad Split Tunnel Service";
const DRIVER_FILENAME: &str = "mullvad-split-tunnel.sys";
const WAIT_STATUS_TIMEOUT: Duration = Duration::from_secs(8);
@@ -34,211 +24,168 @@ const WAIT_STATUS_TIMEOUT: Duration = Duration::from_secs(8);
pub enum Error {
/// Failed to open service control manager
#[error(display = "Failed to connect to service control manager")]
- OpenServiceControlManager(#[error(source)] windows::core::Error),
+ OpenServiceControlManager(#[error(source)] windows_service::Error),
/// Failed to create a service handle
#[error(display = "Failed to open service")]
- OpenServiceHandle(#[error(source)] windows::core::Error),
+ OpenServiceHandle(#[error(source)] windows_service::Error),
/// Failed to start split tunnel service
#[error(display = "Failed to start split tunnel device driver service")]
- StartService(#[error(source)] windows::core::Error),
+ StartService(#[error(source)] windows_service::Error),
/// Failed to check service status
#[error(display = "Failed to query service status")]
- QueryServiceStatus(#[error(source)] windows::core::Error),
+ QueryServiceStatus(#[error(source)] windows_service::Error),
/// Failed to open service config
#[error(display = "Failed to retrieve service config")]
- QueryServiceConfig(#[error(source)] windows::core::Error),
+ QueryServiceConfig(#[error(source)] windows_service::Error),
/// Failed to install ST service
#[error(display = "Failed to install split tunnel driver")]
- InstallService(#[error(source)] windows::core::Error),
+ InstallService(#[error(source)] windows_service::Error),
/// Failed to start ST service
#[error(display = "Timed out waiting on service to start")]
StartTimeout,
/// Failed to connect to existing driver
- #[error(display = "Failed to connect to old service")]
- ConnectOldService(#[error(source)] super::driver::DeviceHandleError),
+ #[error(display = "Failed to open service handle")]
+ OpenHandle(#[error(source)] super::driver::DeviceHandleError),
/// Failed to reset existing driver
- #[error(display = "Failed to reset old service state")]
- ResetOldDriver(#[error(source)] io::Error),
-}
-
-struct ScopedServiceHandle(SC_HANDLE);
-
-impl Drop for ScopedServiceHandle {
- fn drop(&mut self) {
- unsafe { CloseServiceHandle(self.0) };
- }
+ #[error(display = "Failed to reset driver state")]
+ ResetDriver(#[error(source)] io::Error),
}
pub fn install_driver_if_required(resource_dir: &Path) -> Result<(), Error> {
- let scm =
- ScopedServiceHandle(unsafe { OpenSCManagerW(PCWSTR::default(), PCWSTR::default(), SC_MANAGER_ALL_ACCESS) }
- .map_err(Error::OpenServiceControlManager)?);
+ let scm = ServiceManager::local_computer(
+ None::<OsString>,
+ ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE,
+ )
+ .map_err(Error::OpenServiceControlManager)?;
let expected_syspath = resource_dir.join(DRIVER_FILENAME);
- let service = unsafe {
- OpenServiceW(
- scm.0,
- PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16),
- SERVICE_ALL_ACCESS,
- )
- .map(ScopedServiceHandle)
- };
-
- let service = match service {
+ let service = match scm.open_service(SPLIT_TUNNEL_SERVICE, ServiceAccess::all()) {
Ok(service) => service,
Err(error) => {
- return if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() {
- // TODO: could be marked for deletion
- unsafe { install_driver(scm.0, &expected_syspath) }
- } else {
- Err(Error::OpenServiceHandle(windows::core::Error::from(error)))
+ return match error {
+ windows_service::Error::Winapi(io_error)
+ if io_error.raw_os_error() == Some(ERROR_SERVICE_DOES_NOT_EXIST.0 as i32) =>
+ {
+ // TODO: could be marked for deletion
+ install_driver(&scm, &expected_syspath)
+ }
+ error => Err(Error::OpenServiceHandle(error)),
};
}
};
- let binpath = unsafe { get_driver_binpath(service.0) }?;
-
- // Replace existing driver if its path is unexpected
-
- if expected_syspath != Path::new(&binpath) {
- log::debug!("The correct ST driver is already installed");
- return unsafe { start_and_wait_for_service(service.0) };
+ if expected_syspath != get_driver_binpath(&service)? {
+ log::debug!("ST driver is already installed");
+ return start_and_wait_for_service(&service);
}
- log::debug!("Replacing ST driver with unexpected path");
-
- unsafe { remove_device(service.0) }?;
- drop(service);
+ log::debug!("Replacing ST driver due to unexpected path");
- unsafe { install_driver(scm.0, &expected_syspath) }
+ remove_device(service)?;
+ install_driver(&scm, &expected_syspath)
}
pub fn stop_driver_service() -> Result<(), Error> {
- let scm =
- ScopedServiceHandle(unsafe { OpenSCManagerW(PCWSTR::default(), PCWSTR::default(), SC_MANAGER_ALL_ACCESS) }
- .map_err(Error::OpenServiceControlManager)?);
- let service = unsafe {
- OpenServiceW(
- scm.0,
- PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16),
- SERVICE_ALL_ACCESS,
- )
- .map(ScopedServiceHandle)
- };
+ let scm = ServiceManager::local_computer(None::<OsString>, ServiceManagerAccess::CONNECT)
+ .map_err(Error::OpenServiceControlManager)?;
- let service = match service {
+ let service = match scm.open_service(SPLIT_TUNNEL_SERVICE, ServiceAccess::all()) {
Ok(service) => service,
Err(error) => {
- return if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() {
- return Ok(());
- } else {
- Err(Error::OpenServiceHandle(windows::core::Error::from(error)))
+ return match error {
+ windows_service::Error::Winapi(io_error)
+ if io_error.raw_os_error() == Some(ERROR_SERVICE_DOES_NOT_EXIST.0 as i32) =>
+ {
+ Ok(())
+ }
+ error => Err(Error::OpenServiceHandle(error)),
};
}
};
- unsafe { stop_service(service.0) }
+ stop_service(&service)
}
-unsafe fn stop_service(service: SC_HANDLE) -> Result<(), Error> {
- let mut service_status = SERVICE_STATUS::default();
- ControlService(service, SERVICE_CONTROL_STOP, &mut service_status);
- wait_for_status(service, SERVICE_STOPPED)
+fn stop_service(service: &Service) -> Result<(), Error> {
+ let _ = service.stop();
+ wait_for_status(service, ServiceState::Stopped)
}
-unsafe fn remove_device(service: SC_HANDLE) -> Result<(), Error> {
- reset_driver(service)?;
- stop_service(service)?;
- DeleteService(service);
+fn remove_device(service: Service) -> Result<(), Error> {
+ reset_driver(&service)?;
+ stop_service(&service)?;
+ let _ = service.delete();
Ok(())
}
-unsafe fn reset_driver(service: SC_HANDLE) -> Result<(), Error> {
- let mut service_status = SERVICE_STATUS::default();
+fn reset_driver(service: &Service) -> Result<(), Error> {
+ let status = service.query_status().map_err(Error::QueryServiceStatus)?;
- if !QueryServiceStatus(service, &mut service_status).as_bool() {
- return Err(Error::QueryServiceStatus(windows::core::Error::from(
- GetLastError(),
- )));
- }
-
- if service_status.dwCurrentState == SERVICE_RUNNING {
- let old_handle = super::driver::DeviceHandle::new_handle_only()
- .map_err(Error::ConnectOldService)?;
- old_handle.reset().map_err(Error::ResetOldDriver)?;
+ if status.current_state == ServiceState::Running {
+ let old_handle =
+ super::driver::DeviceHandle::new_handle_only().map_err(Error::OpenHandle)?;
+ old_handle.reset().map_err(Error::ResetDriver)?;
}
Ok(())
}
-unsafe fn install_driver(scm: SC_HANDLE, syspath: &Path) -> Result<(), Error> {
+fn install_driver(scm: &ServiceManager, syspath: &Path) -> Result<(), Error> {
log::debug!("Installing split tunnel driver");
- let binary_path: Vec<u16> = syspath
- .as_os_str()
- .encode_wide()
- .chain(std::iter::once(0u16))
- .collect();
-
- let service = CreateServiceW(
- scm,
- PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16),
- PCWSTR(SERVICE_DISPLAY_NAME as *const _ as *const u16),
- SERVICE_ALL_ACCESS,
- SERVICE_KERNEL_DRIVER,
- SERVICE_DEMAND_START,
- SERVICE_ERROR_NORMAL,
- PCWSTR(binary_path.as_ptr()),
- PCWSTR(ptr::null()),
- ptr::null_mut(),
- PCWSTR(ptr::null()),
- PCWSTR(ptr::null()),
- PCWSTR(ptr::null()),
- )
- .map_err(Error::InstallService)?;
+ let service_info = ServiceInfo {
+ name: SPLIT_TUNNEL_SERVICE.into(),
+ display_name: SPLIT_TUNNEL_DISPLAY_NAME.into(),
+ service_type: ServiceType::KERNEL_DRIVER,
+ start_type: ServiceStartType::OnDemand,
+ error_control: ServiceErrorControl::Normal,
+ executable_path: syspath.to_path_buf(),
+ launch_arguments: vec![],
+ dependencies: vec![],
+ account_name: None,
+ account_password: None,
+ };
- log::debug!("Created split tunnel service");
+ let service = scm
+ .create_service(
+ &service_info,
+ ServiceAccess::START | ServiceAccess::QUERY_STATUS,
+ )
+ .map_err(Error::InstallService)?;
- let service = ScopedServiceHandle(service);
- start_and_wait_for_service(service.0)
+ start_and_wait_for_service(&service)
}
-unsafe fn start_and_wait_for_service(service: SC_HANDLE) -> Result<(), Error> {
- if !StartServiceW(service, &[]).as_bool() {
- let last_error = GetLastError();
+fn start_and_wait_for_service(service: &Service) -> Result<(), Error> {
+ log::debug!("Starting split tunnel service");
- if last_error == ERROR_SERVICE_ALREADY_RUNNING {
- return Ok(());
+ if let Err(error) = service.start::<&OsStr>(&[]) {
+ if let windows_service::Error::Winapi(error) = &error {
+ if error.raw_os_error() == Some(ERROR_SERVICE_ALREADY_RUNNING.0 as i32) {
+ return Ok(());
+ }
}
-
- return Err(Error::StartService(windows::core::Error::from(last_error)));
+ return Err(Error::StartService(error));
}
- log::debug!("Starting split tunnel service");
-
- wait_for_status(service, SERVICE_RUNNING)
+ wait_for_status(service, ServiceState::Running)
}
-unsafe fn wait_for_status(service: SC_HANDLE, target_state: SERVICE_STATUS_CURRENT_STATE) -> Result<(), Error> {
- let mut service_status = SERVICE_STATUS::default();
+fn wait_for_status(service: &Service, target_state: ServiceState) -> Result<(), Error> {
let initial_time = std::time::Instant::now();
loop {
- if !QueryServiceStatus(service, &mut service_status).as_bool() {
- return Err(Error::QueryServiceStatus(windows::core::Error::from(
- GetLastError(),
- )));
- }
+ let status = service.query_status().map_err(Error::QueryServiceStatus)?;
- if service_status.dwCurrentState == target_state {
+ if status.current_state == target_state {
break;
}
@@ -252,38 +199,7 @@ unsafe fn wait_for_status(service: SC_HANDLE, target_state: SERVICE_STATUS_CURRE
Ok(())
}
-unsafe fn get_driver_binpath(service: SC_HANDLE) -> Result<OsString, Error> {
- let mut config_buf = vec![];
- let config;
-
- let mut bytes_needed = 0u32;
-
- let result = QueryServiceConfigW(service, ptr::null_mut(), 0, &mut bytes_needed);
- if !result.as_bool() {
- let last_error = GetLastError();
- if last_error != ERROR_INSUFFICIENT_BUFFER {
- return Err(Error::QueryServiceConfig(windows::core::Error::from(
- last_error,
- )));
- }
- }
-
- config_buf.resize(usize::try_from(bytes_needed).unwrap(), 0u8);
-
- let result = QueryServiceConfigW(
- service,
- config_buf.as_mut_ptr() as _,
- u32::try_from(config_buf.len()).unwrap(),
- &mut bytes_needed,
- );
-
- if !result.as_bool() {
- return Err(Error::QueryServiceConfig(windows::core::Error::from(
- GetLastError(),
- )));
- }
-
- config = &*(config_buf.as_ptr() as *const QUERY_SERVICE_CONFIGW);
-
- Ok(WideCStr::from_ptr_str(config.lpBinaryPathName.0).to_os_string())
+fn get_driver_binpath(service: &Service) -> Result<PathBuf, Error> {
+ let config = service.query_config().map_err(Error::QueryServiceConfig)?;
+ Ok(config.executable_path)
}