summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--talpid-core/src/routing/mod.rs3
-rw-r--r--talpid-core/src/routing/unix.rs117
-rw-r--r--talpid-core/src/routing/windows.rs135
-rw-r--r--talpid-core/src/tunnel/openvpn.rs17
-rw-r--r--talpid-core/src/tunnel_state_machine/mod.rs4
5 files changed, 192 insertions, 84 deletions
diff --git a/talpid-core/src/routing/mod.rs b/talpid-core/src/routing/mod.rs
index a4f7dd1f76..d75073df8a 100644
--- a/talpid-core/src/routing/mod.rs
+++ b/talpid-core/src/routing/mod.rs
@@ -17,8 +17,7 @@ use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN;
pub use imp::{Error, RouteManager};
-#[cfg(target_os = "linux")]
-pub use imp::RouteManagerCommand;
+pub use imp::RouteManagerHandle;
/// A netowrk route with a specific network node, destinaiton and an optional metric.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs
index c766fbfbe8..147a33a1a5 100644
--- a/talpid-core/src/routing/unix.rs
+++ b/talpid-core/src/routing/unix.rs
@@ -8,7 +8,6 @@ use futures::channel::{
oneshot,
};
use std::{collections::HashSet, io};
-use talpid_types::ErrorExt;
#[cfg(target_os = "linux")]
use std::net::IpAddr;
@@ -30,9 +29,9 @@ pub use imp::Error as PlatformError;
/// Errors that can be encountered whilst initializing RouteManager
#[derive(err_derive::Error, Debug)]
pub enum Error {
- /// Routing manager thread panicked before starting routing manager
- #[error(display = "Routing manager thread panicked before starting routing manager")]
- RoutingManagerThreadPanic,
+ /// Route manager thread may have panicked
+ #[error(display = "The channel sender was dropped")]
+ ManagerChannelDown,
/// Platform specific error occured
#[error(display = "Internal route manager error")]
PlatformError(#[error(source)] imp::Error),
@@ -47,6 +46,43 @@ pub enum Error {
RouteManagerDown,
}
+/// Handle to a route manager.
+#[derive(Clone)]
+pub struct RouteManagerHandle {
+ runtime: tokio::runtime::Handle,
+ tx: UnboundedSender<RouteManagerCommand>,
+}
+
+impl RouteManagerHandle {
+ /// Applies the given routes while the route manager is running.
+ pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<(), Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ self.runtime
+ .block_on(response_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
+ .map_err(Error::PlatformError)
+ }
+
+ /// Set the link to be ignored by the exclusions routing table.
+ #[cfg(target_os = "linux")]
+ pub fn set_tunnel_link(&self, interface: &str) -> Result<(), Error> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::SetTunnelLink(
+ interface.to_string(),
+ response_tx,
+ ))
+ .map_err(|_| Error::RouteManagerDown)?;
+ Ok(self
+ .runtime
+ .block_on(response_rx)
+ .map_err(|_| Error::ManagerChannelDown)?)
+ }
+}
+
/// Commands for the underlying route manager object.
#[derive(Debug)]
pub enum RouteManagerCommand {
@@ -82,23 +118,20 @@ pub enum RouteManagerCommand {
/// the route will be adjusted dynamically when the default route changes.
pub struct RouteManager {
manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
- runtime: tokio::runtime::Runtime,
+ runtime: tokio::runtime::Handle,
}
impl RouteManager {
/// Constructs a RouteManager and applies the required routes.
/// Takes a set of network destinations and network nodes as an argument, and applies said
/// routes.
- pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> {
+ pub fn new(
+ runtime: tokio::runtime::Handle,
+ required_routes: HashSet<RequiredRoute>,
+ ) -> Result<Self, Error> {
let (manage_tx, manage_rx) = mpsc::unbounded();
- let mut runtime = tokio::runtime::Builder::new()
- .threaded_scheduler()
- .core_threads(1)
- .max_threads(1)
- .enable_all()
- .build()?;
let manager = runtime.block_on(imp::RouteManagerImpl::new(required_routes))?;
- runtime.handle().spawn(manager.run(manage_rx));
+ runtime.spawn(manager.run(manage_rx));
Ok(Self {
runtime,
@@ -120,7 +153,7 @@ impl RouteManager {
}
if self.runtime.block_on(wait_rx).is_err() {
- log::error!("RouteManager paniced while shutting down");
+ log::error!("{}", Error::ManagerChannelDown);
}
}
}
@@ -136,16 +169,10 @@ impl RouteManager {
return Err(Error::RouteManagerDown);
}
- match self.runtime.block_on(result_rx) {
- Ok(result) => result.map_err(Error::PlatformError),
- Err(error) => {
- log::trace!(
- "{}",
- error.display_chain_with_msg("oneshot channel is closed")
- );
- Ok(())
- }
- }
+ self.runtime
+ .block_on(result_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
+ .map_err(Error::PlatformError)
} else {
Err(Error::RouteManagerDown)
}
@@ -176,13 +203,10 @@ impl RouteManager {
return Err(Error::RouteManagerDown);
}
- match self.runtime.block_on(result_rx) {
- Ok(result) => result.map_err(Error::PlatformError),
- Err(error) => {
- log::trace!("{}", error.display_chain_with_msg("channel is closed"));
- Ok(())
- }
- }
+ self.runtime
+ .block_on(result_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
+ .map_err(Error::PlatformError)
} else {
Err(Error::RouteManagerDown)
}
@@ -218,23 +242,21 @@ impl RouteManager {
{
return Err(Error::RouteManagerDown);
}
- match self.runtime.block_on(result_rx) {
- Ok(()) => Ok(()),
- Err(error) => {
- log::trace!("{}", error.display_chain_with_msg("channel is closed"));
- Ok(())
- }
- }
+ self.runtime
+ .block_on(result_rx)
+ .map_err(|_| Error::ManagerChannelDown)
} else {
Err(Error::RouteManagerDown)
}
}
/// Retrieve a sender directly to the command channel.
- #[cfg(target_os = "linux")]
- pub fn channel(&self) -> Result<UnboundedSender<RouteManagerCommand>, Error> {
+ pub fn handle(&self) -> Result<RouteManagerHandle, Error> {
if let Some(tx) = &self.manage_tx {
- Ok(tx.clone())
+ Ok(RouteManagerHandle {
+ runtime: self.runtime.clone(),
+ tx: tx.clone(),
+ })
} else {
Err(Error::RouteManagerDown)
}
@@ -243,7 +265,7 @@ impl RouteManager {
/// Exposes runtime handle
#[cfg(target_os = "linux")]
pub fn runtime_handle(&self) -> tokio::runtime::Handle {
- self.runtime.handle().clone()
+ self.runtime.clone()
}
/// Route DNS requests through the tunnel interface.
@@ -266,13 +288,10 @@ impl RouteManager {
return Err(Error::RouteManagerDown);
}
- match self.runtime.block_on(result_rx) {
- Ok(result) => result.map_err(Error::PlatformError),
- Err(error) => {
- log::trace!("{}", error.display_chain_with_msg("channel is closed"));
- Ok(())
- }
- }
+ self.runtime
+ .block_on(result_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
+ .map_err(Error::PlatformError)
} else {
Err(Error::RouteManagerDown)
}
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index 412384574b..47a45672dc 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -1,10 +1,20 @@
use super::NetNode;
use crate::{routing::RequiredRoute, winnet};
+use futures::{
+ channel::{
+ mpsc::{self, UnboundedReceiver, UnboundedSender},
+ oneshot,
+ },
+ StreamExt,
+};
use std::collections::HashSet;
/// Windows routing errors.
#[derive(err_derive::Error, Debug)]
pub enum Error {
+ /// The sender was dropped unexpectedly -- possible panic
+ #[error(display = "The channel sender was dropped")]
+ ManagerChannelDown,
/// Failure to initialize route manager
#[error(display = "Failed to start route manager")]
FailedToStartManager,
@@ -14,6 +24,9 @@ pub enum Error {
/// Failure to clear routes
#[error(display = "Failed to clear applied routes")]
ClearRoutesFailed,
+ /// Attempt to use route manager that has been dropped
+ #[error(display = "Cannot send message to route manager since it is down")]
+ RouteManagerDown,
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -21,24 +34,103 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Manages routes by calling into WinNet
pub struct RouteManager {
callback_handles: Vec<winnet::WinNetCallbackHandle>,
- is_stopped: bool,
+ runtime: tokio::runtime::Handle,
+ manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
+}
+
+/// Handle to a route manager.
+#[derive(Clone)]
+pub struct RouteManagerHandle {
+ runtime: tokio::runtime::Handle,
+ tx: UnboundedSender<RouteManagerCommand>,
+}
+
+impl RouteManagerHandle {
+ /// Applies the given routes while the route manager is running.
+ pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ self.runtime
+ .block_on(response_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
+ }
+}
+
+#[derive(Debug)]
+pub enum RouteManagerCommand {
+ AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
+ Shutdown,
}
impl RouteManager {
/// Creates a new route manager that will apply the provided routes and ensure they exist until
/// it's stopped.
- pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> {
+ pub fn new(
+ runtime: tokio::runtime::Handle,
+ required_routes: HashSet<RequiredRoute>,
+ ) -> Result<Self> {
if !winnet::activate_routing_manager() {
return Err(Error::FailedToStartManager);
}
+ let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = Self {
callback_handles: vec![],
- is_stopped: false,
+ runtime: runtime.clone(),
+ manage_tx: Some(manage_tx),
};
+ runtime.spawn(RouteManager::listen(manage_rx));
manager.add_routes(required_routes)?;
+
Ok(manager)
}
+ /// Retrieve a sender directly to the command channel.
+ pub fn handle(&self) -> Result<RouteManagerHandle> {
+ if let Some(tx) = &self.manage_tx {
+ Ok(RouteManagerHandle {
+ runtime: self.runtime.clone(),
+ tx: tx.clone(),
+ })
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) {
+ while let Some(command) = manage_rx.next().await {
+ match command {
+ RouteManagerCommand::AddRoutes(routes, tx) => {
+ let routes: Vec<_> = routes
+ .iter()
+ .map(|route| {
+ let destination = winnet::WinNetIpNetwork::from(route.prefix);
+ match &route.node {
+ NetNode::DefaultNode => {
+ winnet::WinNetRoute::through_default_node(destination)
+ }
+ NetNode::RealNode(node) => winnet::WinNetRoute::new(
+ winnet::WinNetNode::from(node),
+ destination,
+ ),
+ }
+ })
+ .collect();
+
+ if winnet::routing_manager_add_routes(&routes) {
+ let _ = tx.send(Ok(()));
+ } else {
+ let _ = tx.send(Err(Error::AddRoutesFailed));
+ }
+ }
+ RouteManagerCommand::Shutdown => {
+ break;
+ }
+ }
+ }
+ }
+
/// Sets a callback that is called whenever the default route changes.
#[cfg(target_os = "windows")]
pub fn add_default_route_callback<T: 'static>(
@@ -46,7 +138,7 @@ impl RouteManager {
callback: Option<winnet::DefaultRouteChangedCallback>,
context: T,
) {
- if self.is_stopped {
+ if self.manage_tx.is_none() {
return;
}
@@ -71,32 +163,31 @@ impl RouteManager {
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
/// can be added
pub fn stop(&mut self) {
- if !self.is_stopped {
+ if let Some(tx) = self.manage_tx.take() {
+ if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() {
+ log::error!("RouteManager channel already down or thread panicked");
+ }
+
self.callback_handles.clear();
winnet::deactivate_routing_manager();
- self.is_stopped = true;
}
}
/// Applies the given routes until [`RouteManager::stop`] is called.
pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
- let routes: Vec<_> = routes
- .iter()
- .map(|route| {
- let destination = winnet::WinNetIpNetwork::from(route.prefix);
- match &route.node {
- NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
- NetNode::RealNode(node) => {
- winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
- }
- }
- })
- .collect();
-
- if winnet::routing_manager_add_routes(&routes) {
- Ok(())
+ if let Some(tx) = &self.manage_tx {
+ let (result_tx, result_rx) = oneshot::channel();
+ if tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
+ .is_err()
+ {
+ return Err(Error::RouteManagerDown);
+ }
+ self.runtime
+ .block_on(result_rx)
+ .map_err(|_| Error::ManagerChannelDown)?
} else {
- Err(Error::AddRoutesFailed)
+ Err(Error::RouteManagerDown)
}
}
diff --git a/talpid-core/src/tunnel/openvpn.rs b/talpid-core/src/tunnel/openvpn.rs
index 8730eadeeb..f83cabca3b 100644
--- a/talpid-core/src/tunnel/openvpn.rs
+++ b/talpid-core/src/tunnel/openvpn.rs
@@ -8,8 +8,6 @@ use crate::{
proxy::{self, ProxyMonitor, ProxyResourceData},
routing,
};
-#[cfg(target_os = "linux")]
-use futures::channel::oneshot;
use std::{
collections::HashMap,
fs,
@@ -172,18 +170,17 @@ impl OpenVpnMonitor<OpenVpnCommand> {
};
#[cfg(target_os = "linux")]
- let route_manager_tx = route_manager.channel().map_err(Error::SetupRoutingError)?;
+ let route_manager_handle = route_manager.handle().map_err(Error::SetupRoutingError)?;
let on_openvpn_event = move |event, env: HashMap<String, String>| {
#[cfg(target_os = "linux")]
if event == openvpn_plugin::EventType::Up {
- let (tx, rx) = oneshot::channel();
- let interface = env.get("dev").unwrap().to_owned();
- route_manager_tx
- .unbounded_send(routing::RouteManagerCommand::SetTunnelLink(interface, tx))
- .unwrap();
- tokio::task::block_in_place(move || {
- futures::executor::block_on(rx).unwrap();
+ let interface = env.get("dev").unwrap();
+ tokio::task::block_in_place(|| {
+ route_manager_handle
+ .clone()
+ .set_tunnel_link(interface)
+ .unwrap();
});
return;
}
diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs
index b98eb820d1..9b33e8ae48 100644
--- a/talpid-core/src/tunnel_state_machine/mod.rs
+++ b/talpid-core/src/tunnel_state_machine/mod.rs
@@ -109,6 +109,7 @@ pub async fn spawn(
let (startup_result_tx, startup_result_rx) = sync_mpsc::channel();
std::thread::spawn(move || {
let state_machine = TunnelStateMachine::new(
+ runtime.clone(),
allow_lan,
block_when_disconnected,
is_offline,
@@ -189,6 +190,7 @@ struct TunnelStateMachine {
impl TunnelStateMachine {
fn new(
+ runtime: tokio::runtime::Handle,
allow_lan: bool,
block_when_disconnected: bool,
is_offline: bool,
@@ -209,7 +211,7 @@ impl TunnelStateMachine {
let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?;
let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?;
let route_manager =
- RouteManager::new(HashSet::new()).map_err(Error::InitRouteManagerError)?;
+ RouteManager::new(runtime, HashSet::new()).map_err(Error::InitRouteManagerError)?;
let mut shared_values = SharedTunnelStateValues {
firewall,
dns_monitor,