diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-05-29 09:45:17 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-07-02 09:54:19 +0200 |
| commit | e5baa0e08816d535a031b3d8575701b8d43fb0c2 (patch) | |
| tree | c4bf2ec1956977676bc25c2630bd38789f43dade | |
| parent | 207ab239223686ff72c43a8a5d615565ab81b5ab (diff) | |
| download | mullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.tar.xz mullvadvpn-e5baa0e08816d535a031b3d8575701b8d43fb0c2.zip | |
Support Windows split tunneling driver
23 files changed, 1715 insertions, 14 deletions
diff --git a/Cargo.lock b/Cargo.lock index 7cb2c21c7f..ffe5a42940 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1355,6 +1355,7 @@ dependencies = [ "err-derive 0.3.0", "futures", "lazy_static", + "log", "mullvad-paths", "mullvad-types", "nix 0.19.1", diff --git a/mullvad-cli/src/cmds/mod.rs b/mullvad-cli/src/cmds/mod.rs index bd7ca97372..2ceb3bfdcf 100644 --- a/mullvad-cli/src/cmds/mod.rs +++ b/mullvad-cli/src/cmds/mod.rs @@ -37,9 +37,9 @@ pub use self::relay::Relay; mod reset; pub use self::reset::Reset; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", windows))] mod split_tunnel; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", windows))] pub use self::split_tunnel::SplitTunnel; mod status; @@ -66,7 +66,7 @@ pub fn get_commands() -> HashMap<&'static str, Box<dyn Command>> { Box::new(Lan), Box::new(Relay), Box::new(Reset), - #[cfg(target_os = "linux")] + #[cfg(any(target_os = "linux", windows))] Box::new(SplitTunnel), Box::new(Status), Box::new(Tunnel), diff --git a/mullvad-cli/src/cmds/split_tunnel/mod.rs b/mullvad-cli/src/cmds/split_tunnel/mod.rs index c7c366d6ea..c9e87f5d7c 100644 --- a/mullvad-cli/src/cmds/split_tunnel/mod.rs +++ b/mullvad-cli/src/cmds/split_tunnel/mod.rs @@ -2,5 +2,9 @@ #[path = "linux.rs"] mod imp; -#[cfg(target_os = "linux")] +#[cfg(windows)] +#[path = "windows.rs"] +mod imp; + +#[cfg(any(target_os = "linux", windows))] pub use imp::*; diff --git a/mullvad-cli/src/cmds/split_tunnel/windows.rs b/mullvad-cli/src/cmds/split_tunnel/windows.rs new file mode 100644 index 0000000000..5f6fa8878e --- /dev/null +++ b/mullvad-cli/src/cmds/split_tunnel/windows.rs @@ -0,0 +1,112 @@ +use crate::{new_rpc_client, Command, Result}; +use clap::value_t_or_exit; + +pub struct SplitTunnel; + +#[mullvad_management_interface::async_trait] +impl Command for SplitTunnel { + fn name(&self) -> &'static str { + "split-tunnel" + } + + fn clap_subcommand(&self) -> clap::App<'static, 'static> { + clap::SubCommand::with_name(self.name()) + .about("Set options for applications to exclude from the tunnel") + .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(create_app_subcommand()) + .subcommand( + clap::SubCommand::with_name("set") + .about("Enable or disable split tunnel") + .arg( + clap::Arg::with_name("policy") + .required(true) + .possible_values(&["on", "off"]), + ), + ) + .subcommand(clap::SubCommand::with_name("get").about("Display the split tunnel status")) + } + + async fn run(&self, matches: &clap::ArgMatches<'_>) -> Result<()> { + match matches.subcommand() { + ("app", Some(matches)) => Self::handle_app_subcommand(matches).await, + ("get", _) => self.get().await, + ("set", Some(matches)) => { + let enabled = value_t_or_exit!(matches.value_of("policy"), String); + self.set(enabled == "on").await + } + _ => { + unreachable!("unhandled command"); + } + } + } +} + +fn create_app_subcommand() -> clap::App<'static, 'static> { + clap::SubCommand::with_name("app") + .about("Manage applications to exclude from the tunnel") + .setting(clap::AppSettings::SubcommandRequiredElseHelp) + .subcommand(clap::SubCommand::with_name("list")) + .subcommand( + clap::SubCommand::with_name("add").arg(clap::Arg::with_name("path").required(true)), + ) + .subcommand( + clap::SubCommand::with_name("remove").arg(clap::Arg::with_name("path").required(true)), + ) + .subcommand(clap::SubCommand::with_name("clear")) +} + +impl SplitTunnel { + async fn handle_app_subcommand(matches: &clap::ArgMatches<'_>) -> Result<()> { + match matches.subcommand() { + ("list", Some(_)) => { + let mut paths = new_rpc_client() + .await? + .get_split_tunnel_apps(()) + .await? + .into_inner(); + + println!("Excluded applications:"); + while let Some(path) = paths.message().await? { + println!(" {}", path); + } + + Ok(()) + } + ("add", Some(matches)) => { + let path = value_t_or_exit!(matches.value_of("path"), String); + new_rpc_client().await?.add_split_tunnel_app(path).await?; + Ok(()) + } + ("remove", Some(matches)) => { + let path = value_t_or_exit!(matches.value_of("path"), String); + new_rpc_client() + .await? + .remove_split_tunnel_app(path) + .await?; + Ok(()) + } + ("clear", Some(_)) => { + new_rpc_client().await?.clear_split_tunnel_apps(()).await?; + Ok(()) + } + _ => unreachable!("unhandled subcommand"), + } + } + + async fn set(&self, enabled: bool) -> Result<()> { + let mut rpc = new_rpc_client().await?; + rpc.set_split_tunnel_state(enabled).await?; + println!("Changed split tunnel setting"); + Ok(()) + } + + async fn get(&self) -> Result<()> { + let mut rpc = new_rpc_client().await?; + let enabled = rpc.get_settings(()).await?.into_inner().split_tunnel; + println!( + "Split tunnel status: {}", + if enabled { "on" } else { "off" } + ); + Ok(()) + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 25afafe250..013b3453db 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -43,6 +43,8 @@ use mullvad_types::{ use settings::SettingsPersister; #[cfg(target_os = "android")] use std::os::unix::io::RawFd; +#[cfg(target_os = "windows")] +use std::{collections::HashSet, ffi::OsString}; use std::{ marker::PhantomData, mem, @@ -52,7 +54,7 @@ use std::{ sync::{mpsc as sync_mpsc, Arc, Weak}, time::Duration, }; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", windows))] use talpid_core::split_tunnel; use talpid_core::{ mpsc::Sender, @@ -113,6 +115,10 @@ pub enum Error { #[error(display = "The account has too many wireguard keys")] TooManyKeys, + #[cfg(windows)] + #[error(display = "Split tunneling error")] + SplitTunnelError(#[error(source)] split_tunnel::Error), + #[error(display = "No wireguard private key available")] NoKeyAvailable, @@ -259,6 +265,21 @@ pub enum DaemonCommand { /// Clear list of processes excluded from the tunnel #[cfg(target_os = "linux")] ClearSplitTunnelProcesses(ResponseTx<(), split_tunnel::Error>), + /// Request list of apps to exclude from the tunnel + #[cfg(windows)] + GetSplitTunnelApps(oneshot::Sender<HashSet<PathBuf>>), + /// Exclude traffic of an application from the tunnel + #[cfg(windows)] + AddSplitTunnelApp(ResponseTx<(), Error>, PathBuf), + /// Remove application from list of apps to exclude from the tunnel + #[cfg(windows)] + RemoveSplitTunnelApp(ResponseTx<(), Error>, PathBuf), + /// Clear list of apps to exclude from the tunnel + #[cfg(windows)] + ClearSplitTunnelApps(ResponseTx<(), Error>), + /// Disable split tunnel + #[cfg(windows)] + SetSplitTunnelState(ResponseTx<(), Error>, bool), /// Makes the daemon exit the main loop and quit. Shutdown, /// Saves the target tunnel state and enters a blocking state. The state is restored @@ -635,6 +656,16 @@ where rpc_runtime.address_cache.peek_address(), TransportProtocol::Tcp, ); + #[cfg(windows)] + let exclude_apps = if settings.split_tunnel { + settings + .split_tunnel_apps + .iter() + .map(|s| OsString::from(s)) + .collect() + } else { + vec![] + }; let tunnel_command_tx = tunnel_state_machine::spawn( settings.allow_lan, @@ -650,6 +681,8 @@ where initial_target_state != TargetState::Secured, #[cfg(target_os = "android")] android_context, + #[cfg(windows)] + exclude_apps, ) .await .map_err(Error::TunnelError)?; @@ -1182,6 +1215,16 @@ where RemoveSplitTunnelProcess(tx, pid) => self.on_remove_split_tunnel_process(tx, pid), #[cfg(target_os = "linux")] ClearSplitTunnelProcesses(tx) => self.on_clear_split_tunnel_processes(tx), + #[cfg(windows)] + GetSplitTunnelApps(tx) => self.on_get_split_tunnel_apps(tx), + #[cfg(windows)] + AddSplitTunnelApp(tx, path) => self.on_add_split_tunnel_app(tx, path).await, + #[cfg(windows)] + RemoveSplitTunnelApp(tx, path) => self.on_remove_split_tunnel_app(tx, path).await, + #[cfg(windows)] + ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await, + #[cfg(windows)] + SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await, Shutdown => self.trigger_shutdown_event(), PrepareRestart => self.on_prepare_restart(), #[cfg(target_os = "android")] @@ -1716,6 +1759,168 @@ where Self::oneshot_send(tx, result, "clear_split_tunnel_processes response"); } + #[cfg(windows)] + fn on_get_split_tunnel_apps(&mut self, tx: oneshot::Sender<HashSet<PathBuf>>) { + Self::oneshot_send( + tx, + self.settings.to_settings().split_tunnel_apps, + "get_split_tunnel_apps response", + ); + } + + /// Update the split app paths in both the settings and tunnel + #[cfg(windows)] + async fn set_split_tunnel_paths( + &mut self, + tx: ResponseTx<(), Error>, + response_msg: &'static str, + settings: Settings, + new_list: HashSet<PathBuf>, + ) { + if new_list == settings.split_tunnel_apps { + Self::oneshot_send(tx, Ok(()), response_msg); + return; + } + + if settings.split_tunnel { + let (result_tx, result_rx) = oneshot::channel(); + self.send_tunnel_command(TunnelCommand::SetExcludedApps( + result_tx, + new_list.iter().map(|s| OsString::from(s)).collect(), + )); + match result_rx.await { + Ok(Ok(_)) => (), + Ok(Err(error)) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set excluded apps list") + ); + Self::oneshot_send(tx, Err(Error::SplitTunnelError(error)), response_msg); + return; + } + Err(_) => { + log::error!("The tunnel failed to return a result"); + return; + } + } + } + + let save_result = self + .settings + .set_split_tunnel_apps(new_list) + .await + .map_err(Error::SettingsError); + match save_result { + Ok(true) => { + Self::oneshot_send(tx, Ok(()), response_msg); + self.event_listener + .notify_settings(self.settings.to_settings()); + } + Err(error) => { + Self::oneshot_send(tx, Err(error), response_msg); + } + Ok(false) => { + // unreachable!("new_list != settings.split_tunnel_apps") + error!("BUG: new_list != settings.split_tunnel_apps"); + } + } + } + + #[cfg(windows)] + async fn on_add_split_tunnel_app(&mut self, tx: ResponseTx<(), Error>, path: PathBuf) { + let settings = self.settings.to_settings(); + + let mut new_list = settings.split_tunnel_apps.clone(); + new_list.insert(path); + + self.set_split_tunnel_paths(tx, "add_split_tunnel_app response", settings, new_list) + .await; + } + + #[cfg(windows)] + async fn on_remove_split_tunnel_app(&mut self, tx: ResponseTx<(), Error>, path: PathBuf) { + let settings = self.settings.to_settings(); + + let mut new_list = settings.split_tunnel_apps.clone(); + new_list.remove(&path); + + self.set_split_tunnel_paths(tx, "remove_split_tunnel_app response", settings, new_list) + .await; + } + + #[cfg(windows)] + async fn on_clear_split_tunnel_apps(&mut self, tx: ResponseTx<(), Error>) { + let settings = self.settings.to_settings(); + let new_list = HashSet::new(); + self.set_split_tunnel_paths(tx, "clear_split_tunnel_apps response", settings, new_list) + .await; + } + + #[cfg(windows)] + async fn on_set_split_tunnel_state(&mut self, tx: ResponseTx<(), Error>, enabled: bool) { + let settings = self.settings.to_settings(); + + if enabled != settings.split_tunnel { + let new_list = if enabled { + settings.split_tunnel_apps.clone() + } else { + HashSet::new() + }; + if !settings.split_tunnel_apps.is_empty() { + let (result_tx, result_rx) = oneshot::channel(); + self.send_tunnel_command(TunnelCommand::SetExcludedApps( + result_tx, + new_list.iter().map(|app| OsString::from(app)).collect(), + )); + match result_rx.await { + Ok(Ok(_)) => (), + Ok(Err(error)) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set excluded apps list") + ); + Self::oneshot_send( + tx, + Err(Error::SplitTunnelError(error)), + "set_split_tunnel_state response", + ); + return; + } + Err(_) => { + log::error!("The tunnel failed to return a result"); + return; + } + } + } + + let save_result = self + .settings + .set_split_tunnel_state(enabled) + .await + .map_err(Error::SettingsError); + match save_result { + Ok(true) => { + Self::oneshot_send(tx, Ok(()), "set_split_tunnel_state response"); + self.event_listener + .notify_settings(self.settings.to_settings()); + } + Err(error) => { + error!( + "{}", + error.display_chain_with_msg("Unable to save settings") + ); + Self::oneshot_send(tx, Err(error), "set_split_tunnel_state response"); + } + Ok(false) => { + // unreachable!("enabled != settings.split_tunnel"), + error!("BUG: enabled != settings.split_tunnel"); + } + } + } else { + Self::oneshot_send(tx, Ok(()), "set_split_tunnel_state response"); + } + } + async fn on_update_relay_settings( &mut self, tx: ResponseTx<(), settings::Error>, diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 2b8bf31a10..cf5903d714 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -18,6 +18,8 @@ use mullvad_types::{ wireguard::{RotationInterval, RotationIntervalError}, }; use parking_lot::RwLock; +#[cfg(windows)] +use std::path::PathBuf; use std::{ cmp, convert::{TryFrom, TryInto}, @@ -52,6 +54,7 @@ impl ManagementService for ManagementServiceImpl { type GetRelayLocationsStream = tokio::sync::mpsc::Receiver<Result<types::RelayListCountry, Status>>; type GetSplitTunnelProcessesStream = tokio::sync::mpsc::UnboundedReceiver<Result<i32, Status>>; + type GetSplitTunnelAppsStream = tokio::sync::mpsc::UnboundedReceiver<Result<String, Status>>; type EventsListenStream = EventsListenerReceiver; // Control and get the tunnel state @@ -641,6 +644,98 @@ impl ManagementService for ManagementServiceImpl { Ok(Response::new(())) } } + + async fn get_split_tunnel_apps( + &self, + _: Request<()>, + ) -> ServiceResult<Self::GetSplitTunnelAppsStream> { + #[cfg(windows)] + { + log::debug!("get_split_tunnel_apps"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::GetSplitTunnelApps(tx))?; + let paths = rx.await.map_err(|_| Status::internal("internal error"))?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + for path in paths { + let _ = tx.send(path.into_os_string().into_string().map_err(|os_path| { + Status::internal(format!("failed to convert OS string: {:?}", os_path)) + })); + } + }); + + Ok(Response::new(rx)) + } + #[cfg(not(windows))] + { + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + Ok(Response::new(rx)) + } + } + + #[cfg(windows)] + async fn add_split_tunnel_app(&self, request: Request<String>) -> ServiceResult<()> { + log::debug!("add_split_tunnel_app"); + let path = PathBuf::from(request.into_inner()); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::AddSplitTunnelApp(tx, path))?; + self.wait_for_result(rx) + .await? + .map_err(map_daemon_error) + .map(Response::new) + } + #[cfg(not(windows))] + async fn add_split_tunnel_app(&self, _: Request<String>) -> ServiceResult<()> { + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn remove_split_tunnel_app(&self, request: Request<String>) -> ServiceResult<()> { + log::debug!("remove_split_tunnel_app"); + let path = PathBuf::from(request.into_inner()); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::RemoveSplitTunnelApp(tx, path))?; + self.wait_for_result(rx) + .await? + .map_err(map_daemon_error) + .map(Response::new) + } + #[cfg(not(windows))] + async fn remove_split_tunnel_app(&self, _: Request<String>) -> ServiceResult<()> { + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn clear_split_tunnel_apps(&self, _: Request<()>) -> ServiceResult<()> { + log::debug!("clear_split_tunnel_apps"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::ClearSplitTunnelApps(tx))?; + self.wait_for_result(rx) + .await? + .map_err(map_daemon_error) + .map(Response::new) + } + #[cfg(not(windows))] + async fn clear_split_tunnel_apps(&self, _: Request<()>) -> ServiceResult<()> { + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn set_split_tunnel_state(&self, request: Request<bool>) -> ServiceResult<()> { + log::debug!("set_split_tunnel_state"); + let enabled = request.into_inner(); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::SetSplitTunnelState(tx, enabled))?; + self.wait_for_result(rx) + .await? + .map_err(map_daemon_error) + .map(Response::new) + } + #[cfg(not(windows))] + async fn set_split_tunnel_state(&self, _: Request<bool>) -> ServiceResult<()> { + Ok(Response::new(())) + } } impl ManagementServiceImpl { diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs index 0b75d0005c..1499ddab53 100644 --- a/mullvad-daemon/src/settings.rs +++ b/mullvad-daemon/src/settings.rs @@ -6,6 +6,8 @@ use mullvad_types::{ settings::{DnsOptions, Settings}, wireguard::{RotationInterval, WireguardData}, }; +#[cfg(target_os = "windows")] +use std::collections::HashSet; use std::{ ops::Deref, path::{Path, PathBuf}, @@ -312,6 +314,21 @@ impl SettingsPersister { self.update(should_save).await } + #[cfg(windows)] + pub async fn set_split_tunnel_apps(&mut self, paths: HashSet<PathBuf>) -> Result<bool, Error> { + let should_save = paths != self.settings.split_tunnel_apps; + if should_save { + self.settings.split_tunnel_apps = paths; + } + self.update(should_save).await + } + + #[cfg(windows)] + pub async fn set_split_tunnel_state(&mut self, enabled: bool) -> Result<bool, Error> { + let should_save = Self::update_field(&mut self.settings.split_tunnel, enabled); + self.update(should_save).await + } + fn update_field<T: Eq>(field: &mut T, new_value: T) -> bool { if *field != new_value { *field = new_value; diff --git a/mullvad-management-interface/Cargo.toml b/mullvad-management-interface/Cargo.toml index 6b4cce66ac..23a5d907cd 100644 --- a/mullvad-management-interface/Cargo.toml +++ b/mullvad-management-interface/Cargo.toml @@ -20,6 +20,7 @@ parity-tokio-ipc = "0.8" futures = "0.3" tokio = { version = "0.2", features = [ "rt-util" ] } triggered = "0.1.1" +log = "0.4" [target.'cfg(unix)'.dependencies] nix = "0.19" diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 636148c7d6..e5ee4cb23f 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ b/mullvad-management-interface/proto/management_interface.proto @@ -58,11 +58,18 @@ service ManagementService { rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {} rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {} - // Split tunneling + // Split tunneling (Linux) rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {} rpc AddSplitTunnelProcess(google.protobuf.Int32Value) returns (google.protobuf.Empty) {} rpc RemoveSplitTunnelProcess(google.protobuf.Int32Value) returns (google.protobuf.Empty) {} rpc ClearSplitTunnelProcesses(google.protobuf.Empty) returns (google.protobuf.Empty) {} + + // Split tunneling (Windows) + rpc GetSplitTunnelApps(google.protobuf.Empty) returns (stream google.protobuf.StringValue) {} + rpc AddSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {} + rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {} + rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {} } message RelaySettingsUpdate { @@ -262,6 +269,8 @@ message Settings { bool auto_connect = 7; TunnelOptions tunnel_options = 8; bool show_beta_releases = 9; + bool split_tunnel = 10; + repeated string split_tunnel_apps = 11; } message RelaySettings { diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs index 7d219e0b4a..bebdb14d77 100644 --- a/mullvad-management-interface/src/types.rs +++ b/mullvad-management-interface/src/types.rs @@ -359,6 +359,20 @@ impl From<mullvad_types::relay_constraints::LocationConstraint> for RelayLocatio impl From<&mullvad_types::settings::Settings> for Settings { fn from(settings: &mullvad_types::settings::Settings) -> Self { + #[cfg(windows)] + let split_tunnel_apps = { + let mut converted_list = vec![]; + for path in settings.split_tunnel_apps.clone().iter() { + match path.as_path().as_os_str().to_str() { + Some(path) => converted_list.push(path.to_string()), + None => { + log::error!("failed to convert OS string: {:?}", path); + } + } + } + converted_list + }; + Self { account_token: settings.get_account_token().unwrap_or_default(), relay_settings: Some(RelaySettings::from(settings.get_relay_settings())), @@ -369,6 +383,14 @@ impl From<&mullvad_types::settings::Settings> for Settings { auto_connect: settings.auto_connect, tunnel_options: Some(TunnelOptions::from(&settings.tunnel_options)), show_beta_releases: settings.show_beta_releases, + #[cfg(windows)] + split_tunnel: settings.split_tunnel, + #[cfg(windows)] + split_tunnel_apps, + #[cfg(not(windows))] + split_tunnel: false, + #[cfg(not(windows))] + split_tunnel_apps: Vec::new(), } } } diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs index faddb9ad0c..bfcffe878a 100644 --- a/mullvad-types/src/settings/mod.rs +++ b/mullvad-types/src/settings/mod.rs @@ -11,6 +11,8 @@ use log::{debug, info}; use serde::{Deserialize, Serialize}; use serde_json; use std::net::IpAddr; +#[cfg(target_os = "windows")] +use std::{collections::HashSet, path::PathBuf}; use talpid_types::net::{self, openvpn, GenericTunnelOptions}; mod migrations; @@ -58,6 +60,12 @@ pub struct Settings { pub tunnel_options: TunnelOptions, /// Whether to notify users of beta updates. pub show_beta_releases: bool, + /// Whether to enable split tunneling for [`Settings::split_tunnel_apps`]. + #[cfg(windows)] + pub split_tunnel: bool, + /// List of applications to exclude from the tunnel. + #[cfg(windows)] + pub split_tunnel_apps: HashSet<PathBuf>, /// Specifies settings schema version #[cfg_attr(target_os = "android", jnix(skip))] settings_version: migrations::SettingsVersion, @@ -79,6 +87,10 @@ impl Default for Settings { auto_connect: false, tunnel_options: TunnelOptions::default(), show_beta_releases: false, + #[cfg(windows)] + split_tunnel: false, + #[cfg(windows)] + split_tunnel_apps: HashSet::new(), settings_version: migrations::CURRENT_SETTINGS_VERSION, } } diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 9d372a7c00..0635b5bcf6 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -78,7 +78,7 @@ byteorder = "1" internet-checksum = "0.2" widestring = "0.4" winreg = { version = "0.7", features = ["transactions"] } -winapi = { version = "0.3.6", features = ["combaseapi", "handleapi", "ifdef", "libloaderapi", "netioapi", "stringapiset", "synchapi", "winbase", "winerror", "winuser"] } +winapi = { version = "0.3.6", features = ["combaseapi", "handleapi", "ifdef", "libloaderapi", "netioapi", "psapi", "stringapiset", "synchapi", "winbase", "winioctl", "winuser"] } socket2 = "0.3" talpid-platform-metadata = { path = "../talpid-platform-metadata" } diff --git a/talpid-core/src/split_tunnel/mod.rs b/talpid-core/src/split_tunnel/mod.rs index c7c366d6ea..3c3f6af294 100644 --- a/talpid-core/src/split_tunnel/mod.rs +++ b/talpid-core/src/split_tunnel/mod.rs @@ -4,3 +4,10 @@ mod imp; #[cfg(target_os = "linux")] pub use imp::*; + +#[cfg(windows)] +#[path = "windows/mod.rs"] +mod imp; + +#[cfg(windows)] +pub use imp::*; diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs new file mode 100644 index 0000000000..26495a5877 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -0,0 +1,514 @@ +use super::windows::{ + get_final_path_name, get_process_creation_time, get_process_device_path, open_process, + ProcessAccess, ProcessSnapshot, +}; +use std::{ + cell::RefCell, + collections::HashMap, + ffi::{OsStr, OsString}, + fs::{self, OpenOptions}, + io, + mem::{self, size_of}, + net::{Ipv4Addr, Ipv6Addr}, + os::windows::{ + ffi::OsStrExt, + fs::OpenOptionsExt, + io::{AsRawHandle, RawHandle}, + }, + ptr, +}; +use winapi::{ + shared::{in6addr::IN6_ADDR, inaddr::IN_ADDR}, + um::{ + ioapiset::DeviceIoControl, + tlhelp32::TH32CS_SNAPPROCESS, + winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER}, + }, +}; + +const DRIVER_SYMBOLIC_NAME: &str = "\\\\.\\MULLVADSPLITTUNNEL"; +const ST_DEVICE_TYPE: u32 = 0x8000; + +const fn ctl_code(device_type: u32, function: u32, method: u32, access: u32) -> u32 { + device_type << 16 | access << 14 | function << 2 | method +} + +#[repr(u32)] +#[allow(dead_code)] +enum DriverIoctlCode { + Initialize = ctl_code(ST_DEVICE_TYPE, 1, METHOD_NEITHER, FILE_ANY_ACCESS), + DequeEvent = ctl_code(ST_DEVICE_TYPE, 2, METHOD_BUFFERED, FILE_ANY_ACCESS), + RegisterProcesses = ctl_code(ST_DEVICE_TYPE, 3, METHOD_BUFFERED, FILE_ANY_ACCESS), + RegisterIpAddresses = ctl_code(ST_DEVICE_TYPE, 4, METHOD_BUFFERED, FILE_ANY_ACCESS), + GetIpAddresses = ctl_code(ST_DEVICE_TYPE, 5, METHOD_BUFFERED, FILE_ANY_ACCESS), + SetConfiguration = ctl_code(ST_DEVICE_TYPE, 6, METHOD_BUFFERED, FILE_ANY_ACCESS), + GetConfiguration = ctl_code(ST_DEVICE_TYPE, 7, METHOD_BUFFERED, FILE_ANY_ACCESS), + ClearConfiguration = ctl_code(ST_DEVICE_TYPE, 8, METHOD_NEITHER, FILE_ANY_ACCESS), + GetState = ctl_code(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS), + QueryProcess = ctl_code(ST_DEVICE_TYPE, 10, METHOD_BUFFERED, FILE_ANY_ACCESS), +} + +#[derive(Debug, PartialEq)] +#[repr(u32)] +#[allow(dead_code)] +pub enum DriverState { + // Default state after being loaded. + None = 0, + // DriverEntry has completed successfully. + // Basically only driver and device objects are created at this point. + Started = 1, + // All subsystems are initialized. + Initialized = 2, + // User mode has registered all processes in the system. + Ready = 3, + // IP addresses are registered. + // A valid configuration is registered. + Engaged = 4, + // Driver is unloading. + Terminating = 5, +} + +pub struct DeviceHandle { + handle: fs::File, +} + +impl DeviceHandle { + pub fn new() -> io::Result<Self> { + // Connect to the driver + log::trace!("Connecting to the driver"); + let handle = OpenOptions::new() + .read(true) + .write(true) + .share_mode(0) + .custom_flags(0) + .attributes(0) + .open(DRIVER_SYMBOLIC_NAME)?; + + let device = Self { handle }; + + // Initialize the driver + let state = device.get_driver_state()?; + if state == DriverState::Started { + log::trace!("Initializing driver"); + device.initialize()?; + } + + // Initialize process tree + let state = device.get_driver_state()?; + if state == DriverState::Initialized { + log::trace!("Registering processes"); + device.register_processes()?; + } + + Ok(device) + } + + fn initialize(&self) -> io::Result<()> { + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::Initialize as u32, + None, + 0, + )?; + Ok(()) + } + + fn register_processes(&self) -> io::Result<()> { + let process_tree_buffer = serialize_process_tree(build_process_tree()?)?; + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::RegisterProcesses as u32, + Some(&process_tree_buffer), + 0, + )?; + Ok(()) + } + + pub fn register_ips( + &self, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> io::Result<()> { + let mut addresses: SplitTunnelAddresses = unsafe { mem::zeroed() }; + + unsafe { + let tunnel_ipv4 = tunnel_ipv4.octets(); + ptr::copy_nonoverlapping( + &tunnel_ipv4[0] as *const u8, + &mut addresses.tunnel_ipv4 as *mut _ as *mut u8, + tunnel_ipv4.len(), + ); + + if let Some(tunnel_ipv6) = tunnel_ipv6 { + let tunnel_ipv6 = tunnel_ipv6.octets(); + ptr::copy_nonoverlapping( + &tunnel_ipv6[0] as *const u8, + &mut addresses.tunnel_ipv6 as *mut _ as *mut u8, + tunnel_ipv6.len(), + ); + } + + let internet_ipv4 = internet_ipv4.octets(); + ptr::copy_nonoverlapping( + &internet_ipv4[0] as *const u8, + &mut addresses.internet_ipv4 as *mut _ as *mut u8, + internet_ipv4.len(), + ); + + if let Some(internet_ipv6) = internet_ipv6 { + let internet_ipv6 = internet_ipv6.octets(); + ptr::copy_nonoverlapping( + &internet_ipv6[0] as *const u8, + &mut addresses.internet_ipv6 as *mut _ as *mut u8, + internet_ipv6.len(), + ); + } + } + + let buffer = &addresses as *const _ as *const u8; + let buffer = + unsafe { std::slice::from_raw_parts(buffer, size_of::<SplitTunnelAddresses>()) }; + + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::RegisterIpAddresses as u32, + Some(buffer), + 0, + )?; + + Ok(()) + } + + pub fn get_driver_state(&self) -> io::Result<DriverState> { + let buffer = device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::GetState as u32, + None, + size_of::<u64>() as u32, + )? + .unwrap(); + + Ok(unsafe { deserialize_buffer(&buffer) }) + } + + pub fn set_config<T: AsRef<OsStr>>(&self, apps: &[T]) -> io::Result<()> { + let mut device_paths = Vec::with_capacity(apps.len()); + for app in apps.as_ref() { + device_paths.push(get_final_path_name(app)?); + } + + log::debug!("Excluded device paths:"); + for path in &device_paths { + log::debug!(" {:?}", path); + } + + let config = make_process_config(&device_paths); + + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::SetConfiguration as u32, + Some(&config), + 0, + )?; + + Ok(()) + } + + pub fn clear_config(&self) -> io::Result<()> { + device_io_control( + self.handle.as_raw_handle(), + DriverIoctlCode::ClearConfiguration as u32, + None, + 0, + )?; + + Ok(()) + } +} + +#[repr(C)] +struct SplitTunnelAddresses { + tunnel_ipv4: IN_ADDR, + internet_ipv4: IN_ADDR, + tunnel_ipv6: IN6_ADDR, + internet_ipv6: IN6_ADDR, +} + +#[repr(C)] +struct ConfigurationHeader { + // Number of entries immediately following the header. + num_entries: usize, + // Total byte length: header + entries + string buffer. + total_length: usize, +} + +#[repr(C)] +struct ConfigurationEntry { + // Offset into buffer region that follows all entries. + // The image name uses the physical path. + name_offset: usize, + // Byte length for non-null terminated wide char string. + name_length: u16, +} + +/// Create a buffer containing a `ConfigurationHeader` and number of `ConfigurationEntry`s, +/// followed by the same number of paths to those entries. +fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> { + let apps: Vec<Vec<u16>> = apps + .iter() + .map(|app| app.as_ref().encode_wide().collect()) + .collect(); + + let total_string_size: usize = apps.iter().map(|app| size_of::<u16>() * app.len()).sum(); + + let total_buffer_size = size_of::<ConfigurationHeader>() + + size_of::<ConfigurationEntry>() * apps.len() + + total_string_size; + + let mut buffer = Vec::<u8>::new(); + buffer.resize(total_buffer_size, 0); + + let (header, tail) = buffer.split_at_mut(size_of::<ConfigurationHeader>()); + + // Serialize configuration header + let header_struct = ConfigurationHeader { + num_entries: apps.len(), + total_length: total_buffer_size, + }; + header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + + // Serialize configuration entries and strings + let (entries, string_data) = tail.split_at_mut(apps.len() * size_of::<ConfigurationEntry>()); + let mut string_offset = 0; + + for (i, app) in apps.iter().enumerate() { + write_string_to_buffer(string_data, string_offset, &app); + + let app_bytelen = size_of::<u16>() * app.len(); + let entry = ConfigurationEntry { + name_offset: string_offset, + name_length: app_bytelen as u16, + }; + let entry_offset = size_of::<ConfigurationEntry>() * i; + entries[entry_offset..entry_offset + size_of::<ConfigurationEntry>()] + .copy_from_slice(unsafe { as_u8_slice(&entry) }); + + string_offset += app_bytelen; + } + + buffer +} + +#[derive(Debug)] +struct ProcessInfo { + pid: u32, + parent_pid: u32, + creation_time: u64, + device_path: Vec<u16>, +} + +/// List process identifiers, their parents, and their device paths. +fn build_process_tree() -> io::Result<Vec<ProcessInfo>> { + let mut process_info = HashMap::new(); + + let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?; + for entry in snap.entries() { + let entry = entry?; + + let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) { + Ok(handle) => Ok(handle), + Err(error) => { + // Skip process objects that cannot be opened + match error.kind() { + // System process + io::ErrorKind::PermissionDenied => continue, + // System idle or csrss process + io::ErrorKind::InvalidInput => continue, + _ => Err(error), + } + } + }?; + + // TODO: Skip objects whose paths or timestamps cannot be obtained? + + process_info.insert( + entry.pid, + RefCell::new(ProcessInfo { + pid: entry.pid, + parent_pid: entry.parent_pid, + creation_time: get_process_creation_time(process.get_raw()).unwrap_or(0), + device_path: get_process_device_path(process.get_raw()) + .unwrap_or(OsString::from("")) + .encode_wide() + .collect(), + }), + ); + } + + // Handle pid recycling + // If the "parent" is younger than the process itself, it is not our parent. + for info in process_info.values() { + let mut info = info.borrow_mut(); + let parent_pid = info.parent_pid; + if parent_pid == 0 { + continue; + } + if let Some(parent_info) = process_info.get(&parent_pid) { + if parent_info.borrow_mut().creation_time > info.creation_time { + info.parent_pid = 0; + } + } + } + + Ok(process_info + .into_iter() + .map(|(_, info)| info.into_inner()) + .collect()) +} + +#[repr(C)] +struct ProcessRegistryHeader { + // Number of entries immediately following the header. + num_entries: usize, + // Total byte length: header + entries + string buffer. + total_length: usize, +} + +#[repr(C)] +struct ProcessRegistryEntry { + pid: RawHandle, + parent_pid: RawHandle, + // Image name offset (following the last entry). + image_name_offset: usize, + // Image name length. + image_name_size: u16, +} + +fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Error> { + // Construct a buffer: + // ProcessRegistryHeader + // ProcessRegistryEntry.. + // Image names.. + + let total_string_size: usize = processes + .iter() + .map(|info| size_of::<u16>() * info.device_path.len()) + .sum(); + let total_buffer_size = size_of::<ProcessRegistryHeader>() + + size_of::<ProcessRegistryEntry>() * processes.len() + + total_string_size; + + let mut buffer = Vec::<u8>::new(); + buffer.resize(total_buffer_size, 0); + + let (header, tail) = buffer.split_at_mut(size_of::<ProcessRegistryHeader>()); + let header_struct = ProcessRegistryHeader { + num_entries: processes.len(), + total_length: total_buffer_size, + }; + header.copy_from_slice(unsafe { as_u8_slice(&header_struct) }); + + let (entries, string_data) = + tail.split_at_mut(size_of::<ProcessRegistryEntry>() * processes.len()); + + let mut string_offset = 0; + + for (i, entry) in processes.into_iter().enumerate() { + let mut out_entry = ProcessRegistryEntry { + pid: entry.pid as usize as RawHandle, + parent_pid: entry.parent_pid as usize as RawHandle, + image_name_size: 0, + image_name_offset: 0, + }; + + if !entry.device_path.is_empty() { + write_string_to_buffer(string_data, string_offset, &entry.device_path); + + out_entry.image_name_size = (entry.device_path.len() * size_of::<u16>()) as u16; + out_entry.image_name_offset = string_offset; + + string_offset += size_of::<u16>() * entry.device_path.len(); + } + + let entry_offset = size_of::<ProcessRegistryEntry>() * i; + entries[entry_offset..entry_offset + size_of::<ProcessRegistryEntry>()] + .copy_from_slice(unsafe { as_u8_slice(&out_entry) }); + } + + Ok(buffer) +} + +/// Send an IOCTL code to the given device handle. +/// `input` specifies an optional buffer to send. +/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0. +pub fn device_io_control( + device: RawHandle, + ioctl_code: u32, + input: Option<&[u8]>, + output_size: u32, +) -> Result<Option<Vec<u8>>, io::Error> { + let input_ptr = match input { + Some(input) => input as *const _ as *mut _, + None => ptr::null_mut(), + }; + let input_len = input.map(|input| input.len()).unwrap_or(0); + + let mut out_buffer = if output_size > 0 { + Some(Vec::with_capacity(output_size as usize)) + } else { + None + }; + + let out_ptr = match out_buffer { + Some(ref mut out_buffer) => out_buffer.as_mut_ptr() as *mut _, + None => ptr::null_mut(), + }; + + let mut returned_bytes = 0u32; + + let result = unsafe { + DeviceIoControl( + device as *mut _, + ioctl_code, + input_ptr, + input_len as u32, + out_ptr, + output_size, + &mut returned_bytes as *mut _, + ptr::null_mut(), // TODO + ) + }; + + if let Some(ref mut out_buffer) = out_buffer { + unsafe { out_buffer.set_len(returned_bytes as usize) }; + } + + if result != 0 { + Ok(out_buffer) + } else { + Err(io::Error::last_os_error()) + } +} + +/// Creates a new instance of an arbitrary type from a byte buffer. +pub unsafe fn deserialize_buffer<T: Sized>(buffer: &Vec<u8>) -> T { + let mut instance: T = mem::zeroed(); + ptr::copy_nonoverlapping(buffer.as_ptr() as *const T, &mut instance as *mut _, 1); + instance +} + +fn write_string_to_buffer(buffer: &mut [u8], byte_offset: usize, string: &[u16]) { + for (i, byte) in string + .iter() + .flat_map(|word| std::array::IntoIter::new(word.to_ne_bytes())) + .enumerate() + { + buffer[byte_offset + i] = byte; + } +} + +unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] { + std::slice::from_raw_parts(object as *const _ as *const _, size_of::<T>()) +} diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs new file mode 100644 index 0000000000..c6b8fae332 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -0,0 +1,69 @@ +mod driver; +mod windows; + +use std::{ + ffi::OsStr, + io, + net::{Ipv4Addr, Ipv6Addr}, +}; +use talpid_types::ErrorExt; + +/// Errors that may occur in [`SplitTunnel`]. +#[derive(err_derive::Error, Debug)] +#[error(no_from)] +pub enum Error { + /// Failed to identify or initialize the driver + #[error(display = "Failed to find or initialize driver")] + InitializationFailed(#[error(source)] io::Error), + + /// Failed to set paths to excluded applications + #[error(display = "Failed to set list of excluded applications")] + SetConfiguration(#[error(source)] io::Error), + + /// Failed to register interface IP addresses + #[error(display = "Failed to register IP addresses for exclusions")] + RegisterIps(#[error(source)] io::Error), +} + +/// Manages applications whose traffic to exclude from the tunnel. +pub struct SplitTunnel(driver::DeviceHandle); + +impl SplitTunnel { + /// Initialize the driver. + pub fn new() -> Result<Self, Error> { + Ok(SplitTunnel( + driver::DeviceHandle::new().map_err(Error::InitializationFailed)?, + )) + } + + /// Set a list of applications to exclude from the tunnel. + pub fn set_paths<T: AsRef<OsStr>>(&self, paths: &[T]) -> Result<(), Error> { + if paths.len() > 0 { + self.0.set_config(paths).map_err(Error::SetConfiguration) + } else { + self.0.clear_config().map_err(Error::SetConfiguration) + } + } + + /// Configures IP addresses used for socket rebinding. + pub fn register_ips( + &self, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> Result<(), Error> { + self.0 + .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6) + .map_err(Error::RegisterIps) + } +} + +impl Drop for SplitTunnel { + fn drop(&mut self) { + let paths: [&OsStr; 0] = []; + if let Err(error) = self.set_paths(&paths) { + log::error!("{}", error.display_chain()); + } + } +} diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs new file mode 100644 index 0000000000..be8631d53c --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -0,0 +1,259 @@ +// TODO: The snapshot code could be combined with the mostly-identical code in +// the windows_exception_logging module. + +use std::{ + ffi::{OsStr, OsString}, + fs::OpenOptions, + io, mem, + os::windows::{ + ffi::OsStringExt, + io::{AsRawHandle, RawHandle}, + }, + ptr, +}; +use winapi::{ + shared::{ + minwindef::{DWORD, FALSE, FILETIME, TRUE}, + ntdef::ULARGE_INTEGER, + winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES}, + }, + um::{ + fileapi::GetFinalPathNameByHandleW, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + processthreadsapi::{GetProcessTimes, OpenProcess}, + psapi::K32GetProcessImageFileNameW, + tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W}, + winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION}, + }, +}; + +/// Return path with the volume device path. +const VOLUME_NAME_NT: u32 = 0x02; + +pub struct ProcessSnapshot { + handle: HANDLE, +} + +impl ProcessSnapshot { + pub fn new(flags: DWORD, process_id: DWORD) -> io::Result<ProcessSnapshot> { + let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; + + if snap == INVALID_HANDLE_VALUE { + Err(io::Error::last_os_error()) + } else { + Ok(ProcessSnapshot { handle: snap }) + } + } + + pub fn handle(&self) -> HANDLE { + self.handle + } + + pub fn entries(&self) -> ProcessSnapshotEntries<'_> { + let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() }; + entry.dwSize = mem::size_of::<PROCESSENTRY32W>() as u32; + + ProcessSnapshotEntries { + snapshot: self, + iter_started: false, + temp_entry: entry, + } + } +} + +impl Drop for ProcessSnapshot { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +pub struct ProcessEntry { + pub pid: u32, + pub parent_pid: u32, +} + +pub struct ProcessSnapshotEntries<'a> { + snapshot: &'a ProcessSnapshot, + iter_started: bool, + temp_entry: PROCESSENTRY32W, +} + +impl Iterator for ProcessSnapshotEntries<'_> { + type Item = io::Result<ProcessEntry>; + + fn next(&mut self) -> Option<io::Result<ProcessEntry>> { + if self.iter_started { + if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + let last_error = io::Error::last_os_error(); + + return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { + None + } else { + Some(Err(last_error)) + }; + } + } else { + if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE { + return Some(Err(io::Error::last_os_error())); + } + self.iter_started = true; + } + + Some(Ok(ProcessEntry { + pid: self.temp_entry.th32ProcessID, + parent_pid: self.temp_entry.th32ParentProcessID, + })) + } +} + +pub fn get_final_path_name<T: AsRef<OsStr>>(path: T) -> Result<OsString, io::Error> { + // TODO: verify that all flags, including security flags, are ok + // TODO: verify that the file is a PE executable? + // TODO: verify that the executable is on a physical drive? + let file = OpenOptions::new().read(true).open(path.as_ref())?; + get_final_path_name_by_handle(file.as_raw_handle()) +} + +pub fn get_final_path_name_by_handle(raw_handle: RawHandle) -> Result<OsString, io::Error> { + let buffer_size = unsafe { + GetFinalPathNameByHandleW(raw_handle as *mut _, ptr::null_mut(), 0u32, VOLUME_NAME_NT) + } as usize; + + if buffer_size == 0 { + return Err(io::Error::last_os_error()); + } + + let mut buffer = Vec::new(); + buffer.reserve_exact(buffer_size); + + let status = unsafe { + GetFinalPathNameByHandleW( + raw_handle as *mut _, + buffer.as_mut_ptr(), + buffer_size as u32, + VOLUME_NAME_NT, + ) + } as usize; + + if status == 0 { + return Err(io::Error::last_os_error()); + } + + unsafe { buffer.set_len(buffer_size - 1) }; + + // TODO: can this be done by stealing 'buffer' instead of copying it? + Ok(OsStringExt::from_wide(&buffer)) +} + +/// Object that frees its handle when dropped. +pub struct WinHandle(RawHandle); + +impl WinHandle { + pub fn get_raw(&self) -> RawHandle { + self.0 + } +} + +impl Drop for WinHandle { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + +#[repr(u32)] +pub enum ProcessAccess { + QueryLimitedInformation = PROCESS_QUERY_LIMITED_INFORMATION, + // TODO: could be extended +} + +/// Open an existing process object. +pub fn open_process( + access: ProcessAccess, + inherit_handle: bool, + pid: u32, +) -> Result<WinHandle, io::Error> { + let handle = unsafe { + OpenProcess( + access as u32, + if inherit_handle { TRUE } else { FALSE }, + pid, + ) + }; + + if handle == ptr::null_mut() { + return Err(io::Error::last_os_error()); + } + Ok(WinHandle(handle)) +} + +/// Returns the age of a running process. +pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> { + // TODO: FileTimeToSystemTime -> chrono::NaiveDateTime + let mut creation_time: FILETIME = unsafe { mem::zeroed() }; + let mut dummy: FILETIME = unsafe { mem::zeroed() }; + if unsafe { + GetProcessTimes( + handle, + &mut creation_time as *mut _, + &mut dummy as *mut _, + &mut dummy as *mut _, + &mut dummy as *mut _, + ) + } == 0 + { + return Err(io::Error::last_os_error()); + } + + let mut uli_time: ULARGE_INTEGER = unsafe { mem::zeroed() }; + unsafe { + uli_time.s_mut().LowPart = creation_time.dwLowDateTime; + uli_time.s_mut().HighPart = creation_time.dwHighDateTime; + } + + Ok(*unsafe { uli_time.QuadPart() }) +} + +/// Returns the device path for a running process. +pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> { + let mut initial_capacity = 512; + loop { + let result = get_process_device_path_inner(handle, initial_capacity); + match result { + Ok(path) => return Ok(path), + Err(error) => { + if ERROR_INSUFFICIENT_BUFFER == error.raw_os_error().unwrap() as u32 { + // Try again with a larger buffer capacity. + initial_capacity *= 2; + continue; + } + return Err(error); + } + } + } +} + +fn get_process_device_path_inner( + handle: RawHandle, + buffer_capacity: usize, +) -> Result<OsString, io::Error> { + let mut buffer = Vec::<u16>::new(); + buffer.reserve_exact(buffer_capacity); + + let written = unsafe { + K32GetProcessImageFileNameW( + handle, + buffer.as_mut_ptr() as *mut _, + buffer.capacity() as u32, + ) + }; + if written == 0 { + return Err(io::Error::last_os_error()); + } + + // `written` does not include a null terminator + unsafe { buffer.set_len(written as usize) }; + + Ok(OsStringExt::from_wide(&buffer)) +} diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 72af2b6f38..d7bf24b49e 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -7,9 +7,20 @@ use crate::{ firewall::FirewallPolicy, tunnel::{CloseHandle, TunnelEvent, TunnelMetadata}, }; +#[cfg(windows)] +use crate::{ + split_tunnel::{self, SplitTunnel}, + winnet::{self, get_best_default_route, interface_luid_to_ip, WinNetAddrFamily}, +}; use cfg_if::cfg_if; use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::net::IpAddr; +#[cfg(windows)] +use std::{ + ffi::OsStr, + net::{Ipv4Addr, Ipv6Addr}, + sync::{Arc, Mutex}, +}; use talpid_types::{ net::TunnelParameters, tunnel::{ErrorStateCause, FirewallPolicyError}, @@ -116,6 +127,137 @@ impl ConnectedState { } } + #[cfg(target_os = "windows")] + pub unsafe extern "system" fn split_tunnel_default_route_change_handler( + event_type: winnet::WinNetDefaultRouteChangeEventType, + address_family: WinNetAddrFamily, + default_route: winnet::WinNetDefaultRoute, + ctx: *mut libc::c_void, + ) { + // Update the "internet interface" IP when best default route changes + let ctx = &mut *(ctx as *mut SplitTunnelDefaultRouteChangeHandlerContext); + + let result = match event_type { + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => { + let ip = interface_luid_to_ip(address_family.clone(), default_route.interface_luid); + + // TODO: Should we block here? + let ip = match ip { + Ok(Some(ip)) => ip, + Ok(None) => { + log::error!("Failed to obtain new default route address: none found",); + // Early return + return; + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to obtain new default route address" + ) + ); + // Early return + return; + } + }; + + match address_family { + WinNetAddrFamily::IPV4 => { + let ip = Ipv4Addr::from(ip); + ctx.internet_ipv4 = ip; + } + WinNetAddrFamily::IPV6 => { + let ip = Ipv6Addr::from(ip); + ctx.internet_ipv6 = Some(ip); + } + } + + ctx.register_ips() + } + // no default route + winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => { + match address_family { + WinNetAddrFamily::IPV4 => { + ctx.internet_ipv4 = Ipv4Addr::new(0, 0, 0, 0); + } + WinNetAddrFamily::IPV6 => { + ctx.internet_ipv6 = None; + } + } + ctx.register_ips() + } + }; + + if let Err(error) = result { + // TODO: Should we block here? + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to register new addresses in split tunnel driver" + ) + ); + } + } + + #[cfg(windows)] + fn update_split_tunnel_addresses( + &self, + shared_values: &mut SharedTunnelStateValues, + ) -> Result<(), BoxedError> { + // Identify tunnel IP addresses + // TODO: Multiple IP addresses? + let mut tunnel_ipv4 = None; + let mut tunnel_ipv6 = None; + + for ip in &self.metadata.ips { + match ip { + IpAddr::V4(address) => tunnel_ipv4 = Some(address.clone()), + IpAddr::V6(address) => tunnel_ipv6 = Some(address.clone()), + } + } + + // Identify IP address that gives us Internet access + let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4) + .map_err(BoxedError::new)? + .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV4, route.interface_luid)) + .transpose() + .map_err(BoxedError::new)? + .flatten(); + let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6) + .map_err(BoxedError::new)? + .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV6, route.interface_luid)) + .transpose() + .map_err(BoxedError::new)? + .flatten(); + + let tunnel_ipv4 = tunnel_ipv4.unwrap_or(Ipv4Addr::new(0, 0, 0, 0)); + let internet_ipv4 = Ipv4Addr::from(internet_ipv4.unwrap_or_default()); + let internet_ipv6 = internet_ipv6.map(|addr| Ipv6Addr::from(addr)); + + let context = SplitTunnelDefaultRouteChangeHandlerContext::new( + shared_values.split_tunnel.clone(), + tunnel_ipv4, + tunnel_ipv6, + internet_ipv4, + internet_ipv6, + ); + + shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex") + .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6) + .map_err(BoxedError::new)?; + + #[cfg(target_os = "windows")] + shared_values.route_manager.add_default_route_callback( + Some(Self::split_tunnel_default_route_change_handler), + context, + ); + + Ok(()) + } + fn set_dns(&self, shared_values: &mut SharedTunnelStateValues) -> Result<(), BoxedError> { let dns_ips = self.get_dns_servers(shared_values); shared_values @@ -150,6 +292,18 @@ impl ConnectedState { } } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn disconnect( self, shared_values: &mut SharedTunnelStateValues, @@ -158,6 +312,24 @@ impl ConnectedState { Self::reset_dns(shared_values); Self::reset_routes(shared_values); + #[cfg(windows)] + if let Err(error) = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex") + .register_ips( + Ipv4Addr::new(0, 0, 0, 0), + None, + Ipv4Addr::new(0, 0, 0, 0), + None, + ) + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to unregister IP addresses") + ); + } + EventConsequence::NewState(DisconnectingState::enter( shared_values, (self.close_handle, self.tunnel_close_event, after_disconnect), @@ -257,6 +429,11 @@ impl ConnectedState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } @@ -326,6 +503,19 @@ impl TunnelState for ConnectedState { ), ) } else { + #[cfg(windows)] + if let Err(error) = connected_state.update_split_tunnel_addresses(shared_values) { + log::error!("{}", error.display_chain()); + return DisconnectingState::enter( + shared_values, + ( + connected_state.close_handle, + connected_state.tunnel_close_event, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ), + ); + } + ( TunnelStateWrapper::from(connected_state), TunnelStateTransition::Connected(tunnel_endpoint), @@ -360,3 +550,44 @@ impl TunnelState for ConnectedState { } } } + +#[cfg(target_os = "windows")] +struct SplitTunnelDefaultRouteChangeHandlerContext { + split_tunnel: Arc<Mutex<SplitTunnel>>, + pub tunnel_ipv4: Ipv4Addr, + pub tunnel_ipv6: Option<Ipv6Addr>, + pub internet_ipv4: Ipv4Addr, + pub internet_ipv6: Option<Ipv6Addr>, +} + +#[cfg(target_os = "windows")] +impl SplitTunnelDefaultRouteChangeHandlerContext { + pub fn new( + split_tunnel: Arc<Mutex<SplitTunnel>>, + tunnel_ipv4: Ipv4Addr, + tunnel_ipv6: Option<Ipv6Addr>, + internet_ipv4: Ipv4Addr, + internet_ipv6: Option<Ipv6Addr>, + ) -> Self { + SplitTunnelDefaultRouteChangeHandlerContext { + split_tunnel, + tunnel_ipv4, + tunnel_ipv6, + internet_ipv4, + internet_ipv6, + } + } + + pub fn register_ips(&self) -> Result<(), split_tunnel::Error> { + let split_tunnel = self + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.register_ips( + self.tunnel_ipv4, + self.tunnel_ipv6, + self.internet_ipv4, + self.internet_ipv6, + ) + } +} diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index b0c87acdb4..34e9eeb2be 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -3,6 +3,8 @@ use super::{ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::{ firewall::FirewallPolicy, routing::RouteManager, @@ -17,6 +19,8 @@ use futures::{ FutureExt, StreamExt, }; use log::{debug, error, info, trace, warn}; +#[cfg(windows)] +use std::ffi::OsStr; use std::{ path::{Path, PathBuf}, thread, @@ -89,6 +93,18 @@ impl ConnectingState { }) } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn start_tunnel( runtime: tokio::runtime::Handle, parameters: TunnelParameters, @@ -314,6 +330,11 @@ impl ConnectingState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 8d2c9bc0fa..cfa7794af7 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -3,7 +3,11 @@ use super::{ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; +#[cfg(windows)] +use crate::split_tunnel; use futures::StreamExt; +#[cfg(windows)] +use std::ffi::OsStr; use talpid_types::ErrorExt; /// No tunnel is running. @@ -36,6 +40,18 @@ impl DisconnectedState { log::error!("{}", error_chain); } } + + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } } impl TunnelState for DisconnectedState { @@ -115,6 +131,11 @@ impl TunnelState for DisconnectedState { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } Some(_) => SameState(self.into()), None => Finished, } diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 7d308d5971..e488896042 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -3,8 +3,12 @@ use super::{ EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::tunnel::CloseHandle; use futures::{future::FusedFuture, StreamExt}; +#[cfg(windows)] +use std::ffi::OsStr; use std::thread; use talpid_types::{ tunnel::{ActionAfterDisconnect, ErrorStateCause}, @@ -59,6 +63,11 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Nothing } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Nothing + } }, AfterDisconnect::Block(reason) => match command { Some(TunnelCommand::AllowLan(allow_lan)) => { @@ -96,6 +105,11 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Block(reason) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Block(reason) + } None => AfterDisconnect::Block(reason), }, AfterDisconnect::Reconnect(retry_attempt) => match command { @@ -134,12 +148,29 @@ impl DisconnectingState { shared_values.bypass_socket(fd, done_tx); AfterDisconnect::Reconnect(retry_attempt) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + AfterDisconnect::Reconnect(retry_attempt) + } }, }; EventConsequence::SameState(self.into()) } + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } + fn after_disconnect( self, block_reason: Option<ErrorStateCause>, diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 5e647c8201..6d772a62b8 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -3,7 +3,11 @@ use super::{ TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, }; use crate::firewall::FirewallPolicy; +#[cfg(windows)] +use crate::split_tunnel; use futures::StreamExt; +#[cfg(windows)] +use std::ffi::OsStr; use talpid_types::{ tunnel::{self as talpid_tunnel, ErrorStateCause, FirewallPolicyError}, ErrorExt, @@ -61,6 +65,18 @@ impl ErrorState { } } } + + #[cfg(windows)] + fn apply_split_tunnel_config<T: AsRef<OsStr>>( + shared_values: &SharedTunnelStateValues, + paths: &[T], + ) -> Result<(), split_tunnel::Error> { + let split_tunnel = shared_values + .split_tunnel + .lock() + .expect("Thread unexpectedly panicked while holding the mutex"); + split_tunnel.set_paths(paths) + } } impl TunnelState for ErrorState { @@ -151,12 +167,17 @@ impl TunnelState for ErrorState { Some(TunnelCommand::Block(reason)) => { NewState(ErrorState::enter(shared_values, reason)) } - #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); SameState(self.into()) } + #[cfg(windows)] + Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { + // TODO: Do nothing here? + let _ = result_tx.send(Self::apply_split_tunnel_config(shared_values, &paths)); + SameState(self.into()) + } } } } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 38496865a4..969c2116a1 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -11,6 +11,8 @@ use self::{ disconnecting_state::{AfterDisconnect, DisconnectingState}, error_state::ErrorState, }; +#[cfg(windows)] +use crate::split_tunnel; use crate::{ dns::DnsMonitor, firewall::{Firewall, FirewallArguments}, @@ -19,6 +21,11 @@ use crate::{ routing::RouteManager, tunnel::{tun_provider::TunProvider, TunnelEvent}, }; +#[cfg(windows)] +use std::ffi::OsString; +#[cfg(windows)] +use std::sync::Mutex; + use futures::{ channel::{mpsc, oneshot}, stream, StreamExt, @@ -47,9 +54,9 @@ pub enum Error { OfflineMonitorError(#[error(source)] crate::offline::Error), /// Unable to set up split tunneling - #[cfg(target_os = "linux")] + #[cfg(target_os = "windows")] #[error(display = "Failed to initialize split tunneling")] - InitSplitTunneling(#[error(source)] crate::split_tunnel::Error), + InitSplitTunneling(#[error(source)] split_tunnel::Error), /// Failed to initialize the system firewall integration. #[error(display = "Failed to initialize the system firewall integration")] @@ -86,6 +93,7 @@ pub async fn spawn( shutdown_tx: oneshot::Sender<()>, reset_firewall: bool, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(windows)] exclude_paths: Vec<OsString>, ) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> { let (command_tx, command_rx) = mpsc::unbounded(); let command_tx = Arc::new(command_tx); @@ -122,6 +130,8 @@ pub async fn spawn( reset_firewall, #[cfg(target_os = "android")] android_context, + #[cfg(windows)] + exclude_paths, )); let state_machine = match state_machine { Ok(state_machine) => { @@ -169,6 +179,12 @@ pub enum TunnelCommand { /// Bypass a socket, allowing traffic to flow through outside the tunnel. #[cfg(target_os = "android")] BypassSocket(RawFd, oneshot::Sender<()>), + /// Set applications that are allowed to send and receive traffic outside of the tunnel. + #[cfg(windows)] + SetExcludedApps( + oneshot::Sender<Result<(), split_tunnel::Error>>, + Vec<OsString>, + ), } type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>; @@ -207,6 +223,7 @@ impl TunnelStateMachine { commands: mpsc::UnboundedReceiver<TunnelCommand>, reset_firewall: bool, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(windows)] exclude_paths: Vec<OsString>, ) -> Result<Self, Error> { let args = FirewallArguments { initialize_blocked: block_when_disconnected || !reset_firewall, @@ -239,6 +256,14 @@ impl TunnelStateMachine { .await .map_err(Error::OfflineMonitorError)?; let is_offline = offline_monitor.is_offline().await; + + #[cfg(windows)] + let split_tunnel = split_tunnel::SplitTunnel::new().map_err(Error::InitSplitTunneling)?; + #[cfg(windows)] + split_tunnel + .set_paths(&exclude_paths) + .map_err(Error::InitSplitTunneling)?; + let mut shared_values = SharedTunnelStateValues { runtime, firewall, @@ -256,6 +281,8 @@ impl TunnelStateMachine { resource_dir, #[cfg(target_os = "linux")] connectivity_check_was_enabled: None, + #[cfg(windows)] + split_tunnel: Arc::new(Mutex::new(split_tunnel)), }; let (initial_state, _) = DisconnectedState::enter(&mut shared_values, reset_firewall); @@ -337,6 +364,9 @@ struct SharedTunnelStateValues { /// NetworkManager's connecitivity check state. #[cfg(target_os = "linux")] connectivity_check_was_enabled: Option<bool>, + /// Management of excluded apps. + #[cfg(windows)] + split_tunnel: Arc<Mutex<split_tunnel::SplitTunnel>>, } impl SharedTunnelStateValues { diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index 79008f8cbc..8b3d503462 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -85,6 +85,7 @@ pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, E } } +#[derive(Debug, Clone)] #[allow(dead_code)] #[repr(u32)] pub enum WinNetAddrFamily { @@ -121,15 +122,33 @@ pub struct WinNetDefaultRoute { pub gateway: WinNetIp, } -impl From<WinNetIp> for IpAddr { - fn from(addr: WinNetIp) -> IpAddr { +impl From<WinNetIp> for Ipv4Addr { + fn from(addr: WinNetIp) -> Ipv4Addr { match addr.addr_family { WinNetAddrFamily::IPV4 => { let mut bytes: [u8; 4] = Default::default(); bytes.clone_from_slice(&addr.ip_bytes[..4]); - IpAddr::V4(Ipv4Addr::from(bytes)) + Ipv4Addr::from(bytes) } - WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)), + WinNetAddrFamily::IPV6 => panic!("address family mismatch"), + } + } +} + +impl From<WinNetIp> for Ipv6Addr { + fn from(addr: WinNetIp) -> Ipv6Addr { + match addr.addr_family { + WinNetAddrFamily::IPV4 => panic!("address family mismatch"), + WinNetAddrFamily::IPV6 => Ipv6Addr::from(addr.ip_bytes), + } + } +} + +impl From<WinNetIp> for IpAddr { + fn from(addr: WinNetIp) -> IpAddr { + match addr.addr_family { + WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::from(addr)), + WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr)), } } } |
