summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2021-06-11 16:24:40 +0200
committerDavid Lönnhager <david.l@mullvad.net>2021-06-16 11:38:12 +0200
commitf3a274cbc5e424ad428e6763a58304427cc7be90 (patch)
tree6d7a3e08def842a22d760178cef25c492b57b662
parentf2a20eba3fccf1121ba7b8d5af78c34a8ed80687 (diff)
downloadmullvadvpn-f3a274cbc5e424ad428e6763a58304427cc7be90.tar.xz
mullvadvpn-f3a274cbc5e424ad428e6763a58304427cc7be90.zip
Improve OpenVPN event handling
-rw-r--r--talpid-core/src/tunnel/mod.rs67
-rw-r--r--talpid-core/src/tunnel/openvpn/mod.rs371
-rw-r--r--talpid-core/src/tunnel/wireguard/mod.rs26
-rw-r--r--talpid-core/src/tunnel_state_machine/connecting_state.rs8
-rw-r--r--talpid-openvpn-plugin/proto/openvpn_plugin.proto9
-rw-r--r--talpid-openvpn-plugin/src/lib.rs36
-rw-r--r--talpid-openvpn-plugin/src/processing.rs19
7 files changed, 318 insertions, 218 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs
index 890fc07b67..3ee6bf1886 100644
--- a/talpid-core/src/tunnel/mod.rs
+++ b/talpid-core/src/tunnel/mod.rs
@@ -1,7 +1,5 @@
use self::tun_provider::TunProvider;
use crate::{logging, routing::RouteManager};
-#[cfg(not(target_os = "android"))]
-use std::collections::HashMap;
use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
@@ -94,54 +92,6 @@ pub struct TunnelMetadata {
pub ipv6_gateway: Option<Ipv6Addr>,
}
-#[cfg(not(target_os = "android"))]
-impl TunnelEvent {
- /// Converts an `openvpn_plugin::EventType` to a `TunnelEvent`.
- /// Returns `None` if there is no corresponding `TunnelEvent`.
- fn from_openvpn_event(
- event: openvpn_plugin::EventType,
- env: &HashMap<String, String>,
- ) -> Option<TunnelEvent> {
- match event {
- openvpn_plugin::EventType::AuthFailed => {
- let reason = env.get("auth_failed_reason").cloned();
- Some(TunnelEvent::AuthFailed(reason))
- }
- openvpn_plugin::EventType::RouteUp => {
- let interface = env
- .get("dev")
- .expect("No \"dev\" in tunnel up event")
- .to_owned();
- let mut ips = vec![env
- .get("ifconfig_local")
- .expect("No \"ifconfig_local\" in tunnel up event")
- .parse()
- .expect("Tunnel IP not in valid format")];
- if let Some(ipv6_address) = env.get("ifconfig_ipv6_local") {
- ips.push(ipv6_address.parse().expect("Tunnel IP not in valid format"));
- }
- let ipv4_gateway = env
- .get("route_vpn_gateway")
- .expect("No \"route_vpn_gateway\" in tunnel up event")
- .parse()
- .expect("Tunnel gateway IP not in valid format");
- let ipv6_gateway = env.get("route_ipv6_gateway_1").map(|v6_str| {
- v6_str
- .parse()
- .expect("V6 Tunnel gateway IP not in valid format")
- });
- Some(TunnelEvent::Up(TunnelMetadata {
- interface,
- ips,
- ipv4_gateway,
- ipv6_gateway,
- }))
- }
- openvpn_plugin::EventType::RoutePredown => Some(TunnelEvent::Down),
- _ => None,
- }
- }
-}
/// Abstraction for monitoring a generic VPN tunnel.
pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
@@ -162,7 +112,11 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: Fn(TunnelEvent) + Send + Clone + Sync + 'static,
+ L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Clone
+ + Sync
+ + 'static,
{
Self::ensure_ipv6_can_be_used_if_enabled(&tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(&tunnel_parameters, log_dir)?;
@@ -215,7 +169,11 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: Fn(TunnelEvent) + Send + Sync + Clone + 'static,
+ L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + Clone
+ + 'static,
{
let config = wireguard::config::Config::from_parameters(&params)?;
let monitor = wireguard::WireguardMonitor::start(
@@ -240,7 +198,10 @@ impl TunnelMonitor {
route_manager: &mut RouteManager,
) -> Result<Self>
where
- L: Fn(TunnelEvent) + Send + Sync + 'static,
+ L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
{
let monitor =
openvpn::OpenVpnMonitor::start(on_event, config, log, resource_dir, route_manager)?;
diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs
index b6817bc591..e3709b0183 100644
--- a/talpid-core/src/tunnel/openvpn/mod.rs
+++ b/talpid-core/src/tunnel/openvpn/mod.rs
@@ -10,12 +10,17 @@ use crate::{
proxy::{self, ProxyMonitor, ProxyResourceData},
routing,
};
-#[cfg(target_os = "linux")]
-use ipnetwork::IpNetwork;
#[cfg(windows)]
use lazy_static::lazy_static;
+#[cfg(target_os = "linux")]
+use std::collections::{HashMap, HashSet};
+#[cfg(windows)]
+use std::{
+ ffi::{OsStr, OsString},
+ os::windows::ffi::OsStrExt,
+ time::Instant,
+};
use std::{
- collections::HashMap,
fs,
io::{self, Write},
path::{Path, PathBuf},
@@ -27,14 +32,6 @@ use std::{
thread,
time::Duration,
};
-#[cfg(target_os = "linux")]
-use std::{collections::HashSet, net::IpAddr};
-#[cfg(windows)]
-use std::{
- ffi::{OsStr, OsString},
- os::windows::ffi::OsStrExt,
- time::Instant,
-};
use talpid_types::{net::openvpn, ErrorExt};
use tokio::task;
#[cfg(target_os = "linux")]
@@ -218,11 +215,6 @@ pub enum Error {
#[error(display = "Failure in Windows syscall")]
WinnetError(#[error(source)] crate::winnet::Error),
- /// Error routes from the provided map
- #[cfg(target_os = "linux")]
- #[error(display = "Failed to parse OpenVPN-provided routes")]
- ParseRouteError(#[error(source)] RouteParseError),
-
/// The map is missing 'dev'
#[cfg(target_os = "linux")]
#[error(display = "Failed to obtain tunnel interface name")]
@@ -342,70 +334,22 @@ impl OpenVpnMonitor<OpenVpnCommand> {
#[cfg(not(target_os = "linux"))] _route_manager: &mut routing::RouteManager,
) -> Result<Self>
where
- L: Fn(TunnelEvent) + Send + Sync + 'static,
+ L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
{
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
.map_err(Error::CredentialsWriteError)?;
-
let proxy_auth_file =
Self::create_proxy_auth_file(&params.proxy).map_err(Error::CredentialsWriteError)?;
-
let user_pass_file_path = user_pass_file.to_path_buf();
-
let proxy_auth_file_path = match proxy_auth_file {
Some(ref file) => Some(file.to_path_buf()),
_ => None,
};
- #[cfg(target_os = "linux")]
- let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
-
- #[cfg(target_os = "linux")]
- let ipv6_enabled = params.generic_options.enable_ipv6;
-
- let on_openvpn_event = move |event, env: HashMap<String, String>| {
- #[cfg(target_os = "linux")]
- if event == openvpn_plugin::EventType::Up {
- tokio::task::block_in_place(|| {
- let routes = extract_routes(&env)
- .unwrap()
- .into_iter()
- .filter(|route| route.prefix.is_ipv4() || ipv6_enabled)
- .collect();
- let route_manager_handle = route_manager_handle.clone();
- if let Err(error) = route_manager_handle.add_routes(routes) {
- log::error!("{}", error.display_chain());
- panic!("Failed to add routes");
- }
-
- if let Err(error) = route_manager_handle.create_routing_rules(ipv6_enabled) {
- log::error!("{}", error.display_chain());
- panic!("Failed to add routes");
- }
- });
- return;
- }
- if event == openvpn_plugin::EventType::RouteUp {
- // The user-pass file has been read. Try to delete it early.
- let _ = fs::remove_file(&user_pass_file_path);
-
- // The proxy auth file has been read. Try to delete it early.
- if let Some(ref file_path) = &proxy_auth_file_path {
- let _ = fs::remove_file(file_path);
- }
-
- #[cfg(windows)]
- tokio::task::block_in_place(|| {
- wait_for_ready_device(env.get("dev").expect("missing tunnel alias")).unwrap();
- });
- }
- match TunnelEvent::from_openvpn_event(event, &env) {
- Some(tunnel_event) => on_event(tunnel_event),
- None => log::debug!("Ignoring OpenVpnEvent {:?}", event),
- }
- };
-
let log_dir: Option<PathBuf> = if let Some(ref log_path) = log_path {
Some(log_path.parent().expect("log_path has no parent").into())
} else {
@@ -514,9 +458,27 @@ impl OpenVpnMonitor<OpenVpnCommand> {
let plugin_path = Self::get_plugin_path(resource_dir)?;
+ #[cfg(target_os = "linux")]
+ let ipv6_enabled = params.generic_options.enable_ipv6;
+ #[cfg(target_os = "linux")]
+ let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
+
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
+
Self::new_internal(
cmd,
- on_openvpn_event,
+ event_server_abort_tx.clone(),
+ event_server_abort_rx,
+ event_server::OpenvpnEventProxyImpl {
+ on_event,
+ user_pass_file_path: user_pass_file_path.clone(),
+ proxy_auth_file_path: proxy_auth_file_path.clone(),
+ abort_server_tx: event_server_abort_tx,
+ #[cfg(target_os = "linux")]
+ route_manager_handle,
+ #[cfg(target_os = "linux")]
+ ipv6_enabled,
+ },
plugin_path,
log_path,
user_pass_file,
@@ -533,34 +495,6 @@ impl OpenVpnMonitor<OpenVpnCommand> {
}
#[cfg(target_os = "linux")]
-#[derive(Debug)]
-struct OpenVpnRoute {
- network: IpNetwork,
- gateway: IpAddr,
-}
-
-#[cfg(target_os = "linux")]
-#[derive(err_derive::Error, Debug)]
-#[error(no_from)]
-#[allow(missing_docs)]
-pub enum RouteParseError {
- #[error(display = "The route contains no network")]
- MissingNetwork,
- #[error(display = "The route contains no gateway")]
- MissingGateway,
- #[error(display = "Failed to parse route network address")]
- ParseNetworkAddress(#[error(source)] std::net::AddrParseError),
- #[error(display = "Failed to parse route network")]
- ParseNetwork(#[error(source)] ipnetwork::IpNetworkError),
- #[error(display = "Failed to parse route mask address")]
- ParseMaskAddress(#[error(source)] std::net::AddrParseError),
- #[error(display = "Failed to convert route mask to prefix")]
- ParseMask(#[error(source)] ipnetwork::IpNetworkError),
- #[error(display = "Failed to parse route gateway address")]
- ParseGatewayAddress(#[error(source)] std::net::AddrParseError),
-}
-
-#[cfg(target_os = "linux")]
fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute>> {
let tun_interface = env.get("dev").ok_or(Error::MissingTunnelInterface)?;
let tun_node = routing::Node::device(tun_interface.to_string());
@@ -574,6 +508,8 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute
impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
fn new_internal<L>(
mut cmd: C,
+ event_server_abort_tx: triggered::Trigger,
+ event_server_abort_rx: triggered::Listener,
on_event: L,
plugin_path: PathBuf,
log_path: Option<PathBuf>,
@@ -583,7 +519,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
#[cfg(windows)] wintun: Box<dyn WintunContext>,
) -> Result<OpenVpnMonitor<C>>
where
- L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static,
+ L: event_server::OpenvpnEventProxy + Send + Sync + 'static,
{
let uuid = uuid::Uuid::new_v4().to_string();
let ipc_path = if cfg!(windows) {
@@ -592,12 +528,9 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
format!("/tmp/talpid-openvpn-{}", uuid)
};
- let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
-
let mut runtime = tokio::runtime::Builder::new()
.threaded_scheduler()
.core_threads(1)
- .max_threads(1)
.enable_all()
.build()
.map_err(Error::RuntimeError)?;
@@ -1015,11 +948,11 @@ mod event_server {
use futures::stream::TryStreamExt;
use parity_tokio_ipc::Endpoint as IpcEndpoint;
use std::{
- collections::HashMap,
- convert::TryFrom,
pin::Pin,
task::{Context, Poll},
};
+ #[cfg(any(target_os = "linux", windows))]
+ use talpid_types::ErrorExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tonic::{
self,
@@ -1030,9 +963,9 @@ mod event_server {
mod proto {
tonic::include_proto!("talpid_openvpn_plugin");
}
- use proto::{
+ pub use proto::{
openvpn_event_proxy_server::{OpenvpnEventProxy, OpenvpnEventProxyServer},
- EventType,
+ EventDetails,
};
#[derive(err_derive::Error, Debug)]
@@ -1047,31 +980,167 @@ mod event_server {
}
/// Implements a gRPC service used to process events sent to by OpenVPN.
- #[derive(Debug)]
- pub struct OpenvpnEventProxyImpl<L> {
- on_event: L,
+ pub struct OpenvpnEventProxyImpl<
+ L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
+ > {
+ pub on_event: L,
+ pub user_pass_file_path: super::PathBuf,
+ pub proxy_auth_file_path: Option<super::PathBuf>,
+ pub abort_server_tx: triggered::Trigger,
+ #[cfg(target_os = "linux")]
+ pub route_manager_handle: super::routing::RouteManagerHandle,
+ #[cfg(target_os = "linux")]
+ pub ipv6_enabled: bool,
}
- #[tonic::async_trait]
- impl<L> OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
- where
- L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static,
+ impl<
+ L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
+ > OpenvpnEventProxyImpl<L>
{
- async fn event(
+ async fn route_up_inner(
&self,
- request: Request<EventType>,
+ request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
- log::trace!("OpenVPN event {:?}", request);
+ let env = request.into_inner().env;
- let request = request.into_inner();
+ let _ = tokio::fs::remove_file(&self.user_pass_file_path).await;
+ if let Some(ref file_path) = &self.proxy_auth_file_path {
+ let _ = tokio::fs::remove_file(file_path).await;
+ }
- let event_type =
- openvpn_plugin::EventType::try_from(request.event).map_err(|event: i32| {
- tonic::Status::invalid_argument(format!("Unknown event type: {}", event))
+ #[cfg(target_os = "linux")]
+ {
+ let route_handle = self.route_manager_handle.clone();
+ let ipv6_enabled = self.ipv6_enabled;
+
+ let routes = super::extract_routes(&env)
+ .map_err(|err| {
+ log::error!("{}", err.display_chain_with_msg("Failed to obtain routes"));
+ tonic::Status::failed_precondition("Failed to obtain routes")
+ })?
+ .into_iter()
+ .filter(|route| route.prefix.is_ipv4() || ipv6_enabled)
+ .collect();
+
+ tokio::task::spawn_blocking(move || {
+ if let Err(error) = route_handle.add_routes(routes) {
+ log::error!("{}", error.display_chain());
+ return Err(tonic::Status::failed_precondition("Failed to add routes"));
+ }
+ if let Err(error) = route_handle.create_routing_rules(ipv6_enabled) {
+ log::error!("{}", error.display_chain());
+ return Err(tonic::Status::failed_precondition("Failed to add routes"));
+ }
+ Ok(())
+ })
+ .await
+ .map_err(|_| tonic::Status::internal("task failed to complete"))??;
+ }
+
+ let tunnel_alias = env
+ .get("dev")
+ .ok_or(tonic::Status::invalid_argument("missing tunnel alias"))?
+ .to_string();
+
+ #[cfg(windows)]
+ {
+ let tunnel_device = tunnel_alias.clone();
+ tokio::task::spawn_blocking(move || super::wait_for_ready_device(&tunnel_device))
+ .await
+ .map_err(|_| tonic::Status::internal("task failed to complete"))?
+ .map_err(|error| {
+ log::error!(
+ "{}",
+ error.display_chain_with_msg("wait_for_ready_device failed")
+ );
+ tonic::Status::unavailable("wait_for_ready_device failed")
+ })?;
+ }
+
+ let mut ips = vec![env
+ .get("ifconfig_local")
+ .ok_or(tonic::Status::invalid_argument(
+ "missing \"ifconfig_local\" in up event",
+ ))?
+ .parse()
+ .map_err(|_| tonic::Status::invalid_argument("Invalid tunnel IPv4 address"))?];
+ if let Some(ipv6_address) = env.get("ifconfig_ipv6_local") {
+ ips.push(
+ ipv6_address.parse().map_err(|_| {
+ tonic::Status::invalid_argument("Invalid tunnel IPv6 address")
+ })?,
+ );
+ }
+ let ipv4_gateway = env
+ .get("route_vpn_gateway")
+ .ok_or(tonic::Status::invalid_argument(
+ "No \"route_vpn_gateway\" in tunnel up event",
+ ))?
+ .parse()
+ .map_err(|_| {
+ tonic::Status::invalid_argument("Invalid tunnel gateway IPv4 address")
})?;
+ let ipv6_gateway = if let Some(ipv6_address) = env.get("route_ipv6_gateway_1") {
+ Some(ipv6_address.parse().map_err(|_| {
+ tonic::Status::invalid_argument("Invalid tunnel gateway IPv6 address")
+ })?)
+ } else {
+ None
+ };
+
+ (self.on_event)(super::TunnelEvent::Up(crate::tunnel::TunnelMetadata {
+ interface: tunnel_alias,
+ ips,
+ ipv4_gateway,
+ ipv6_gateway,
+ }))
+ .await;
+
+ Ok(Response::new(()))
+ }
+ }
+
+ #[tonic::async_trait]
+ impl<
+ L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
+ > OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
+ {
+ async fn auth_failed(
+ &self,
+ request: Request<EventDetails>,
+ ) -> std::result::Result<Response<()>, tonic::Status> {
+ let env = request.into_inner().env;
+ (self.on_event)(super::TunnelEvent::AuthFailed(
+ env.get("auth_failed_reason").cloned(),
+ ))
+ .await;
+ Ok(Response::new(()))
+ }
- (self.on_event)(event_type, request.env);
+ async fn route_up(
+ &self,
+ request: Request<EventDetails>,
+ ) -> std::result::Result<Response<()>, tonic::Status> {
+ self.route_up_inner(request).await.map_err(|error| {
+ self.abort_server_tx.trigger();
+ error
+ })
+ }
+ async fn route_predown(
+ &self,
+ _request: Request<EventDetails>,
+ ) -> std::result::Result<Response<()>, tonic::Status> {
+ (self.on_event)(super::TunnelEvent::Down).await;
Ok(Response::new(()))
}
}
@@ -1079,20 +1148,18 @@ mod event_server {
pub async fn start<L>(
ipc_path: String,
server_start_tx: std::sync::mpsc::Sender<()>,
- on_event: L,
+ event_proxy: L,
abort_rx: triggered::Listener,
) -> std::result::Result<(), Error>
where
- L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static,
+ L: OpenvpnEventProxy + Sync + Send + 'static,
{
let endpoint = IpcEndpoint::new(ipc_path);
let incoming = endpoint.incoming().map_err(Error::StartServer)?;
let _ = server_start_tx.send(());
- let server = OpenvpnEventProxyImpl { on_event };
-
Server::builder()
- .add_service(OpenvpnEventProxyServer::new(server))
+ .add_service(OpenvpnEventProxyServer::new(event_proxy))
.serve_with_incoming_shutdown(incoming.map_ok(StreamBox), abort_rx)
.await
.map_err(Error::TonicError)
@@ -1269,6 +1336,30 @@ mod tests {
}
}
+ struct TestOpenvpnEventProxy {}
+
+ #[async_trait::async_trait]
+ impl event_server::OpenvpnEventProxy for TestOpenvpnEventProxy {
+ async fn auth_failed(
+ &self,
+ _request: tonic::Request<event_server::EventDetails>,
+ ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
+ Ok(tonic::Response::new(()))
+ }
+ async fn route_up(
+ &self,
+ _request: tonic::Request<event_server::EventDetails>,
+ ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
+ Ok(tonic::Response::new(()))
+ }
+ async fn route_predown(
+ &self,
+ _request: tonic::Request<event_server::EventDetails>,
+ ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
+ Ok(tonic::Response::new(()))
+ }
+ }
+
#[derive(Debug, Default, Clone)]
struct TestOpenVpnBuilder {
pub plugin: Arc<Mutex<Option<PathBuf>>>,
@@ -1319,9 +1410,12 @@ mod tests {
#[test]
fn sets_plugin() {
let builder = TestOpenVpnBuilder::default();
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"./my_test_plugin".into(),
None,
TempFile::new(),
@@ -1339,9 +1433,12 @@ mod tests {
#[test]
fn sets_log() {
let builder = TestOpenVpnBuilder::default();
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let _ = OpenVpnMonitor::new_internal(
builder.clone(),
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"".into(),
Some(PathBuf::from("./my_test_log_file")),
TempFile::new(),
@@ -1360,9 +1457,12 @@ mod tests {
fn exit_successfully() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(0));
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let testee = OpenVpnMonitor::new_internal(
builder,
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
@@ -1379,9 +1479,12 @@ mod tests {
fn exit_error() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let testee = OpenVpnMonitor::new_internal(
builder,
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
@@ -1398,9 +1501,12 @@ mod tests {
fn wait_closed() {
let mut builder = TestOpenVpnBuilder::default();
builder.process_handle = Some(TestProcessHandle(1));
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let testee = OpenVpnMonitor::new_internal(
builder,
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
@@ -1417,9 +1523,12 @@ mod tests {
#[test]
fn failed_process_start() {
let builder = TestOpenVpnBuilder::default();
+ let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger();
let result = OpenVpnMonitor::new_internal(
builder,
- |_, _| {},
+ event_server_abort_tx,
+ event_server_abort_rx,
+ TestOpenvpnEventProxy {},
"".into(),
None,
TempFile::new(),
diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs
index 3ffaa9995c..4d7fa478b1 100644
--- a/talpid-core/src/tunnel/wireguard/mod.rs
+++ b/talpid-core/src/tunnel/wireguard/mod.rs
@@ -74,10 +74,16 @@ pub enum Error {
/// Spawns and monitors a wireguard tunnel
pub struct WireguardMonitor {
+ runtime: tokio::runtime::Handle,
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
- event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>,
+ event_callback: Box<
+ dyn (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + 'static,
+ >,
close_msg_sender: mpsc::Sender<CloseMsg>,
close_msg_receiver: mpsc::Receiver<CloseMsg>,
#[cfg(target_os = "windows")]
@@ -149,7 +155,13 @@ impl Drop for TcpProxy {
impl WireguardMonitor {
/// Starts a WireGuard tunnel with the given config
- pub fn start<F: Fn(TunnelEvent) + Send + Sync + Clone + 'static>(
+ pub fn start<
+ F: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>)
+ + Send
+ + Sync
+ + Clone
+ + 'static,
+ >(
runtime: tokio::runtime::Handle,
mut config: Config,
log_path: Option<&Path>,
@@ -176,7 +188,7 @@ impl WireguardMonitor {
let iface_luid = tunnel.get_interface_luid();
let metadata = Self::tunnel_metadata(&iface_name, &config);
- (on_event)(TunnelEvent::InterfaceUp(metadata.clone()));
+ runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone())));
#[cfg(target_os = "windows")]
route_manager
@@ -188,6 +200,7 @@ impl WireguardMonitor {
#[cfg(target_os = "windows")]
let (stop_setup_tx, stop_setup_rx) = futures::channel::oneshot::channel();
let monitor = WireguardMonitor {
+ runtime: runtime.clone(),
tunnel: Arc::new(Mutex::new(Some(tunnel))),
event_callback,
close_msg_sender,
@@ -209,8 +222,6 @@ impl WireguardMonitor {
.map_err(Error::ConnectivityMonitorError)?;
let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
- #[cfg(windows)]
- let runtime = route_manager.runtime_handle();
std::thread::spawn(move || {
#[cfg(windows)]
@@ -265,7 +276,7 @@ impl WireguardMonitor {
match connectivity_monitor.establish_connectivity() {
Ok(true) => {
- (on_event)(TunnelEvent::Up(metadata));
+ runtime.block_on((on_event)(TunnelEvent::Up(metadata)));
if let Err(error) = connectivity_monitor.run() {
log::error!(
@@ -368,7 +379,8 @@ impl WireguardMonitor {
self.stop_tunnel();
- (self.event_callback)(TunnelEvent::Down);
+ self.runtime
+ .block_on((self.event_callback)(TunnelEvent::Down));
wait_result
}
diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs
index 8993354916..8486ae0a3e 100644
--- a/talpid-core/src/tunnel_state_machine/connecting_state.rs
+++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs
@@ -99,9 +99,11 @@ impl ConnectingState {
retry_attempt: u32,
) -> crate::tunnel::Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded();
- let on_tunnel_event = move |event| {
- let _ = event_tx.unbounded_send(event);
- };
+ let on_tunnel_event =
+ move |event| -> Box<dyn std::future::Future<Output = ()> + Unpin + Send> {
+ let _ = event_tx.unbounded_send(event);
+ Box::new(futures::future::ready(()))
+ };
let monitor = TunnelMonitor::start(
runtime,
diff --git a/talpid-openvpn-plugin/proto/openvpn_plugin.proto b/talpid-openvpn-plugin/proto/openvpn_plugin.proto
index 07bd9d6f1a..156caa8881 100644
--- a/talpid-openvpn-plugin/proto/openvpn_plugin.proto
+++ b/talpid-openvpn-plugin/proto/openvpn_plugin.proto
@@ -5,10 +5,11 @@ package talpid_openvpn_plugin;
import "google/protobuf/empty.proto";
service OpenvpnEventProxy {
- rpc Event(EventType) returns (google.protobuf.Empty) {}
+ rpc AuthFailed(EventDetails) returns (google.protobuf.Empty) {}
+ rpc RouteUp(EventDetails) returns (google.protobuf.Empty) {}
+ rpc RoutePredown(EventDetails) returns (google.protobuf.Empty) {}
}
-message EventType {
- int32 event = 1;
- map<string, string> env = 2;
+message EventDetails {
+ map<string, string> env = 1;
}
diff --git a/talpid-openvpn-plugin/src/lib.rs b/talpid-openvpn-plugin/src/lib.rs
index e0ddc0100f..09316d05df 100644
--- a/talpid-openvpn-plugin/src/lib.rs
+++ b/talpid-openvpn-plugin/src/lib.rs
@@ -28,6 +28,9 @@ pub enum Error {
#[error(display = "Unable to parse arguments from OpenVPN")]
ParseArgsFailed(#[error(source)] std::str::Utf8Error),
+
+ #[error(display = "Unhandled event type: {:?}", _0)]
+ UnhandledEvent(openvpn_plugin::EventType),
}
@@ -35,8 +38,6 @@ pub enum Error {
/// events.
pub static INTERESTING_EVENTS: &'static [EventType] = &[
EventType::AuthFailed,
- #[cfg(target_os = "linux")]
- EventType::Up,
EventType::RouteUp,
EventType::RoutePredown,
];
@@ -45,7 +46,7 @@ openvpn_plugin!(
crate::openvpn_open,
crate::openvpn_close,
crate::openvpn_event,
- crate::Mutex<EventProcessor>
+ crate::Mutex<Option<EventProcessor>>
);
pub struct Arguments {
@@ -55,7 +56,7 @@ pub struct Arguments {
fn openvpn_open(
args: Vec<CString>,
_env: HashMap<CString, CString>,
-) -> Result<(Vec<EventType>, Mutex<EventProcessor>), Error> {
+) -> Result<(Vec<EventType>, Mutex<Option<EventProcessor>>), Error> {
env_logger::init();
log::debug!("Initializing plugin");
@@ -66,7 +67,7 @@ fn openvpn_open(
);
let processor = EventProcessor::new(arguments)?;
- Ok((INTERESTING_EVENTS.to_vec(), Mutex::new(processor)))
+ Ok((INTERESTING_EVENTS.to_vec(), Mutex::new(Some(processor))))
}
fn parse_args(args: &[CString]) -> Result<Arguments, Error> {
@@ -81,7 +82,7 @@ fn parse_args(args: &[CString]) -> Result<Arguments, Error> {
}
-fn openvpn_close(_handle: Mutex<EventProcessor>) {
+fn openvpn_close(_handle: Mutex<Option<EventProcessor>>) {
log::info!("Unloading plugin");
}
@@ -89,21 +90,26 @@ fn openvpn_event(
event: EventType,
_args: Vec<CString>,
env: HashMap<CString, CString>,
- handle: &mut Mutex<EventProcessor>,
+ handle: &mut Mutex<Option<EventProcessor>>,
) -> Result<EventResult, Error> {
log::debug!("Received event: {:?}", event);
let parsed_env = openvpn_plugin::ffi::parse::env_utf8(&env).map_err(Error::ParseEnvFailed)?;
- let result = handle
+ let mut ctx = handle
.lock()
- .expect("failed to obtain mutex for EventProcessor")
- .process_event(event, parsed_env);
- match result {
- Ok(()) => Ok(EventResult::Success),
- Err(e) => {
- log::error!("{}", e.display_chain());
- Ok(EventResult::Failure)
+ .expect("failed to obtain mutex for EventProcessor");
+ if let Some(processor) = ctx.as_mut() {
+ match processor.process_event(event, parsed_env) {
+ Ok(()) => Ok(EventResult::Success),
+ Err(e) => {
+ log::error!("{}", e.display_chain());
+ *ctx = None;
+ Ok(EventResult::Failure)
+ }
}
+ } else {
+ log::error!("Client has been closed");
+ Ok(EventResult::Failure)
}
}
diff --git a/talpid-openvpn-plugin/src/processing.rs b/talpid-openvpn-plugin/src/processing.rs
index 291ad41092..c266c2d1b5 100644
--- a/talpid-openvpn-plugin/src/processing.rs
+++ b/talpid-openvpn-plugin/src/processing.rs
@@ -62,11 +62,20 @@ impl EventProcessor {
) -> Result<(), Error> {
log::debug!("Processing \"{:?}\" event", event);
- let future = self.ipc_client.event(proto::EventType {
- event: event as i16 as i32,
- env,
- });
- let response = self.runtime.block_on(future);
+ let details = proto::EventDetails { env };
+
+ let response = match event {
+ openvpn_plugin::EventType::AuthFailed => {
+ self.runtime.block_on(self.ipc_client.auth_failed(details))
+ }
+ openvpn_plugin::EventType::RouteUp => {
+ self.runtime.block_on(self.ipc_client.route_up(details))
+ }
+ openvpn_plugin::EventType::RoutePredown => self
+ .runtime
+ .block_on(self.ipc_client.route_predown(details)),
+ other => return Err(Error::UnhandledEvent(other)),
+ };
response.map(|_| ()).map_err(Error::SendEvent)
}
}