diff options
| -rw-r--r-- | Cargo.lock | 22 | ||||
| -rw-r--r-- | Cargo.toml | 1 | ||||
| -rw-r--r-- | windows-service/Cargo.toml | 15 | ||||
| -rw-r--r-- | windows-service/examples/simple_service.rs | 451 | ||||
| -rw-r--r-- | windows-service/src/lib.rs | 18 | ||||
| -rw-r--r-- | windows-service/src/service.rs | 390 | ||||
| -rw-r--r-- | windows-service/src/service_control_handler.rs | 144 | ||||
| -rw-r--r-- | windows-service/src/service_dispatcher.rs | 90 | ||||
| -rw-r--r-- | windows-service/src/service_manager.rs | 196 | ||||
| -rw-r--r-- | windows-service/src/shell_escape.rs | 120 |
10 files changed, 1447 insertions, 0 deletions
diff --git a/Cargo.lock b/Cargo.lock index bf3c9f6102..b096d0c5b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1189,6 +1189,15 @@ dependencies = [ ] [[package]] +name = "simplelog" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "chrono 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "slab" version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1582,6 +1591,18 @@ dependencies = [ ] [[package]] +name = "windows-service" +version = "0.1.0" +dependencies = [ + "bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "error-chain 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", + "simplelog 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "widestring 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "ws" version = "0.7.5" source = "git+https://github.com/tomusdrw/ws-rs#368ce39e2aa8700d568ca29dbacaecdf1bf749d1" @@ -1739,6 +1760,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum shell-escape 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "dd5cc96481d54583947bfe88bf30c23d53f883c6cd0145368b69989d97b84ef8" "checksum shell32-sys 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "9ee04b46101f57121c9da2b151988283b6beb79b34f5bb29a58ee48cb695122c" "checksum simple-signal 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c1eb01a0c2d12db9e52684e73038eac812494e5937571ae2631f5cf53dc56687" +"checksum simplelog 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ce595117de34b75e057b41e99079e43e9fcc4e5ec9c7ba5f2fea55321f0c624e" "checksum slab 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d807fd58c4181bbabed77cb3b891ba9748241a552bcc5be698faaebefc54f46e" "checksum slab 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b4fcaed89ab08ef143da37bc52adbcc04d4a69014f4c1208d6b51f0c47bc23" "checksum slab 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fdeff4cd9ecff59ec7e3744cbca73dfe5ac35c2aedb2cfba8a1c715a18912e9d" diff --git a/Cargo.toml b/Cargo.toml index 1bed44a013..68e56d35aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "mullvad-daemon", + "windows-service", "mullvad-cli", "mullvad-types", "mullvad-rpc", diff --git a/windows-service/Cargo.toml b/windows-service/Cargo.toml new file mode 100644 index 0000000000..0e73261ce8 --- /dev/null +++ b/windows-service/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "windows-service" +version = "0.1.0" +authors = ["Mullvad VPN <admin@mullvad.net>", "Andrej Mihajlov <and@mullvad.net>"] + +[target.'cfg(windows)'.dependencies] +bitflags = "1.0.1" +error-chain = "0.11" +winapi = { version = "0.3", features = ["std", "winsvc", "winerror"] } +widestring = "0.3.0" + +[dev-dependencies] +error-chain = "0.11" +log = "0.4" +simplelog = { version = "0.5", default-features = false } diff --git a/windows-service/examples/simple_service.rs b/windows-service/examples/simple_service.rs new file mode 100644 index 0000000000..d175e69564 --- /dev/null +++ b/windows-service/examples/simple_service.rs @@ -0,0 +1,451 @@ +// Simple service example. +// +// All commands mentioned below shall be executed in Command Prompt with Administrator privileges. +// +// Service self-installation: `simple_service.exe --install-service` +// Service self-removal: `simple_service.exe --remove-service` +// +// Start the service: `net start simpleservice` +// Pause the service: `net pause simpleservice` +// Resume the service: `net continue simpleservice` +// Stop the service: `net stop simpleservice` +// +// Simple service outputs all logs in C:\Windows\Temp\simple-service.log. +// If you have GNU tools installed, you can follow the log using: +// `tail -F C:\Windows\Temp\simple-service.log` +// + +#[macro_use] +extern crate error_chain; +#[macro_use] +extern crate log; +extern crate simplelog; + +#[cfg(windows)] +#[macro_use] +extern crate windows_service; + +#[cfg(not(windows))] +fn main() { + panic!("This program is only intended to run on Windows."); +} + +#[cfg(windows)] +fn main() { + simple_service::run(); +} + +#[cfg(windows)] +mod simple_service { + use std::ffi::OsString; + use std::fs::OpenOptions; + use std::path::PathBuf; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::mpsc; + use std::time::Duration; + use std::{env, io, thread, time}; + + use log::LevelFilter; + use simplelog::{CombinedLogger, Config, WriteLogger}; + + use windows_service::service::{ServiceAccess, ServiceControl, ServiceControlAccept, + ServiceErrorControl, ServiceExitCode, ServiceInfo, + ServiceStartType, ServiceState, ServiceStatus, ServiceType}; + use windows_service::service_control_handler::{self, ServiceControlHandlerResult, + ServiceStatusHandle}; + use windows_service::service_dispatcher; + use windows_service::service_manager::{ServiceManager, ServiceManagerAccess}; + use windows_service::ChainedError; + + static SERVICE_NAME: &'static str = "SimpleService"; + static SERVICE_DISPLAY_NAME: &'static str = "Simple Service"; + + error_chain! { + errors { + InstallService { + description("Failed to install the service") + } + RemoveService { + description("Failed to remove the service") + } + OpenLogFile(path: PathBuf) { + description("Unable to open log file for writing") + display("Unable to open log file for writing: {}", path.to_string_lossy()) + } + InitLogger { + description("Cannot initialize logger") + } + } + foreign_links { + SetLoggerError(::log::SetLoggerError); + } + } + + static CHECKPOINT_COUNTER: AtomicUsize = AtomicUsize::new(0); + + pub fn update_service_status( + status_handle: &ServiceStatusHandle, + next_state: ServiceState, + exit_code: ServiceExitCode, + wait_hint: Duration, + ) -> io::Result<()> { + // Automatically bump the checkpoint when updating the pending events to tell the system + // that the service is making a progress in transition from pending to final state. + // `wait_hint` should reflect the estimated time for transition to complete. + let checkpoint = match next_state { + ServiceState::ContinuePending + | ServiceState::PausePending + | ServiceState::StartPending + | ServiceState::StopPending => CHECKPOINT_COUNTER.fetch_add(1, Ordering::SeqCst) + 1, + _ => 0, + }; + let service_status = ServiceStatus { + service_type: ServiceType::OwnProcess, + current_state: next_state, + controls_accepted: accepted_controls_by_state(next_state), + exit_code: exit_code, + checkpoint: checkpoint as u32, + wait_hint: wait_hint, + }; + info!( + "Update service status: {:?}, checkpoint: {}, wait_hint: {:?}", + service_status.current_state, service_status.checkpoint, service_status.wait_hint + ); + status_handle.set_service_status(service_status) + } + + pub fn run() { + if let Some(command) = env::args().nth(1) { + match command.as_ref() { + "--install-service" => { + if let Err(e) = install_service() { + println!("{}", e.display_chain()); + } else { + println!("Installed the service."); + } + } + "--remove-service" => { + if let Err(e) = remove_service() { + println!("{}", e.display_chain()); + } else { + println!("Removed the service."); + } + } + "--run-service" => { + // Setup file logger since there is no stdout when running as a service. + if let Err(err) = init_logger() { + panic!("Unable to initialize logger: {}", err.display_chain()); + } + + // Start the service dispatcher. + // This will block current thread until the service stopped. + let result = service_dispatcher::start_dispatcher(SERVICE_NAME, service_main); + + match result { + Err(ref e) => { + error!("Failed to start service dispatcher: {}", e.display_chain()); + } + Ok(_) => { + info!("Service dispatcher exited."); + } + }; + } + _ => println!("Unsupported command: {}", command), + } + } else { + println!("Usage:"); + println!("--install-service to install the service"); + println!("--remove-service to uninstall the service"); + println!("--run-service to run the service"); + } + } + + define_windows_service!(service_main, handle_service_main); + + pub fn handle_service_main(arguments: Vec<OsString>) { + // Create a shutdown channel to release this thread when stopping the service + let (event_tx, event_rx) = mpsc::channel(); + + info!("Received arguments: {:?}", arguments); + + // Register service event handler + let event_handler = move |control_event| -> ServiceControlHandlerResult { + match control_event { + // Notifies a service to report its current status information to the service + // control manager. Always return NO_ERROR even if not implemented. + ServiceControl::Interrogate => ServiceControlHandlerResult::NoError, + + // Handle primary control events + ServiceControl::Pause + | ServiceControl::Continue + | ServiceControl::Stop + | ServiceControl::Shutdown => { + event_tx.send(control_event).unwrap(); + ServiceControlHandlerResult::NoError + } + + _ => ServiceControlHandlerResult::NotImplemented, + } + }; + + let result = service_control_handler::register_control_handler(SERVICE_NAME, event_handler); + match result { + Ok(status_handle) => { + run_service(status_handle, event_rx); + } + Err(ref e) => { + error!("Cannot register a service control handler: {}", e); + } + }; + + info!("Quit service main."); + } + + #[derive(Debug, Copy, Clone)] + enum DaemonEvent { + Continue, + Pause, + Stop, + } + + fn start_event_monitor( + service_status_handle: ServiceStatusHandle, + event_rx: mpsc::Receiver<ServiceControl>, + daemon_tx: mpsc::Sender<DaemonEvent>, + ) -> thread::JoinHandle<()> { + thread::spawn(move || { + loop { + match event_rx.recv().unwrap() { + ServiceControl::Pause => { + info!("Pausing the service."); + + update_service_status( + &service_status_handle, + ServiceState::PausePending, + ServiceExitCode::Win32(0), + Duration::from_secs(2), + ).unwrap(); + + daemon_tx.send(DaemonEvent::Pause).unwrap(); + } + + ServiceControl::Continue => { + info!("Continuing the service."); + + update_service_status( + &service_status_handle, + ServiceState::ContinuePending, + ServiceExitCode::Win32(0), + Duration::from_secs(2), + ).unwrap(); + + daemon_tx.send(DaemonEvent::Continue).unwrap(); + } + + ServiceControl::Stop => { + info!("Stopping the service."); + + update_service_status( + &service_status_handle, + ServiceState::StopPending, + ServiceExitCode::Win32(0), + Duration::from_secs(2), + ).unwrap(); + + daemon_tx.send(DaemonEvent::Stop).unwrap(); + break; // break the loop + } + + ServiceControl::Shutdown => { + info!("Exiting due to shutdown."); + + update_service_status( + &service_status_handle, + ServiceState::StopPending, + ServiceExitCode::Win32(0), + Duration::from_secs(1), + ).unwrap(); + + daemon_tx.send(DaemonEvent::Stop).unwrap(); + break; // break the loop + } + + _ => (), + }; + } + }) + } + + fn start_worker( + service_status_handle: ServiceStatusHandle, + daemon_rx: mpsc::Receiver<DaemonEvent>, + ) -> thread::JoinHandle<()> { + thread::spawn(move || { + let mut is_running = true; + let mut is_paused = false; + + // Tell Windows that the service is running now + update_service_status( + &service_status_handle, + ServiceState::Running, + ServiceExitCode::Win32(0), + Duration::default(), + ).unwrap(); + + while is_running { + // Do some work + if !is_paused { + info!("Working..."); + } + + // Poll events + match daemon_rx.recv_timeout(Duration::from_secs(1)) { + Ok(DaemonEvent::Pause) => { + is_paused = true; + + update_service_status( + &service_status_handle, + ServiceState::Paused, + ServiceExitCode::Win32(0), + Duration::default(), + ).unwrap(); + } + Ok(DaemonEvent::Continue) => { + is_paused = false; + + update_service_status( + &service_status_handle, + ServiceState::Running, + ServiceExitCode::Win32(0), + Duration::default(), + ).unwrap(); + } + Ok(DaemonEvent::Stop) | Err(mpsc::RecvTimeoutError::Disconnected) => { + is_running = false; + + update_service_status( + &service_status_handle, + ServiceState::Stopped, + ServiceExitCode::Win32(0), + Duration::default(), + ).unwrap(); + } + Err(mpsc::RecvTimeoutError::Timeout) => (), + }; + } + }) + } + + fn run_service(status_handle: ServiceStatusHandle, event_rx: mpsc::Receiver<ServiceControl>) { + let (daemon_tx, daemon_rx) = mpsc::channel(); + + // Tell Windows that the service is starting up + update_service_status( + &status_handle, + ServiceState::StartPending, + ServiceExitCode::Win32(0), + Duration::from_secs(5), + ).unwrap(); + + let event_monitor_handle = start_event_monitor(status_handle, event_rx, daemon_tx); + let worker_thread_handle = start_worker(status_handle, daemon_rx); + + // Block current thread until other threads complete execution + event_monitor_handle.join().unwrap(); + worker_thread_handle.join().unwrap(); + } + + /// Returns the list of accepted service events at each stage of the service lifecycle. + fn accepted_controls_by_state(state: ServiceState) -> ServiceControlAccept { + match state { + ServiceState::StartPending + | ServiceState::PausePending + | ServiceState::ContinuePending => ServiceControlAccept::empty(), + ServiceState::Running => { + ServiceControlAccept::STOP | ServiceControlAccept::PAUSE_CONTINUE + | ServiceControlAccept::SHUTDOWN + } + ServiceState::Paused => { + ServiceControlAccept::STOP | ServiceControlAccept::PAUSE_CONTINUE + | ServiceControlAccept::SHUTDOWN + } + ServiceState::StopPending | ServiceState::Stopped => ServiceControlAccept::empty(), + } + } + + fn install_service() -> Result<()> { + let manager_access = ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE; + let service_manager = ServiceManager::local_computer(None::<&str>, manager_access) + .chain_err(|| ErrorKind::InstallService)?; + let service_info = get_service_info(); + service_manager + .create_service(service_info, ServiceAccess::empty()) + .map(|_| ()) + .chain_err(|| ErrorKind::InstallService) + } + + fn remove_service() -> Result<()> { + let manager_access = ServiceManagerAccess::CONNECT; + let service_manager = ServiceManager::local_computer(None::<&str>, manager_access) + .chain_err(|| ErrorKind::RemoveService)?; + + let service_access = + ServiceAccess::QUERY_STATUS | ServiceAccess::STOP | ServiceAccess::DELETE; + let service = service_manager + .open_service(SERVICE_NAME, service_access) + .chain_err(|| ErrorKind::RemoveService)?; + + loop { + let service_status = service + .query_status() + .chain_err(|| ErrorKind::RemoveService)?; + + match service_status.current_state { + ServiceState::StopPending => (), + ServiceState::Stopped => { + println!("Removing the service..."); + service.delete().chain_err(|| ErrorKind::RemoveService)?; + return Ok(()); // explicit return + } + _ => { + println!("Stopping the service..."); + service.stop().chain_err(|| ErrorKind::RemoveService)?; + } + } + + thread::sleep(time::Duration::from_secs(1)) + } + } + + fn get_service_info() -> ServiceInfo { + ServiceInfo { + name: OsString::from(SERVICE_NAME), + display_name: OsString::from(SERVICE_DISPLAY_NAME), + service_type: ServiceType::OwnProcess, + start_type: ServiceStartType::OnDemand, + error_control: ServiceErrorControl::Normal, + executable_path: env::current_exe().unwrap(), + launch_arguments: vec![OsString::from("--run-service")], + account_name: None, // run as System + account_password: None, + } + } + + fn init_logger() -> Result<()> { + let windows_directory = env::var_os("WINDIR").unwrap(); + let log_file_path = PathBuf::from(windows_directory) + .join("Temp") + .join("simple-service.log"); + + let log_file = OpenOptions::new() + .create(true) + .append(true) + .open(log_file_path.as_path()) + .chain_err(|| ErrorKind::OpenLogFile(log_file_path))?; + + let file_logger = WriteLogger::new(LevelFilter::Trace, Config::default(), log_file); + + CombinedLogger::init(vec![file_logger]).chain_err(|| ErrorKind::InitLogger) + } + +} diff --git a/windows-service/src/lib.rs b/windows-service/src/lib.rs new file mode 100644 index 0000000000..bae17eb41d --- /dev/null +++ b/windows-service/src/lib.rs @@ -0,0 +1,18 @@ +#![cfg(windows)] + +#[macro_use] +extern crate bitflags; +#[macro_use] +extern crate error_chain; +extern crate widestring; +extern crate winapi; + +pub use error_chain::ChainedError; + +pub mod service; +pub mod service_control_handler; +pub mod service_manager; +#[macro_use] +pub mod service_dispatcher; + +mod shell_escape; diff --git a/windows-service/src/service.rs b/windows-service/src/service.rs new file mode 100644 index 0000000000..f609221b8c --- /dev/null +++ b/windows-service/src/service.rs @@ -0,0 +1,390 @@ +use std::ffi::OsString; +use std::path::PathBuf; +use std::time::Duration; +use std::{io, mem}; + +use winapi::shared::winerror::ERROR_SERVICE_SPECIFIC_ERROR; +use winapi::um::{winnt, winsvc}; + +mod errors { + error_chain! { + errors { + InvalidServiceType(raw_value: u32) { + description("Invalid service type value") + display("Invalid service type value: {}", raw_value) + } + InvalidServiceState(raw_value: u32) { + description("Invalid service state") + display("Invalid service state value: {}", raw_value) + } + InvalidServiceControl(raw_value: u32) { + description("Invalid service control") + display("Invalid service control value: {}", raw_value) + } + } + foreign_links { + System(::std::io::Error); + } + } +} +pub use self::errors::*; + +/// Enum describing types of windows services +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ServiceType { + /// Service that runs in its own process. + OwnProcess = winnt::SERVICE_WIN32_OWN_PROCESS, +} + +impl ServiceType { + pub fn from_raw(raw_value: u32) -> Result<Self> { + let service_type = match raw_value { + x if x == ServiceType::OwnProcess.to_raw() => ServiceType::OwnProcess, + _ => Err(ErrorKind::InvalidServiceType(raw_value))?, + }; + Ok(service_type) + } + + pub fn to_raw(&self) -> u32 { + *self as u32 + } +} + +/// Flags describing the access permissions when working with services +bitflags! { + pub struct ServiceAccess: u32 { + /// Can query the service status + const QUERY_STATUS = winsvc::SERVICE_QUERY_STATUS; + + /// Can start the service + const START = winsvc::SERVICE_START; + + // Can stop the service + const STOP = winsvc::SERVICE_STOP; + + /// Can pause or continue the service execution + const PAUSE_CONTINUE = winsvc::SERVICE_PAUSE_CONTINUE; + + /// Can ask the service to report its status + const INTERROGATE = winsvc::SERVICE_INTERROGATE; + + /// Can delete the service + const DELETE = winnt::DELETE; + } +} + +/// Enum describing the start options for windows services +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ServiceStartType { + /// Autostart on system startup + AutoStart = winnt::SERVICE_AUTO_START, + /// Service is enabled, can be started manually + OnDemand = winnt::SERVICE_DEMAND_START, + /// Disabled service + Disabled = winnt::SERVICE_DISABLED, +} + +impl ServiceStartType { + pub fn to_raw(&self) -> u32 { + *self as u32 + } +} + +/// Error handling strategy for service failures. +/// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms682450(v=vs.85).aspx +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ServiceErrorControl { + Critical = winnt::SERVICE_ERROR_CRITICAL, + Ignore = winnt::SERVICE_ERROR_IGNORE, + Normal = winnt::SERVICE_ERROR_NORMAL, + Severe = winnt::SERVICE_ERROR_SEVERE, +} + +impl ServiceErrorControl { + pub fn to_raw(&self) -> u32 { + *self as u32 + } +} + +/// A struct that describes the service +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ServiceInfo { + /// Service name + pub name: OsString, + + /// Friendly service name + pub display_name: OsString, + + pub service_type: ServiceType, + pub start_type: ServiceStartType, + pub error_control: ServiceErrorControl, + + /// Path to the service binary. + pub executable_path: PathBuf, + + /// Launch arguments passed to `main` when system starts the service. + /// This is not the same as arguments passed to `service_main`. + pub launch_arguments: Vec<OsString>, + + /// Account to use for running the service. + /// for example: NT Authority\System. + /// use `None` to run as LocalSystem. + pub account_name: Option<OsString>, + + /// Account password. + /// For system accounts this should normally be `None`. + pub account_password: Option<OsString>, +} + +/// Enum describing the service control operations +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ServiceControl { + Continue = winsvc::SERVICE_CONTROL_CONTINUE, + Interrogate = winsvc::SERVICE_CONTROL_INTERROGATE, + NetBindAdd = winsvc::SERVICE_CONTROL_NETBINDADD, + NetBindDisable = winsvc::SERVICE_CONTROL_NETBINDDISABLE, + NetBindEnable = winsvc::SERVICE_CONTROL_NETBINDENABLE, + NetBindRemove = winsvc::SERVICE_CONTROL_NETBINDREMOVE, + ParamChange = winsvc::SERVICE_CONTROL_PARAMCHANGE, + Pause = winsvc::SERVICE_CONTROL_PAUSE, + Preshutdown = winsvc::SERVICE_CONTROL_PRESHUTDOWN, + Shutdown = winsvc::SERVICE_CONTROL_SHUTDOWN, + Stop = winsvc::SERVICE_CONTROL_STOP, +} + +impl ServiceControl { + pub fn from_raw(raw_value: u32) -> Result<Self> { + let service_control = match raw_value { + x if x == ServiceControl::Continue.to_raw() => ServiceControl::Continue, + x if x == ServiceControl::Interrogate.to_raw() => ServiceControl::Interrogate, + x if x == ServiceControl::NetBindAdd.to_raw() => ServiceControl::NetBindAdd, + x if x == ServiceControl::NetBindDisable.to_raw() => ServiceControl::NetBindDisable, + x if x == ServiceControl::NetBindEnable.to_raw() => ServiceControl::NetBindEnable, + x if x == ServiceControl::NetBindRemove.to_raw() => ServiceControl::NetBindRemove, + x if x == ServiceControl::ParamChange.to_raw() => ServiceControl::ParamChange, + x if x == ServiceControl::Pause.to_raw() => ServiceControl::Pause, + x if x == ServiceControl::Preshutdown.to_raw() => ServiceControl::Preshutdown, + x if x == ServiceControl::Shutdown.to_raw() => ServiceControl::Shutdown, + x if x == ServiceControl::Stop.to_raw() => ServiceControl::Stop, + other => Err(ErrorKind::InvalidServiceControl(other))?, + }; + Ok(service_control) + } + + pub fn to_raw(&self) -> u32 { + *self as u32 + } +} + +/// Service state returned as a part of ServiceStatus +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum ServiceState { + Stopped = winsvc::SERVICE_STOPPED, + StartPending = winsvc::SERVICE_START_PENDING, + StopPending = winsvc::SERVICE_STOP_PENDING, + Running = winsvc::SERVICE_RUNNING, + ContinuePending = winsvc::SERVICE_CONTINUE_PENDING, + PausePending = winsvc::SERVICE_PAUSE_PENDING, + Paused = winsvc::SERVICE_PAUSED, +} + +impl ServiceState { + fn from_raw(raw_state: u32) -> Result<Self> { + let service_state = match raw_state { + x if x == ServiceState::Stopped.to_raw() => ServiceState::Stopped, + x if x == ServiceState::StartPending.to_raw() => ServiceState::StartPending, + x if x == ServiceState::StopPending.to_raw() => ServiceState::StopPending, + x if x == ServiceState::Running.to_raw() => ServiceState::Running, + x if x == ServiceState::ContinuePending.to_raw() => ServiceState::ContinuePending, + x if x == ServiceState::PausePending.to_raw() => ServiceState::PausePending, + x if x == ServiceState::Paused.to_raw() => ServiceState::Paused, + other => Err(ErrorKind::InvalidServiceState(other))?, + }; + Ok(service_state) + } + + fn to_raw(&self) -> u32 { + *self as u32 + } +} + +/// Service exit code abstraction. +/// +/// This struct provides a logic around the relationship between `win32_exit_code` and +/// `service_specific_exit_code`. +/// +/// The service can either return a win32 error code or a custom error +/// code. In that case `win32_exit_code` has to be set to `ERROR_SERVICE_SPECIFIC_ERROR` and +/// the `service_specific_exit_code` assigned with custom error code. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ServiceExitCode { + Win32(u32), + ServiceSpecific(u32), +} + +impl ServiceExitCode { + fn copy_to(&self, raw_service_status: &mut winsvc::SERVICE_STATUS) { + match *self { + ServiceExitCode::Win32(win32_error_code) => { + raw_service_status.dwWin32ExitCode = win32_error_code; + raw_service_status.dwServiceSpecificExitCode = 0; + } + ServiceExitCode::ServiceSpecific(service_error_code) => { + raw_service_status.dwWin32ExitCode = ERROR_SERVICE_SPECIFIC_ERROR; + raw_service_status.dwServiceSpecificExitCode = service_error_code; + } + } + } +} + +impl<'a> From<&'a winsvc::SERVICE_STATUS> for ServiceExitCode { + fn from(service_status: &'a winsvc::SERVICE_STATUS) -> Self { + if service_status.dwWin32ExitCode == ERROR_SERVICE_SPECIFIC_ERROR { + ServiceExitCode::ServiceSpecific(service_status.dwServiceSpecificExitCode) + } else { + ServiceExitCode::Win32(service_status.dwWin32ExitCode) + } + } +} + +/// Flags describing accepted types of service control requests +bitflags! { + pub struct ServiceControlAccept: u32 { + /// The service is a network component that can accept changes in its binding without being + /// stopped and restarted. This allows service to receive `ServiceControl::Netbind*` + /// family of events. + const NETBIND_CHANGE = winsvc::SERVICE_ACCEPT_NETBINDCHANGE; + + /// The service can reread its startup parameters without being stopped and restarted. + const PARAM_CHANGE = winsvc::SERVICE_ACCEPT_PARAMCHANGE; + + /// The service can be paused and continued. + const PAUSE_CONTINUE = winsvc::SERVICE_ACCEPT_PAUSE_CONTINUE; + + /// The service can perform preshutdown tasks. + /// Mutually exclusive with shutdown. + const PRESHUTDOWN = winsvc::SERVICE_ACCEPT_PRESHUTDOWN; + + /// The service is notified when system shutdown occurs. + /// Mutually exclusive with preshutdown. + const SHUTDOWN = winsvc::SERVICE_ACCEPT_SHUTDOWN; + + /// The service can be stopped. + const STOP = winsvc::SERVICE_ACCEPT_STOP; + } +} + +/// Service status. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ServiceStatus { + /// Type of service + pub service_type: ServiceType, + + /// Current state of the service + pub current_state: ServiceState, + + /// Control commands that service accepts. + pub controls_accepted: ServiceControlAccept, + + /// Service exit code + pub exit_code: ServiceExitCode, + + /// Service initialization progress value that should be increased during a lengthy start, + /// stop, pause or continue eration. For example the service should increment the value as + /// it completes each step of initialization. + /// This value must be zero if the service does not have any pending start, stop, pause or + /// continue operations. + pub checkpoint: u32, + + /// Estimated time for pending operation. + /// This basically works as a timeout until the service manager assumes that the service hung. + /// This could be either circumvented by updating the `current_state` or incrementing a + /// `checkpoint` value. + pub wait_hint: Duration, +} + +impl ServiceStatus { + pub(super) fn to_raw(&self) -> winsvc::SERVICE_STATUS { + let mut raw_status = unsafe { mem::zeroed::<winsvc::SERVICE_STATUS>() }; + raw_status.dwServiceType = self.service_type.to_raw(); + raw_status.dwCurrentState = self.current_state.to_raw(); + raw_status.dwControlsAccepted = self.controls_accepted.bits(); + + self.exit_code.copy_to(&mut raw_status); + + raw_status.dwCheckPoint = self.checkpoint; + + // we lose precision here but dwWaitHint should never be too big. + raw_status.dwWaitHint = (self.wait_hint.as_secs() * 1000) as u32; + + raw_status + } + + fn from_raw(raw_status: winsvc::SERVICE_STATUS) -> Result<Self> { + Ok(ServiceStatus { + service_type: ServiceType::from_raw(raw_status.dwServiceType)?, + current_state: ServiceState::from_raw(raw_status.dwCurrentState)?, + controls_accepted: ServiceControlAccept::from_bits_truncate( + raw_status.dwControlsAccepted, + ), + exit_code: ServiceExitCode::from(&raw_status), + checkpoint: raw_status.dwCheckPoint, + wait_hint: Duration::from_millis(raw_status.dwWaitHint as u64), + }) + } +} + + +pub struct Service(winsvc::SC_HANDLE); + +impl Service { + /// Internal constructor + pub(super) unsafe fn from_handle(handle: winsvc::SC_HANDLE) -> Self { + Service(handle) + } + + pub fn stop(&self) -> Result<ServiceStatus> { + self.send_control_command(ServiceControl::Stop) + } + + pub fn query_status(&self) -> Result<ServiceStatus> { + let mut raw_status = unsafe { mem::zeroed::<winsvc::SERVICE_STATUS>() }; + let success = unsafe { winsvc::QueryServiceStatus(self.0, &mut raw_status) }; + if success == 1 { + ServiceStatus::from_raw(raw_status) + } else { + Err(io::Error::last_os_error().into()) + } + } + + pub fn delete(self) -> io::Result<()> { + let success = unsafe { winsvc::DeleteService(self.0) }; + if success == 1 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } + } + + fn send_control_command(&self, command: ServiceControl) -> Result<ServiceStatus> { + let mut raw_status = unsafe { mem::zeroed::<winsvc::SERVICE_STATUS>() }; + let success = unsafe { winsvc::ControlService(self.0, command.to_raw(), &mut raw_status) }; + + if success == 1 { + ServiceStatus::from_raw(raw_status).map_err(|err| err.into()) + } else { + Err(io::Error::last_os_error().into()) + } + } +} + +impl Drop for Service { + fn drop(&mut self) { + unsafe { winsvc::CloseServiceHandle(self.0) }; + } +} diff --git a/windows-service/src/service_control_handler.rs b/windows-service/src/service_control_handler.rs new file mode 100644 index 0000000000..b722e94f52 --- /dev/null +++ b/windows-service/src/service_control_handler.rs @@ -0,0 +1,144 @@ +use std::ffi::OsStr; +use std::io; +use widestring::WideCString; +use winapi::shared::winerror::{ERROR_CALL_NOT_IMPLEMENTED, NO_ERROR}; +use winapi::um::winsvc; + +use service::{ServiceControl, ServiceStatus}; + +mod errors { + error_chain! { + errors { + InvalidServiceName { + description("Invalid service name") + } + } + foreign_links { + System(::std::io::Error); + } + } +} +pub use self::errors::*; + +/// Struct that holds unique token for updating the status of the corresponding service. +#[derive(Debug, Clone, Copy)] +pub struct ServiceStatusHandle(winsvc::SERVICE_STATUS_HANDLE); + +impl ServiceStatusHandle { + fn from_handle(handle: winsvc::SERVICE_STATUS_HANDLE) -> Self { + ServiceStatusHandle(handle) + } + + /// Report the new service status to the system + pub fn set_service_status(&self, service_status: ServiceStatus) -> io::Result<()> { + let mut raw_service_status = service_status.to_raw(); + let result = unsafe { winsvc::SetServiceStatus(self.0, &mut raw_service_status) }; + if result == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } +} + +// Underlying SERVICE_STATUS_HANDLE is thread safe. +// See remarks section for more info: +// https://msdn.microsoft.com/en-us/library/windows/desktop/ms686241(v=vs.85).aspx +unsafe impl Send for ServiceStatusHandle {} + +/// Abstraction over the return value of service control handler. +/// The meaning of each of variants in this enum depends on the type of received event. +/// See the "Return value" section of corresponding MSDN article for more info: +/// https://msdn.microsoft.com/en-us/library/windows/desktop/ms683241(v=vs.85).aspx +#[derive(Debug)] +pub enum ServiceControlHandlerResult { + /// Either used to aknowledge the call or grant the permission in advanced events. + NoError, + /// The received event is not implemented. + NotImplemented, + /// This variant is used to deny permission and return the reason error code in advanced + /// events. + Other(u32), +} + +impl ServiceControlHandlerResult { + pub fn to_raw(&self) -> u32 { + match *self { + ServiceControlHandlerResult::NoError => NO_ERROR, + ServiceControlHandlerResult::NotImplemented => ERROR_CALL_NOT_IMPLEMENTED, + ServiceControlHandlerResult::Other(code) => code, + } + } +} + +/// Register a closure for receiving service events. +/// Returns `ServiceStatusHandle` that can be used to report the service status back to the system. +pub fn register_control_handler< + S: AsRef<OsStr>, + F: Fn(ServiceControl) -> ServiceControlHandlerResult + 'static, +>( + service_name: S, + event_handler: F, +) -> Result<ServiceStatusHandle> { + // Move closure data on heap. + // The Box<HandlerFn> is a trait object and is stored on stack at this point. + let heap_event_handler = Box::new(event_handler) as Box<HandlerFn>; + + // Box again to move trait object to heap. + let boxed_event_handler: Box<Box<HandlerFn>> = Box::new(heap_event_handler); + + // Important: leak the Box<Box<HandlerFn>> which will be released in `service_control_handler`. + let context = Box::into_raw(boxed_event_handler) as *mut ::std::os::raw::c_void; + + let service_name = + WideCString::from_str(service_name).chain_err(|| ErrorKind::InvalidServiceName)?; + let status_handle = unsafe { + winsvc::RegisterServiceCtrlHandlerExW( + service_name.as_ptr(), + Some(service_control_handler), + context, + ) + }; + + if status_handle.is_null() { + Err(io::Error::last_os_error().into()) + } else { + Ok(ServiceStatusHandle::from_handle(status_handle)) + } +} + +/// Alias for control event handler closure. +type HandlerFn = Fn(ServiceControl) -> ServiceControlHandlerResult; + +/// Static service control handler +#[allow(dead_code)] +extern "system" fn service_control_handler( + control: u32, + _event_type: u32, + _event_data: *mut ::std::os::raw::c_void, + context: *mut ::std::os::raw::c_void, +) -> u32 { + // Important: cast context to &mut Box<HandlerFn> without taking ownership. + let handler_fn = unsafe { &mut *(context as *mut Box<HandlerFn>) }; + + match ServiceControl::from_raw(control) { + Ok(service_control) => { + let return_code = ((handler_fn)(service_control)).to_raw(); + + // Important: release context upon Stop, Shutdown or Preshutdown at the end of the + // service lifecycle. + match service_control { + ServiceControl::Stop | ServiceControl::Shutdown | ServiceControl::Preshutdown => { + let _owned_boxed_handler: Box<Box<HandlerFn>> = + unsafe { Box::from_raw(context as *mut Box<HandlerFn>) }; + } + _ => (), + }; + + return_code + } + + // Report all unknown control commands as unimplemented + Err(_) => ServiceControlHandlerResult::NotImplemented.to_raw(), + } +} diff --git a/windows-service/src/service_dispatcher.rs b/windows-service/src/service_dispatcher.rs new file mode 100644 index 0000000000..fb9d5df8f3 --- /dev/null +++ b/windows-service/src/service_dispatcher.rs @@ -0,0 +1,90 @@ +use std::ffi::{OsStr, OsString}; +use std::{io, ptr}; +use widestring::{WideCStr, WideCString}; +use winapi::um::winsvc; + +mod errors { + error_chain! { + errors { + InvalidServiceName { + description("Invalid service name") + } + } + foreign_links { + System(::std::io::Error); + } + } +} +pub use self::errors::*; + +/// Macro to generate a "service_main" function for Windows service. +/// +/// The `service_main` function parses service arguments provided by the system +/// and passes them with a call to `$service_main_handler`. +/// +/// `$function_name` - name of the "service_main" callback. +/// `$service_main_handler` - function with a signature `fn(Vec<OsString>)` that's called from +/// generated `$function_name`. Accepts parsed service arguments as `Vec<OsString>`. Its +/// responsibility is to create a `ServiceControlHandler`, start processing control events and +/// report the service status to the system. +/// +#[macro_export] +macro_rules! define_windows_service { + ($function_name:ident, $service_main_handler:ident) => { + /// Static callback used by the system to bootstrap the service. + /// Do not call it directly. + extern "system" fn $function_name(argc: u32, argv: *mut *mut u16) { + let arguments = unsafe { $crate::service_dispatcher::parse_raw_arguments(argc, argv) }; + + $service_main_handler(arguments); + } + }; +} + +/// Start service control dispatcher. +/// +/// Once started the service control dispatcher blocks the current thread execution +/// until the service is stopped. +/// +/// Upon successful initialization, system calls the `service_main` in +/// background thread which parses service arguments received from the system and +/// passes them to higher level `$service_main_handler` handler. +/// +/// On failure: immediately returns an error, no threads are spawned. +/// +pub fn start_dispatcher<T: AsRef<OsStr>>( + service_name: T, + service_main: extern "system" fn(u32, *mut *mut u16), +) -> Result<()> { + let service_name = + WideCString::from_str(service_name).chain_err(|| ErrorKind::InvalidServiceName)?; + let service_table: &[winsvc::SERVICE_TABLE_ENTRYW] = &[ + winsvc::SERVICE_TABLE_ENTRYW { + lpServiceName: service_name.as_ptr(), + lpServiceProc: Some(service_main), + }, + // the last item has to be { null, null } + winsvc::SERVICE_TABLE_ENTRYW { + lpServiceName: ptr::null(), + lpServiceProc: None, + }, + ]; + + let result = unsafe { winsvc::StartServiceCtrlDispatcherW(service_table.as_ptr()) }; + if result == 0 { + Err(io::Error::last_os_error().into()) + } else { + Ok(()) + } +} + +/// Parse raw arguments received from `service_main` into Vec. +pub unsafe fn parse_raw_arguments(argc: u32, argv: *mut *mut u16) -> Vec<OsString> { + (0..argc) + .into_iter() + .map(|i| { + let array_element_ptr: *mut *mut u16 = argv.offset(i as isize); + WideCStr::from_ptr_str(*array_element_ptr).to_os_string() + }) + .collect() +} diff --git a/windows-service/src/service_manager.rs b/windows-service/src/service_manager.rs new file mode 100644 index 0000000000..b4dceb9d3e --- /dev/null +++ b/windows-service/src/service_manager.rs @@ -0,0 +1,196 @@ +use std::borrow::Cow; +use std::ffi::OsStr; +use std::{io, ptr}; + +use widestring::{NulError, WideCString, WideString}; +use winapi::um::winsvc; + +use service::{Service, ServiceAccess, ServiceInfo}; +use shell_escape; + +mod errors { + error_chain! { + errors { + InvalidAccountName { + description("Invalid account name") + } + InvalidAccountPassword { + description("Invalid account password") + } + InvalidDisplayName { + description("Invalid display name") + } + InvalidDatabaseName { + description("Invalid database name") + } + InvalidExecutablePath { + description("Invalid executable path") + } + InvalidLaunchArgument { + description("Invalid launch argument") + } + InvalidMachineName { + description("Invalid machine name") + } + InvalidServiceName { + description("Invalid service name") + } + } + foreign_links { + System(::std::io::Error); + } + } +} +pub use self::errors::*; + +/// Flags describing access permissions for ServiceManager +bitflags! { + pub struct ServiceManagerAccess: u32 { + /// Can connect to service control manager + const CONNECT = winsvc::SC_MANAGER_CONNECT; + + /// Can create services + const CREATE_SERVICE = winsvc::SC_MANAGER_CREATE_SERVICE; + + /// Can enumerate services + const ENUMERATE_SERVICE = winsvc::SC_MANAGER_ENUMERATE_SERVICE; + } +} + +/// Service control manager +pub struct ServiceManager(winsvc::SC_HANDLE); + +impl ServiceManager { + /// Private initializer + /// Passing None for machine connects to local machine + /// Passing None for database connects to active database + fn new<M: AsRef<OsStr>, D: AsRef<OsStr>>( + machine: Option<M>, + database: Option<D>, + request_access: ServiceManagerAccess, + ) -> Result<Self> { + let machine_name = to_wide(machine).chain_err(|| ErrorKind::InvalidMachineName)?; + let database_name = to_wide(database).chain_err(|| ErrorKind::InvalidDatabaseName)?; + let handle = unsafe { + winsvc::OpenSCManagerW( + machine_name.map_or(ptr::null(), |s| s.as_ptr()), + database_name.map_or(ptr::null(), |s| s.as_ptr()), + request_access.bits(), + ) + }; + + if handle.is_null() { + Err(io::Error::last_os_error().into()) + } else { + Ok(ServiceManager(handle)) + } + } + + /// Passing None for database connects to active database + pub fn local_computer<D: AsRef<OsStr>>( + database: Option<D>, + request_access: ServiceManagerAccess, + ) -> Result<Self> { + ServiceManager::new(None::<&OsStr>, database, request_access) + } + + /// Passing None for database connects to active database + pub fn remote_computer<M: AsRef<OsStr>, D: AsRef<OsStr>>( + machine: M, + database: Option<D>, + request_access: ServiceManagerAccess, + ) -> Result<Self> { + ServiceManager::new(Some(machine), database, request_access) + } + + pub fn create_service( + &self, + service_info: ServiceInfo, + service_access: ServiceAccess, + ) -> Result<Service> { + let service_name = + WideCString::from_str(service_info.name).chain_err(|| ErrorKind::InvalidServiceName)?; + let display_name = WideCString::from_str(service_info.display_name) + .chain_err(|| ErrorKind::InvalidDisplayName)?; + let account_name = + to_wide(service_info.account_name).chain_err(|| ErrorKind::InvalidAccountName)?; + let account_password = + to_wide(service_info.account_password).chain_err(|| ErrorKind::InvalidAccountPassword)?; + + // escape executable path and arguments and combine them into single command + let executable_path = escape_wide(service_info.executable_path) + .chain_err(|| ErrorKind::InvalidExecutablePath)?; + + let mut launch_command_buffer = WideString::new(); + launch_command_buffer.push(executable_path); + + for launch_argument in service_info.launch_arguments.iter() { + let wide = escape_wide(launch_argument).chain_err(|| ErrorKind::InvalidLaunchArgument)?; + + launch_command_buffer.push_str(" "); + launch_command_buffer.push(wide); + } + + let launch_command = WideCString::from_wide_str(launch_command_buffer).unwrap(); + + let service_handle = unsafe { + winsvc::CreateServiceW( + self.0, + service_name.as_ptr(), + display_name.as_ptr(), + service_access.bits(), + service_info.service_type.to_raw(), + service_info.start_type.to_raw(), + service_info.error_control.to_raw(), + launch_command.as_ptr(), + ptr::null(), // load ordering group + ptr::null_mut(), // tag id within the load ordering group + ptr::null(), // service dependencies + account_name.map_or(ptr::null(), |s| s.as_ptr()), + account_password.map_or(ptr::null(), |s| s.as_ptr()), + ) + }; + + if service_handle.is_null() { + Err(io::Error::last_os_error().into()) + } else { + Ok(unsafe { Service::from_handle(service_handle) }) + } + } + + pub fn open_service<T: AsRef<OsStr>>( + &self, + name: T, + request_access: ServiceAccess, + ) -> Result<Service> { + let service_name = WideCString::from_str(name).chain_err(|| ErrorKind::InvalidServiceName)?; + let service_handle = + unsafe { winsvc::OpenServiceW(self.0, service_name.as_ptr(), request_access.bits()) }; + + if service_handle.is_null() { + Err(io::Error::last_os_error().into()) + } else { + Ok(unsafe { Service::from_handle(service_handle) }) + } + } +} + +impl Drop for ServiceManager { + fn drop(&mut self) { + unsafe { winsvc::CloseServiceHandle(self.0) }; + } +} + +fn to_wide<T: AsRef<OsStr>>(s: Option<T>) -> ::std::result::Result<Option<WideCString>, NulError> { + if let Some(s) = s { + Ok(Some(WideCString::from_str(s)?)) + } else { + Ok(None) + } +} + +fn escape_wide<T: AsRef<OsStr>>(s: T) -> ::std::result::Result<WideString, NulError> { + let escaped = shell_escape::escape(Cow::Borrowed(s.as_ref())); + let wide = WideCString::from_str(escaped)?; + Ok(wide.to_wide_string()) +} diff --git a/windows-service/src/shell_escape.rs b/windows-service/src/shell_escape.rs new file mode 100644 index 0000000000..12014a3b15 --- /dev/null +++ b/windows-service/src/shell_escape.rs @@ -0,0 +1,120 @@ +use std::borrow::Cow; +use std::ffi::{OsStr, OsString}; +use std::iter::repeat; +use std::os::windows::ffi::{OsStrExt, OsStringExt}; + +/// Common UTF-16 code points. +mod utf16 { + pub const DOUBLEQUOTE: u16 = '"' as u16; + pub const BACKSLASH: u16 = '\\' as u16; + pub const SPACE: u16 = ' ' as u16; + pub const LINEFEED: u16 = '\n' as u16; + pub const HTAB: u16 = '\t' as u16; + pub const VTAB: u16 = 0x000B; // '\v' +} + +/// Loselessly escape shell arguments on Windows. +/// +/// Inspired by https://blogs.msdn.microsoft.com/twistylittlepassagesallalike/2011/04/23/everyone-quotes-command-line-arguments-the-wrong-way/. +/// Heavily based on https://github.com/sfackler/shell-escape +pub fn escape(s: Cow<OsStr>) -> Cow<OsStr> { + static ESCAPE_CHARS: &'static [u16] = &[ + utf16::DOUBLEQUOTE, + utf16::SPACE, + utf16::LINEFEED, + utf16::HTAB, + utf16::VTAB, + ]; + let needs_escape = s.is_empty() || s.encode_wide().any(|ref c| ESCAPE_CHARS.contains(c)); + if !needs_escape { + return s; + } + + let mut escaped_wide_string: Vec<u16> = Vec::with_capacity(s.len() + 2); + escaped_wide_string.push(utf16::DOUBLEQUOTE); + + let mut chars = s.encode_wide().peekable(); + loop { + let mut num_slashes = 0; + while let Some(&utf16::BACKSLASH) = chars.peek() { + chars.next(); + num_slashes += 1; + } + + match chars.next() { + Some(utf16::DOUBLEQUOTE) => { + escaped_wide_string.extend(repeat(utf16::BACKSLASH).take(num_slashes * 2 + 1)); + escaped_wide_string.push(utf16::DOUBLEQUOTE); + } + Some(c) => { + escaped_wide_string.extend(repeat(utf16::BACKSLASH).take(num_slashes)); + escaped_wide_string.push(c); + } + None => { + escaped_wide_string.extend(repeat(utf16::BACKSLASH).take(num_slashes * 2)); + break; + } + } + } + + escaped_wide_string.push(utf16::DOUBLEQUOTE); + + Cow::Owned(OsString::from_wide(&escaped_wide_string)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_no_escape() { + assert_eq!( + escape(Cow::Borrowed(OsStr::new("--aaa=bbb-ccc"))), + OsStr::new("--aaa=bbb-ccc") + ); + } + + #[test] + fn test_escape_empty_argument() { + assert_eq!(escape(Cow::Borrowed(OsStr::new(""))), OsStr::new(r#""""#)); + } + + #[test] + fn test_escape_argument_with_spaces() { + assert_eq!( + escape(Cow::Borrowed(OsStr::new("linker=gcc -L/foo -Wl,bar"))), + OsStr::new(r#""linker=gcc -L/foo -Wl,bar""#) + ); + } + + #[test] + fn test_escape_nested_quotes() { + assert_eq!( + escape(Cow::Borrowed(OsStr::new(r#"--features="default""#))), + OsStr::new(r#""--features=\"default\"""#) + ); + } + + + #[test] + fn test_escape_multiple_backslashes_and_nested_quotes() { + assert_eq!( + escape(Cow::Borrowed(OsStr::new(r#"hello \\\"quote\\\""#))), + OsStr::new(r#""hello \\\\\\\"quote\\\\\\\"""#) + ); + } + + // Input: + // child.exe "\some\directory with\spaces\" argument2 + // + // Parsed as: + // 0: [child.exe] + // 1: [\some\directory with\spaces" argument2] + #[test] + fn test_escape_trailing_backslash() { + assert_eq!( + escape(Cow::Borrowed(OsStr::new(r#"\some\directory with\spaces\"#))), + OsStr::new(r#""\some\directory with\spaces\\""#) + ); + } +} |
