diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-10-30 13:42:02 +0100 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-11-06 16:17:01 +0100 |
| commit | 29259edada0e5a255cafb5ae5930c375cdb68764 (patch) | |
| tree | 4b419f2a76eeba0762d994bd70835de8e181dda4 | |
| parent | 30cafc9cfe8326f09ab8787a2ae3999370eb1ef6 (diff) | |
| download | mullvadvpn-29259edada0e5a255cafb5ae5930c375cdb68764.tar.xz mullvadvpn-29259edada0e5a255cafb5ae5930c375cdb68764.zip | |
Add command channel to the Windows route manager
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 123 |
1 files changed, 101 insertions, 22 deletions
diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index f3e2197602..80c21d5e56 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -1,5 +1,12 @@ use super::NetNode; use crate::{routing::RequiredRoute, winnet}; +use futures::{ + channel::{ + mpsc::{self, UnboundedReceiver, UnboundedSender}, + oneshot, + }, + StreamExt, +}; use std::collections::HashSet; /// Windows routing errors. @@ -14,6 +21,9 @@ pub enum Error { /// Failure to clear routes #[error(display = "Failed to clear applied routes")] ClearRoutesFailed, + /// Attempt to use route manager that has been dropped + #[error(display = "Cannot send message to route manager since it is down")] + RouteManagerDown, } pub type Result<T> = std::result::Result<T, Error>; @@ -21,8 +31,32 @@ pub type Result<T> = std::result::Result<T, Error>; /// Manages routes by calling into WinNet pub struct RouteManager { callback_handles: Vec<winnet::WinNetCallbackHandle>, - is_stopped: bool, runtime: tokio::runtime::Handle, + manage_tx: Option<UnboundedSender<RouteManagerCommand>>, +} + +/// Handle to a route manager. +#[derive(Clone)] +pub struct RouteManagerHandle { + runtime: tokio::runtime::Handle, + tx: UnboundedSender<RouteManagerCommand>, +} + +impl RouteManagerHandle { + /// Applies the given routes while the route manager is running. + pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + self.runtime.block_on(response_rx).unwrap() + } +} + +#[derive(Debug)] +pub enum RouteManagerCommand { + AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>), + Shutdown, } impl RouteManager { @@ -35,15 +69,63 @@ impl RouteManager { if !winnet::activate_routing_manager() { return Err(Error::FailedToStartManager); } + let (manage_tx, manage_rx) = mpsc::unbounded(); let manager = Self { callback_handles: vec![], - is_stopped: false, - runtime, + runtime: runtime.clone(), + manage_tx: Some(manage_tx), }; + runtime.spawn(RouteManager::listen(manage_rx)); manager.add_routes(required_routes)?; + Ok(manager) } + /// Retrieve a sender directly to the command channel. + pub fn handle(&self) -> Result<RouteManagerHandle> { + if let Some(tx) = &self.manage_tx { + Ok(RouteManagerHandle { + runtime: self.runtime.clone(), + tx: tx.clone(), + }) + } else { + Err(Error::RouteManagerDown) + } + } + + async fn listen(mut manage_rx: UnboundedReceiver<RouteManagerCommand>) { + while let Some(command) = manage_rx.next().await { + match command { + RouteManagerCommand::AddRoutes(routes, tx) => { + let routes: Vec<_> = routes + .iter() + .map(|route| { + let destination = winnet::WinNetIpNetwork::from(route.prefix); + match &route.node { + NetNode::DefaultNode => { + winnet::WinNetRoute::through_default_node(destination) + } + NetNode::RealNode(node) => winnet::WinNetRoute::new( + winnet::WinNetNode::from(node), + destination, + ), + } + }) + .collect(); + + if winnet::routing_manager_add_routes(&routes) { + let _ = tx.send(Ok(())); + } else { + let _ = tx.send(Err(Error::AddRoutesFailed)); + } + } + RouteManagerCommand::Shutdown => { + break; + } + } + } + } + /// Sets a callback that is called whenever the default route changes. #[cfg(target_os = "windows")] pub fn add_default_route_callback<T: 'static>( @@ -51,7 +133,7 @@ impl RouteManager { callback: Option<winnet::DefaultRouteChangedCallback>, context: T, ) { - if self.is_stopped { + if self.manage_tx.is_none() { return; } @@ -76,32 +158,29 @@ impl RouteManager { /// Stops the routing manager and invalidates the route manager - no new default route callbacks /// can be added pub fn stop(&mut self) { - if !self.is_stopped { + if let Some(tx) = self.manage_tx.take() { + if tx.unbounded_send(RouteManagerCommand::Shutdown).is_err() { + log::error!("RouteManager channel already down or thread panicked"); + } + self.callback_handles.clear(); winnet::deactivate_routing_manager(); - self.is_stopped = true; } } /// Applies the given routes until [`RouteManager::stop`] is called. pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { - let routes: Vec<_> = routes - .iter() - .map(|route| { - let destination = winnet::WinNetIpNetwork::from(route.prefix); - match &route.node { - NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), - NetNode::RealNode(node) => { - winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) - } - } - }) - .collect(); - - if winnet::routing_manager_add_routes(&routes) { - Ok(()) + if let Some(tx) = &self.manage_tx { + let (result_tx, result_rx) = oneshot::channel(); + if tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .is_err() + { + return Err(Error::RouteManagerDown); + } + self.runtime.block_on(result_rx).unwrap() } else { - Err(Error::AddRoutesFailed) + Err(Error::RouteManagerDown) } } |
