diff options
| author | David Lönnhager <david.l@mullvad.net> | 2021-06-11 16:24:40 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2021-06-16 11:38:12 +0200 |
| commit | f3a274cbc5e424ad428e6763a58304427cc7be90 (patch) | |
| tree | 6d7a3e08def842a22d760178cef25c492b57b662 | |
| parent | f2a20eba3fccf1121ba7b8d5af78c34a8ed80687 (diff) | |
| download | mullvadvpn-f3a274cbc5e424ad428e6763a58304427cc7be90.tar.xz mullvadvpn-f3a274cbc5e424ad428e6763a58304427cc7be90.zip | |
Improve OpenVPN event handling
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 67 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn/mod.rs | 371 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 26 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 8 | ||||
| -rw-r--r-- | talpid-openvpn-plugin/proto/openvpn_plugin.proto | 9 | ||||
| -rw-r--r-- | talpid-openvpn-plugin/src/lib.rs | 36 | ||||
| -rw-r--r-- | talpid-openvpn-plugin/src/processing.rs | 19 |
7 files changed, 318 insertions, 218 deletions
diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 890fc07b67..3ee6bf1886 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,7 +1,5 @@ use self::tun_provider::TunProvider; use crate::{logging, routing::RouteManager}; -#[cfg(not(target_os = "android"))] -use std::collections::HashMap; use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -94,54 +92,6 @@ pub struct TunnelMetadata { pub ipv6_gateway: Option<Ipv6Addr>, } -#[cfg(not(target_os = "android"))] -impl TunnelEvent { - /// Converts an `openvpn_plugin::EventType` to a `TunnelEvent`. - /// Returns `None` if there is no corresponding `TunnelEvent`. - fn from_openvpn_event( - event: openvpn_plugin::EventType, - env: &HashMap<String, String>, - ) -> Option<TunnelEvent> { - match event { - openvpn_plugin::EventType::AuthFailed => { - let reason = env.get("auth_failed_reason").cloned(); - Some(TunnelEvent::AuthFailed(reason)) - } - openvpn_plugin::EventType::RouteUp => { - let interface = env - .get("dev") - .expect("No \"dev\" in tunnel up event") - .to_owned(); - let mut ips = vec![env - .get("ifconfig_local") - .expect("No \"ifconfig_local\" in tunnel up event") - .parse() - .expect("Tunnel IP not in valid format")]; - if let Some(ipv6_address) = env.get("ifconfig_ipv6_local") { - ips.push(ipv6_address.parse().expect("Tunnel IP not in valid format")); - } - let ipv4_gateway = env - .get("route_vpn_gateway") - .expect("No \"route_vpn_gateway\" in tunnel up event") - .parse() - .expect("Tunnel gateway IP not in valid format"); - let ipv6_gateway = env.get("route_ipv6_gateway_1").map(|v6_str| { - v6_str - .parse() - .expect("V6 Tunnel gateway IP not in valid format") - }); - Some(TunnelEvent::Up(TunnelMetadata { - interface, - ips, - ipv4_gateway, - ipv6_gateway, - })) - } - openvpn_plugin::EventType::RoutePredown => Some(TunnelEvent::Down), - _ => None, - } - } -} /// Abstraction for monitoring a generic VPN tunnel. pub struct TunnelMonitor { monitor: InternalTunnelMonitor, @@ -162,7 +112,11 @@ impl TunnelMonitor { route_manager: &mut RouteManager, ) -> Result<Self> where - L: Fn(TunnelEvent) + Send + Clone + Sync + 'static, + L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Clone + + Sync + + 'static, { Self::ensure_ipv6_can_be_used_if_enabled(&tunnel_parameters)?; let log_file = Self::prepare_tunnel_log_file(&tunnel_parameters, log_dir)?; @@ -215,7 +169,11 @@ impl TunnelMonitor { route_manager: &mut RouteManager, ) -> Result<Self> where - L: Fn(TunnelEvent) + Send + Sync + Clone + 'static, + L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + Clone + + 'static, { let config = wireguard::config::Config::from_parameters(¶ms)?; let monitor = wireguard::WireguardMonitor::start( @@ -240,7 +198,10 @@ impl TunnelMonitor { route_manager: &mut RouteManager, ) -> Result<Self> where - L: Fn(TunnelEvent) + Send + Sync + 'static, + L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, { let monitor = openvpn::OpenVpnMonitor::start(on_event, config, log, resource_dir, route_manager)?; diff --git a/talpid-core/src/tunnel/openvpn/mod.rs b/talpid-core/src/tunnel/openvpn/mod.rs index b6817bc591..e3709b0183 100644 --- a/talpid-core/src/tunnel/openvpn/mod.rs +++ b/talpid-core/src/tunnel/openvpn/mod.rs @@ -10,12 +10,17 @@ use crate::{ proxy::{self, ProxyMonitor, ProxyResourceData}, routing, }; -#[cfg(target_os = "linux")] -use ipnetwork::IpNetwork; #[cfg(windows)] use lazy_static::lazy_static; +#[cfg(target_os = "linux")] +use std::collections::{HashMap, HashSet}; +#[cfg(windows)] +use std::{ + ffi::{OsStr, OsString}, + os::windows::ffi::OsStrExt, + time::Instant, +}; use std::{ - collections::HashMap, fs, io::{self, Write}, path::{Path, PathBuf}, @@ -27,14 +32,6 @@ use std::{ thread, time::Duration, }; -#[cfg(target_os = "linux")] -use std::{collections::HashSet, net::IpAddr}; -#[cfg(windows)] -use std::{ - ffi::{OsStr, OsString}, - os::windows::ffi::OsStrExt, - time::Instant, -}; use talpid_types::{net::openvpn, ErrorExt}; use tokio::task; #[cfg(target_os = "linux")] @@ -218,11 +215,6 @@ pub enum Error { #[error(display = "Failure in Windows syscall")] WinnetError(#[error(source)] crate::winnet::Error), - /// Error routes from the provided map - #[cfg(target_os = "linux")] - #[error(display = "Failed to parse OpenVPN-provided routes")] - ParseRouteError(#[error(source)] RouteParseError), - /// The map is missing 'dev' #[cfg(target_os = "linux")] #[error(display = "Failed to obtain tunnel interface name")] @@ -342,70 +334,22 @@ impl OpenVpnMonitor<OpenVpnCommand> { #[cfg(not(target_os = "linux"))] _route_manager: &mut routing::RouteManager, ) -> Result<Self> where - L: Fn(TunnelEvent) + Send + Sync + 'static, + L: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, { let user_pass_file = Self::create_credentials_file(¶ms.config.username, ¶ms.config.password) .map_err(Error::CredentialsWriteError)?; - let proxy_auth_file = Self::create_proxy_auth_file(¶ms.proxy).map_err(Error::CredentialsWriteError)?; - let user_pass_file_path = user_pass_file.to_path_buf(); - let proxy_auth_file_path = match proxy_auth_file { Some(ref file) => Some(file.to_path_buf()), _ => None, }; - #[cfg(target_os = "linux")] - let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; - - #[cfg(target_os = "linux")] - let ipv6_enabled = params.generic_options.enable_ipv6; - - let on_openvpn_event = move |event, env: HashMap<String, String>| { - #[cfg(target_os = "linux")] - if event == openvpn_plugin::EventType::Up { - tokio::task::block_in_place(|| { - let routes = extract_routes(&env) - .unwrap() - .into_iter() - .filter(|route| route.prefix.is_ipv4() || ipv6_enabled) - .collect(); - let route_manager_handle = route_manager_handle.clone(); - if let Err(error) = route_manager_handle.add_routes(routes) { - log::error!("{}", error.display_chain()); - panic!("Failed to add routes"); - } - - if let Err(error) = route_manager_handle.create_routing_rules(ipv6_enabled) { - log::error!("{}", error.display_chain()); - panic!("Failed to add routes"); - } - }); - return; - } - if event == openvpn_plugin::EventType::RouteUp { - // The user-pass file has been read. Try to delete it early. - let _ = fs::remove_file(&user_pass_file_path); - - // The proxy auth file has been read. Try to delete it early. - if let Some(ref file_path) = &proxy_auth_file_path { - let _ = fs::remove_file(file_path); - } - - #[cfg(windows)] - tokio::task::block_in_place(|| { - wait_for_ready_device(env.get("dev").expect("missing tunnel alias")).unwrap(); - }); - } - match TunnelEvent::from_openvpn_event(event, &env) { - Some(tunnel_event) => on_event(tunnel_event), - None => log::debug!("Ignoring OpenVpnEvent {:?}", event), - } - }; - let log_dir: Option<PathBuf> = if let Some(ref log_path) = log_path { Some(log_path.parent().expect("log_path has no parent").into()) } else { @@ -514,9 +458,27 @@ impl OpenVpnMonitor<OpenVpnCommand> { let plugin_path = Self::get_plugin_path(resource_dir)?; + #[cfg(target_os = "linux")] + let ipv6_enabled = params.generic_options.enable_ipv6; + #[cfg(target_os = "linux")] + let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; + + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); + Self::new_internal( cmd, - on_openvpn_event, + event_server_abort_tx.clone(), + event_server_abort_rx, + event_server::OpenvpnEventProxyImpl { + on_event, + user_pass_file_path: user_pass_file_path.clone(), + proxy_auth_file_path: proxy_auth_file_path.clone(), + abort_server_tx: event_server_abort_tx, + #[cfg(target_os = "linux")] + route_manager_handle, + #[cfg(target_os = "linux")] + ipv6_enabled, + }, plugin_path, log_path, user_pass_file, @@ -533,34 +495,6 @@ impl OpenVpnMonitor<OpenVpnCommand> { } #[cfg(target_os = "linux")] -#[derive(Debug)] -struct OpenVpnRoute { - network: IpNetwork, - gateway: IpAddr, -} - -#[cfg(target_os = "linux")] -#[derive(err_derive::Error, Debug)] -#[error(no_from)] -#[allow(missing_docs)] -pub enum RouteParseError { - #[error(display = "The route contains no network")] - MissingNetwork, - #[error(display = "The route contains no gateway")] - MissingGateway, - #[error(display = "Failed to parse route network address")] - ParseNetworkAddress(#[error(source)] std::net::AddrParseError), - #[error(display = "Failed to parse route network")] - ParseNetwork(#[error(source)] ipnetwork::IpNetworkError), - #[error(display = "Failed to parse route mask address")] - ParseMaskAddress(#[error(source)] std::net::AddrParseError), - #[error(display = "Failed to convert route mask to prefix")] - ParseMask(#[error(source)] ipnetwork::IpNetworkError), - #[error(display = "Failed to parse route gateway address")] - ParseGatewayAddress(#[error(source)] std::net::AddrParseError), -} - -#[cfg(target_os = "linux")] fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute>> { let tun_interface = env.get("dev").ok_or(Error::MissingTunnelInterface)?; let tun_node = routing::Node::device(tun_interface.to_string()); @@ -574,6 +508,8 @@ fn extract_routes(env: &HashMap<String, String>) -> Result<HashSet<RequiredRoute impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { fn new_internal<L>( mut cmd: C, + event_server_abort_tx: triggered::Trigger, + event_server_abort_rx: triggered::Listener, on_event: L, plugin_path: PathBuf, log_path: Option<PathBuf>, @@ -583,7 +519,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { #[cfg(windows)] wintun: Box<dyn WintunContext>, ) -> Result<OpenVpnMonitor<C>> where - L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static, + L: event_server::OpenvpnEventProxy + Send + Sync + 'static, { let uuid = uuid::Uuid::new_v4().to_string(); let ipc_path = if cfg!(windows) { @@ -592,12 +528,9 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> { format!("/tmp/talpid-openvpn-{}", uuid) }; - let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); - let mut runtime = tokio::runtime::Builder::new() .threaded_scheduler() .core_threads(1) - .max_threads(1) .enable_all() .build() .map_err(Error::RuntimeError)?; @@ -1015,11 +948,11 @@ mod event_server { use futures::stream::TryStreamExt; use parity_tokio_ipc::Endpoint as IpcEndpoint; use std::{ - collections::HashMap, - convert::TryFrom, pin::Pin, task::{Context, Poll}, }; + #[cfg(any(target_os = "linux", windows))] + use talpid_types::ErrorExt; use tokio::io::{AsyncRead, AsyncWrite}; use tonic::{ self, @@ -1030,9 +963,9 @@ mod event_server { mod proto { tonic::include_proto!("talpid_openvpn_plugin"); } - use proto::{ + pub use proto::{ openvpn_event_proxy_server::{OpenvpnEventProxy, OpenvpnEventProxyServer}, - EventType, + EventDetails, }; #[derive(err_derive::Error, Debug)] @@ -1047,31 +980,167 @@ mod event_server { } /// Implements a gRPC service used to process events sent to by OpenVPN. - #[derive(Debug)] - pub struct OpenvpnEventProxyImpl<L> { - on_event: L, + pub struct OpenvpnEventProxyImpl< + L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, + > { + pub on_event: L, + pub user_pass_file_path: super::PathBuf, + pub proxy_auth_file_path: Option<super::PathBuf>, + pub abort_server_tx: triggered::Trigger, + #[cfg(target_os = "linux")] + pub route_manager_handle: super::routing::RouteManagerHandle, + #[cfg(target_os = "linux")] + pub ipv6_enabled: bool, } - #[tonic::async_trait] - impl<L> OpenvpnEventProxy for OpenvpnEventProxyImpl<L> - where - L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static, + impl< + L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, + > OpenvpnEventProxyImpl<L> { - async fn event( + async fn route_up_inner( &self, - request: Request<EventType>, + request: Request<EventDetails>, ) -> std::result::Result<Response<()>, tonic::Status> { - log::trace!("OpenVPN event {:?}", request); + let env = request.into_inner().env; - let request = request.into_inner(); + let _ = tokio::fs::remove_file(&self.user_pass_file_path).await; + if let Some(ref file_path) = &self.proxy_auth_file_path { + let _ = tokio::fs::remove_file(file_path).await; + } - let event_type = - openvpn_plugin::EventType::try_from(request.event).map_err(|event: i32| { - tonic::Status::invalid_argument(format!("Unknown event type: {}", event)) + #[cfg(target_os = "linux")] + { + let route_handle = self.route_manager_handle.clone(); + let ipv6_enabled = self.ipv6_enabled; + + let routes = super::extract_routes(&env) + .map_err(|err| { + log::error!("{}", err.display_chain_with_msg("Failed to obtain routes")); + tonic::Status::failed_precondition("Failed to obtain routes") + })? + .into_iter() + .filter(|route| route.prefix.is_ipv4() || ipv6_enabled) + .collect(); + + tokio::task::spawn_blocking(move || { + if let Err(error) = route_handle.add_routes(routes) { + log::error!("{}", error.display_chain()); + return Err(tonic::Status::failed_precondition("Failed to add routes")); + } + if let Err(error) = route_handle.create_routing_rules(ipv6_enabled) { + log::error!("{}", error.display_chain()); + return Err(tonic::Status::failed_precondition("Failed to add routes")); + } + Ok(()) + }) + .await + .map_err(|_| tonic::Status::internal("task failed to complete"))??; + } + + let tunnel_alias = env + .get("dev") + .ok_or(tonic::Status::invalid_argument("missing tunnel alias"))? + .to_string(); + + #[cfg(windows)] + { + let tunnel_device = tunnel_alias.clone(); + tokio::task::spawn_blocking(move || super::wait_for_ready_device(&tunnel_device)) + .await + .map_err(|_| tonic::Status::internal("task failed to complete"))? + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("wait_for_ready_device failed") + ); + tonic::Status::unavailable("wait_for_ready_device failed") + })?; + } + + let mut ips = vec![env + .get("ifconfig_local") + .ok_or(tonic::Status::invalid_argument( + "missing \"ifconfig_local\" in up event", + ))? + .parse() + .map_err(|_| tonic::Status::invalid_argument("Invalid tunnel IPv4 address"))?]; + if let Some(ipv6_address) = env.get("ifconfig_ipv6_local") { + ips.push( + ipv6_address.parse().map_err(|_| { + tonic::Status::invalid_argument("Invalid tunnel IPv6 address") + })?, + ); + } + let ipv4_gateway = env + .get("route_vpn_gateway") + .ok_or(tonic::Status::invalid_argument( + "No \"route_vpn_gateway\" in tunnel up event", + ))? + .parse() + .map_err(|_| { + tonic::Status::invalid_argument("Invalid tunnel gateway IPv4 address") })?; + let ipv6_gateway = if let Some(ipv6_address) = env.get("route_ipv6_gateway_1") { + Some(ipv6_address.parse().map_err(|_| { + tonic::Status::invalid_argument("Invalid tunnel gateway IPv6 address") + })?) + } else { + None + }; + + (self.on_event)(super::TunnelEvent::Up(crate::tunnel::TunnelMetadata { + interface: tunnel_alias, + ips, + ipv4_gateway, + ipv6_gateway, + })) + .await; + + Ok(Response::new(())) + } + } + + #[tonic::async_trait] + impl< + L: (Fn(super::TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, + > OpenvpnEventProxy for OpenvpnEventProxyImpl<L> + { + async fn auth_failed( + &self, + request: Request<EventDetails>, + ) -> std::result::Result<Response<()>, tonic::Status> { + let env = request.into_inner().env; + (self.on_event)(super::TunnelEvent::AuthFailed( + env.get("auth_failed_reason").cloned(), + )) + .await; + Ok(Response::new(())) + } - (self.on_event)(event_type, request.env); + async fn route_up( + &self, + request: Request<EventDetails>, + ) -> std::result::Result<Response<()>, tonic::Status> { + self.route_up_inner(request).await.map_err(|error| { + self.abort_server_tx.trigger(); + error + }) + } + async fn route_predown( + &self, + _request: Request<EventDetails>, + ) -> std::result::Result<Response<()>, tonic::Status> { + (self.on_event)(super::TunnelEvent::Down).await; Ok(Response::new(())) } } @@ -1079,20 +1148,18 @@ mod event_server { pub async fn start<L>( ipc_path: String, server_start_tx: std::sync::mpsc::Sender<()>, - on_event: L, + event_proxy: L, abort_rx: triggered::Listener, ) -> std::result::Result<(), Error> where - L: Fn(openvpn_plugin::EventType, HashMap<String, String>) + Send + Sync + 'static, + L: OpenvpnEventProxy + Sync + Send + 'static, { let endpoint = IpcEndpoint::new(ipc_path); let incoming = endpoint.incoming().map_err(Error::StartServer)?; let _ = server_start_tx.send(()); - let server = OpenvpnEventProxyImpl { on_event }; - Server::builder() - .add_service(OpenvpnEventProxyServer::new(server)) + .add_service(OpenvpnEventProxyServer::new(event_proxy)) .serve_with_incoming_shutdown(incoming.map_ok(StreamBox), abort_rx) .await .map_err(Error::TonicError) @@ -1269,6 +1336,30 @@ mod tests { } } + struct TestOpenvpnEventProxy {} + + #[async_trait::async_trait] + impl event_server::OpenvpnEventProxy for TestOpenvpnEventProxy { + async fn auth_failed( + &self, + _request: tonic::Request<event_server::EventDetails>, + ) -> std::result::Result<tonic::Response<()>, tonic::Status> { + Ok(tonic::Response::new(())) + } + async fn route_up( + &self, + _request: tonic::Request<event_server::EventDetails>, + ) -> std::result::Result<tonic::Response<()>, tonic::Status> { + Ok(tonic::Response::new(())) + } + async fn route_predown( + &self, + _request: tonic::Request<event_server::EventDetails>, + ) -> std::result::Result<tonic::Response<()>, tonic::Status> { + Ok(tonic::Response::new(())) + } + } + #[derive(Debug, Default, Clone)] struct TestOpenVpnBuilder { pub plugin: Arc<Mutex<Option<PathBuf>>>, @@ -1319,9 +1410,12 @@ mod tests { #[test] fn sets_plugin() { let builder = TestOpenVpnBuilder::default(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let _ = OpenVpnMonitor::new_internal( builder.clone(), - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "./my_test_plugin".into(), None, TempFile::new(), @@ -1339,9 +1433,12 @@ mod tests { #[test] fn sets_log() { let builder = TestOpenVpnBuilder::default(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let _ = OpenVpnMonitor::new_internal( builder.clone(), - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "".into(), Some(PathBuf::from("./my_test_log_file")), TempFile::new(), @@ -1360,9 +1457,12 @@ mod tests { fn exit_successfully() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(0)); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let testee = OpenVpnMonitor::new_internal( builder, - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "".into(), None, TempFile::new(), @@ -1379,9 +1479,12 @@ mod tests { fn exit_error() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let testee = OpenVpnMonitor::new_internal( builder, - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "".into(), None, TempFile::new(), @@ -1398,9 +1501,12 @@ mod tests { fn wait_closed() { let mut builder = TestOpenVpnBuilder::default(); builder.process_handle = Some(TestProcessHandle(1)); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let testee = OpenVpnMonitor::new_internal( builder, - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "".into(), None, TempFile::new(), @@ -1417,9 +1523,12 @@ mod tests { #[test] fn failed_process_start() { let builder = TestOpenVpnBuilder::default(); + let (event_server_abort_tx, event_server_abort_rx) = triggered::trigger(); let result = OpenVpnMonitor::new_internal( builder, - |_, _| {}, + event_server_abort_tx, + event_server_abort_rx, + TestOpenvpnEventProxy {}, "".into(), None, TempFile::new(), diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index 3ffaa9995c..4d7fa478b1 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -74,10 +74,16 @@ pub enum Error { /// Spawns and monitors a wireguard tunnel pub struct WireguardMonitor { + runtime: tokio::runtime::Handle, /// Tunnel implementation tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, /// Callback to signal tunnel events - event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>, + event_callback: Box< + dyn (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + 'static, + >, close_msg_sender: mpsc::Sender<CloseMsg>, close_msg_receiver: mpsc::Receiver<CloseMsg>, #[cfg(target_os = "windows")] @@ -149,7 +155,13 @@ impl Drop for TcpProxy { impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config - pub fn start<F: Fn(TunnelEvent) + Send + Sync + Clone + 'static>( + pub fn start< + F: (Fn(TunnelEvent) -> Box<dyn std::future::Future<Output = ()> + Unpin + Send>) + + Send + + Sync + + Clone + + 'static, + >( runtime: tokio::runtime::Handle, mut config: Config, log_path: Option<&Path>, @@ -176,7 +188,7 @@ impl WireguardMonitor { let iface_luid = tunnel.get_interface_luid(); let metadata = Self::tunnel_metadata(&iface_name, &config); - (on_event)(TunnelEvent::InterfaceUp(metadata.clone())); + runtime.block_on((on_event)(TunnelEvent::InterfaceUp(metadata.clone()))); #[cfg(target_os = "windows")] route_manager @@ -188,6 +200,7 @@ impl WireguardMonitor { #[cfg(target_os = "windows")] let (stop_setup_tx, stop_setup_rx) = futures::channel::oneshot::channel(); let monitor = WireguardMonitor { + runtime: runtime.clone(), tunnel: Arc::new(Mutex::new(Some(tunnel))), event_callback, close_msg_sender, @@ -209,8 +222,6 @@ impl WireguardMonitor { .map_err(Error::ConnectivityMonitorError)?; let route_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; - #[cfg(windows)] - let runtime = route_manager.runtime_handle(); std::thread::spawn(move || { #[cfg(windows)] @@ -265,7 +276,7 @@ impl WireguardMonitor { match connectivity_monitor.establish_connectivity() { Ok(true) => { - (on_event)(TunnelEvent::Up(metadata)); + runtime.block_on((on_event)(TunnelEvent::Up(metadata))); if let Err(error) = connectivity_monitor.run() { log::error!( @@ -368,7 +379,8 @@ impl WireguardMonitor { self.stop_tunnel(); - (self.event_callback)(TunnelEvent::Down); + self.runtime + .block_on((self.event_callback)(TunnelEvent::Down)); wait_result } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 8993354916..8486ae0a3e 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -99,9 +99,11 @@ impl ConnectingState { retry_attempt: u32, ) -> crate::tunnel::Result<Self> { let (event_tx, event_rx) = mpsc::unbounded(); - let on_tunnel_event = move |event| { - let _ = event_tx.unbounded_send(event); - }; + let on_tunnel_event = + move |event| -> Box<dyn std::future::Future<Output = ()> + Unpin + Send> { + let _ = event_tx.unbounded_send(event); + Box::new(futures::future::ready(())) + }; let monitor = TunnelMonitor::start( runtime, diff --git a/talpid-openvpn-plugin/proto/openvpn_plugin.proto b/talpid-openvpn-plugin/proto/openvpn_plugin.proto index 07bd9d6f1a..156caa8881 100644 --- a/talpid-openvpn-plugin/proto/openvpn_plugin.proto +++ b/talpid-openvpn-plugin/proto/openvpn_plugin.proto @@ -5,10 +5,11 @@ package talpid_openvpn_plugin; import "google/protobuf/empty.proto"; service OpenvpnEventProxy { - rpc Event(EventType) returns (google.protobuf.Empty) {} + rpc AuthFailed(EventDetails) returns (google.protobuf.Empty) {} + rpc RouteUp(EventDetails) returns (google.protobuf.Empty) {} + rpc RoutePredown(EventDetails) returns (google.protobuf.Empty) {} } -message EventType { - int32 event = 1; - map<string, string> env = 2; +message EventDetails { + map<string, string> env = 1; } diff --git a/talpid-openvpn-plugin/src/lib.rs b/talpid-openvpn-plugin/src/lib.rs index e0ddc0100f..09316d05df 100644 --- a/talpid-openvpn-plugin/src/lib.rs +++ b/talpid-openvpn-plugin/src/lib.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(display = "Unable to parse arguments from OpenVPN")] ParseArgsFailed(#[error(source)] std::str::Utf8Error), + + #[error(display = "Unhandled event type: {:?}", _0)] + UnhandledEvent(openvpn_plugin::EventType), } @@ -35,8 +38,6 @@ pub enum Error { /// events. pub static INTERESTING_EVENTS: &'static [EventType] = &[ EventType::AuthFailed, - #[cfg(target_os = "linux")] - EventType::Up, EventType::RouteUp, EventType::RoutePredown, ]; @@ -45,7 +46,7 @@ openvpn_plugin!( crate::openvpn_open, crate::openvpn_close, crate::openvpn_event, - crate::Mutex<EventProcessor> + crate::Mutex<Option<EventProcessor>> ); pub struct Arguments { @@ -55,7 +56,7 @@ pub struct Arguments { fn openvpn_open( args: Vec<CString>, _env: HashMap<CString, CString>, -) -> Result<(Vec<EventType>, Mutex<EventProcessor>), Error> { +) -> Result<(Vec<EventType>, Mutex<Option<EventProcessor>>), Error> { env_logger::init(); log::debug!("Initializing plugin"); @@ -66,7 +67,7 @@ fn openvpn_open( ); let processor = EventProcessor::new(arguments)?; - Ok((INTERESTING_EVENTS.to_vec(), Mutex::new(processor))) + Ok((INTERESTING_EVENTS.to_vec(), Mutex::new(Some(processor)))) } fn parse_args(args: &[CString]) -> Result<Arguments, Error> { @@ -81,7 +82,7 @@ fn parse_args(args: &[CString]) -> Result<Arguments, Error> { } -fn openvpn_close(_handle: Mutex<EventProcessor>) { +fn openvpn_close(_handle: Mutex<Option<EventProcessor>>) { log::info!("Unloading plugin"); } @@ -89,21 +90,26 @@ fn openvpn_event( event: EventType, _args: Vec<CString>, env: HashMap<CString, CString>, - handle: &mut Mutex<EventProcessor>, + handle: &mut Mutex<Option<EventProcessor>>, ) -> Result<EventResult, Error> { log::debug!("Received event: {:?}", event); let parsed_env = openvpn_plugin::ffi::parse::env_utf8(&env).map_err(Error::ParseEnvFailed)?; - let result = handle + let mut ctx = handle .lock() - .expect("failed to obtain mutex for EventProcessor") - .process_event(event, parsed_env); - match result { - Ok(()) => Ok(EventResult::Success), - Err(e) => { - log::error!("{}", e.display_chain()); - Ok(EventResult::Failure) + .expect("failed to obtain mutex for EventProcessor"); + if let Some(processor) = ctx.as_mut() { + match processor.process_event(event, parsed_env) { + Ok(()) => Ok(EventResult::Success), + Err(e) => { + log::error!("{}", e.display_chain()); + *ctx = None; + Ok(EventResult::Failure) + } } + } else { + log::error!("Client has been closed"); + Ok(EventResult::Failure) } } diff --git a/talpid-openvpn-plugin/src/processing.rs b/talpid-openvpn-plugin/src/processing.rs index 291ad41092..c266c2d1b5 100644 --- a/talpid-openvpn-plugin/src/processing.rs +++ b/talpid-openvpn-plugin/src/processing.rs @@ -62,11 +62,20 @@ impl EventProcessor { ) -> Result<(), Error> { log::debug!("Processing \"{:?}\" event", event); - let future = self.ipc_client.event(proto::EventType { - event: event as i16 as i32, - env, - }); - let response = self.runtime.block_on(future); + let details = proto::EventDetails { env }; + + let response = match event { + openvpn_plugin::EventType::AuthFailed => { + self.runtime.block_on(self.ipc_client.auth_failed(details)) + } + openvpn_plugin::EventType::RouteUp => { + self.runtime.block_on(self.ipc_client.route_up(details)) + } + openvpn_plugin::EventType::RoutePredown => self + .runtime + .block_on(self.ipc_client.route_predown(details)), + other => return Err(Error::UnhandledEvent(other)), + }; response.map(|_| ()).map_err(Error::SendEvent) } } |
