summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-07-02 16:11:57 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-07-02 16:11:57 +0200
commitf77cde89c6b8d3d580ee773628f89211a24852c5 (patch)
tree93f33f7be297caa2d2e4b0c0ba032e04d82abe57
parent207ab239223686ff72c43a8a5d615565ab81b5ab (diff)
parent4b0d04e6534ebfecc4a905237468774af934bd45 (diff)
downloadmullvadvpn-f77cde89c6b8d3d580ee773628f89211a24852c5.tar.xz
mullvadvpn-f77cde89c6b8d3d580ee773628f89211a24852c5.zip
Merge branch 'split-tunnel-win-update'
-rw-r--r--Cargo.lock11
-rw-r--r--gui/src/main/daemon-rpc.ts2
-rw-r--r--mullvad-cli/src/cmds/mod.rs6
-rw-r--r--mullvad-cli/src/cmds/split_tunnel/mod.rs6
-rw-r--r--mullvad-cli/src/cmds/split_tunnel/windows.rs121
-rw-r--r--mullvad-cli/src/format.rs2
-rw-r--r--mullvad-daemon/src/lib.rs194
-rw-r--r--mullvad-daemon/src/management_interface.rs84
-rw-r--r--mullvad-daemon/src/settings.rs18
-rw-r--r--mullvad-management-interface/Cargo.toml1
-rw-r--r--mullvad-management-interface/proto/management_interface.proto15
-rw-r--r--mullvad-management-interface/src/types.rs25
-rw-r--r--mullvad-types/src/settings/mod.rs16
-rw-r--r--talpid-core/Cargo.toml3
-rw-r--r--talpid-core/src/routing/windows.rs30
-rw-r--r--talpid-core/src/split_tunnel/mod.rs7
-rw-r--r--talpid-core/src/split_tunnel/windows/driver.rs888
-rw-r--r--talpid-core/src/split_tunnel/windows/mod.rs574
-rw-r--r--talpid-core/src/split_tunnel/windows/windows.rs286
-rw-r--r--talpid-core/src/tunnel/mod.rs8
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs73
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs24
-rw-r--r--talpid-core/src/tunnel_state_machine/connected_state.rs20
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs56
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnected_state.rs34
-rw-r--r--talpid-core/src/tunnel_state_machine/disconnecting_state.rs15
-rw-r--r--talpid-core/src/tunnel_state_machine/error_state.rs16
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs43
-rw-r--r--talpid-core/src/winnet.rs37
-rw-r--r--talpid-openvpn-plugin/proto/openvpn_plugin.proto1
-rw-r--r--talpid-openvpn-plugin/src/lib.rs1
-rw-r--r--talpid-openvpn-plugin/src/processing.rs1
-rw-r--r--talpid-types/src/tunnel.rs5
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitdhcp.cpp4
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitdhcpserver.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitdns.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitlan.cpp4
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitlanservice.cpp4
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitloopback.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitndp.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp2
-rw-r--r--windows/winfw/src/winfw/rules/dns/permitnontunnel.cpp4
-rw-r--r--windows/winfw/src/winfw/rules/dns/permittunnel.cpp4
-rw-r--r--windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp2
-rw-r--r--windows/winnet/src/winnet/winnet.cpp2
46 files changed, 2557 insertions, 102 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 7cb2c21c7f..a7a7ff27c4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1147,6 +1147,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
[[package]]
+name = "memoffset"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
name = "miniz_oxide"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1355,6 +1364,7 @@ dependencies = [
"err-derive 0.3.0",
"futures",
"lazy_static",
+ "log",
"mullvad-paths",
"mullvad-types",
"nix 0.19.1",
@@ -2539,6 +2549,7 @@ dependencies = [
"lazy_static",
"libc",
"log",
+ "memoffset",
"mnl",
"netlink-packet-core",
"netlink-packet-route",
diff --git a/gui/src/main/daemon-rpc.ts b/gui/src/main/daemon-rpc.ts
index 8ed75b8885..c740393e2a 100644
--- a/gui/src/main/daemon-rpc.ts
+++ b/gui/src/main/daemon-rpc.ts
@@ -801,6 +801,8 @@ function convertFromTunnelStateErrorCause(
};
return { reason: 'tunnel_parameter_error', details: parameterErrorMap[state.parameterError] };
}
+ case grpcTypes.ErrorState.Cause.SPLIT_TUNNEL_ERROR:
+ return { reason: 'start_tunnel_error' };
case grpcTypes.ErrorState.Cause.VPN_PERMISSION_DENIED:
// VPN_PERMISSION_DENIED is only ever created on Android
throw invalidErrorStateCause;
diff --git a/mullvad-cli/src/cmds/mod.rs b/mullvad-cli/src/cmds/mod.rs
index bd7ca97372..2ceb3bfdcf 100644
--- a/mullvad-cli/src/cmds/mod.rs
+++ b/mullvad-cli/src/cmds/mod.rs
@@ -37,9 +37,9 @@ pub use self::relay::Relay;
mod reset;
pub use self::reset::Reset;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", windows))]
mod split_tunnel;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", windows))]
pub use self::split_tunnel::SplitTunnel;
mod status;
@@ -66,7 +66,7 @@ pub fn get_commands() -> HashMap<&'static str, Box<dyn Command>> {
Box::new(Lan),
Box::new(Relay),
Box::new(Reset),
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", windows))]
Box::new(SplitTunnel),
Box::new(Status),
Box::new(Tunnel),
diff --git a/mullvad-cli/src/cmds/split_tunnel/mod.rs b/mullvad-cli/src/cmds/split_tunnel/mod.rs
index c7c366d6ea..c9e87f5d7c 100644
--- a/mullvad-cli/src/cmds/split_tunnel/mod.rs
+++ b/mullvad-cli/src/cmds/split_tunnel/mod.rs
@@ -2,5 +2,9 @@
#[path = "linux.rs"]
mod imp;
-#[cfg(target_os = "linux")]
+#[cfg(windows)]
+#[path = "windows.rs"]
+mod imp;
+
+#[cfg(any(target_os = "linux", windows))]
pub use imp::*;
diff --git a/mullvad-cli/src/cmds/split_tunnel/windows.rs b/mullvad-cli/src/cmds/split_tunnel/windows.rs
new file mode 100644
index 0000000000..402186a77c
--- /dev/null
+++ b/mullvad-cli/src/cmds/split_tunnel/windows.rs
@@ -0,0 +1,121 @@
+use crate::{new_rpc_client, Command, Result};
+use clap::value_t_or_exit;
+
+pub struct SplitTunnel;
+
+#[mullvad_management_interface::async_trait]
+impl Command for SplitTunnel {
+ fn name(&self) -> &'static str {
+ "split-tunnel"
+ }
+
+ fn clap_subcommand(&self) -> clap::App<'static, 'static> {
+ clap::SubCommand::with_name(self.name())
+ .about("Set options for applications to exclude from the tunnel")
+ .setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(create_app_subcommand())
+ .subcommand(
+ clap::SubCommand::with_name("set")
+ .about("Enable or disable split tunnel")
+ .arg(
+ clap::Arg::with_name("policy")
+ .required(true)
+ .possible_values(&["on", "off"]),
+ ),
+ )
+ .subcommand(clap::SubCommand::with_name("get").about("Display the split tunnel status"))
+ }
+
+ async fn run(&self, matches: &clap::ArgMatches<'_>) -> Result<()> {
+ match matches.subcommand() {
+ ("app", Some(matches)) => Self::handle_app_subcommand(matches).await,
+ ("get", _) => self.get().await,
+ ("set", Some(matches)) => {
+ let enabled = value_t_or_exit!(matches.value_of("policy"), String);
+ self.set(enabled == "on").await
+ }
+ _ => {
+ unreachable!("unhandled command");
+ }
+ }
+ }
+}
+
+fn create_app_subcommand() -> clap::App<'static, 'static> {
+ clap::SubCommand::with_name("app")
+ .about("Manage applications to exclude from the tunnel")
+ .setting(clap::AppSettings::SubcommandRequiredElseHelp)
+ .subcommand(clap::SubCommand::with_name("list"))
+ .subcommand(
+ clap::SubCommand::with_name("add").arg(clap::Arg::with_name("path").required(true)),
+ )
+ .subcommand(
+ clap::SubCommand::with_name("remove").arg(clap::Arg::with_name("path").required(true)),
+ )
+ .subcommand(clap::SubCommand::with_name("clear"))
+}
+
+impl SplitTunnel {
+ async fn handle_app_subcommand(matches: &clap::ArgMatches<'_>) -> Result<()> {
+ match matches.subcommand() {
+ ("list", Some(_)) => {
+ let paths = new_rpc_client()
+ .await?
+ .get_settings(())
+ .await?
+ .into_inner()
+ .split_tunnel
+ .unwrap()
+ .apps;
+
+ println!("Excluded applications:");
+ for path in &paths {
+ println!(" {}", path);
+ }
+
+ Ok(())
+ }
+ ("add", Some(matches)) => {
+ let path = value_t_or_exit!(matches.value_of("path"), String);
+ new_rpc_client().await?.add_split_tunnel_app(path).await?;
+ Ok(())
+ }
+ ("remove", Some(matches)) => {
+ let path = value_t_or_exit!(matches.value_of("path"), String);
+ new_rpc_client()
+ .await?
+ .remove_split_tunnel_app(path)
+ .await?;
+ Ok(())
+ }
+ ("clear", Some(_)) => {
+ new_rpc_client().await?.clear_split_tunnel_apps(()).await?;
+ Ok(())
+ }
+ _ => unreachable!("unhandled subcommand"),
+ }
+ }
+
+ async fn set(&self, enabled: bool) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ rpc.set_split_tunnel_state(enabled).await?;
+ println!("Changed split tunnel setting");
+ Ok(())
+ }
+
+ async fn get(&self) -> Result<()> {
+ let mut rpc = new_rpc_client().await?;
+ let enabled = rpc
+ .get_settings(())
+ .await?
+ .into_inner()
+ .split_tunnel
+ .unwrap()
+ .enable_exclusions;
+ println!(
+ "Split tunnel status: {}",
+ if enabled { "on" } else { "off" }
+ );
+ Ok(())
+ }
+}
diff --git a/mullvad-cli/src/format.rs b/mullvad-cli/src/format.rs
index 52014076c3..b056ffff53 100644
--- a/mullvad-cli/src/format.rs
+++ b/mullvad-cli/src/format.rs
@@ -160,6 +160,8 @@ fn error_state_to_string(error_state: &ErrorState) -> String {
IsOffline => "This device is offline, no tunnels can be established",
#[cfg(target_os = "android")]
VpnPermissionDenied => "The Android VPN permission was denied when creating the tunnel",
+ #[cfg(target_os = "windows")]
+ SplitTunnelError => "The split tunneling module reported an error",
#[cfg(not(target_os = "android"))]
_ => unreachable!("unknown error cause"),
};
diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs
index 25afafe250..698db84b53 100644
--- a/mullvad-daemon/src/lib.rs
+++ b/mullvad-daemon/src/lib.rs
@@ -43,6 +43,8 @@ use mullvad_types::{
use settings::SettingsPersister;
#[cfg(target_os = "android")]
use std::os::unix::io::RawFd;
+#[cfg(target_os = "windows")]
+use std::{collections::HashSet, ffi::OsString};
use std::{
marker::PhantomData,
mem,
@@ -52,7 +54,7 @@ use std::{
sync::{mpsc as sync_mpsc, Arc, Weak},
time::Duration,
};
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", windows))]
use talpid_core::split_tunnel;
use talpid_core::{
mpsc::Sender,
@@ -113,6 +115,10 @@ pub enum Error {
#[error(display = "The account has too many wireguard keys")]
TooManyKeys,
+ #[cfg(windows)]
+ #[error(display = "Split tunneling error")]
+ SplitTunnelError(#[error(source)] split_tunnel::Error),
+
#[error(display = "No wireguard private key available")]
NoKeyAvailable,
@@ -259,6 +265,18 @@ pub enum DaemonCommand {
/// Clear list of processes excluded from the tunnel
#[cfg(target_os = "linux")]
ClearSplitTunnelProcesses(ResponseTx<(), split_tunnel::Error>),
+ /// Exclude traffic of an application from the tunnel
+ #[cfg(windows)]
+ AddSplitTunnelApp(ResponseTx<(), Error>, PathBuf),
+ /// Remove application from list of apps to exclude from the tunnel
+ #[cfg(windows)]
+ RemoveSplitTunnelApp(ResponseTx<(), Error>, PathBuf),
+ /// Clear list of apps to exclude from the tunnel
+ #[cfg(windows)]
+ ClearSplitTunnelApps(ResponseTx<(), Error>),
+ /// Disable split tunnel
+ #[cfg(windows)]
+ SetSplitTunnelState(ResponseTx<(), Error>, bool),
/// Makes the daemon exit the main loop and quit.
Shutdown,
/// Saves the target tunnel state and enters a blocking state. The state is restored
@@ -635,6 +653,17 @@ where
rpc_runtime.address_cache.peek_address(),
TransportProtocol::Tcp,
);
+ #[cfg(windows)]
+ let exclude_apps = if settings.split_tunnel.enable_exclusions {
+ settings
+ .split_tunnel
+ .apps
+ .iter()
+ .map(|s| OsString::from(s))
+ .collect()
+ } else {
+ vec![]
+ };
let tunnel_command_tx = tunnel_state_machine::spawn(
settings.allow_lan,
@@ -650,6 +679,8 @@ where
initial_target_state != TargetState::Secured,
#[cfg(target_os = "android")]
android_context,
+ #[cfg(windows)]
+ exclude_apps,
)
.await
.map_err(Error::TunnelError)?;
@@ -1182,6 +1213,14 @@ where
RemoveSplitTunnelProcess(tx, pid) => self.on_remove_split_tunnel_process(tx, pid),
#[cfg(target_os = "linux")]
ClearSplitTunnelProcesses(tx) => self.on_clear_split_tunnel_processes(tx),
+ #[cfg(windows)]
+ AddSplitTunnelApp(tx, path) => self.on_add_split_tunnel_app(tx, path).await,
+ #[cfg(windows)]
+ RemoveSplitTunnelApp(tx, path) => self.on_remove_split_tunnel_app(tx, path).await,
+ #[cfg(windows)]
+ ClearSplitTunnelApps(tx) => self.on_clear_split_tunnel_apps(tx).await,
+ #[cfg(windows)]
+ SetSplitTunnelState(tx, enabled) => self.on_set_split_tunnel_state(tx, enabled).await,
Shutdown => self.trigger_shutdown_event(),
PrepareRestart => self.on_prepare_restart(),
#[cfg(target_os = "android")]
@@ -1716,6 +1755,159 @@ where
Self::oneshot_send(tx, result, "clear_split_tunnel_processes response");
}
+ /// Update the split app paths in both the settings and tunnel
+ #[cfg(windows)]
+ async fn set_split_tunnel_paths(
+ &mut self,
+ tx: ResponseTx<(), Error>,
+ response_msg: &'static str,
+ settings: Settings,
+ new_list: HashSet<PathBuf>,
+ ) {
+ if new_list == settings.split_tunnel.apps {
+ Self::oneshot_send(tx, Ok(()), response_msg);
+ return;
+ }
+
+ if settings.split_tunnel.enable_exclusions {
+ let (result_tx, result_rx) = oneshot::channel();
+ self.send_tunnel_command(TunnelCommand::SetExcludedApps(
+ result_tx,
+ new_list.iter().map(|s| OsString::from(s)).collect(),
+ ));
+ match result_rx.await {
+ Ok(Ok(_)) => (),
+ Ok(Err(error)) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set excluded apps list")
+ );
+ Self::oneshot_send(tx, Err(Error::SplitTunnelError(error)), response_msg);
+ return;
+ }
+ Err(_) => {
+ log::error!("The tunnel failed to return a result");
+ return;
+ }
+ }
+ }
+
+ let save_result = self
+ .settings
+ .set_split_tunnel_apps(new_list)
+ .await
+ .map_err(Error::SettingsError);
+ match save_result {
+ Ok(true) => {
+ Self::oneshot_send(tx, Ok(()), response_msg);
+ self.event_listener
+ .notify_settings(self.settings.to_settings());
+ }
+ Err(error) => {
+ Self::oneshot_send(tx, Err(error), response_msg);
+ }
+ Ok(false) => {
+ // unreachable!("new_list != settings.split_tunnel_apps")
+ error!("BUG: new_list != settings.split_tunnel_apps");
+ }
+ }
+ }
+
+ #[cfg(windows)]
+ async fn on_add_split_tunnel_app(&mut self, tx: ResponseTx<(), Error>, path: PathBuf) {
+ let settings = self.settings.to_settings();
+
+ let mut new_list = settings.split_tunnel.apps.clone();
+ new_list.insert(path);
+
+ self.set_split_tunnel_paths(tx, "add_split_tunnel_app response", settings, new_list)
+ .await;
+ }
+
+ #[cfg(windows)]
+ async fn on_remove_split_tunnel_app(&mut self, tx: ResponseTx<(), Error>, path: PathBuf) {
+ let settings = self.settings.to_settings();
+
+ let mut new_list = settings.split_tunnel.apps.clone();
+ new_list.remove(&path);
+
+ self.set_split_tunnel_paths(tx, "remove_split_tunnel_app response", settings, new_list)
+ .await;
+ }
+
+ #[cfg(windows)]
+ async fn on_clear_split_tunnel_apps(&mut self, tx: ResponseTx<(), Error>) {
+ let settings = self.settings.to_settings();
+ let new_list = HashSet::new();
+ self.set_split_tunnel_paths(tx, "clear_split_tunnel_apps response", settings, new_list)
+ .await;
+ }
+
+ #[cfg(windows)]
+ async fn on_set_split_tunnel_state(&mut self, tx: ResponseTx<(), Error>, enabled: bool) {
+ let settings = self.settings.to_settings();
+
+ if enabled != settings.split_tunnel.enable_exclusions {
+ let new_list = if enabled {
+ settings.split_tunnel.apps.clone()
+ } else {
+ HashSet::new()
+ };
+ if !settings.split_tunnel.apps.is_empty() {
+ let (result_tx, result_rx) = oneshot::channel();
+ self.send_tunnel_command(TunnelCommand::SetExcludedApps(
+ result_tx,
+ new_list.iter().map(|app| OsString::from(app)).collect(),
+ ));
+ match result_rx.await {
+ Ok(Ok(_)) => (),
+ Ok(Err(error)) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to set excluded apps list")
+ );
+ Self::oneshot_send(
+ tx,
+ Err(Error::SplitTunnelError(error)),
+ "set_split_tunnel_state response",
+ );
+ return;
+ }
+ Err(_) => {
+ log::error!("The tunnel failed to return a result");
+ return;
+ }
+ }
+ }
+
+ let save_result = self
+ .settings
+ .set_split_tunnel_state(enabled)
+ .await
+ .map_err(Error::SettingsError);
+ match save_result {
+ Ok(true) => {
+ Self::oneshot_send(tx, Ok(()), "set_split_tunnel_state response");
+ self.event_listener
+ .notify_settings(self.settings.to_settings());
+ }
+ Err(error) => {
+ error!(
+ "{}",
+ error.display_chain_with_msg("Unable to save settings")
+ );
+ Self::oneshot_send(tx, Err(error), "set_split_tunnel_state response");
+ }
+ Ok(false) => {
+ // unreachable!("enabled != settings.split_tunnel"),
+ error!("BUG: enabled != settings.split_tunnel");
+ }
+ }
+ } else {
+ Self::oneshot_send(tx, Ok(()), "set_split_tunnel_state response");
+ }
+ }
+
async fn on_update_relay_settings(
&mut self,
tx: ResponseTx<(), settings::Error>,
diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs
index 2b8bf31a10..ed69f84838 100644
--- a/mullvad-daemon/src/management_interface.rs
+++ b/mullvad-daemon/src/management_interface.rs
@@ -18,6 +18,8 @@ use mullvad_types::{
wireguard::{RotationInterval, RotationIntervalError},
};
use parking_lot::RwLock;
+#[cfg(windows)]
+use std::path::PathBuf;
use std::{
cmp,
convert::{TryFrom, TryInto},
@@ -641,6 +643,69 @@ impl ManagementService for ManagementServiceImpl {
Ok(Response::new(()))
}
}
+
+ #[cfg(windows)]
+ async fn add_split_tunnel_app(&self, request: Request<String>) -> ServiceResult<()> {
+ log::debug!("add_split_tunnel_app");
+ let path = PathBuf::from(request.into_inner());
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::AddSplitTunnelApp(tx, path))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+ #[cfg(not(windows))]
+ async fn add_split_tunnel_app(&self, _: Request<String>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
+
+ #[cfg(windows)]
+ async fn remove_split_tunnel_app(&self, request: Request<String>) -> ServiceResult<()> {
+ log::debug!("remove_split_tunnel_app");
+ let path = PathBuf::from(request.into_inner());
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::RemoveSplitTunnelApp(tx, path))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+ #[cfg(not(windows))]
+ async fn remove_split_tunnel_app(&self, _: Request<String>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
+
+ #[cfg(windows)]
+ async fn clear_split_tunnel_apps(&self, _: Request<()>) -> ServiceResult<()> {
+ log::debug!("clear_split_tunnel_apps");
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::ClearSplitTunnelApps(tx))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+ #[cfg(not(windows))]
+ async fn clear_split_tunnel_apps(&self, _: Request<()>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
+
+ #[cfg(windows)]
+ async fn set_split_tunnel_state(&self, request: Request<bool>) -> ServiceResult<()> {
+ log::debug!("set_split_tunnel_state");
+ let enabled = request.into_inner();
+ let (tx, rx) = oneshot::channel();
+ self.send_command_to_daemon(DaemonCommand::SetSplitTunnelState(tx, enabled))?;
+ self.wait_for_result(rx)
+ .await?
+ .map_err(map_daemon_error)
+ .map(Response::new)
+ }
+ #[cfg(not(windows))]
+ async fn set_split_tunnel_state(&self, _: Request<bool>) -> ServiceResult<()> {
+ Ok(Response::new(()))
+ }
}
impl ManagementServiceImpl {
@@ -807,6 +872,8 @@ fn map_daemon_error(error: crate::Error) -> Status {
match error {
DaemonError::RestError(error) => map_rest_error(error),
DaemonError::SettingsError(error) => map_settings_error(error),
+ #[cfg(windows)]
+ DaemonError::SplitTunnelError(error) => map_split_tunnel_error(error),
DaemonError::AccountHistory(error) => map_account_history_error(error),
DaemonError::NoAccountToken | DaemonError::NoAccountTokenHistory => {
Status::unauthenticated(error.to_string())
@@ -815,6 +882,23 @@ fn map_daemon_error(error: crate::Error) -> Status {
}
}
+#[cfg(windows)]
+/// Converts [`talpid_core::split_tunnel::Error`] into a tonic status.
+fn map_split_tunnel_error(error: talpid_core::split_tunnel::Error) -> Status {
+ use talpid_core::split_tunnel::Error;
+
+ match &error {
+ Error::RegisterIps(io_error) | Error::SetConfiguration(io_error) => {
+ if io_error.kind() == std::io::ErrorKind::NotFound {
+ Status::not_found(format!("{}: {}", error, io_error))
+ } else {
+ Status::unknown(error.to_string())
+ }
+ }
+ _ => Status::unknown(error.to_string()),
+ }
+}
+
/// Converts a REST API voucher error into a tonic status.
fn map_rest_voucher_error(error: RestError) -> Status {
match error {
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings.rs
index 0b75d0005c..8940a1a0a8 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings.rs
@@ -6,6 +6,8 @@ use mullvad_types::{
settings::{DnsOptions, Settings},
wireguard::{RotationInterval, WireguardData},
};
+#[cfg(target_os = "windows")]
+use std::collections::HashSet;
use std::{
ops::Deref,
path::{Path, PathBuf},
@@ -312,6 +314,22 @@ impl SettingsPersister {
self.update(should_save).await
}
+ #[cfg(windows)]
+ pub async fn set_split_tunnel_apps(&mut self, paths: HashSet<PathBuf>) -> Result<bool, Error> {
+ let should_save = paths != self.settings.split_tunnel.apps;
+ if should_save {
+ self.settings.split_tunnel.apps = paths;
+ }
+ self.update(should_save).await
+ }
+
+ #[cfg(windows)]
+ pub async fn set_split_tunnel_state(&mut self, enabled: bool) -> Result<bool, Error> {
+ let should_save =
+ Self::update_field(&mut self.settings.split_tunnel.enable_exclusions, enabled);
+ self.update(should_save).await
+ }
+
fn update_field<T: Eq>(field: &mut T, new_value: T) -> bool {
if *field != new_value {
*field = new_value;
diff --git a/mullvad-management-interface/Cargo.toml b/mullvad-management-interface/Cargo.toml
index 6b4cce66ac..23a5d907cd 100644
--- a/mullvad-management-interface/Cargo.toml
+++ b/mullvad-management-interface/Cargo.toml
@@ -20,6 +20,7 @@ parity-tokio-ipc = "0.8"
futures = "0.3"
tokio = { version = "0.2", features = [ "rt-util" ] }
triggered = "0.1.1"
+log = "0.4"
[target.'cfg(unix)'.dependencies]
nix = "0.19"
diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto
index 636148c7d6..33c4294db7 100644
--- a/mullvad-management-interface/proto/management_interface.proto
+++ b/mullvad-management-interface/proto/management_interface.proto
@@ -58,11 +58,17 @@ service ManagementService {
rpc GetWireguardKey(google.protobuf.Empty) returns (PublicKey) {}
rpc VerifyWireguardKey(google.protobuf.Empty) returns (google.protobuf.BoolValue) {}
- // Split tunneling
+ // Split tunneling (Linux)
rpc GetSplitTunnelProcesses(google.protobuf.Empty) returns (stream google.protobuf.Int32Value) {}
rpc AddSplitTunnelProcess(google.protobuf.Int32Value) returns (google.protobuf.Empty) {}
rpc RemoveSplitTunnelProcess(google.protobuf.Int32Value) returns (google.protobuf.Empty) {}
rpc ClearSplitTunnelProcesses(google.protobuf.Empty) returns (google.protobuf.Empty) {}
+
+ // Split tunneling (Windows)
+ rpc AddSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc RemoveSplitTunnelApp(google.protobuf.StringValue) returns (google.protobuf.Empty) {}
+ rpc ClearSplitTunnelApps(google.protobuf.Empty) returns (google.protobuf.Empty) {}
+ rpc SetSplitTunnelState(google.protobuf.BoolValue) returns (google.protobuf.Empty) {}
}
message RelaySettingsUpdate {
@@ -101,6 +107,7 @@ message ErrorState {
TUNNEL_PARAMETER_ERROR = 5;
IS_OFFLINE = 6;
VPN_PERMISSION_DENIED = 7;
+ SPLIT_TUNNEL_ERROR = 8;
}
enum GenerationError {
@@ -262,6 +269,12 @@ message Settings {
bool auto_connect = 7;
TunnelOptions tunnel_options = 8;
bool show_beta_releases = 9;
+ SplitTunnelSettings split_tunnel = 10;
+}
+
+message SplitTunnelSettings {
+ bool enable_exclusions = 1;
+ repeated string apps = 2;
}
message RelaySettings {
diff --git a/mullvad-management-interface/src/types.rs b/mullvad-management-interface/src/types.rs
index 7d219e0b4a..11547887a4 100644
--- a/mullvad-management-interface/src/types.rs
+++ b/mullvad-management-interface/src/types.rs
@@ -143,6 +143,10 @@ impl From<mullvad_types::states::TunnelState> for TunnelState {
talpid_tunnel::ErrorStateCause::VpnPermissionDenied => {
i32::from(Cause::VpnPermissionDenied)
}
+ #[cfg(target_os = "windows")]
+ talpid_tunnel::ErrorStateCause::SplitTunnelError => {
+ i32::from(Cause::SplitTunnelError)
+ }
},
blocking_error: error_state.block_failure().map(map_firewall_error),
auth_fail_reason: if let talpid_tunnel::ErrorStateCause::AuthFailed(
@@ -359,6 +363,26 @@ impl From<mullvad_types::relay_constraints::LocationConstraint> for RelayLocatio
impl From<&mullvad_types::settings::Settings> for Settings {
fn from(settings: &mullvad_types::settings::Settings) -> Self {
+ #[cfg(windows)]
+ let split_tunnel = {
+ let mut converted_list = vec![];
+ for path in settings.split_tunnel.apps.clone().iter() {
+ match path.as_path().as_os_str().to_str() {
+ Some(path) => converted_list.push(path.to_string()),
+ None => {
+ log::error!("failed to convert OS string: {:?}", path);
+ }
+ }
+ }
+
+ Some(SplitTunnelSettings {
+ enable_exclusions: settings.split_tunnel.enable_exclusions,
+ apps: converted_list,
+ })
+ };
+ #[cfg(not(windows))]
+ let split_tunnel = None;
+
Self {
account_token: settings.get_account_token().unwrap_or_default(),
relay_settings: Some(RelaySettings::from(settings.get_relay_settings())),
@@ -369,6 +393,7 @@ impl From<&mullvad_types::settings::Settings> for Settings {
auto_connect: settings.auto_connect,
tunnel_options: Some(TunnelOptions::from(&settings.tunnel_options)),
show_beta_releases: settings.show_beta_releases,
+ split_tunnel,
}
}
}
diff --git a/mullvad-types/src/settings/mod.rs b/mullvad-types/src/settings/mod.rs
index faddb9ad0c..a15cb8fa61 100644
--- a/mullvad-types/src/settings/mod.rs
+++ b/mullvad-types/src/settings/mod.rs
@@ -11,6 +11,8 @@ use log::{debug, info};
use serde::{Deserialize, Serialize};
use serde_json;
use std::net::IpAddr;
+#[cfg(target_os = "windows")]
+use std::{collections::HashSet, path::PathBuf};
use talpid_types::net::{self, openvpn, GenericTunnelOptions};
mod migrations;
@@ -58,11 +60,23 @@ pub struct Settings {
pub tunnel_options: TunnelOptions,
/// Whether to notify users of beta updates.
pub show_beta_releases: bool,
+ /// Split tunneling settings
+ #[cfg(windows)]
+ pub split_tunnel: SplitTunnelSettings,
/// Specifies settings schema version
#[cfg_attr(target_os = "android", jnix(skip))]
settings_version: migrations::SettingsVersion,
}
+#[cfg(windows)]
+#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
+pub struct SplitTunnelSettings {
+ /// Toggles split tunneling on or off
+ pub enable_exclusions: bool,
+ /// List of applications to exclude from the tunnel.
+ pub apps: HashSet<PathBuf>,
+}
+
impl Default for Settings {
fn default() -> Self {
Settings {
@@ -79,6 +93,8 @@ impl Default for Settings {
auto_connect: false,
tunnel_options: TunnelOptions::default(),
show_beta_releases: false,
+ #[cfg(windows)]
+ split_tunnel: SplitTunnelSettings::default(),
settings_version: migrations::CURRENT_SETTINGS_VERSION,
}
}
diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml
index 9d372a7c00..1c60b5b66e 100644
--- a/talpid-core/Cargo.toml
+++ b/talpid-core/Cargo.toml
@@ -78,9 +78,10 @@ byteorder = "1"
internet-checksum = "0.2"
widestring = "0.4"
winreg = { version = "0.7", features = ["transactions"] }
-winapi = { version = "0.3.6", features = ["combaseapi", "handleapi", "ifdef", "libloaderapi", "netioapi", "stringapiset", "synchapi", "winbase", "winerror", "winuser"] }
+winapi = { version = "0.3.6", features = ["combaseapi", "handleapi", "ifdef", "libloaderapi", "netioapi", "psapi", "stringapiset", "synchapi", "winbase", "winioctl", "winuser"] }
socket2 = "0.3"
talpid-platform-metadata = { path = "../talpid-platform-metadata" }
+memoffset = "0.6"
[build-dependencies]
tonic-build = { version = "0.3", default-features = false, features = ["transport", "prost"] }
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index cba05cb936..be812ff99a 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -24,6 +24,9 @@ pub enum Error {
/// Failure to clear routes
#[error(display = "Failed to clear applied routes")]
ClearRoutesFailed,
+ /// WinNet returned an error while adding default route callback
+ #[error(display = "Failed to set callback for default route")]
+ FailedToAddDefaultRouteCallback,
/// Attempt to use route manager that has been dropped
#[error(display = "Cannot send message to route manager since it is down")]
RouteManagerDown,
@@ -33,7 +36,6 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Manages routes by calling into WinNet
pub struct RouteManager {
- callback_handles: Vec<winnet::WinNetCallbackHandle>,
runtime: tokio::runtime::Handle,
manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
}
@@ -73,7 +75,6 @@ impl RouteManager {
}
let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = Self {
- callback_handles: vec![],
runtime: runtime.clone(),
manage_tx: Some(manage_tx),
};
@@ -129,32 +130,16 @@ impl RouteManager {
}
/// Sets a callback that is called whenever the default route changes.
- #[cfg(target_os = "windows")]
pub fn add_default_route_callback<T: 'static>(
&mut self,
callback: Option<winnet::DefaultRouteChangedCallback>,
context: T,
- ) {
+ ) -> Result<winnet::WinNetCallbackHandle> {
if self.manage_tx.is_none() {
- return;
+ return Err(Error::RouteManagerDown);
}
-
- match winnet::add_default_route_change_callback(callback, context) {
- Err(_e) => {
- // not sure if this should panic
- log::error!("Failed to add callback!");
- }
- Ok(handle) => {
- self.callback_handles.push(handle);
- }
- }
- }
-
- /// Removes all routes previously applied in [`RouteManager::new`] or
- /// [`RouteManager::add_routes`].
- pub fn clear_default_route_callbacks(&mut self) {
- // `WinNetCallbackHandle::drop` removes these callbacks.
- self.callback_handles.clear();
+ winnet::add_default_route_change_callback(callback, context)
+ .map_err(|_| Error::FailedToAddDefaultRouteCallback)
}
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
@@ -165,7 +150,6 @@ impl RouteManager {
log::error!("RouteManager channel already down or thread panicked");
}
- self.callback_handles.clear();
winnet::deactivate_routing_manager();
}
}
diff --git a/talpid-core/src/split_tunnel/mod.rs b/talpid-core/src/split_tunnel/mod.rs
index c7c366d6ea..3c3f6af294 100644
--- a/talpid-core/src/split_tunnel/mod.rs
+++ b/talpid-core/src/split_tunnel/mod.rs
@@ -4,3 +4,10 @@ mod imp;
#[cfg(target_os = "linux")]
pub use imp::*;
+
+#[cfg(windows)]
+#[path = "windows/mod.rs"]
+mod imp;
+
+#[cfg(windows)]
+pub use imp::*;
diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs
new file mode 100644
index 0000000000..a162e35ffe
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/driver.rs
@@ -0,0 +1,888 @@
+use super::windows::{
+ get_device_path, get_process_creation_time, get_process_device_path, open_process,
+ ProcessAccess, ProcessSnapshot,
+};
+use memoffset::offset_of;
+use std::{
+ cell::RefCell,
+ collections::HashMap,
+ ffi::{OsStr, OsString},
+ fs::{self, OpenOptions},
+ io,
+ mem::{self, size_of},
+ net::{Ipv4Addr, Ipv6Addr},
+ os::windows::{
+ ffi::{OsStrExt, OsStringExt},
+ fs::OpenOptionsExt,
+ io::{AsRawHandle, RawHandle},
+ },
+ ptr,
+ time::Duration,
+};
+use winapi::{
+ shared::{
+ in6addr::IN6_ADDR,
+ inaddr::IN_ADDR,
+ minwindef::{FALSE, TRUE},
+ ntdef::NTSTATUS,
+ winerror::{ERROR_INVALID_PARAMETER, ERROR_IO_PENDING},
+ },
+ um::{
+ handleapi::CloseHandle,
+ ioapiset::{DeviceIoControl, GetOverlappedResult},
+ minwinbase::OVERLAPPED,
+ synchapi::{CreateEventW, WaitForSingleObject},
+ tlhelp32::TH32CS_SNAPPROCESS,
+ winbase::{FILE_FLAG_OVERLAPPED, INFINITE, WAIT_ABANDONED, WAIT_FAILED, WAIT_OBJECT_0},
+ winioctl::{FILE_ANY_ACCESS, METHOD_BUFFERED, METHOD_NEITHER},
+ },
+};
+
+const DRIVER_SYMBOLIC_NAME: &str = "\\\\.\\MULLVADSPLITTUNNEL";
+const ST_DEVICE_TYPE: u32 = 0x8000;
+
+const DRIVER_IO_TIMEOUT: Duration = Duration::from_secs(3);
+
+const fn ctl_code(device_type: u32, function: u32, method: u32, access: u32) -> u32 {
+ device_type << 16 | access << 14 | function << 2 | method
+}
+
+#[repr(u32)]
+#[allow(dead_code)]
+pub enum DriverIoctlCode {
+ Initialize = ctl_code(ST_DEVICE_TYPE, 1, METHOD_NEITHER, FILE_ANY_ACCESS),
+ DequeEvent = ctl_code(ST_DEVICE_TYPE, 2, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ RegisterProcesses = ctl_code(ST_DEVICE_TYPE, 3, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ RegisterIpAddresses = ctl_code(ST_DEVICE_TYPE, 4, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ GetIpAddresses = ctl_code(ST_DEVICE_TYPE, 5, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ SetConfiguration = ctl_code(ST_DEVICE_TYPE, 6, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ GetConfiguration = ctl_code(ST_DEVICE_TYPE, 7, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ ClearConfiguration = ctl_code(ST_DEVICE_TYPE, 8, METHOD_NEITHER, FILE_ANY_ACCESS),
+ GetState = ctl_code(ST_DEVICE_TYPE, 9, METHOD_BUFFERED, FILE_ANY_ACCESS),
+ QueryProcess = ctl_code(ST_DEVICE_TYPE, 10, METHOD_BUFFERED, FILE_ANY_ACCESS),
+}
+
+#[derive(Debug, PartialEq)]
+#[repr(u32)]
+#[allow(dead_code)]
+pub enum DriverState {
+ // Default state after being loaded.
+ None = 0,
+ // DriverEntry has completed successfully.
+ // Basically only driver and device objects are created at this point.
+ Started = 1,
+ // All subsystems are initialized.
+ Initialized = 2,
+ // User mode has registered all processes in the system.
+ Ready = 3,
+ // IP addresses are registered.
+ // A valid configuration is registered.
+ Engaged = 4,
+ // Driver is unloading.
+ Terminating = 5,
+}
+
+#[repr(u32)]
+#[derive(Clone, Copy)]
+#[allow(dead_code)]
+pub enum EventId {
+ StartSplittingProcess = 0,
+ StopSplittingProcess,
+
+ // ErrorFlag = 0x80000000,
+ ErrorStartSplittingProcess = 0x80000001,
+ ErrorStopSplittingProcess,
+
+ ErrorMessage,
+
+ Unknown,
+}
+
+pub enum EventBody {
+ SplittingEvent {
+ process_id: usize,
+ reason: SplittingChangeReason,
+ image: OsString,
+ },
+ SplittingError {
+ process_id: usize,
+ image: OsString,
+ },
+ ErrorMessage {
+ status: NTSTATUS,
+ message: OsString,
+ },
+}
+
+#[repr(u32)]
+#[derive(Debug)]
+#[allow(dead_code)]
+pub enum SplittingChangeReason {
+ ByInheritance = 0,
+ ByConfig = 1,
+}
+
+pub struct DeviceHandle {
+ handle: fs::File,
+}
+
+unsafe impl Sync for DeviceHandle {}
+unsafe impl Send for DeviceHandle {}
+
+impl DeviceHandle {
+ pub fn new() -> io::Result<Self> {
+ // Connect to the driver
+ log::trace!("Connecting to the driver");
+ let handle = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .share_mode(0)
+ .custom_flags(FILE_FLAG_OVERLAPPED)
+ .attributes(0)
+ .open(DRIVER_SYMBOLIC_NAME)?;
+
+ let device = Self { handle };
+
+ // Initialize the driver
+ let state = device.get_driver_state()?;
+ if state == DriverState::Started {
+ log::trace!("Initializing driver");
+ device.initialize()?;
+ }
+
+ // Initialize process tree
+ let state = device.get_driver_state()?;
+ if state == DriverState::Initialized {
+ log::trace!("Registering processes");
+ device.register_processes()?;
+ }
+
+ Ok(device)
+ }
+
+ fn initialize(&self) -> io::Result<()> {
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::Initialize as u32,
+ None,
+ 0,
+ Some(DRIVER_IO_TIMEOUT),
+ )?;
+ Ok(())
+ }
+
+ fn register_processes(&self) -> io::Result<()> {
+ let process_tree_buffer = serialize_process_tree(build_process_tree()?)?;
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::RegisterProcesses as u32,
+ Some(&process_tree_buffer),
+ 0,
+ Some(DRIVER_IO_TIMEOUT),
+ )?;
+ Ok(())
+ }
+
+ pub fn register_ips(
+ &self,
+ tunnel_ipv4: Option<Ipv4Addr>,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Option<Ipv4Addr>,
+ internet_ipv6: Option<Ipv6Addr>,
+ ) -> io::Result<()> {
+ log::debug!("Register IPs: tunnel IPv4: {:?}, tunnel IPv6 {:?}, internet IPv4: {:?}, internet IPv6: {:?}", tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6);
+ let mut addresses: SplitTunnelAddresses = unsafe { mem::zeroed() };
+
+ unsafe {
+ if let Some(tunnel_ipv4) = tunnel_ipv4 {
+ let tunnel_ipv4 = tunnel_ipv4.octets();
+ ptr::copy_nonoverlapping(
+ &tunnel_ipv4[0] as *const u8,
+ &mut addresses.tunnel_ipv4 as *mut _ as *mut u8,
+ tunnel_ipv4.len(),
+ );
+ }
+
+ if let Some(tunnel_ipv6) = tunnel_ipv6 {
+ let tunnel_ipv6 = tunnel_ipv6.octets();
+ ptr::copy_nonoverlapping(
+ &tunnel_ipv6[0] as *const u8,
+ &mut addresses.tunnel_ipv6 as *mut _ as *mut u8,
+ tunnel_ipv6.len(),
+ );
+ }
+
+ if let Some(internet_ipv4) = internet_ipv4 {
+ let internet_ipv4 = internet_ipv4.octets();
+ ptr::copy_nonoverlapping(
+ &internet_ipv4[0] as *const u8,
+ &mut addresses.internet_ipv4 as *mut _ as *mut u8,
+ internet_ipv4.len(),
+ );
+ }
+
+ if let Some(internet_ipv6) = internet_ipv6 {
+ let internet_ipv6 = internet_ipv6.octets();
+ ptr::copy_nonoverlapping(
+ &internet_ipv6[0] as *const u8,
+ &mut addresses.internet_ipv6 as *mut _ as *mut u8,
+ internet_ipv6.len(),
+ );
+ }
+ }
+
+ let buffer = &addresses as *const _ as *const u8;
+ let buffer =
+ unsafe { std::slice::from_raw_parts(buffer, size_of::<SplitTunnelAddresses>()) };
+
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::RegisterIpAddresses as u32,
+ Some(buffer),
+ 0,
+ Some(DRIVER_IO_TIMEOUT),
+ )?;
+
+ Ok(())
+ }
+
+ pub fn get_driver_state(&self) -> io::Result<DriverState> {
+ let buffer = device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::GetState as u32,
+ None,
+ size_of::<u64>() as u32,
+ Some(DRIVER_IO_TIMEOUT),
+ )?
+ .unwrap();
+
+ Ok(unsafe { deserialize_buffer(&buffer) })
+ }
+
+ pub fn set_config<T: AsRef<OsStr>>(&self, apps: &[T]) -> io::Result<()> {
+ let mut device_paths = Vec::with_capacity(apps.len());
+ for app in apps.as_ref() {
+ device_paths.push(get_device_path(app.as_ref())?);
+ }
+
+ log::debug!("Excluded device paths:");
+ for path in &device_paths {
+ log::debug!(" {:?}", path);
+ }
+
+ let config = make_process_config(&device_paths);
+
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::SetConfiguration as u32,
+ Some(&config),
+ 0,
+ Some(DRIVER_IO_TIMEOUT),
+ )?;
+
+ Ok(())
+ }
+
+ pub fn clear_config(&self) -> io::Result<()> {
+ device_io_control(
+ self.handle.as_raw_handle(),
+ DriverIoctlCode::ClearConfiguration as u32,
+ None,
+ 0,
+ Some(DRIVER_IO_TIMEOUT),
+ )?;
+
+ Ok(())
+ }
+}
+
+impl AsRawHandle for DeviceHandle {
+ fn as_raw_handle(&self) -> RawHandle {
+ self.handle.as_raw_handle()
+ }
+}
+
+#[repr(C)]
+struct SplitTunnelAddresses {
+ tunnel_ipv4: IN_ADDR,
+ internet_ipv4: IN_ADDR,
+ tunnel_ipv6: IN6_ADDR,
+ internet_ipv6: IN6_ADDR,
+}
+
+#[repr(C)]
+struct ConfigurationHeader {
+ // Number of entries immediately following the header.
+ num_entries: usize,
+ // Total byte length: header + entries + string buffer.
+ total_length: usize,
+}
+
+#[repr(C)]
+struct ConfigurationEntry {
+ // Offset into buffer region that follows all entries.
+ // The image name uses the physical path.
+ name_offset: usize,
+ // Byte length for non-null terminated wide char string.
+ name_length: u16,
+}
+
+/// Create a buffer containing a `ConfigurationHeader` and number of `ConfigurationEntry`s,
+/// followed by the same number of paths to those entries.
+fn make_process_config<T: AsRef<OsStr>>(apps: &[T]) -> Vec<u8> {
+ let apps: Vec<Vec<u16>> = apps
+ .iter()
+ .map(|app| app.as_ref().encode_wide().collect())
+ .collect();
+
+ let total_string_size: usize = apps.iter().map(|app| size_of::<u16>() * app.len()).sum();
+
+ let total_buffer_size = size_of::<ConfigurationHeader>()
+ + size_of::<ConfigurationEntry>() * apps.len()
+ + total_string_size;
+
+ let mut buffer = Vec::<u8>::new();
+ buffer.resize(total_buffer_size, 0);
+
+ let (header, tail) = buffer.split_at_mut(size_of::<ConfigurationHeader>());
+
+ // Serialize configuration header
+ let header_struct = ConfigurationHeader {
+ num_entries: apps.len(),
+ total_length: total_buffer_size,
+ };
+ header.copy_from_slice(unsafe { as_u8_slice(&header_struct) });
+
+ // Serialize configuration entries and strings
+ let (entries, string_data) = tail.split_at_mut(apps.len() * size_of::<ConfigurationEntry>());
+ let mut string_offset = 0;
+
+ for (i, app) in apps.iter().enumerate() {
+ write_string_to_buffer(string_data, string_offset, &app);
+
+ let app_bytelen = size_of::<u16>() * app.len();
+ let entry = ConfigurationEntry {
+ name_offset: string_offset,
+ name_length: app_bytelen as u16,
+ };
+ let entry_offset = size_of::<ConfigurationEntry>() * i;
+ entries[entry_offset..entry_offset + size_of::<ConfigurationEntry>()]
+ .copy_from_slice(unsafe { as_u8_slice(&entry) });
+
+ string_offset += app_bytelen;
+ }
+
+ buffer
+}
+
+#[derive(Debug)]
+struct ProcessInfo {
+ pid: u32,
+ parent_pid: u32,
+ creation_time: u64,
+ device_path: Vec<u16>,
+}
+
+/// List process identifiers, their parents, and their device paths.
+fn build_process_tree() -> io::Result<Vec<ProcessInfo>> {
+ let mut process_info = HashMap::new();
+
+ let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?;
+ for entry in snap.entries() {
+ let entry = entry?;
+
+ let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) {
+ Ok(handle) => Ok(handle),
+ Err(error) => {
+ // Skip process objects that cannot be opened
+ match error.kind() {
+ // System process
+ io::ErrorKind::PermissionDenied => continue,
+ // System idle or csrss process
+ io::ErrorKind::InvalidInput => continue,
+ io::ErrorKind::Other => {
+ // Old rust lib maps INVALID_PARAMETER to "Other"
+ if error.raw_os_error() == Some(ERROR_INVALID_PARAMETER as i32) {
+ continue;
+ }
+ Err(error)
+ }
+ _ => Err(error),
+ }
+ }
+ }?;
+
+ // TODO: Skip objects whose paths or timestamps cannot be obtained?
+
+ process_info.insert(
+ entry.pid,
+ RefCell::new(ProcessInfo {
+ pid: entry.pid,
+ parent_pid: entry.parent_pid,
+ creation_time: get_process_creation_time(process.get_raw()).unwrap_or(0),
+ device_path: get_process_device_path(process.get_raw())
+ .unwrap_or(OsString::from(""))
+ .encode_wide()
+ .collect(),
+ }),
+ );
+ }
+
+ // Handle pid recycling
+ // If the "parent" is younger than the process itself, it is not our parent.
+ for info in process_info.values() {
+ let mut info = info.borrow_mut();
+ let parent_pid = info.parent_pid;
+ if parent_pid == 0 {
+ continue;
+ }
+ if let Some(parent_info) = process_info.get(&parent_pid) {
+ if parent_info.borrow_mut().creation_time > info.creation_time {
+ info.parent_pid = 0;
+ }
+ }
+ }
+
+ Ok(process_info
+ .into_iter()
+ .map(|(_, info)| info.into_inner())
+ .collect())
+}
+
+#[repr(C)]
+struct ProcessRegistryHeader {
+ // Number of entries immediately following the header.
+ num_entries: usize,
+ // Total byte length: header + entries + string buffer.
+ total_length: usize,
+}
+
+#[repr(C)]
+struct ProcessRegistryEntry {
+ pid: RawHandle,
+ parent_pid: RawHandle,
+ // Image name offset (following the last entry).
+ image_name_offset: usize,
+ // Image name length.
+ image_name_size: u16,
+}
+
+fn serialize_process_tree(processes: Vec<ProcessInfo>) -> Result<Vec<u8>, io::Error> {
+ // Construct a buffer:
+ // ProcessRegistryHeader
+ // ProcessRegistryEntry..
+ // Image names..
+
+ let total_string_size: usize = processes
+ .iter()
+ .map(|info| size_of::<u16>() * info.device_path.len())
+ .sum();
+ let total_buffer_size = size_of::<ProcessRegistryHeader>()
+ + size_of::<ProcessRegistryEntry>() * processes.len()
+ + total_string_size;
+
+ let mut buffer = Vec::<u8>::new();
+ buffer.resize(total_buffer_size, 0);
+
+ let (header, tail) = buffer.split_at_mut(size_of::<ProcessRegistryHeader>());
+ let header_struct = ProcessRegistryHeader {
+ num_entries: processes.len(),
+ total_length: total_buffer_size,
+ };
+ header.copy_from_slice(unsafe { as_u8_slice(&header_struct) });
+
+ let (entries, string_data) =
+ tail.split_at_mut(size_of::<ProcessRegistryEntry>() * processes.len());
+
+ let mut string_offset = 0;
+
+ for (i, entry) in processes.into_iter().enumerate() {
+ let mut out_entry = ProcessRegistryEntry {
+ pid: entry.pid as usize as RawHandle,
+ parent_pid: entry.parent_pid as usize as RawHandle,
+ image_name_size: 0,
+ image_name_offset: 0,
+ };
+
+ if !entry.device_path.is_empty() {
+ write_string_to_buffer(string_data, string_offset, &entry.device_path);
+
+ out_entry.image_name_size = (entry.device_path.len() * size_of::<u16>()) as u16;
+ out_entry.image_name_offset = string_offset;
+
+ string_offset += size_of::<u16>() * entry.device_path.len();
+ }
+
+ let entry_offset = size_of::<ProcessRegistryEntry>() * i;
+ entries[entry_offset..entry_offset + size_of::<ProcessRegistryEntry>()]
+ .copy_from_slice(unsafe { as_u8_slice(&out_entry) });
+ }
+
+ Ok(buffer)
+}
+
+#[repr(C)]
+struct EventHeader {
+ event_id: EventId,
+ event_size: usize,
+ event_data: [u8; 1],
+}
+
+#[repr(C)]
+struct SplittingEventHeader {
+ process_id: usize,
+ reason: SplittingChangeReason,
+ image_name_length: u16,
+ image_name_data: [u16; 1],
+}
+
+#[repr(C)]
+struct SplittingErrorEventHeader {
+ process_id: usize,
+ image_name_length: u16,
+ image_name_data: [u16; 1],
+}
+
+#[repr(C)]
+struct ErrorMessageEventHeader {
+ status: NTSTATUS,
+ error_message_length: u16,
+ error_message_data: [u16; 1],
+}
+
+pub fn parse_event_buffer(buffer: &Vec<u8>) -> Option<(EventId, EventBody)> {
+ let mut raw_event_id = 0u32;
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[0],
+ &mut raw_event_id as *mut _ as *mut u8,
+ mem::size_of::<u32>(),
+ )
+ };
+ if raw_event_id >= EventId::Unknown as u32 {
+ return None;
+ }
+
+ let mut event_header: EventHeader = unsafe { mem::zeroed() };
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[0],
+ &mut event_header as *mut _ as *mut u8,
+ offset_of!(EventHeader, event_data),
+ )
+ };
+
+ match event_header.event_id {
+ EventId::StartSplittingProcess | EventId::StopSplittingProcess => {
+ let mut event: SplittingEventHeader = unsafe { mem::zeroed() };
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[offset_of!(EventHeader, event_data)],
+ &mut event as *mut _ as *mut u8,
+ offset_of!(SplittingEventHeader, image_name_data),
+ )
+ };
+
+ let mut image_name = Vec::new();
+ image_name.resize(
+ event.image_name_length as usize / mem::size_of::<u16>(),
+ 0u16,
+ );
+
+ let string_byte_offset = offset_of!(EventHeader, event_data)
+ + offset_of!(SplittingEventHeader, image_name_data);
+
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[string_byte_offset] as *const _ as *const u16,
+ image_name.as_mut_ptr(),
+ image_name.len(),
+ )
+ };
+
+ Some((
+ event_header.event_id,
+ EventBody::SplittingEvent {
+ process_id: event.process_id,
+ reason: event.reason,
+ image: OsStringExt::from_wide(&image_name),
+ },
+ ))
+ }
+ EventId::ErrorStartSplittingProcess | EventId::ErrorStopSplittingProcess => {
+ let mut event: SplittingErrorEventHeader = unsafe { mem::zeroed() };
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[offset_of!(EventHeader, event_data)],
+ &mut event as *mut _ as *mut u8,
+ offset_of!(SplittingErrorEventHeader, image_name_data),
+ )
+ };
+
+ let mut image_name = Vec::new();
+ image_name.resize(
+ event.image_name_length as usize / mem::size_of::<u16>(),
+ 0u16,
+ );
+
+ let string_byte_offset = offset_of!(EventHeader, event_data)
+ + offset_of!(SplittingErrorEventHeader, image_name_data);
+
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[string_byte_offset] as *const _ as *const u16,
+ image_name.as_mut_ptr(),
+ image_name.len(),
+ )
+ };
+
+ Some((
+ event_header.event_id,
+ EventBody::SplittingError {
+ process_id: event.process_id,
+ image: OsStringExt::from_wide(&image_name),
+ },
+ ))
+ }
+ EventId::ErrorMessage => {
+ let mut event: ErrorMessageEventHeader = unsafe { mem::zeroed() };
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[offset_of!(EventHeader, event_data)],
+ &mut event as *mut _ as *mut u8,
+ offset_of!(ErrorMessageEventHeader, error_message_data),
+ )
+ };
+
+ let mut error_message = Vec::new();
+ error_message.resize(
+ event.error_message_length as usize / mem::size_of::<u16>(),
+ 0u16,
+ );
+
+ let string_byte_offset = offset_of!(EventHeader, event_data)
+ + offset_of!(ErrorMessageEventHeader, error_message_data);
+
+ unsafe {
+ ptr::copy_nonoverlapping(
+ &buffer[string_byte_offset] as *const _ as *const u16,
+ error_message.as_mut_ptr(),
+ error_message.len(),
+ )
+ };
+
+ Some((
+ event_header.event_id,
+ EventBody::ErrorMessage {
+ status: event.status,
+ message: OsStringExt::from_wide(&error_message),
+ },
+ ))
+ }
+ EventId::Unknown => None,
+ }
+}
+
+/// Send an IOCTL code to the given device handle.
+/// `input` specifies an optional buffer to send.
+/// Upon success, a buffer of size `output_size` is returned, or None if `output_size` is 0.
+pub fn device_io_control(
+ device: RawHandle,
+ ioctl_code: u32,
+ input: Option<&[u8]>,
+ output_size: u32,
+ timeout: Option<Duration>,
+) -> Result<Option<Vec<u8>>, io::Error> {
+ struct HandleOwner {
+ handle: RawHandle,
+ }
+ impl Drop for HandleOwner {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.handle) };
+ }
+ }
+
+ let mut overlapped: OVERLAPPED = unsafe { mem::zeroed() };
+ overlapped.hEvent = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+
+ if overlapped.hEvent == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+
+ let _handle_owner = HandleOwner {
+ handle: overlapped.hEvent,
+ };
+
+ let mut out_buffer = if output_size > 0 {
+ Some(Vec::with_capacity(output_size as usize))
+ } else {
+ None
+ };
+
+ device_io_control_buffer(
+ device,
+ ioctl_code,
+ input,
+ out_buffer.as_mut(),
+ &overlapped,
+ timeout,
+ )
+ .map(|()| out_buffer)
+}
+
+/// Send an IOCTL code to the given device handle.
+/// `input` specifies an optional buffer to send.
+/// Upon success, `output` buffer will contain at most `output.capacity()` bytes of data.
+pub fn device_io_control_buffer(
+ device: RawHandle,
+ ioctl_code: u32,
+ input: Option<&[u8]>,
+ mut output: Option<&mut Vec<u8>>,
+ overlapped: &OVERLAPPED,
+ timeout: Option<Duration>,
+) -> Result<(), io::Error> {
+ let input_ptr = match input {
+ Some(input) => input as *const _ as *mut _,
+ None => ptr::null_mut(),
+ };
+ let input_len = input.map(|input| input.len()).unwrap_or(0);
+
+ let out_ptr = match output {
+ Some(ref mut output) => output.as_mut_ptr() as *mut _,
+ None => ptr::null_mut(),
+ };
+ let output_size = if let Some(ref output) = output {
+ output.capacity()
+ } else {
+ 0
+ };
+
+ let event = overlapped.hEvent;
+
+ let mut returned_bytes = 0u32;
+ let overlapped = overlapped as *const _ as *mut _;
+
+ let result = unsafe {
+ DeviceIoControl(
+ device as *mut _,
+ ioctl_code,
+ input_ptr,
+ input_len as u32,
+ out_ptr,
+ output_size as u32,
+ &mut returned_bytes,
+ overlapped,
+ )
+ };
+
+ if result != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "Expected pending operation",
+ ));
+ }
+
+ let last_error = io::Error::last_os_error();
+ if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) {
+ return Err(last_error);
+ }
+
+ let timeout = timeout
+ .map(|timeout| timeout.as_millis() as u32)
+ .unwrap_or(INFINITE);
+ let result = unsafe { WaitForSingleObject(event, timeout) };
+ match result {
+ WAIT_FAILED => return Err(io::Error::last_os_error()),
+ WAIT_ABANDONED => return Err(io::Error::new(io::ErrorKind::Other, "abandoned mutex")),
+ WAIT_OBJECT_0 => (),
+ error => return Err(io::Error::from_raw_os_error(error as i32)),
+ }
+
+ let result =
+ unsafe { GetOverlappedResult(device as *mut _, overlapped, &mut returned_bytes, FALSE) };
+
+ if result == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ if let Some(ref mut output) = output {
+ unsafe { output.set_len(returned_bytes as usize) };
+ }
+
+ Ok(())
+}
+
+/// Send an IOCTL code to the given device handle.
+/// `input` specifies an optional buffer to send.
+/// The result must be obtained using `GetOverlappedResult[Ex]`.
+pub unsafe fn device_io_control_buffer_async(
+ device: RawHandle,
+ ioctl_code: u32,
+ mut output: Option<&mut Vec<u8>>,
+ input: Option<&[u8]>,
+ overlapped: &OVERLAPPED,
+) -> Result<(), io::Error> {
+ let input_ptr = match input {
+ Some(input) => input as *const _ as *mut _,
+ None => ptr::null_mut(),
+ };
+ let input_len = input.map(|input| input.len()).unwrap_or(0);
+
+ let out_ptr = match output {
+ Some(ref mut output) => output.as_mut_ptr() as *mut _,
+ None => ptr::null_mut(),
+ };
+ let output_size = if let Some(ref output) = output {
+ output.capacity()
+ } else {
+ 0
+ };
+
+ let overlapped = overlapped as *const _ as *mut _;
+
+ let result = DeviceIoControl(
+ device as *mut _,
+ ioctl_code,
+ input_ptr,
+ input_len as u32,
+ out_ptr,
+ output_size as u32,
+ ptr::null_mut(),
+ overlapped,
+ );
+
+ if result != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "Expected pending operation",
+ ));
+ }
+
+ let last_error = io::Error::last_os_error();
+ if last_error.raw_os_error() != Some(ERROR_IO_PENDING as i32) {
+ return Err(last_error);
+ }
+
+ Ok(())
+}
+
+/// Creates a new instance of an arbitrary type from a byte buffer.
+pub unsafe fn deserialize_buffer<T: Sized>(buffer: &Vec<u8>) -> T {
+ let mut instance: T = mem::zeroed();
+ ptr::copy_nonoverlapping(buffer.as_ptr() as *const T, &mut instance as *mut _, 1);
+ instance
+}
+
+fn write_string_to_buffer(buffer: &mut [u8], byte_offset: usize, string: &[u16]) {
+ for (i, byte) in string
+ .iter()
+ .flat_map(|word| std::array::IntoIter::new(word.to_ne_bytes()))
+ .enumerate()
+ {
+ buffer[byte_offset + i] = byte;
+ }
+}
+
+unsafe fn as_u8_slice<T: Sized>(object: &T) -> &[u8] {
+ std::slice::from_raw_parts(object as *const _ as *const _, size_of::<T>())
+}
diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs
new file mode 100644
index 0000000000..2e538a014b
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/mod.rs
@@ -0,0 +1,574 @@
+mod driver;
+mod windows;
+
+use crate::{
+ tunnel::TunnelMetadata,
+ tunnel_state_machine::TunnelCommand,
+ winnet::{
+ self, get_best_default_route, interface_luid_to_ip, WinNetAddrFamily, WinNetCallbackHandle,
+ },
+};
+use futures::channel::mpsc;
+use lazy_static::lazy_static;
+use std::{
+ convert::TryFrom,
+ ffi::{OsStr, OsString},
+ io, mem,
+ net::{IpAddr, Ipv4Addr, Ipv6Addr},
+ os::windows::io::{AsRawHandle, RawHandle},
+ ptr,
+ sync::{mpsc as sync_mpsc, Arc, Weak},
+ time::Duration,
+};
+use talpid_types::{tunnel::ErrorStateCause, ErrorExt};
+use winapi::{
+ shared::minwindef::{FALSE, TRUE},
+ um::{
+ handleapi::CloseHandle,
+ ioapiset::GetOverlappedResult,
+ minwinbase::OVERLAPPED,
+ synchapi::{CreateEventW, SetEvent, WaitForMultipleObjects, WaitForSingleObject},
+ winbase::{INFINITE, WAIT_ABANDONED_0, WAIT_OBJECT_0},
+ },
+};
+
+const DRIVER_EVENT_BUFFER_SIZE: usize = 2048;
+
+lazy_static! {
+ static ref RESERVED_IP_V4: Ipv4Addr = "192.0.2.123".parse().unwrap();
+}
+
+/// Errors that may occur in [`SplitTunnel`].
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+ /// Failed to identify or initialize the driver
+ #[error(display = "Failed to find or initialize driver")]
+ InitializationFailed(#[error(source)] io::Error),
+
+ /// Failed to set paths to excluded applications
+ #[error(display = "Failed to set list of excluded applications")]
+ SetConfiguration(#[error(source)] io::Error),
+
+ /// Failed to register interface IP addresses
+ #[error(display = "Failed to register IP addresses for exclusions")]
+ RegisterIps(#[error(source)] io::Error),
+
+ /// Failed to clear interface IP addresses
+ #[error(display = "Failed to clear registered IP addresses")]
+ ClearIps(#[error(source)] io::Error),
+
+ /// Failed to set up the driver event loop
+ #[error(display = "Failed to set up the driver event loop")]
+ EventThreadError(#[error(source)] io::Error),
+
+ /// Failed to obtain default route
+ #[error(display = "Failed to obtain the default route")]
+ ObtainDefaultRoute(#[error(source)] winnet::Error),
+
+ /// Failed to obtain an IP address given a network interface LUID
+ #[error(display = "Failed to obtain IP address for interface LUID")]
+ LuidToIp(#[error(source)] winnet::Error),
+
+ /// Failed to set up callback for monitoring default route changes
+ #[error(display = "Failed to register default route change callback")]
+ RegisterRouteChangeCallback,
+
+ /// Unexpected IP parsing error
+ #[error(display = "Failed to parse IP address")]
+ IpParseError,
+
+ /// The request handling thread is stuck
+ #[error(display = "The ST request thread is stuck")]
+ RequestThreadStuck,
+
+ /// The request handling thread is down
+ #[error(display = "The ST request thread is down")]
+ RequestThreadDown,
+}
+
+/// Manages applications whose traffic to exclude from the tunnel.
+pub struct SplitTunnel {
+ request_tx: RequestTx,
+ event_thread: Option<std::thread::JoinHandle<()>>,
+ quit_event: RawHandle,
+ _route_change_callback: Option<WinNetCallbackHandle>,
+ daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
+}
+
+enum Request {
+ SetPaths(Vec<OsString>),
+ RegisterIps(
+ Option<Ipv4Addr>,
+ Option<Ipv6Addr>,
+ Option<Ipv4Addr>,
+ Option<Ipv6Addr>,
+ ),
+}
+type RequestResponseTx = sync_mpsc::Sender<Result<(), Error>>;
+type RequestTx = sync_mpsc::SyncSender<(Request, RequestResponseTx)>;
+
+const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
+
+struct EventThreadContext {
+ handle: Arc<driver::DeviceHandle>,
+ event_overlapped: OVERLAPPED,
+ quit_event: RawHandle,
+}
+unsafe impl Send for EventThreadContext {}
+
+impl SplitTunnel {
+ /// Initialize the driver.
+ pub fn new(daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>) -> Result<Self, Error> {
+ let (request_tx, handle) = Self::spawn_request_thread()?;
+
+ let mut event_overlapped: OVERLAPPED = unsafe { mem::zeroed() };
+ event_overlapped.hEvent =
+ unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+ if event_overlapped.hEvent == ptr::null_mut() {
+ return Err(Error::EventThreadError(io::Error::last_os_error()));
+ }
+
+ let quit_event = unsafe { CreateEventW(ptr::null_mut(), TRUE, FALSE, ptr::null()) };
+
+ let event_context = EventThreadContext {
+ handle: handle.clone(),
+ event_overlapped,
+ quit_event,
+ };
+
+ let event_thread = std::thread::spawn(move || {
+ use driver::{EventBody, EventId};
+
+ let mut data_buffer = Vec::with_capacity(DRIVER_EVENT_BUFFER_SIZE);
+ let mut returned_bytes = 0u32;
+
+ let event_objects = [
+ event_context.event_overlapped.hEvent,
+ event_context.quit_event,
+ ];
+
+ loop {
+ if unsafe { WaitForSingleObject(event_context.quit_event, 0) == WAIT_OBJECT_0 } {
+ // Quit event was signaled
+ break;
+ }
+
+ if let Err(error) = unsafe {
+ driver::device_io_control_buffer_async(
+ event_context.handle.as_raw_handle(),
+ driver::DriverIoctlCode::DequeEvent as u32,
+ Some(&mut data_buffer),
+ None,
+ &event_context.event_overlapped,
+ )
+ } {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("device_io_control failed")
+ );
+ continue;
+ }
+
+ let result = unsafe {
+ WaitForMultipleObjects(
+ event_objects.len() as u32,
+ &event_objects[0],
+ FALSE,
+ INFINITE,
+ )
+ };
+
+ let signaled_index = if result >= WAIT_OBJECT_0
+ && result < WAIT_OBJECT_0 + event_objects.len() as u32
+ {
+ result - WAIT_OBJECT_0
+ } else if result >= WAIT_ABANDONED_0
+ && result < WAIT_ABANDONED_0 + event_objects.len() as u32
+ {
+ result - WAIT_ABANDONED_0
+ } else {
+ let error = io::Error::last_os_error();
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("WaitForMultipleObjects failed")
+ );
+
+ continue;
+ };
+
+ if event_context.quit_event == event_objects[signaled_index as usize] {
+ // Quit event was signaled
+ break;
+ }
+
+ let result = unsafe {
+ GetOverlappedResult(
+ event_context.handle.as_raw_handle(),
+ &event_context.event_overlapped as *const _ as *mut _,
+ &mut returned_bytes,
+ TRUE,
+ )
+ };
+
+ if result == 0 {
+ let error = io::Error::last_os_error();
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("GetOverlappedResult failed")
+ );
+
+ continue;
+ }
+
+ unsafe { data_buffer.set_len(returned_bytes as usize) };
+
+ let event = driver::parse_event_buffer(&data_buffer);
+
+ let (event_id, event_body) = match event {
+ Some((event_id, event_body)) => (event_id, event_body),
+ None => continue,
+ };
+
+ let event_str = match &event_id {
+ EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => {
+ "Start splitting process"
+ }
+ EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => {
+ "Stop splitting process"
+ }
+ EventId::ErrorMessage => "ErrorMessage",
+ _ => "Unknown event ID",
+ };
+
+ match event_body {
+ EventBody::SplittingEvent {
+ process_id,
+ reason,
+ image,
+ } => {
+ log::trace!(
+ "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}",
+ event_str,
+ process_id,
+ reason,
+ image,
+ );
+ }
+ EventBody::SplittingError { process_id, image } => {
+ log::error!(
+ "FAILED: {}:\n\tpid: {}\n\timage: {:?}",
+ event_str,
+ process_id,
+ image,
+ );
+ }
+ EventBody::ErrorMessage { status, message } => {
+ log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy())
+ }
+ }
+ }
+
+ log::debug!("Stopping split tunnel event thread");
+
+ unsafe { CloseHandle(event_context.event_overlapped.hEvent) };
+ unsafe { CloseHandle(event_context.quit_event) };
+ });
+
+ Ok(SplitTunnel {
+ request_tx,
+ event_thread: Some(event_thread),
+ quit_event,
+ _route_change_callback: None,
+ daemon_tx,
+ })
+ }
+
+ fn spawn_request_thread() -> Result<(RequestTx, Arc<driver::DeviceHandle>), Error> {
+ let (tx, rx): (RequestTx, _) = sync_mpsc::sync_channel(3);
+ let (init_tx, init_rx) = sync_mpsc::channel();
+
+ std::thread::spawn(move || {
+ let result = driver::DeviceHandle::new()
+ .map(Arc::new)
+ .map_err(Error::InitializationFailed);
+ let handle = match result {
+ Ok(handle) => {
+ let _ = init_tx.send(Ok(handle.clone()));
+ handle
+ }
+ Err(error) => {
+ let _ = init_tx.send(Err(error));
+ return;
+ }
+ };
+
+ while let Ok((request, response_tx)) = rx.recv() {
+ let response = match request {
+ Request::SetPaths(paths) => {
+ if paths.len() > 0 {
+ handle.set_config(&paths).map_err(Error::SetConfiguration)
+ } else {
+ handle.clear_config().map_err(Error::SetConfiguration)
+ }
+ }
+ Request::RegisterIps(
+ mut tunnel_ipv4,
+ mut tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ ) => {
+ if internet_ipv4.is_none() && internet_ipv6.is_none() {
+ tunnel_ipv4 = None;
+ tunnel_ipv6 = None;
+ }
+ handle
+ .register_ips(tunnel_ipv4, tunnel_ipv6, internet_ipv4, internet_ipv6)
+ .map_err(Error::RegisterIps)
+ }
+ };
+ if response_tx.send(response).is_err() {
+ log::error!("A response could not be sent for a completed request");
+ }
+ }
+ });
+
+ let handle = init_rx
+ .recv_timeout(REQUEST_TIMEOUT)
+ .map_err(|_| Error::RequestThreadStuck)??;
+
+ Ok((tx, handle))
+ }
+
+ fn send_request(&self, request: Request) -> Result<(), Error> {
+ Self::send_request_inner(&self.request_tx, request)
+ }
+
+ fn send_request_inner(request_tx: &RequestTx, request: Request) -> Result<(), Error> {
+ let (response_tx, response_rx) = sync_mpsc::channel();
+
+ request_tx
+ .try_send((request, response_tx))
+ .map_err(|error| match error {
+ sync_mpsc::TrySendError::Disconnected(_) => Error::RequestThreadDown,
+ sync_mpsc::TrySendError::Full(_) => Error::RequestThreadStuck,
+ })?;
+
+ response_rx
+ .recv_timeout(REQUEST_TIMEOUT)
+ .map_err(|_| Error::RequestThreadStuck)?
+ }
+
+ /// Set a list of applications to exclude from the tunnel.
+ pub fn set_paths<T: AsRef<OsStr>>(&self, paths: &[T]) -> Result<(), Error> {
+ self.send_request(Request::SetPaths(
+ paths
+ .iter()
+ .map(|path| path.as_ref().to_os_string())
+ .collect(),
+ ))
+ }
+
+ /// Instructs the driver to redirect traffic from sockets bound to 0.0.0.0, ::, or the
+ /// tunnel addresses (if any) to the default route.
+ pub fn set_tunnel_addresses(&mut self, metadata: Option<&TunnelMetadata>) -> Result<(), Error> {
+ let mut tunnel_ipv4 = None;
+ let mut tunnel_ipv6 = None;
+
+ if let Some(metadata) = metadata {
+ for ip in &metadata.ips {
+ match ip {
+ IpAddr::V4(address) => tunnel_ipv4 = Some(address.clone()),
+ IpAddr::V6(address) => tunnel_ipv6 = Some(address.clone()),
+ }
+ }
+ }
+
+ // Identify IP address that gives us Internet access
+ let internet_ipv4 = get_best_default_route(WinNetAddrFamily::IPV4)
+ .map_err(Error::ObtainDefaultRoute)?
+ .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV4, route.interface_luid))
+ .transpose()
+ .map_err(Error::LuidToIp)?
+ .flatten();
+ let internet_ipv6 = get_best_default_route(WinNetAddrFamily::IPV6)
+ .map_err(Error::ObtainDefaultRoute)?
+ .map(|route| interface_luid_to_ip(WinNetAddrFamily::IPV6, route.interface_luid))
+ .transpose()
+ .map_err(Error::LuidToIp)?
+ .flatten();
+
+ let tunnel_ipv4 = Some(tunnel_ipv4.unwrap_or(*RESERVED_IP_V4));
+ let internet_ipv4 = internet_ipv4
+ .map(|addr| Ipv4Addr::try_from(addr).map_err(|_| Error::IpParseError))
+ .transpose()?;
+ let internet_ipv6 = internet_ipv6
+ .map(|addr| Ipv6Addr::try_from(addr).map_err(|_| Error::IpParseError))
+ .transpose()?;
+
+ let context = SplitTunnelDefaultRouteChangeHandlerContext::new(
+ self.request_tx.clone(),
+ self.daemon_tx.clone(),
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ );
+
+ self._route_change_callback = None;
+
+ self.send_request(Request::RegisterIps(
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ ))?;
+
+ self._route_change_callback = winnet::add_default_route_change_callback(
+ Some(split_tunnel_default_route_change_handler),
+ context,
+ )
+ .map(Some)
+ .map_err(|_| Error::RegisterRouteChangeCallback)?;
+
+ Ok(())
+ }
+
+ /// Instructs the driver to stop redirecting tunnel traffic and INADDR_ANY.
+ pub fn clear_tunnel_addresses(&mut self) -> Result<(), Error> {
+ self._route_change_callback = None;
+ self.send_request(Request::RegisterIps(None, None, None, None))
+ }
+}
+
+impl Drop for SplitTunnel {
+ fn drop(&mut self) {
+ if let Some(_event_thread) = self.event_thread.take() {
+ if unsafe { SetEvent(self.quit_event) } == 0 {
+ let error = io::Error::last_os_error();
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to close ST event thread")
+ );
+ }
+ // Not joining `event_thread`: It may be unresponsive.
+ }
+
+ let paths: [&OsStr; 0] = [];
+ if let Err(error) = self.set_paths(&paths) {
+ log::error!("{}", error.display_chain());
+ }
+ }
+}
+
+struct SplitTunnelDefaultRouteChangeHandlerContext {
+ request_tx: RequestTx,
+ pub daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ pub tunnel_ipv4: Option<Ipv4Addr>,
+ pub tunnel_ipv6: Option<Ipv6Addr>,
+ pub internet_ipv4: Option<Ipv4Addr>,
+ pub internet_ipv6: Option<Ipv6Addr>,
+}
+
+impl SplitTunnelDefaultRouteChangeHandlerContext {
+ pub fn new(
+ request_tx: RequestTx,
+ daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
+ tunnel_ipv4: Option<Ipv4Addr>,
+ tunnel_ipv6: Option<Ipv6Addr>,
+ internet_ipv4: Option<Ipv4Addr>,
+ internet_ipv6: Option<Ipv6Addr>,
+ ) -> Self {
+ SplitTunnelDefaultRouteChangeHandlerContext {
+ request_tx,
+ daemon_tx,
+ tunnel_ipv4,
+ tunnel_ipv6,
+ internet_ipv4,
+ internet_ipv6,
+ }
+ }
+
+ pub fn register_ips(&self) -> Result<(), Error> {
+ SplitTunnel::send_request_inner(
+ &self.request_tx,
+ Request::RegisterIps(
+ self.tunnel_ipv4,
+ self.tunnel_ipv6,
+ self.internet_ipv4,
+ self.internet_ipv6,
+ ),
+ )
+ }
+}
+
+unsafe extern "system" fn split_tunnel_default_route_change_handler(
+ event_type: winnet::WinNetDefaultRouteChangeEventType,
+ address_family: WinNetAddrFamily,
+ default_route: winnet::WinNetDefaultRoute,
+ ctx: *mut libc::c_void,
+) {
+ // Update the "internet interface" IP when best default route changes
+ let ctx = &mut *(ctx as *mut SplitTunnelDefaultRouteChangeHandlerContext);
+
+ let daemon_tx = ctx.daemon_tx.upgrade();
+ let maybe_send = move |content| {
+ if let Some(tx) = daemon_tx {
+ let _ = tx.unbounded_send(content);
+ }
+ };
+
+ let result = match event_type {
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteChanged => {
+ match interface_luid_to_ip(address_family.clone(), default_route.interface_luid) {
+ Ok(Some(ip)) => match IpAddr::from(ip) {
+ IpAddr::V4(addr) => ctx.internet_ipv4 = Some(addr),
+ IpAddr::V6(addr) => ctx.internet_ipv6 = Some(addr),
+ },
+ Ok(None) => {
+ log::warn!("Failed to obtain default route interface address");
+ match address_family {
+ WinNetAddrFamily::IPV4 => {
+ ctx.internet_ipv4 = None;
+ }
+ WinNetAddrFamily::IPV6 => {
+ ctx.internet_ipv6 = None;
+ }
+ }
+ }
+ Err(error) => {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to obtain default route interface address"
+ )
+ );
+ maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError));
+ return;
+ }
+ };
+
+ ctx.register_ips()
+ }
+ // no default route
+ winnet::WinNetDefaultRouteChangeEventType::DefaultRouteRemoved => {
+ match address_family {
+ WinNetAddrFamily::IPV4 => {
+ ctx.internet_ipv4 = None;
+ }
+ WinNetAddrFamily::IPV6 => {
+ ctx.internet_ipv6 = None;
+ }
+ }
+ ctx.register_ips()
+ }
+ };
+
+ if let Err(error) = result {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("Failed to register new addresses in split tunnel driver")
+ );
+ maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError));
+ }
+}
diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs
new file mode 100644
index 0000000000..b706a73203
--- /dev/null
+++ b/talpid-core/src/split_tunnel/windows/windows.rs
@@ -0,0 +1,286 @@
+// TODO: The snapshot code could be combined with the mostly-identical code in
+// the windows_exception_logging module.
+
+use std::{
+ ffi::{OsStr, OsString},
+ io, iter, mem,
+ os::windows::{
+ ffi::{OsStrExt, OsStringExt},
+ io::RawHandle,
+ },
+ path::Path,
+ ptr,
+};
+use winapi::{
+ shared::{
+ minwindef::{DWORD, FALSE, FILETIME, TRUE},
+ ntdef::ULARGE_INTEGER,
+ winerror::{ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES},
+ },
+ um::{
+ fileapi::QueryDosDeviceW,
+ handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ processthreadsapi::{GetProcessTimes, OpenProcess},
+ psapi::K32GetProcessImageFileNameW,
+ tlhelp32::{CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W},
+ winnt::{HANDLE, PROCESS_QUERY_LIMITED_INFORMATION},
+ },
+};
+
+pub struct ProcessSnapshot {
+ handle: HANDLE,
+}
+
+impl ProcessSnapshot {
+ pub fn new(flags: DWORD, process_id: DWORD) -> io::Result<ProcessSnapshot> {
+ let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) };
+
+ if snap == INVALID_HANDLE_VALUE {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(ProcessSnapshot { handle: snap })
+ }
+ }
+
+ pub fn handle(&self) -> HANDLE {
+ self.handle
+ }
+
+ pub fn entries(&self) -> ProcessSnapshotEntries<'_> {
+ let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() };
+ entry.dwSize = mem::size_of::<PROCESSENTRY32W>() as u32;
+
+ ProcessSnapshotEntries {
+ snapshot: self,
+ iter_started: false,
+ temp_entry: entry,
+ }
+ }
+}
+
+impl Drop for ProcessSnapshot {
+ fn drop(&mut self) {
+ unsafe {
+ CloseHandle(self.handle);
+ }
+ }
+}
+
+pub struct ProcessEntry {
+ pub pid: u32,
+ pub parent_pid: u32,
+}
+
+pub struct ProcessSnapshotEntries<'a> {
+ snapshot: &'a ProcessSnapshot,
+ iter_started: bool,
+ temp_entry: PROCESSENTRY32W,
+}
+
+impl Iterator for ProcessSnapshotEntries<'_> {
+ type Item = io::Result<ProcessEntry>;
+
+ fn next(&mut self) -> Option<io::Result<ProcessEntry>> {
+ if self.iter_started {
+ if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE {
+ let last_error = io::Error::last_os_error();
+
+ return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES {
+ None
+ } else {
+ Some(Err(last_error))
+ };
+ }
+ } else {
+ if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == FALSE {
+ return Some(Err(io::Error::last_os_error()));
+ }
+ self.iter_started = true;
+ }
+
+ Some(Ok(ProcessEntry {
+ pid: self.temp_entry.th32ProcessID,
+ parent_pid: self.temp_entry.th32ParentProcessID,
+ }))
+ }
+}
+
+/// Obtains a device path without resolving links or mount points.
+pub fn get_device_path<T: AsRef<Path>>(path: T) -> Result<OsString, io::Error> {
+ if !path.as_ref().is_absolute() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "path must be absolute",
+ ));
+ }
+
+ let drive_comp = path.as_ref().components().next();
+ let drive = match drive_comp {
+ Some(std::path::Component::Prefix(prefix)) => prefix.as_os_str(),
+ _ => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "invalid drive label",
+ ))
+ }
+ };
+
+ let mut new_path = query_dos_device(drive)?;
+ let suffix = path
+ .as_ref()
+ .strip_prefix(drive_comp.unwrap())
+ .expect("path missing own component");
+ new_path.push(suffix);
+
+ Ok(new_path)
+}
+
+/// Obtains the real device path for a label (such as C:).
+/// The underlying function may return multiple paths, but only the first is returned.
+fn query_dos_device<T: AsRef<OsStr>>(device_name: T) -> io::Result<OsString> {
+ let device_name_c: Vec<u16> = device_name
+ .as_ref()
+ .encode_wide()
+ .chain(iter::once(0u16))
+ .collect();
+ let mut new_prefix = vec![0u16; 64];
+
+ loop {
+ let prefix_len = unsafe {
+ QueryDosDeviceW(
+ device_name_c.as_ptr(),
+ new_prefix.as_mut_ptr(),
+ new_prefix.len() as u32,
+ ) as usize
+ };
+
+ if prefix_len == 0 {
+ let last_error = io::Error::last_os_error();
+ if last_error.raw_os_error() == Some(ERROR_INSUFFICIENT_BUFFER as i32) {
+ // resize buffer and try again
+ new_prefix.resize(2 * new_prefix.len(), 0);
+ continue;
+ }
+ break Err(last_error);
+ }
+
+ // We must scan for the first null terminator
+ // Because `new_prefix` may contain multiple strings.
+
+ let real_len = new_prefix.iter().position(|&c| c == 0u16).unwrap();
+ unsafe { new_prefix.set_len(real_len) };
+
+ break Ok(OsString::from_wide(&new_prefix));
+ }
+}
+
+/// Object that frees its handle when dropped.
+pub struct WinHandle(RawHandle);
+
+impl WinHandle {
+ pub fn get_raw(&self) -> RawHandle {
+ self.0
+ }
+}
+
+impl Drop for WinHandle {
+ fn drop(&mut self) {
+ unsafe { CloseHandle(self.0) };
+ }
+}
+
+#[repr(u32)]
+pub enum ProcessAccess {
+ QueryLimitedInformation = PROCESS_QUERY_LIMITED_INFORMATION,
+ // TODO: could be extended
+}
+
+/// Open an existing process object.
+pub fn open_process(
+ access: ProcessAccess,
+ inherit_handle: bool,
+ pid: u32,
+) -> Result<WinHandle, io::Error> {
+ let handle = unsafe {
+ OpenProcess(
+ access as u32,
+ if inherit_handle { TRUE } else { FALSE },
+ pid,
+ )
+ };
+
+ if handle == ptr::null_mut() {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(WinHandle(handle))
+}
+
+/// Returns the age of a running process.
+pub fn get_process_creation_time(handle: RawHandle) -> Result<u64, io::Error> {
+ // TODO: FileTimeToSystemTime -> chrono::NaiveDateTime
+ let mut creation_time: FILETIME = unsafe { mem::zeroed() };
+ let mut dummy: FILETIME = unsafe { mem::zeroed() };
+ if unsafe {
+ GetProcessTimes(
+ handle,
+ &mut creation_time as *mut _,
+ &mut dummy as *mut _,
+ &mut dummy as *mut _,
+ &mut dummy as *mut _,
+ )
+ } == 0
+ {
+ return Err(io::Error::last_os_error());
+ }
+
+ let mut uli_time: ULARGE_INTEGER = unsafe { mem::zeroed() };
+ unsafe {
+ uli_time.s_mut().LowPart = creation_time.dwLowDateTime;
+ uli_time.s_mut().HighPart = creation_time.dwHighDateTime;
+ }
+
+ Ok(*unsafe { uli_time.QuadPart() })
+}
+
+/// Returns the device path for a running process.
+pub fn get_process_device_path(handle: RawHandle) -> Result<OsString, io::Error> {
+ let mut initial_capacity = 512;
+ loop {
+ let result = get_process_device_path_inner(handle, initial_capacity);
+ match result {
+ Ok(path) => return Ok(path),
+ Err(error) => {
+ if ERROR_INSUFFICIENT_BUFFER == error.raw_os_error().unwrap() as u32 {
+ // Try again with a larger buffer capacity.
+ initial_capacity *= 2;
+ continue;
+ }
+ return Err(error);
+ }
+ }
+ }
+}
+
+fn get_process_device_path_inner(
+ handle: RawHandle,
+ buffer_capacity: usize,
+) -> Result<OsString, io::Error> {
+ let mut buffer = Vec::<u16>::new();
+ buffer.reserve_exact(buffer_capacity);
+
+ let written = unsafe {
+ K32GetProcessImageFileNameW(
+ handle,
+ buffer.as_mut_ptr() as *mut _,
+ buffer.capacity() as u32,
+ )
+ };
+ if written == 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ // `written` does not include a null terminator
+ unsafe { buffer.set_len(written as usize) };
+
+ Ok(OsStringExt::from_wide(&buffer))
+}
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 3ee6bf1886..570cb64983 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -71,7 +71,7 @@ pub enum Error {
pub enum TunnelEvent {
/// Sent when the tunnel fails to connect due to an authentication error.
AuthFailed(Option<String>),
- /// Sent when the tunnel interface has been created.
+ /// Sent when the tunnel interface has been created, before routes are set up.
InterfaceUp(TunnelMetadata),
/// Sent when the tunnel comes up and is ready for traffic.
Up(TunnelMetadata),
@@ -112,7 +112,7 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Clone
+ Sync
@@ -169,7 +169,7 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ Clone
@@ -198,7 +198,7 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index 87dbc55101..f25f2624da 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -334,7 +334,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(not(target_os = "linux"))] _route_manager: &mut routing::RouteManager,
) -> Result<Self>
where
- L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
@@ -945,9 +945,11 @@ impl ProcessHandle for OpenVpnProcHandle {
mod event_server {
+ use crate::tunnel::TunnelMetadata;
use futures::stream::TryStreamExt;
use parity_tokio_ipc::Endpoint as IpcEndpoint;
use std::{
+ collections::HashMap,
pin::Pin,
task::{Context, Poll},
};
@@ -981,7 +983,9 @@ mod event_server {
/// Implements a gRPC service used to process events sent to by OpenVPN.
pub struct OpenvpnEventProxyImpl<
- L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(
+ super::TunnelEvent,
+ ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
@@ -997,12 +1001,27 @@ mod event_server {
}
impl<
- L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(
+ super::TunnelEvent,
+ )
+ -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
> OpenvpnEventProxyImpl<L>
{
+ async fn up_inner(
+ &self,
+ request: Request<EventDetails>,
+ ) -> std::result::Result<Response<()>, tonic::Status> {
+ let env = request.into_inner().env;
+ (self.on_event)(super::TunnelEvent::InterfaceUp(Self::get_tunnel_metadata(
+ &env,
+ )?))
+ .await;
+ Ok(Response::new(()))
+ }
+
async fn route_up_inner(
&self,
request: Request<EventDetails>,
@@ -1038,14 +1057,11 @@ mod event_server {
}
}
- let tunnel_alias = env
- .get("dev")
- .ok_or(tonic::Status::invalid_argument("missing tunnel alias"))?
- .to_string();
+ let metadata = Self::get_tunnel_metadata(&env)?;
#[cfg(windows)]
{
- let tunnel_device = tunnel_alias.clone();
+ let tunnel_device = metadata.interface.clone();
tokio::task::spawn_blocking(move || super::wait_for_ready_device(&tunnel_device))
.await
.map_err(|_| tonic::Status::internal("task failed to complete"))?
@@ -1058,6 +1074,19 @@ mod event_server {
})?;
}
+ (self.on_event)(super::TunnelEvent::Up(metadata)).await;
+
+ Ok(Response::new(()))
+ }
+
+ fn get_tunnel_metadata(
+ env: &HashMap<String, String>,
+ ) -> std::result::Result<TunnelMetadata, tonic::Status> {
+ let tunnel_alias = env
+ .get("dev")
+ .ok_or(tonic::Status::invalid_argument("missing tunnel alias"))?
+ .to_string();
+
let mut ips = vec![env
.get("ifconfig_local")
.ok_or(tonic::Status::invalid_argument(
@@ -1089,21 +1118,21 @@ mod event_server {
None
};
- (self.on_event)(super::TunnelEvent::Up(crate::tunnel::TunnelMetadata {
+ Ok(TunnelMetadata {
interface: tunnel_alias,
ips,
ipv4_gateway,
ipv6_gateway,
- }))
- .await;
-
- Ok(Response::new(()))
+ })
}
}
#[tonic::async_trait]
impl<
- L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ L: (Fn(
+ super::TunnelEvent,
+ )
+ -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
@@ -1121,6 +1150,16 @@ mod event_server {
Ok(Response::new(()))
}
+ async fn up(
+ &self,
+ request: Request<EventDetails>,
+ ) -> std::result::Result<Response<()>, tonic::Status> {
+ self.up_inner(request).await.map_err(|error| {
+ self.abort_server_tx.trigger();
+ error
+ })
+ }
+
async fn route_up(
&self,
request: Request<EventDetails>,
@@ -1341,6 +1380,12 @@ mod tests {
) -> std::result::Result<tonic::Response<()>, tonic::Status> {
Ok(tonic::Response::new(()))
}
+ async fn up(
+ &self,
+ _request: tonic::Request<event_server::EventDetails>,
+ ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
+ Ok(tonic::Response::new(()))
+ }
async fn route_up(
&self,
_request: tonic::Request<event_server::EventDetails>,
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 899589dbc4..034cabd316 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -79,7 +79,7 @@ pub struct WireguardMonitor {
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
event_callback: Box<
- dyn (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ dyn (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
@@ -90,6 +90,8 @@ pub struct WireguardMonitor {
stop_setup_tx: Option<futures::channel::oneshot::Sender<()>>,
pinger_stop_sender: mpsc::Sender<()>,
_tcp_proxies: Vec<TcpProxy>,
+ #[cfg(target_os = "windows")]
+ _callback_handle: Option<crate::winnet::WinNetCallbackHandle>,
}
#[cfg(target_os = "linux")]
@@ -156,7 +158,7 @@ impl Drop for TcpProxy {
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
pub fn start<
- F: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ F: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ Clone
@@ -187,12 +189,14 @@ impl WireguardMonitor {
#[cfg(windows)]
let iface_luid = tunnel.get_interface_luid();
- let metadata = Self::tunnel_metadata(&iface_name, &config);
- runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone())));
-
#[cfg(target_os = "windows")]
- route_manager
- .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ());
+ let callback_handle = route_manager
+ .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ())
+ .ok();
+ #[cfg(target_os = "windows")]
+ if callback_handle.is_none() {
+ log::warn!("Failed to register default route callback");
+ }
let event_callback = Box::new(on_event.clone());
let (close_msg_sender, close_msg_receiver) = mpsc::channel();
@@ -209,6 +213,8 @@ impl WireguardMonitor {
stop_setup_tx: Some(stop_setup_tx),
pinger_stop_sender: pinger_tx,
_tcp_proxies: tcp_proxies,
+ #[cfg(target_os = "windows")]
+ _callback_handle: callback_handle,
};
let gateway = config.ipv4_gateway;
@@ -223,7 +229,11 @@ impl WireguardMonitor {
let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
+ let metadata = Self::tunnel_metadata(&iface_name, &config);
+
std::thread::spawn(move || {
+ runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone())));
+
#[cfg(windows)]
{
let iface_close_sender = close_sender.clone();
diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs
index 72af2b6f38..49365d379e 100644
--- a/talpid-core/src/tunnel_state_machine/connected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connected_state.rs
@@ -8,7 +8,11 @@ use crate::{
tunnel::{CloseHandle, TunnelEvent, TunnelMetadata},
};
use cfg_if::cfg_if;
-use futures::{channel::mpsc, stream::Fuse, StreamExt};
+use futures::{
+ channel::{mpsc, oneshot},
+ stream::Fuse,
+ StreamExt,
+};
use std::net::IpAddr;
use talpid_types::{
net::TunnelParameters,
@@ -21,7 +25,8 @@ use crate::tunnel::TunnelMonitor;
use super::connecting_state::TunnelCloseEvent;
-pub(crate) type TunnelEventsReceiver = Fuse<mpsc::UnboundedReceiver<TunnelEvent>>;
+pub(crate) type TunnelEventsReceiver =
+ Fuse<mpsc::UnboundedReceiver<(TunnelEvent, oneshot::Sender<()>)>>;
pub struct ConnectedStateBootstrap {
@@ -133,8 +138,6 @@ impl ConnectedState {
}
fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
- #[cfg(windows)]
- shared_values.route_manager.clear_default_route_callbacks();
if let Err(error) = shared_values.route_manager.clear_routes() {
log::error!("{}", error.display_chain_with_msg("Failed to clear routes"));
}
@@ -257,18 +260,23 @@ impl ConnectedState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ SameState(self.into())
+ }
}
}
fn handle_tunnel_events(
self,
- event: Option<TunnelEvent>,
+ event: Option<(TunnelEvent, oneshot::Sender<()>)>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
use self::EventConsequence::*;
match event {
- Some(TunnelEvent::Down) | None => {
+ Some((TunnelEvent::Down, _)) | None => {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
Some(_) => SameState(self.into()),
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index b0c87acdb4..f9a52c7881 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -100,9 +100,12 @@ impl ConnectingState {
) -> crate::tunnel::Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded();
let on_tunnel_event =
- move |event| -> Box<dyn std::future::Future<Output = ()> + Unpin + Send> {
- let _ = event_tx.unbounded_send(event);
- Box::new(futures::future::ready(()))
+ move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
+ let (tx, rx) = oneshot::channel();
+ let _ = event_tx.unbounded_send((event, tx));
+ Box::pin(async move {
+ let _ = rx.await;
+ })
};
let monitor = TunnelMonitor::start(
@@ -198,8 +201,6 @@ impl ConnectingState {
}
fn reset_routes(shared_values: &mut SharedTunnelStateValues) {
- #[cfg(windows)]
- shared_values.route_manager.clear_default_route_callbacks();
if let Err(error) = shared_values.route_manager.clear_routes() {
log::error!("{}", error.display_chain_with_msg("Failed to clear routes"));
}
@@ -314,23 +315,44 @@ impl ConnectingState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ SameState(self.into())
+ }
}
}
fn handle_tunnel_events(
mut self,
- event: Option<tunnel::TunnelEvent>,
+ event: Option<(tunnel::TunnelEvent, oneshot::Sender<()>)>,
shared_values: &mut SharedTunnelStateValues,
) -> EventConsequence {
use self::EventConsequence::*;
match event {
- Some(TunnelEvent::AuthFailed(reason)) => self.disconnect(
+ Some((TunnelEvent::AuthFailed(reason), _)) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)),
),
- Some(TunnelEvent::InterfaceUp(tunnel_metadata)) => {
- self.tunnel_metadata = Some(tunnel_metadata);
+ Some((TunnelEvent::InterfaceUp(metadata), _done_tx)) => {
+ #[cfg(windows)]
+ if let Err(error) = shared_values
+ .split_tunnel
+ .set_tunnel_addresses(Some(&metadata))
+ {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to register addresses with split tunnel driver"
+ )
+ );
+ return self.disconnect(
+ shared_values,
+ AfterDisconnect::Block(ErrorStateCause::SplitTunnelError),
+ );
+ }
+ self.tunnel_metadata = Some(metadata);
match Self::set_firewall_policy(
shared_values,
&self.tunnel_parameters,
@@ -343,11 +365,11 @@ impl ConnectingState {
),
}
}
- Some(TunnelEvent::Up(metadata)) => NewState(ConnectedState::enter(
+ Some((TunnelEvent::Up(metadata), _)) => NewState(ConnectedState::enter(
shared_values,
self.into_connected_state_bootstrap(metadata),
)),
- Some(TunnelEvent::Down) => SameState(self.into()),
+ Some((TunnelEvent::Down, _)) => SameState(self.into()),
None => {
// The channel was closed
debug!("The tunnel disconnected unexpectedly");
@@ -437,6 +459,18 @@ impl TunnelState for ConnectingState {
ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err))
}
Ok(tunnel_parameters) => {
+ #[cfg(windows)]
+ if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to reset addresses in split tunnel driver"
+ )
+ );
+
+ return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError);
+ }
+
if let Err(error) =
Self::set_firewall_policy(shared_values, &tunnel_parameters, &None)
{
diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
index 8d2c9bc0fa..9b5d51bd51 100644
--- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs
@@ -36,6 +36,31 @@ impl DisconnectedState {
log::error!("{}", error_chain);
}
}
+
+ #[cfg(windows)]
+ fn register_split_tunnel_addresses(
+ shared_values: &mut SharedTunnelStateValues,
+ should_reset_firewall: bool,
+ ) {
+ if should_reset_firewall && !shared_values.block_when_disconnected {
+ if let Err(error) = shared_values.split_tunnel.clear_tunnel_addresses() {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to unregister addresses with split tunnel driver"
+ )
+ );
+ }
+ } else {
+ if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
+ log::error!(
+ "{}",
+ error
+ .display_chain_with_msg("Failed to reset addresses in split tunnel driver")
+ );
+ }
+ }
+ }
}
impl TunnelState for DisconnectedState {
@@ -45,6 +70,8 @@ impl TunnelState for DisconnectedState {
shared_values: &mut SharedTunnelStateValues,
should_reset_firewall: Self::Bootstrap,
) -> (TunnelStateWrapper, TunnelStateTransition) {
+ #[cfg(windows)]
+ Self::register_split_tunnel_addresses(shared_values, should_reset_firewall);
Self::set_firewall_policy(shared_values, should_reset_firewall);
#[cfg(target_os = "linux")]
shared_values.reset_connectivity_check();
@@ -98,6 +125,8 @@ impl TunnelState for DisconnectedState {
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
shared_values.block_when_disconnected = block_when_disconnected;
+ #[cfg(windows)]
+ Self::register_split_tunnel_addresses(shared_values, true);
Self::set_firewall_policy(shared_values, true);
}
SameState(self.into())
@@ -115,6 +144,11 @@ impl TunnelState for DisconnectedState {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ SameState(self.into())
+ }
Some(_) => SameState(self.into()),
None => Finished,
}
diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
index 7d308d5971..71fbd6ae95 100644
--- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs
@@ -59,6 +59,11 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Nothing
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ AfterDisconnect::Nothing
+ }
},
AfterDisconnect::Block(reason) => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
@@ -96,6 +101,11 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Block(reason)
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ AfterDisconnect::Block(reason)
+ }
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
@@ -134,6 +144,11 @@ impl DisconnectingState {
shared_values.bypass_socket(fd, done_tx);
AfterDisconnect::Reconnect(retry_attempt)
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ AfterDisconnect::Reconnect(retry_attempt)
+ }
},
};
diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs
index 5e647c8201..2f1f461dfd 100644
--- a/talpid-core/src/tunnel_state_machine/error_state.rs
+++ b/talpid-core/src/tunnel_state_machine/error_state.rs
@@ -70,6 +70,16 @@ impl TunnelState for ErrorState {
shared_values: &mut SharedTunnelStateValues,
block_reason: Self::Bootstrap,
) -> (TunnelStateWrapper, TunnelStateTransition) {
+ #[cfg(windows)]
+ if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg(
+ "Failed to register addresses with split tunnel driver"
+ )
+ );
+ }
+
#[cfg(not(target_os = "android"))]
let block_failure = Self::set_firewall_policy(shared_values).err();
#[cfg(target_os = "android")]
@@ -151,12 +161,16 @@ impl TunnelState for ErrorState {
Some(TunnelCommand::Block(reason)) => {
NewState(ErrorState::enter(shared_values, reason))
}
-
#[cfg(target_os = "android")]
Some(TunnelCommand::BypassSocket(fd, done_tx)) => {
shared_values.bypass_socket(fd, done_tx);
SameState(self.into())
}
+ #[cfg(windows)]
+ Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
+ let _ = result_tx.send(shared_values.split_tunnel.set_paths(&paths));
+ SameState(self.into())
+ }
}
}
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index 38496865a4..97324e3484 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -11,6 +11,8 @@ use self::{
disconnecting_state::{AfterDisconnect, DisconnectingState},
error_state::ErrorState,
};
+#[cfg(windows)]
+use crate::split_tunnel;
use crate::{
dns::DnsMonitor,
firewall::{Firewall, FirewallArguments},
@@ -19,6 +21,9 @@ use crate::{
routing::RouteManager,
tunnel::{tun_provider::TunProvider, TunnelEvent},
};
+#[cfg(windows)]
+use std::ffi::OsString;
+
use futures::{
channel::{mpsc, oneshot},
stream, StreamExt,
@@ -47,9 +52,9 @@ pub enum Error {
OfflineMonitorError(#[error(source)] crate::offline::Error),
/// Unable to set up split tunneling
- #[cfg(target_os = "linux")]
+ #[cfg(target_os = "windows")]
#[error(display = "Failed to initialize split tunneling")]
- InitSplitTunneling(#[error(source)] crate::split_tunnel::Error),
+ InitSplitTunneling(#[error(source)] split_tunnel::Error),
/// Failed to initialize the system firewall integration.
#[error(display = "Failed to initialize the system firewall integration")]
@@ -86,6 +91,7 @@ pub async fn spawn(
shutdown_tx: oneshot::Sender<()>,
reset_firewall: bool,
#[cfg(target_os = "android")] android_context: AndroidContext,
+ #[cfg(windows)] exclude_paths: Vec<OsString>,
) -> Result<Arc<mpsc::UnboundedSender<TunnelCommand>>, Error> {
let (command_tx, command_rx) = mpsc::unbounded();
let command_tx = Arc::new(command_tx);
@@ -122,6 +128,8 @@ pub async fn spawn(
reset_firewall,
#[cfg(target_os = "android")]
android_context,
+ #[cfg(windows)]
+ exclude_paths,
));
let state_machine = match state_machine {
Ok(state_machine) => {
@@ -169,13 +177,19 @@ pub enum TunnelCommand {
/// Bypass a socket, allowing traffic to flow through outside the tunnel.
#[cfg(target_os = "android")]
BypassSocket(RawFd, oneshot::Sender<()>),
+ /// Set applications that are allowed to send and receive traffic outside of the tunnel.
+ #[cfg(windows)]
+ SetExcludedApps(
+ oneshot::Sender<Result<(), split_tunnel::Error>>,
+ Vec<OsString>,
+ ),
}
type TunnelCommandReceiver = stream::Fuse<mpsc::UnboundedReceiver<TunnelCommand>>;
enum EventResult {
Command(Option<TunnelCommand>),
- Event(Option<TunnelEvent>),
+ Event(Option<(TunnelEvent, oneshot::Sender<()>)>),
Close(Result<Option<ErrorStateCause>, oneshot::Canceled>),
}
@@ -204,9 +218,10 @@ impl TunnelStateMachine {
log_dir: Option<PathBuf>,
resource_dir: PathBuf,
cache_dir: impl AsRef<Path>,
- commands: mpsc::UnboundedReceiver<TunnelCommand>,
+ commands_rx: mpsc::UnboundedReceiver<TunnelCommand>,
reset_firewall: bool,
#[cfg(target_os = "android")] android_context: AndroidContext,
+ #[cfg(windows)] exclude_paths: Vec<OsString>,
) -> Result<Self, Error> {
let args = FirewallArguments {
initialize_blocked: block_when_disconnected || !reset_firewall,
@@ -228,7 +243,7 @@ impl TunnelStateMachine {
)
.map_err(Error::InitDnsMonitorError)?;
let mut offline_monitor = offline::spawn_monitor(
- command_tx,
+ command_tx.clone(),
#[cfg(target_os = "linux")]
route_manager
.handle()
@@ -239,7 +254,18 @@ impl TunnelStateMachine {
.await
.map_err(Error::OfflineMonitorError)?;
let is_offline = offline_monitor.is_offline().await;
+
+ #[cfg(windows)]
+ let split_tunnel =
+ split_tunnel::SplitTunnel::new(command_tx).map_err(Error::InitSplitTunneling)?;
+ #[cfg(windows)]
+ split_tunnel
+ .set_paths(&exclude_paths)
+ .map_err(Error::InitSplitTunneling)?;
+
let mut shared_values = SharedTunnelStateValues {
+ #[cfg(windows)]
+ split_tunnel,
runtime,
firewall,
dns_monitor,
@@ -262,7 +288,7 @@ impl TunnelStateMachine {
Ok(TunnelStateMachine {
current_state: Some(initial_state),
- commands: commands.fuse(),
+ commands: commands_rx.fuse(),
shared_values,
})
}
@@ -310,6 +336,11 @@ pub trait TunnelParametersGenerator: Send + 'static {
/// Values that are common to all tunnel states.
struct SharedTunnelStateValues {
+ /// Management of excluded apps.
+ /// This object should be dropped before deinitializing WinFw (dropping the `Firewall`
+ /// instance), since the driver may add filters to the same sublayer.
+ #[cfg(windows)]
+ split_tunnel: split_tunnel::SplitTunnel,
runtime: tokio::runtime::Handle,
firewall: Firewall,
dns_monitor: DnsMonitor,
diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs
index 79008f8cbc..4f71fa0040 100644
--- a/talpid-core/src/winnet.rs
+++ b/talpid-core/src/winnet.rs
@@ -4,6 +4,7 @@ use crate::{logging::windows::log_sink, routing::Node};
use ipnetwork::IpNetwork;
use libc::c_void;
use std::{
+ convert::TryFrom,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
ptr,
};
@@ -85,6 +86,7 @@ pub fn ensure_best_metric_for_interface(interface_alias: &str) -> Result<bool, E
}
}
+#[derive(Debug, Clone)]
#[allow(dead_code)]
#[repr(u32)]
pub enum WinNetAddrFamily {
@@ -121,15 +123,40 @@ pub struct WinNetDefaultRoute {
pub gateway: WinNetIp,
}
-impl From<WinNetIp> for IpAddr {
- fn from(addr: WinNetIp) -> IpAddr {
+#[derive(Debug)]
+pub struct WrongIpFamilyError;
+
+impl TryFrom<WinNetIp> for Ipv4Addr {
+ type Error = WrongIpFamilyError;
+
+ fn try_from(addr: WinNetIp) -> Result<Ipv4Addr, WrongIpFamilyError> {
match addr.addr_family {
WinNetAddrFamily::IPV4 => {
let mut bytes: [u8; 4] = Default::default();
bytes.clone_from_slice(&addr.ip_bytes[..4]);
- IpAddr::V4(Ipv4Addr::from(bytes))
+ Ok(Ipv4Addr::from(bytes))
}
- WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::from(addr.ip_bytes)),
+ WinNetAddrFamily::IPV6 => Err(WrongIpFamilyError),
+ }
+ }
+}
+
+impl TryFrom<WinNetIp> for Ipv6Addr {
+ type Error = WrongIpFamilyError;
+
+ fn try_from(addr: WinNetIp) -> Result<Ipv6Addr, WrongIpFamilyError> {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => Err(WrongIpFamilyError),
+ WinNetAddrFamily::IPV6 => Ok(Ipv6Addr::from(addr.ip_bytes)),
+ }
+ }
+}
+
+impl From<WinNetIp> for IpAddr {
+ fn from(addr: WinNetIp) -> IpAddr {
+ match addr.addr_family {
+ WinNetAddrFamily::IPV4 => IpAddr::V4(Ipv4Addr::try_from(addr).unwrap()),
+ WinNetAddrFamily::IPV6 => IpAddr::V6(Ipv6Addr::try_from(addr).unwrap()),
}
}
}
@@ -362,8 +389,6 @@ pub fn get_best_default_route(
}
}
-// TODO: Remove attribute once this is in use.
-#[allow(dead_code)]
pub fn interface_luid_to_ip(
family: WinNetAddrFamily,
luid: u64,
diff --git a/talpid-openvpn-plugin/proto/openvpn_plugin.proto b/talpid-openvpn-plugin/proto/openvpn_plugin.proto
index 156caa8881..7795803e35 100644
--- a/talpid-openvpn-plugin/proto/openvpn_plugin.proto
+++ b/talpid-openvpn-plugin/proto/openvpn_plugin.proto
@@ -6,6 +6,7 @@ import "google/protobuf/empty.proto";
service OpenvpnEventProxy {
rpc AuthFailed(EventDetails) returns (google.protobuf.Empty) {}
+ rpc Up(EventDetails) returns (google.protobuf.Empty) {}
rpc RouteUp(EventDetails) returns (google.protobuf.Empty) {}
rpc RoutePredown(EventDetails) returns (google.protobuf.Empty) {}
}
diff --git a/talpid-openvpn-plugin/src/lib.rs b/talpid-openvpn-plugin/src/lib.rs
index 09316d05df..85c7e6a921 100644
--- a/talpid-openvpn-plugin/src/lib.rs
+++ b/talpid-openvpn-plugin/src/lib.rs
@@ -38,6 +38,7 @@ pub enum Error {
/// events.
pub static INTERESTING_EVENTS: &'static [EventType] = &[
EventType::AuthFailed,
+ EventType::Up,
EventType::RouteUp,
EventType::RoutePredown,
];
diff --git a/talpid-openvpn-plugin/src/processing.rs b/talpid-openvpn-plugin/src/processing.rs
index c266c2d1b5..e6d2a77349 100644
--- a/talpid-openvpn-plugin/src/processing.rs
+++ b/talpid-openvpn-plugin/src/processing.rs
@@ -68,6 +68,7 @@ impl EventProcessor {
openvpn_plugin::EventType::AuthFailed => {
self.runtime.block_on(self.ipc_client.auth_failed(details))
}
+ openvpn_plugin::EventType::Up => self.runtime.block_on(self.ipc_client.up(details)),
openvpn_plugin::EventType::RouteUp => {
self.runtime.block_on(self.ipc_client.route_up(details))
}
diff --git a/talpid-types/src/tunnel.rs b/talpid-types/src/tunnel.rs
index b3a75999b4..3c300c30cc 100644
--- a/talpid-types/src/tunnel.rs
+++ b/talpid-types/src/tunnel.rs
@@ -104,6 +104,9 @@ pub enum ErrorStateCause {
/// The Android VPN permission was denied.
#[cfg(target_os = "android")]
VpnPermissionDenied,
+ /// Error reported by split tunnel module.
+ #[cfg(target_os = "windows")]
+ SplitTunnelError,
}
/// Errors that can occur when generating tunnel parameters.
@@ -194,6 +197,8 @@ impl fmt::Display for ErrorStateCause {
IsOffline => "This device is offline, no tunnels can be established",
#[cfg(target_os = "android")]
VpnPermissionDenied => "The Android VPN permission was denied when creating the tunnel",
+ #[cfg(target_os = "windows")]
+ SplitTunnelError => "The split tunneling module reported an error",
};
write!(f, "{}", description)
diff --git a/windows/winfw/src/winfw/rules/baseline/permitdhcp.cpp b/windows/winfw/src/winfw/rules/baseline/permitdhcp.cpp
index 6e5b2896b9..6987934f9d 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitdhcp.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitdhcp.cpp
@@ -42,7 +42,7 @@ bool PermitDhcp::applyIpv4(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
@@ -94,7 +94,7 @@ bool PermitDhcp::applyIpv6(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
diff --git a/windows/winfw/src/winfw/rules/baseline/permitdhcpserver.cpp b/windows/winfw/src/winfw/rules/baseline/permitdhcpserver.cpp
index a8fb2e3036..f3fa2a853d 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitdhcpserver.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitdhcpserver.cpp
@@ -46,7 +46,7 @@ bool PermitDhcpServer::applyIpv4(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
diff --git a/windows/winfw/src/winfw/rules/baseline/permitdns.cpp b/windows/winfw/src/winfw/rules/baseline/permitdns.cpp
index 8ce530edaa..26ca8828af 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitdns.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitdns.cpp
@@ -26,7 +26,7 @@ bool PermitDns::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
diff --git a/windows/winfw/src/winfw/rules/baseline/permitlan.cpp b/windows/winfw/src/winfw/rules/baseline/permitlan.cpp
index 2397c78cdd..7c4e7d8a8e 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitlan.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitlan.cpp
@@ -32,7 +32,7 @@ bool PermitLan::applyIpv4(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
@@ -84,7 +84,7 @@ bool PermitLan::applyIpv6(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
diff --git a/windows/winfw/src/winfw/rules/baseline/permitlanservice.cpp b/windows/winfw/src/winfw/rules/baseline/permitlanservice.cpp
index d729b4ad52..61aae2851c 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitlanservice.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitlanservice.cpp
@@ -32,7 +32,7 @@ bool PermitLanService::applyIpv4(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4);
@@ -60,7 +60,7 @@ bool PermitLanService::applyIpv6(IObjectInstaller &objectInstaller) const
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6);
diff --git a/windows/winfw/src/winfw/rules/baseline/permitloopback.cpp b/windows/winfw/src/winfw/rules/baseline/permitloopback.cpp
index 123bed4b42..fdab4b9c40 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitloopback.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitloopback.cpp
@@ -25,7 +25,7 @@ bool PermitLoopback::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
diff --git a/windows/winfw/src/winfw/rules/baseline/permitndp.cpp b/windows/winfw/src/winfw/rules/baseline/permitndp.cpp
index 135fbb9979..60c95ec8e9 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitndp.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitndp.cpp
@@ -32,7 +32,7 @@ bool PermitNdp::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
index c09f7b631c..e756d68464 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnel.cpp
@@ -30,7 +30,7 @@ bool PermitVpnTunnel::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
{
diff --git a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
index d24830db8f..00fbc8e76b 100644
--- a/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
+++ b/windows/winfw/src/winfw/rules/baseline/permitvpntunnelservice.cpp
@@ -30,7 +30,7 @@ bool PermitVpnTunnelService::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4)
.sublayer(MullvadGuids::SublayerBaseline())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4);
diff --git a/windows/winfw/src/winfw/rules/dns/permitnontunnel.cpp b/windows/winfw/src/winfw/rules/dns/permitnontunnel.cpp
index 729254d1f4..d9b6942243 100644
--- a/windows/winfw/src/winfw/rules/dns/permitnontunnel.cpp
+++ b/windows/winfw/src/winfw/rules/dns/permitnontunnel.cpp
@@ -38,7 +38,7 @@ bool PermitNonTunnel::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerDns())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
@@ -77,7 +77,7 @@ bool PermitNonTunnel::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
.sublayer(MullvadGuids::SublayerDns())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
diff --git a/windows/winfw/src/winfw/rules/dns/permittunnel.cpp b/windows/winfw/src/winfw/rules/dns/permittunnel.cpp
index cc1af84223..578038cbcf 100644
--- a/windows/winfw/src/winfw/rules/dns/permittunnel.cpp
+++ b/windows/winfw/src/winfw/rules/dns/permittunnel.cpp
@@ -38,7 +38,7 @@ bool PermitTunnel::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V4)
.sublayer(MullvadGuids::SublayerDns())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V4);
@@ -73,7 +73,7 @@ bool PermitTunnel::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(FWPM_LAYER_ALE_AUTH_CONNECT_V6)
.sublayer(MullvadGuids::SublayerDns())
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(FWPM_LAYER_ALE_AUTH_CONNECT_V6);
diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
index ee5ffcb0c4..a403230df9 100644
--- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
+++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp
@@ -90,7 +90,7 @@ bool PermitVpnRelay::apply(IObjectInstaller &objectInstaller)
.provider(MullvadGuids::Provider())
.layer(LayerFromIp(m_relay))
.sublayer(TranslateSublayer(m_sublayer))
- .weight(wfp::FilterBuilder::WeightClass::Max)
+ .weight(wfp::FilterBuilder::WeightClass::Medium)
.permit();
wfp::ConditionBuilder conditionBuilder(LayerFromIp(m_relay));
diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp
index e82345ebee..d7b34c1bc0 100644
--- a/windows/winnet/src/winnet/winnet.cpp
+++ b/windows/winnet/src/winnet/winnet.cpp
@@ -554,6 +554,8 @@ WinNet_DeactivateRouteManager(
{
delete g_RouteManager;
g_RouteManager = nullptr;
+
+ g_RouteManagerLogSink.reset();
}
catch (...)
{