summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Lönnhager <david.l@mullvad.net>2020-10-30 13:42:02 +0100
committerDavid Lönnhager <david.l@mullvad.net>2020-11-06 16:17:01 +0100
commit29259edada0e5a255cafb5ae5930c375cdb68764 (patch)
tree4b419f2a76eeba0762d994bd70835de8e181dda4
parent30cafc9cfe8326f09ab8787a2ae3999370eb1ef6 (diff)
downloadmullvadvpn-29259edada0e5a255cafb5ae5930c375cdb68764.tar.xz
mullvadvpn-29259edada0e5a255cafb5ae5930c375cdb68764.zip
Add command channel to the Windows route manager
-rw-r--r--talpid-core/src/routing/windows.rs123
1 files changed, 101 insertions, 22 deletions
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs
index f3e2197602..80c21d5e56 100644
--- a/talpid-core/src/routing/windows.rs
+++ b/talpid-core/src/routing/windows.rs
@@ -1,5 +1,12 @@
use super::NetNode;
use crate::{routing::RequiredRoute, winnet};
+use futures::{
+ channel::{
+ mpsc::{self, UnboundedReceiver, UnboundedSender},
+ oneshot,
+ },
+ StreamExt,
+};
use std::collections::HashSet;
/// Windows routing errors.
@@ -14,6 +21,9 @@ pub enum Error {
/// Failure to clear routes
#[error(display = "Failed to clear applied routes")]
ClearRoutesFailed,
+ /// Attempt to use route manager that has been dropped
+ #[error(display = "Cannot send message to route manager since it is down")]
+ RouteManagerDown,
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -21,8 +31,32 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Manages routes by calling into WinNet
pub struct RouteManager {
callback_handles: Vec<winnet::WinNetCallbackHandle>,
- is_stopped: bool,
runtime: tokio::runtime::Handle,
+ manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
+}
+
+/// Handle to a route manager.
+#[derive(Clone)]
+pub struct RouteManagerHandle {
+ runtime: tokio::runtime::Handle,
+ tx: UnboundedSender<RouteManagerCommand>,
+}
+
+impl RouteManagerHandle {
+ /// Applies the given routes while the route manager is running.
+ pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
+ .map_err(|_| Error::RouteManagerDown)?;
+ self.runtime.block_on(response_rx).unwrap()
+ }
+}
+
+#[derive(Debug)]
+pub enum RouteManagerCommand {
+ AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
+ Shutdown,
}
impl RouteManager {
@@ -35,15 +69,63 @@ impl RouteManager {
if !winnet::activate_routing_manager() {
return Err(Error::FailedToStartManager);
}
+ let (manage_tx, manage_rx) = mpsc::unbounded();
let manager = Self {
callback_handles: vec![],
- is_stopped: false,
- runtime,
+ runtime: runtime.clone(),
+ manage_tx: Some(manage_tx),
};
+ runtime.spawn(RouteManager::listen(manage_rx));
manager.add_routes(required_routes)?;
+
Ok(manager)
}
+ /// Retrieve a sender directly to the command channel.
+ pub fn handle(&self) -> Result<RouteManagerHandle> {
+ if let Some(tx) = &self.manage_tx {
+ Ok(RouteManagerHandle {
+ runtime: self.runtime.clone(),
+ tx: tx.clone(),
+ })
+ } else {
+ Err(Error::RouteManagerDown)
+ }
+ }
+
+ async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) {
+ while let Some(command) = manage_rx.next().await {
+ match command {
+ RouteManagerCommand::AddRoutes(routes, tx) => {
+ let routes: Vec<_> = routes
+ .iter()
+ .map(|route| {
+ let destination = winnet::WinNetIpNetwork::from(route.prefix);
+ match &route.node {
+ NetNode::DefaultNode => {
+ winnet::WinNetRoute::through_default_node(destination)
+ }
+ NetNode::RealNode(node) => winnet::WinNetRoute::new(
+ winnet::WinNetNode::from(node),
+ destination,
+ ),
+ }
+ })
+ .collect();
+
+ if winnet::routing_manager_add_routes(&routes) {
+ let _ = tx.send(Ok(()));
+ } else {
+ let _ = tx.send(Err(Error::AddRoutesFailed));
+ }
+ }
+ RouteManagerCommand::Shutdown => {
+ break;
+ }
+ }
+ }
+ }
+
/// Sets a callback that is called whenever the default route changes.
#[cfg(target_os = "windows")]
pub fn add_default_route_callback<T: 'static>(
@@ -51,7 +133,7 @@ impl RouteManager {
callback: Option<winnet::DefaultRouteChangedCallback>,
context: T,
) {
- if self.is_stopped {
+ if self.manage_tx.is_none() {
return;
}
@@ -76,32 +158,29 @@ impl RouteManager {
/// Stops the routing manager and invalidates the route manager - no new default route callbacks
/// can be added
pub fn stop(&mut self) {
- if !self.is_stopped {
+ if let Some(tx) = self.manage_tx.take() {
+ if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() {
+ log::error!("RouteManager channel already down or thread panicked");
+ }
+
self.callback_handles.clear();
winnet::deactivate_routing_manager();
- self.is_stopped = true;
}
}
/// Applies the given routes until [`RouteManager::stop`] is called.
pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
- let routes: Vec<_> = routes
- .iter()
- .map(|route| {
- let destination = winnet::WinNetIpNetwork::from(route.prefix);
- match &route.node {
- NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination),
- NetNode::RealNode(node) => {
- winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination)
- }
- }
- })
- .collect();
-
- if winnet::routing_manager_add_routes(&routes) {
- Ok(())
+ if let Some(tx) = &self.manage_tx {
+ let (result_tx, result_rx) = oneshot::channel();
+ if tx
+ .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx))
+ .is_err()
+ {
+ return Err(Error::RouteManagerDown);
+ }
+ self.runtime.block_on(result_rx).unwrap()
} else {
- Err(Error::AddRoutesFailed)
+ Err(Error::RouteManagerDown)
}
}