summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/Cargo.toml2
-rw-r--r--talpid-core/src/split_tunnel/windows/driver.rs15
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs21
-rw-r--r--talpid-core/src/split_tunnel/windows/service.rs249
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs1
5 files changed, 277 insertions, 11 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 6216a23bb2..2ba2675e8f 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -90,6 +90,7 @@ version = "0.36.1"
features = [
"Win32_Foundation",
"Win32_Globalization",
+ "Win32_Security",
"Win32_System_Com",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_Ioctl",
@@ -97,6 +98,7 @@ features = [
"Win32_System_LibraryLoader",
"Win32_System_ProcessStatus",
"Win32_System_Registry",
+ "Win32_System_Services",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_System_WindowsProgramming",
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs
index f73e170cce..a3eb338c8c 100644
--- a/talpid-core/src/split_tunnel/windows/driver.rs
+++ b/talpid-core/src/split_tunnel/windows/driver.rs
@@ -214,8 +214,14 @@ pub enum DeviceHandleError {
impl DeviceHandle {
pub fn new() -> Result<Self, DeviceHandleError> {
- // Connect to the driver
+ let device = Self::new_handle_only()?;
+ device.reinitialize()?;
+ Ok(device)
+ }
+
+ pub(super) fn new_handle_only() -> Result<Self, DeviceHandleError> {
log::trace!("Connecting to the driver");
+
let handle = OpenOptions::new()
.read(true)
.write(true)
@@ -228,10 +234,7 @@ impl DeviceHandle {
Some(ERROR_ACCESS_DENIED) => DeviceHandleError::ConnectionDenied,
_ => DeviceHandleError::ConnectionError(e),
})?;
-
- let device = Self { handle };
- device.reinitialize()?;
- Ok(device)
+ Ok(Self { handle })
}
pub fn reinitialize(&self) -> Result<(), DeviceHandleError> {
@@ -385,7 +388,7 @@ impl DeviceHandle {
Ok(())
}
- fn reset(&self) -> io::Result<()> {
+ pub(super) fn reset(&self) -> io::Result<()> {
device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?;
Ok(())
}
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
index 47c7d4ec8f..175957621f 100644
--- a/talpid-core/src/split_tunnel/windows/mod.rs
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -1,5 +1,6 @@
mod driver;
mod path_monitor;
+mod service;
mod volume_monitor;
mod windows;
@@ -39,6 +40,10 @@ const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123);
#[derive(err_derive::Error, Debug)]
#[error(no_from)]
pub enum Error {
+ /// Failed to install or start driver service
+ #[error(display = "Failed to start driver service")]
+ ServiceError(#[error(source)] service::Error),
+
/// Failed to initialize the driver
#[error(display = "Failed to initialize driver")]
InitializationError(#[error(source)] driver::DeviceHandleError),
@@ -173,6 +178,7 @@ impl SplitTunnel {
/// Initialize the split tunnel device.
pub fn new(
runtime: tokio::runtime::Handle,
+ resource_dir: PathBuf,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
volume_update_rx: mpsc::UnboundedReceiver<()>,
power_mgmt_rx: PowerManagementListener,
@@ -180,7 +186,7 @@ impl SplitTunnel {
let excluded_processes = Arc::new(RwLock::new(HashMap::new()));
let (request_tx, handle) =
- Self::spawn_request_thread(volume_update_rx, excluded_processes.clone())?;
+ Self::spawn_request_thread(resource_dir, volume_update_rx, excluded_processes.clone())?;
let (event_thread, quit_event) =
Self::spawn_event_listener(handle, excluded_processes.clone())?;
@@ -400,6 +406,7 @@ impl SplitTunnel {
}
fn spawn_request_thread(
+ resource_dir: PathBuf,
volume_update_rx: mpsc::UnboundedReceiver<()>,
excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
) -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
@@ -421,10 +428,14 @@ impl SplitTunnel {
);
std::thread::spawn(move || {
- let result = driver::DeviceHandle::new()
- .map(Arc::new)
- .map_err(Error::InitializationError);
- let handle = match result {
+ let init_fn = || {
+ service::install_driver_if_required(&resource_dir).map_err(Error::ServiceError)?;
+ driver::DeviceHandle::new()
+ .map(Arc::new)
+ .map_err(Error::InitializationError)
+ };
+
+ let handle = match init_fn() {
Ok(handle) => {
let _ = init_tx.send(Ok(handle.clone()));
handle
diff --git a/talpid-core/src/split_tunnel/windows/service.rs b/talpid-core/src/split_tunnel/windows/service.rs
new file mode 100644
index 0000000000..0618a1cfbf
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/service.rs
@@ -0,0 +1,249 @@
+use std::{io, os::windows::prelude::OsStrExt, path::Path, ptr, time::Duration, ffi::OsString};
+use widestring::{WideCStr, WideCString};
+use windows::{
+ core::PCWSTR,
+ Win32::{
+ Foundation::{
+ GetLastError, ERROR_INSUFFICIENT_BUFFER, ERROR_SERVICE_DOES_NOT_EXIST, HANDLE, ERROR_SERVICE_ALREADY_RUNNING,
+ },
+ Security::SC_HANDLE,
+ System::{
+ Services::{
+ CloseServiceHandle, CreateServiceW, OpenSCManagerW, OpenServiceW,
+ QueryServiceConfigW, QUERY_SERVICE_CONFIGW, SC_MANAGER_ALL_ACCESS,
+ SERVICE_ALL_ACCESS, SERVICE_DEMAND_START, SERVICE_ERROR_NORMAL,
+ SERVICE_KERNEL_DRIVER, DeleteService, StartServiceW, QueryServiceStatus,
+ SERVICE_STATUS, SERVICE_RUNNING, ControlService, SERVICE_CONTROL_STOP,
+ },
+ SystemServices::GENERIC_READ,
+ },
+ },
+};
+use talpid_types::ErrorExt;
+
+const SPLIT_TUNNEL_SERVICE: &[u8] =
+ b"m\0u\0l\0l\0v\0a\0d\0-\0s\0p\0l\0i\0t\0-\0t\0u\0n\0n\0e\0l\0\0\0";
+const SERVICE_DISPLAY_NAME: &[u8] =
+ b"M\0u\0l\0l\0v\0a\0d\0 \0S\0p\0l\0i\0t\0 \0T\0u\0n\0n\0e\0l\0 \0S\0e\0r\0v\0i\0c\0e\0\0\0";
+const DRIVER_FILENAME: &str = "mullvad-split-tunnel.sys";
+
+const START_TIMEOUT: Duration = Duration::from_secs(8);
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to open service control manager
+ #[error(display = "Failed to connect to service control manager")]
+ OpenServiceControlManager(#[error(source)] windows::core::Error),
+
+ /// Failed to create a service handle
+ #[error(display = "Failed to open service")]
+ OpenServiceHandle(#[error(source)] windows::core::Error),
+
+ /// Failed to start split tunnel service
+ #[error(display = "Failed to start split tunnel device driver service")]
+ StartService(#[error(source)] windows::core::Error),
+
+ /// Failed to check service status
+ #[error(display = "Failed to query service status")]
+ QueryServiceStatus(#[error(source)] windows::core::Error),
+
+ /// Failed to open service config
+ #[error(display = "Failed to retrieve service config")]
+ QueryServiceConfig(#[error(source)] windows::core::Error),
+
+ /// Failed to install ST service
+ #[error(display = "Failed to install split tunnel driver")]
+ InstallService(#[error(source)] windows::core::Error),
+
+ /// Failed to start ST service
+ #[error(display = "Timed out waiting on service to start")]
+ StartTimeout,
+
+ /// Failed to connect to existing driver
+ #[error(display = "Failed to connect to old service")]
+ ConnectOldService(#[error(source)] super::driver::DeviceHandleError),
+
+ /// Failed to reset existing driver
+ #[error(display = "Failed to reset old service state")]
+ ResetOldDriver(#[error(source)] io::Error),
+}
+
+struct ScopedServiceHandle(SC_HANDLE);
+
+impl Drop for ScopedServiceHandle {
+ fn drop(&mut self) {
+ unsafe { CloseServiceHandle(self.0) };
+ }
+}
+
+pub fn install_driver_if_required(resource_dir: &Path) -> Result<(), Error> {
+ let scm =
+ ScopedServiceHandle(unsafe { OpenSCManagerW(PCWSTR::default(), PCWSTR::default(), SC_MANAGER_ALL_ACCESS) }
+ .map_err(Error::OpenServiceControlManager)?);
+
+ let expected_syspath = resource_dir.join(DRIVER_FILENAME);
+
+ let service = unsafe {
+ OpenServiceW(
+ scm.0,
+ PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16),
+ SERVICE_ALL_ACCESS,
+ )
+ .map(ScopedServiceHandle)
+ };
+
+ let service = match service {
+ Ok(service) => service,
+ Err(error) => {
+ return if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() {
+ // TODO: could be marked for deletion
+ unsafe { install_driver(scm.0, &expected_syspath) }
+ } else {
+ Err(Error::OpenServiceHandle(windows::core::Error::from(error)))
+ };
+ }
+ };
+
+ let binpath = unsafe { get_driver_binpath(service.0) }?;
+
+ // Replace existing driver if its path is unexpected
+
+ if expected_syspath != Path::new(&binpath) {
+ log::debug!("The correct ST driver is already installed");
+ return unsafe { start_and_wait_for_service(service.0) };
+ }
+
+ log::debug!("Replacing ST driver with unexpected path");
+
+ unsafe { remove_device(service.0) }?;
+ drop(service);
+
+ unsafe { install_driver(scm.0, &expected_syspath) }
+}
+
+unsafe fn remove_device(service: SC_HANDLE) -> Result<(), Error> {
+ if let Err(error) = (|| -> Result<(), Error> {
+ let old_handle = super::driver::DeviceHandle::new_handle_only()
+ .map_err(Error::ConnectOldService)?;
+ old_handle.reset().map_err(Error::ResetOldDriver)?;
+ Ok(())
+ })() {
+ log::warn!("{}", error.display_chain_with_msg("Failed to reset existing ST service"));
+ }
+
+ let mut service_status = SERVICE_STATUS::default();
+ ControlService(service, SERVICE_CONTROL_STOP, &mut service_status);
+
+ // TODO: wait?
+
+ DeleteService(service);
+
+ // TODO: handle error
+
+ Ok(())
+}
+
+unsafe fn install_driver(scm: SC_HANDLE, syspath: &Path) -> Result<(), Error> {
+ log::debug!("Installing split tunnel driver");
+
+ let binary_path: Vec<u16> = syspath
+ .as_os_str()
+ .encode_wide()
+ .chain(std::iter::once(0u16))
+ .collect();
+
+ let service = CreateServiceW(
+ scm,
+ PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16),
+ PCWSTR(SERVICE_DISPLAY_NAME as *const _ as *const u16),
+ SERVICE_ALL_ACCESS,
+ SERVICE_KERNEL_DRIVER,
+ SERVICE_DEMAND_START,
+ SERVICE_ERROR_NORMAL,
+ PCWSTR(binary_path.as_ptr()),
+ PCWSTR(ptr::null()),
+ ptr::null_mut(),
+ PCWSTR(ptr::null()),
+ PCWSTR(ptr::null()),
+ PCWSTR(ptr::null()),
+ )
+ .map_err(Error::InstallService)?;
+
+ log::debug!("Created split tunnel service");
+
+ let service = ScopedServiceHandle(service);
+ start_and_wait_for_service(service.0)
+}
+
+unsafe fn start_and_wait_for_service(service: SC_HANDLE) -> Result<(), Error> {
+ if !StartServiceW(service, &[]).as_bool() {
+ let last_error = GetLastError();
+
+ if last_error == ERROR_SERVICE_ALREADY_RUNNING {
+ return Ok(());
+ }
+
+ return Err(Error::StartService(windows::core::Error::from(last_error)));
+ }
+
+ log::debug!("Starting split tunnel service");
+
+ let mut service_status = SERVICE_STATUS::default();
+ let initial_time = std::time::Instant::now();
+ loop {
+ if !QueryServiceStatus(service, &mut service_status).as_bool() {
+ return Err(Error::QueryServiceStatus(windows::core::Error::from(
+ GetLastError(),
+ )));
+ }
+
+ if service_status.dwCurrentState == SERVICE_RUNNING {
+ break;
+ }
+
+ if initial_time.elapsed() >= START_TIMEOUT {
+ return Err(Error::StartTimeout);
+ }
+
+ std::thread::sleep(std::time::Duration::from_secs(1));
+ }
+
+ Ok(())
+}
+
+unsafe fn get_driver_binpath(service: SC_HANDLE) -> Result<OsString, Error> {
+ let mut config_buf = vec![];
+ let config;
+
+ let mut bytes_needed = 0u32;
+
+ let result = QueryServiceConfigW(service, ptr::null_mut(), 0, &mut bytes_needed);
+ if !result.as_bool() {
+ let last_error = GetLastError();
+ if last_error != ERROR_INSUFFICIENT_BUFFER {
+ return Err(Error::QueryServiceConfig(windows::core::Error::from(
+ last_error,
+ )));
+ }
+ }
+
+ config_buf.resize(usize::try_from(bytes_needed).unwrap(), 0u8);
+
+ let result = QueryServiceConfigW(
+ service,
+ config_buf.as_mut_ptr() as _,
+ u32::try_from(config_buf.len()).unwrap(),
+ &mut bytes_needed,
+ );
+
+ if !result.as_bool() {
+ return Err(Error::QueryServiceConfig(windows::core::Error::from(
+ GetLastError(),
+ )));
+ }
+
+ config = &*(config_buf.as_ptr() as *const QUERY_SERVICE_CONFIGW);
+
+ Ok(WideCStr::from_ptr_str(config.lpBinaryPathName.0).to_os_string())
+}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 061798b1e2..c1b52278f0 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -261,6 +261,7 @@ impl TunnelStateMachine {
#[cfg(windows)]
let split_tunnel = split_tunnel::SplitTunnel::new(
runtime.clone(),
+ args.resource_dir.clone(),
args.command_tx.clone(),
volume_update_rx,
power_mgmt_rx.clone(),