diff options
| -rw-r--r-- | talpid-core/src/routing/mod.rs | 3 | ||||
| -rw-r--r-- | talpid-core/src/routing/unix.rs | 117 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 135 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/openvpn.rs | 17 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 4 |
5 files changed, 192 insertions, 84 deletions
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs index a4f7dd1f76..d75073df8a 100644 --- a/talpid-core/src/routing/mod.rs +++ b/talpid-core/src/routing/mod.rs @@ -17,8 +17,7 @@ use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; pub use imp::{Error, RouteManager}; -#[cfg(target_os = "linux")] -pub use imp::RouteManagerCommand; +pub use imp::RouteManagerHandle; /// A netowrk route with a specific network node, destinaiton and an optional metric. #[derive(Debug, Hash, Eq, PartialEq, Clone)] diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index c766fbfbe8..147a33a1a5 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -8,7 +8,6 @@ use futures::channel::{ oneshot, }; use std::{collections::HashSet, io}; -use talpid_types::ErrorExt; #[cfg(target_os = "linux")] use std::net::IpAddr; @@ -30,9 +29,9 @@ pub use imp::Error as PlatformError; /// Errors that can be encountered whilst initializing RouteManager #[derive(err_derive::Error, Debug)] pub enum Error { - /// Routing manager thread panicked before starting routing manager - #[error(display = "Routing manager thread panicked before starting routing manager")] - RoutingManagerThreadPanic, + /// Route manager thread may have panicked + #[error(display = "The channel sender was dropped")] + ManagerChannelDown, /// Platform specific error occured #[error(display = "Internal route manager error")] PlatformError(#[error(source)] imp::Error), @@ -47,6 +46,43 @@ pub enum Error { RouteManagerDown, } +/// Handle to a route manager. +#[derive(Clone)] +pub struct RouteManagerHandle { + runtime: tokio::runtime::Handle, + tx: UnboundedSender<RouteManagerCommand>, +} + +impl RouteManagerHandle { + /// Applies the given routes while the route manager is running. + pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + self.runtime + .block_on(response_rx) + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) + } + + /// Set the link to be ignored by the exclusions routing table. + #[cfg(target_os = "linux")] + pub fn set_tunnel_link(&self, interface: &str) -> Result<(), Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::SetTunnelLink( + interface.to_string(), + response_tx, + )) + .map_err(|_| Error::RouteManagerDown)?; + Ok(self + .runtime + .block_on(response_rx) + .map_err(|_| Error::ManagerChannelDown)?) + } +} + /// Commands for the underlying route manager object. #[derive(Debug)] pub enum RouteManagerCommand { @@ -82,23 +118,20 @@ pub enum RouteManagerCommand { /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { manage_tx: Option<UnboundedSender<RouteManagerCommand>>, - runtime: tokio::runtime::Runtime, + runtime: tokio::runtime::Handle, } impl RouteManager { /// Constructs a RouteManager and applies the required routes. /// Takes a set of network destinations and network nodes as an argument, and applies said /// routes. - pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> { + pub fn new( + runtime: tokio::runtime::Handle, + required_routes: HashSet<RequiredRoute>, + ) -> Result<Self, Error> { let (manage_tx, manage_rx) = mpsc::unbounded(); - let mut runtime = tokio::runtime::Builder::new() - .threaded_scheduler() - .core_threads(1) - .max_threads(1) - .enable_all() - .build()?; let manager = runtime.block_on(imp::RouteManagerImpl::new(required_routes))?; - runtime.handle().spawn(manager.run(manage_rx)); + runtime.spawn(manager.run(manage_rx)); Ok(Self { runtime, @@ -120,7 +153,7 @@ impl RouteManager { } if self.runtime.block_on(wait_rx).is_err() { - log::error!("RouteManager paniced while shutting down"); + log::error!("{}", Error::ManagerChannelDown); } } } @@ -136,16 +169,10 @@ impl RouteManager { return Err(Error::RouteManagerDown); } - match self.runtime.block_on(result_rx) { - Ok(result) => result.map_err(Error::PlatformError), - Err(error) => { - log::trace!( - "{}", - error.display_chain_with_msg("oneshot channel is closed") - ); - Ok(()) - } - } + self.runtime + .block_on(result_rx) + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) } else { Err(Error::RouteManagerDown) } @@ -176,13 +203,10 @@ impl RouteManager { return Err(Error::RouteManagerDown); } - match self.runtime.block_on(result_rx) { - Ok(result) => result.map_err(Error::PlatformError), - Err(error) => { - log::trace!("{}", error.display_chain_with_msg("channel is closed")); - Ok(()) - } - } + self.runtime + .block_on(result_rx) + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) } else { Err(Error::RouteManagerDown) } @@ -218,23 +242,21 @@ impl RouteManager { { return Err(Error::RouteManagerDown); } - match self.runtime.block_on(result_rx) { - Ok(()) => Ok(()), - Err(error) => { - log::trace!("{}", error.display_chain_with_msg("channel is closed")); - Ok(()) - } - } + self.runtime + .block_on(result_rx) + .map_err(|_| Error::ManagerChannelDown) } else { Err(Error::RouteManagerDown) } } /// Retrieve a sender directly to the command channel. - #[cfg(target_os = "linux")] - pub fn channel(&self) -> Result<UnboundedSender<RouteManagerCommand>, Error> { + pub fn handle(&self) -> Result<RouteManagerHandle, Error> { if let Some(tx) = &self.manage_tx { - Ok(tx.clone()) + Ok(RouteManagerHandle { + runtime: self.runtime.clone(), + tx: tx.clone(), + }) } else { Err(Error::RouteManagerDown) } @@ -243,7 +265,7 @@ impl RouteManager { /// Exposes runtime handle #[cfg(target_os = "linux")] pub fn runtime_handle(&self) -> tokio::runtime::Handle { - self.runtime.handle().clone() + self.runtime.clone() } /// Route DNS requests through the tunnel interface. @@ -266,13 +288,10 @@ impl RouteManager { return Err(Error::RouteManagerDown); } - match self.runtime.block_on(result_rx) { - Ok(result) => result.map_err(Error::PlatformError), - Err(error) => { - log::trace!("{}", error.display_chain_with_msg("channel is closed")); - Ok(()) - } - } + self.runtime + .block_on(result_rx) + .map_err(|_| Error::ManagerChannelDown)? + .map_err(Error::PlatformError) } else { Err(Error::RouteManagerDown) } diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index 412384574b..47a45672dc 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -1,10 +1,20 @@ use super::NetNode; use crate::{routing::RequiredRoute, winnet}; +use futures::{ + channel::{ + mpsc::{self, UnboundedReceiver, UnboundedSender}, + oneshot, + }, + StreamExt, +}; use std::collections::HashSet; /// Windows routing errors. #[derive(err_derive::Error, Debug)] pub enum Error { + /// The sender was dropped unexpectedly -- possible panic + #[error(display = "The channel sender was dropped")] + ManagerChannelDown, /// Failure to initialize route manager #[error(display = "Failed to start route manager")] FailedToStartManager, @@ -14,6 +24,9 @@ pub enum Error { /// Failure to clear routes #[error(display = "Failed to clear applied routes")] ClearRoutesFailed, + /// Attempt to use route manager that has been dropped + #[error(display = "Cannot send message to route manager since it is down")] + RouteManagerDown, } pub type Result<T> = std::result::Result<T, Error>; @@ -21,24 +34,103 @@ pub type Result<T> = std::result::Result<T, Error>; /// Manages routes by calling into WinNet pub struct RouteManager { callback_handles: Vec<winnet::WinNetCallbackHandle>, - is_stopped: bool, + runtime: tokio::runtime::Handle, + manage_tx: Option<UnboundedSender<RouteManagerCommand>>, +} + +/// Handle to a route manager. +#[derive(Clone)] +pub struct RouteManagerHandle { + runtime: tokio::runtime::Handle, + tx: UnboundedSender<RouteManagerCommand>, +} + +impl RouteManagerHandle { + /// Applies the given routes while the route manager is running. + pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + self.runtime + .block_on(response_rx) + .map_err(|_| Error::ManagerChannelDown)? + } +} + +#[derive(Debug)] +pub enum RouteManagerCommand { + AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), + Shutdown, } impl RouteManager { /// Creates a new route manager that will apply the provided routes and ensure they exist until /// it's stopped. - pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { + pub fn new( + runtime: tokio::runtime::Handle, + required_routes: HashSet<RequiredRoute>, + ) -> Result<Self> { if !winnet::activate_routing_manager() { return Err(Error::FailedToStartManager); } + let (manage_tx, manage_rx) = mpsc::unbounded(); let manager = Self { callback_handles: vec![], - is_stopped: false, + runtime: runtime.clone(), + manage_tx: Some(manage_tx), }; + runtime.spawn(RouteManager::listen(manage_rx)); manager.add_routes(required_routes)?; + Ok(manager) } + /// Retrieve a sender directly to the command channel. + pub fn handle(&self) -> Result<RouteManagerHandle> { + if let Some(tx) = &self.manage_tx { + Ok(RouteManagerHandle { + runtime: self.runtime.clone(), + tx: tx.clone(), + }) + } else { + Err(Error::RouteManagerDown) + } + } + + async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) { + while let Some(command) = manage_rx.next().await { + match command { + RouteManagerCommand::AddRoutes(routes, tx) => { + let routes: Vec<_> = routes + .iter() + .map(|route| { + let destination = winnet::WinNetIpNetwork::from(route.prefix); + match &route.node { + NetNode::DefaultNode => { + winnet::WinNetRoute::through_default_node(destination) + } + NetNode::RealNode(node) => winnet::WinNetRoute::new( + winnet::WinNetNode::from(node), + destination, + ), + } + }) + .collect(); + + if winnet::routing_manager_add_routes(&routes) { + let _ = tx.send(Ok(())); + } else { + let _ = tx.send(Err(Error::AddRoutesFailed)); + } + } + RouteManagerCommand::Shutdown => { + break; + } + } + } + } + /// Sets a callback that is called whenever the default route changes. #[cfg(target_os = "windows")] pub fn add_default_route_callback<T: 'static>( @@ -46,7 +138,7 @@ impl RouteManager { callback: Option<winnet::DefaultRouteChangedCallback>, context: T, ) { - if self.is_stopped { + if self.manage_tx.is_none() { return; } @@ -71,32 +163,31 @@ impl RouteManager { /// Stops the routing manager and invalidates the route manager - no new default route callbacks /// can be added pub fn stop(&mut self) { - if !self.is_stopped { + if let Some(tx) = self.manage_tx.take() { + if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() { + log::error!("RouteManager channel already down or thread panicked"); + } + self.callback_handles.clear(); winnet::deactivate_routing_manager(); - self.is_stopped = true; } } /// Applies the given routes until [`RouteManager::stop`] is called. pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - let routes: Vec<_> = routes - .iter() - .map(|route| { - let destination = winnet::WinNetIpNetwork::from(route.prefix); - match &route.node { - NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), - NetNode::RealNode(node) => { - winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) - } - } - }) - .collect(); - - if winnet::routing_manager_add_routes(&routes) { - Ok(()) + if let Some(tx) = &self.manage_tx { + let (result_tx, result_rx) = oneshot::channel(); + if tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .is_err() + { + return Err(Error::RouteManagerDown); + } + self.runtime + .block_on(result_rx) + .map_err(|_| Error::ManagerChannelDown)? } else { - Err(Error::AddRoutesFailed) + Err(Error::RouteManagerDown) } } diff --git a/talpid-core/src/tunnel/openvpn.rs b/talpid-core/src/tunnel/openvpn.rs index 8730eadeeb..f83cabca3b 100644 --- a/talpid-core/src/tunnel/openvpn.rs +++ b/talpid-core/src/tunnel/openvpn.rs @@ -8,8 +8,6 @@ use crate::{ proxy::{self, ProxyMonitor, ProxyResourceData}, routing, }; -#[cfg(target_os = "linux")] -use futures::channel::oneshot; use std::{ collections::HashMap, fs, @@ -172,18 +170,17 @@ impl OpenVpnMonitor<OpenVpnCommand> { }; #[cfg(target_os = "linux")] - let route_manager_tx = route_manager.channel().map_err(Error::SetupRoutingError)?; + let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?; let on_openvpn_event = move |event, env: HashMap<String, String>| { #[cfg(target_os = "linux")] if event == openvpn_plugin::EventType::Up { - let (tx, rx) = oneshot::channel(); - let interface = env.get("dev").unwrap().to_owned(); - route_manager_tx - .unbounded_send(routing::RouteManagerCommand::SetTunnelLink(interface, tx)) - .unwrap(); - tokio::task::block_in_place(move || { - futures::executor::block_on(rx).unwrap(); + let interface = env.get("dev").unwrap(); + tokio::task::block_in_place(|| { + route_manager_handle + .clone() + .set_tunnel_link(interface) + .unwrap(); }); return; } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index b98eb820d1..9b33e8ae48 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -109,6 +109,7 @@ pub async fn spawn( let (startup_result_tx, startup_result_rx) = sync_mpsc::channel(); std::thread::spawn(move || { let state_machine = TunnelStateMachine::new( + runtime.clone(), allow_lan, block_when_disconnected, is_offline, @@ -189,6 +190,7 @@ struct TunnelStateMachine { impl TunnelStateMachine { fn new( + runtime: tokio::runtime::Handle, allow_lan: bool, block_when_disconnected: bool, is_offline: bool, @@ -209,7 +211,7 @@ impl TunnelStateMachine { let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?; let route_manager = - RouteManager::new(HashSet::new()).map_err(Error::InitRouteManagerError)?; + RouteManager::new(runtime, HashSet::new()).map_err(Error::InitRouteManagerError)?; let mut shared_values = SharedTunnelStateValues { firewall, dns_monitor, |
