diff options
| -rw-r--r-- | talpid-core/Cargo.toml | 2 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/driver.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/mod.rs | 21 | ||||
| -rw-r--r-- | talpid-core/src/split_tunnel/windows/service.rs | 249 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 1 |
5 files changed, 277 insertions, 11 deletions
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 6216a23bb2..2ba2675e8f 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -90,6 +90,7 @@ version = "0.36.1" features = [ "Win32_Foundation", "Win32_Globalization", + "Win32_Security", "Win32_System_Com", "Win32_System_Diagnostics_ToolHelp", "Win32_System_Ioctl", @@ -97,6 +98,7 @@ features = [ "Win32_System_LibraryLoader", "Win32_System_ProcessStatus", "Win32_System_Registry", + "Win32_System_Services", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming", diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index f73e170cce..a3eb338c8c 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -214,8 +214,14 @@ pub enum DeviceHandleError { impl DeviceHandle { pub fn new() -> Result<Self, DeviceHandleError> { - // Connect to the driver + let device = Self::new_handle_only()?; + device.reinitialize()?; + Ok(device) + } + + pub(super) fn new_handle_only() -> Result<Self, DeviceHandleError> { log::trace!("Connecting to the driver"); + let handle = OpenOptions::new() .read(true) .write(true) @@ -228,10 +234,7 @@ impl DeviceHandle { Some(ERROR_ACCESS_DENIED) => DeviceHandleError::ConnectionDenied, _ => DeviceHandleError::ConnectionError(e), })?; - - let device = Self { handle }; - device.reinitialize()?; - Ok(device) + Ok(Self { handle }) } pub fn reinitialize(&self) -> Result<(), DeviceHandleError> { @@ -385,7 +388,7 @@ impl DeviceHandle { Ok(()) } - fn reset(&self) -> io::Result<()> { + pub(super) fn reset(&self) -> io::Result<()> { device_io_control(self, DriverIoctlCode::Reset as u32, None, 0)?; Ok(()) } diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 47c7d4ec8f..175957621f 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -1,5 +1,6 @@ mod driver; mod path_monitor; +mod service; mod volume_monitor; mod windows; @@ -39,6 +40,10 @@ const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); #[derive(err_derive::Error, Debug)] #[error(no_from)] pub enum Error { + /// Failed to install or start driver service + #[error(display = "Failed to start driver service")] + ServiceError(#[error(source)] service::Error), + /// Failed to initialize the driver #[error(display = "Failed to initialize driver")] InitializationError(#[error(source)] driver::DeviceHandleError), @@ -173,6 +178,7 @@ impl SplitTunnel { /// Initialize the split tunnel device. pub fn new( runtime: tokio::runtime::Handle, + resource_dir: PathBuf, daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>, volume_update_rx: mpsc::UnboundedReceiver<()>, power_mgmt_rx: PowerManagementListener, @@ -180,7 +186,7 @@ impl SplitTunnel { let excluded_processes = Arc::new(RwLock::new(HashMap::new())); let (request_tx, handle) = - Self::spawn_request_thread(volume_update_rx, excluded_processes.clone())?; + Self::spawn_request_thread(resource_dir, volume_update_rx, excluded_processes.clone())?; let (event_thread, quit_event) = Self::spawn_event_listener(handle, excluded_processes.clone())?; @@ -400,6 +406,7 @@ impl SplitTunnel { } fn spawn_request_thread( + resource_dir: PathBuf, volume_update_rx: mpsc::UnboundedReceiver<()>, excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>, ) -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> { @@ -421,10 +428,14 @@ impl SplitTunnel { ); std::thread::spawn(move || { - let result = driver::DeviceHandle::new() - .map(Arc::new) - .map_err(Error::InitializationError); - let handle = match result { + let init_fn = || { + service::install_driver_if_required(&resource_dir).map_err(Error::ServiceError)?; + driver::DeviceHandle::new() + .map(Arc::new) + .map_err(Error::InitializationError) + }; + + let handle = match init_fn() { Ok(handle) => { let _ = init_tx.send(Ok(handle.clone())); handle diff --git a/talpid-core/src/split_tunnel/windows/service.rs b/talpid-core/src/split_tunnel/windows/service.rs new file mode 100644 index 0000000000..0618a1cfbf --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/service.rs @@ -0,0 +1,249 @@ +use std::{io, os::windows::prelude::OsStrExt, path::Path, ptr, time::Duration, ffi::OsString}; +use widestring::{WideCStr, WideCString}; +use windows::{ + core::PCWSTR, + Win32::{ + Foundation::{ + GetLastError, ERROR_INSUFFICIENT_BUFFER, ERROR_SERVICE_DOES_NOT_EXIST, HANDLE, ERROR_SERVICE_ALREADY_RUNNING, + }, + Security::SC_HANDLE, + System::{ + Services::{ + CloseServiceHandle, CreateServiceW, OpenSCManagerW, OpenServiceW, + QueryServiceConfigW, QUERY_SERVICE_CONFIGW, SC_MANAGER_ALL_ACCESS, + SERVICE_ALL_ACCESS, SERVICE_DEMAND_START, SERVICE_ERROR_NORMAL, + SERVICE_KERNEL_DRIVER, DeleteService, StartServiceW, QueryServiceStatus, + SERVICE_STATUS, SERVICE_RUNNING, ControlService, SERVICE_CONTROL_STOP, + }, + SystemServices::GENERIC_READ, + }, + }, +}; +use talpid_types::ErrorExt; + +const SPLIT_TUNNEL_SERVICE: &[u8] = + b"m\0u\0l\0l\0v\0a\0d\0-\0s\0p\0l\0i\0t\0-\0t\0u\0n\0n\0e\0l\0\0\0"; +const SERVICE_DISPLAY_NAME: &[u8] = + b"M\0u\0l\0l\0v\0a\0d\0 \0S\0p\0l\0i\0t\0 \0T\0u\0n\0n\0e\0l\0 \0S\0e\0r\0v\0i\0c\0e\0\0\0"; +const DRIVER_FILENAME: &str = "mullvad-split-tunnel.sys"; + +const START_TIMEOUT: Duration = Duration::from_secs(8); + +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to open service control manager + #[error(display = "Failed to connect to service control manager")] + OpenServiceControlManager(#[error(source)] windows::core::Error), + + /// Failed to create a service handle + #[error(display = "Failed to open service")] + OpenServiceHandle(#[error(source)] windows::core::Error), + + /// Failed to start split tunnel service + #[error(display = "Failed to start split tunnel device driver service")] + StartService(#[error(source)] windows::core::Error), + + /// Failed to check service status + #[error(display = "Failed to query service status")] + QueryServiceStatus(#[error(source)] windows::core::Error), + + /// Failed to open service config + #[error(display = "Failed to retrieve service config")] + QueryServiceConfig(#[error(source)] windows::core::Error), + + /// Failed to install ST service + #[error(display = "Failed to install split tunnel driver")] + InstallService(#[error(source)] windows::core::Error), + + /// Failed to start ST service + #[error(display = "Timed out waiting on service to start")] + StartTimeout, + + /// Failed to connect to existing driver + #[error(display = "Failed to connect to old service")] + ConnectOldService(#[error(source)] super::driver::DeviceHandleError), + + /// Failed to reset existing driver + #[error(display = "Failed to reset old service state")] + ResetOldDriver(#[error(source)] io::Error), +} + +struct ScopedServiceHandle(SC_HANDLE); + +impl Drop for ScopedServiceHandle { + fn drop(&mut self) { + unsafe { CloseServiceHandle(self.0) }; + } +} + +pub fn install_driver_if_required(resource_dir: &Path) -> Result<(), Error> { + let scm = + ScopedServiceHandle(unsafe { OpenSCManagerW(PCWSTR::default(), PCWSTR::default(), SC_MANAGER_ALL_ACCESS) } + .map_err(Error::OpenServiceControlManager)?); + + let expected_syspath = resource_dir.join(DRIVER_FILENAME); + + let service = unsafe { + OpenServiceW( + scm.0, + PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16), + SERVICE_ALL_ACCESS, + ) + .map(ScopedServiceHandle) + }; + + let service = match service { + Ok(service) => service, + Err(error) => { + return if error.code() == ERROR_SERVICE_DOES_NOT_EXIST.to_hresult() { + // TODO: could be marked for deletion + unsafe { install_driver(scm.0, &expected_syspath) } + } else { + Err(Error::OpenServiceHandle(windows::core::Error::from(error))) + }; + } + }; + + let binpath = unsafe { get_driver_binpath(service.0) }?; + + // Replace existing driver if its path is unexpected + + if expected_syspath != Path::new(&binpath) { + log::debug!("The correct ST driver is already installed"); + return unsafe { start_and_wait_for_service(service.0) }; + } + + log::debug!("Replacing ST driver with unexpected path"); + + unsafe { remove_device(service.0) }?; + drop(service); + + unsafe { install_driver(scm.0, &expected_syspath) } +} + +unsafe fn remove_device(service: SC_HANDLE) -> Result<(), Error> { + if let Err(error) = (|| -> Result<(), Error> { + let old_handle = super::driver::DeviceHandle::new_handle_only() + .map_err(Error::ConnectOldService)?; + old_handle.reset().map_err(Error::ResetOldDriver)?; + Ok(()) + })() { + log::warn!("{}", error.display_chain_with_msg("Failed to reset existing ST service")); + } + + let mut service_status = SERVICE_STATUS::default(); + ControlService(service, SERVICE_CONTROL_STOP, &mut service_status); + + // TODO: wait? + + DeleteService(service); + + // TODO: handle error + + Ok(()) +} + +unsafe fn install_driver(scm: SC_HANDLE, syspath: &Path) -> Result<(), Error> { + log::debug!("Installing split tunnel driver"); + + let binary_path: Vec<u16> = syspath + .as_os_str() + .encode_wide() + .chain(std::iter::once(0u16)) + .collect(); + + let service = CreateServiceW( + scm, + PCWSTR(SPLIT_TUNNEL_SERVICE as *const _ as *const u16), + PCWSTR(SERVICE_DISPLAY_NAME as *const _ as *const u16), + SERVICE_ALL_ACCESS, + SERVICE_KERNEL_DRIVER, + SERVICE_DEMAND_START, + SERVICE_ERROR_NORMAL, + PCWSTR(binary_path.as_ptr()), + PCWSTR(ptr::null()), + ptr::null_mut(), + PCWSTR(ptr::null()), + PCWSTR(ptr::null()), + PCWSTR(ptr::null()), + ) + .map_err(Error::InstallService)?; + + log::debug!("Created split tunnel service"); + + let service = ScopedServiceHandle(service); + start_and_wait_for_service(service.0) +} + +unsafe fn start_and_wait_for_service(service: SC_HANDLE) -> Result<(), Error> { + if !StartServiceW(service, &[]).as_bool() { + let last_error = GetLastError(); + + if last_error == ERROR_SERVICE_ALREADY_RUNNING { + return Ok(()); + } + + return Err(Error::StartService(windows::core::Error::from(last_error))); + } + + log::debug!("Starting split tunnel service"); + + let mut service_status = SERVICE_STATUS::default(); + let initial_time = std::time::Instant::now(); + loop { + if !QueryServiceStatus(service, &mut service_status).as_bool() { + return Err(Error::QueryServiceStatus(windows::core::Error::from( + GetLastError(), + ))); + } + + if service_status.dwCurrentState == SERVICE_RUNNING { + break; + } + + if initial_time.elapsed() >= START_TIMEOUT { + return Err(Error::StartTimeout); + } + + std::thread::sleep(std::time::Duration::from_secs(1)); + } + + Ok(()) +} + +unsafe fn get_driver_binpath(service: SC_HANDLE) -> Result<OsString, Error> { + let mut config_buf = vec![]; + let config; + + let mut bytes_needed = 0u32; + + let result = QueryServiceConfigW(service, ptr::null_mut(), 0, &mut bytes_needed); + if !result.as_bool() { + let last_error = GetLastError(); + if last_error != ERROR_INSUFFICIENT_BUFFER { + return Err(Error::QueryServiceConfig(windows::core::Error::from( + last_error, + ))); + } + } + + config_buf.resize(usize::try_from(bytes_needed).unwrap(), 0u8); + + let result = QueryServiceConfigW( + service, + config_buf.as_mut_ptr() as _, + u32::try_from(config_buf.len()).unwrap(), + &mut bytes_needed, + ); + + if !result.as_bool() { + return Err(Error::QueryServiceConfig(windows::core::Error::from( + GetLastError(), + ))); + } + + config = &*(config_buf.as_ptr() as *const QUERY_SERVICE_CONFIGW); + + Ok(WideCStr::from_ptr_str(config.lpBinaryPathName.0).to_os_string()) +} diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 061798b1e2..c1b52278f0 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -261,6 +261,7 @@ impl TunnelStateMachine { #[cfg(windows)] let split_tunnel = split_tunnel::SplitTunnel::new( runtime.clone(), + args.resource_dir.clone(), args.command_tx.clone(), volume_update_rx, power_mgmt_rx.clone(), |
