diff options
| author | David Lönnhager <david.l@mullvad.net> | 2020-05-13 17:57:30 +0200 |
|---|---|---|
| committer | David Lönnhager <david.l@mullvad.net> | 2020-05-13 17:57:30 +0200 |
| commit | 410e79ecbd767c63a63ffd795cebe9e80fe74df5 (patch) | |
| tree | 630318cc2ed04e0e6fd682749cc80c00cd74779f | |
| parent | ba215529bbf0c6bbe708c538d0260f565833d77d (diff) | |
| parent | f60115e4992a298518942793edd0e05e802037e7 (diff) | |
| download | mullvadvpn-410e79ecbd767c63a63ffd795cebe9e80fe74df5.tar.xz mullvadvpn-410e79ecbd767c63a63ffd795cebe9e80fe74df5.zip | |
Merge branch 'routemgm-update'
| -rw-r--r-- | talpid-core/src/routing/android.rs | 30 | ||||
| -rw-r--r-- | talpid-core/src/routing/linux.rs | 170 | ||||
| -rw-r--r-- | talpid-core/src/routing/macos.rs | 146 | ||||
| -rw-r--r-- | talpid-core/src/routing/unix.rs | 87 | ||||
| -rw-r--r-- | talpid-core/src/routing/windows.rs | 69 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/mod.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/tunnel/wireguard/mod.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connected_state.rs | 15 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/connecting_state.rs | 20 | ||||
| -rw-r--r-- | talpid-core/src/tunnel_state_machine/mod.rs | 10 | ||||
| -rw-r--r-- | talpid-core/src/winnet.rs | 20 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.cpp | 49 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/routing/routemanager.h | 1 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.cpp | 29 | ||||
| -rw-r--r-- | windows/winnet/src/winnet/winnet.h | 7 |
15 files changed, 474 insertions, 209 deletions
diff --git a/talpid-core/src/routing/android.rs b/talpid-core/src/routing/android.rs index 1f47027343..d364f0769e 100644 --- a/talpid-core/src/routing/android.rs +++ b/talpid-core/src/routing/android.rs @@ -1,5 +1,5 @@ -use crate::routing::RequiredRoute; -use futures01::{sync::oneshot, Async, Future}; +use crate::routing::{imp::RouteManagerCommand, RequiredRoute}; +use futures01::{stream::Stream, sync::mpsc}; use std::collections::HashSet; /// Stub error type for routing errors on Android. @@ -9,30 +9,26 @@ pub struct Error; /// Stub route manager for Android pub struct RouteManagerImpl { - shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, + manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>, } impl RouteManagerImpl { pub fn new( _required_routes: HashSet<RequiredRoute>, - shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, + manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>, ) -> Result<Self, Error> { - Ok(RouteManagerImpl { shutdown_rx }) + Ok(RouteManagerImpl { manage_rx }) } -} - -impl Future for RouteManagerImpl { - type Item = (); - type Error = Error; - fn poll(&mut self) -> Result<Async<()>, Error> { - match self.shutdown_rx.poll() { - Ok(Async::Ready(result_tx)) => { - result_tx.send(()).map_err(|()| Error)?; - Ok(Async::Ready(())) + pub fn wait(self) -> Result<(), Error> { + for msg in self.manage_rx.wait() { + if let Ok(command) = msg { + if let RouteManagerCommand::Shutdown(tx) = command { + tx.send(()).map_err(|()| Error)?; + break; + } } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => Ok(Async::Ready(())), } + Ok(()) } } diff --git a/talpid-core/src/routing/linux.rs b/talpid-core/src/routing/linux.rs index 124c25e319..7a7bfa4865 100644 --- a/talpid-core/src/routing/linux.rs +++ b/talpid-core/src/routing/linux.rs @@ -1,17 +1,21 @@ -use crate::routing::{NetNode, Node, RequiredRoute, Route}; +use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route}; + +use talpid_types::ErrorExt; use ipnetwork::IpNetwork; use std::{ collections::{BTreeMap, HashSet}, io, net::IpAddr, + thread, }; -use futures01::sync::oneshot as old_oneshot; +use futures01::{stream::Stream as old_stream, sync::mpsc as old_mpsc}; use futures::{ - channel::mpsc::UnboundedReceiver, compat::Future01CompatExt, future::FutureExt, StreamExt, - TryStreamExt, + channel::mpsc::{self, UnboundedReceiver}, + future::FutureExt, + StreamExt, TryStreamExt, }; @@ -62,10 +66,13 @@ pub enum Error { #[error(display = "Unknown device index - {}", _0)] UnknownDeviceIndex(u32), + + #[error(display = "Shutting down route manager")] + Shutdown, } pub struct RouteManagerImpl { - shutdown_rx: old_oneshot::Receiver<old_oneshot::Sender<()>>, + manage_rx: old_mpsc::UnboundedReceiver<RouteManagerCommand>, manager: RouteManagerImplInner, runtime: tokio02::runtime::Runtime, } @@ -74,7 +81,7 @@ impl RouteManagerImpl { /// Creates a new RouteManagerImplInner. pub fn new( required_routes: HashSet<RequiredRoute>, - shutdown_rx: old_oneshot::Receiver<old_oneshot::Sender<()>>, + manage_rx: old_mpsc::UnboundedReceiver<RouteManagerCommand>, ) -> Result<Self> { let mut runtime = tokio02::runtime::Builder::new() .basic_scheduler() @@ -87,7 +94,7 @@ impl RouteManagerImpl { let manager = runtime.block_on(RouteManagerImplInner::new(required_routes))?; Ok(Self { - shutdown_rx, + manage_rx, runtime, manager, }) @@ -95,11 +102,28 @@ impl RouteManagerImpl { pub fn wait(self) -> Result<()> { let Self { - shutdown_rx, + manage_rx, mut runtime, manager, } = self; - runtime.block_on(manager.into_future(shutdown_rx)) + + let (new_manage_tx, new_manage_rx) = mpsc::unbounded(); + + thread::spawn(move || { + for msg in manage_rx.wait() { + match msg { + Ok(msg) => { + if new_manage_tx.unbounded_send(msg).is_err() { + log::error!("RouteManager receiver unexpectedly dropped"); + break; + } + } + Err(_) => break, + } + } + }); + + runtime.block_on(manager.into_future(new_manage_rx)) } } @@ -140,32 +164,12 @@ impl RouteManagerImplInner { let iface_map = Self::initialize_link_map(&handle).await?; - - let mut required_normal_routes = HashSet::new(); - let mut required_default_routes = HashSet::new(); - - for route in required_routes { - match route.node { - NetNode::RealNode(node) => { - required_normal_routes - .insert(Route::new(node, route.prefix).table(route.table_id)); - } - NetNode::DefaultNode => { - required_default_routes.insert(RequiredDefaultRoute { - table_id: route.table_id, - destination: route.prefix, - }); - } - } - } - - let mut monitor = Self { iface_map, handle, messages, - required_default_routes, + required_default_routes: HashSet::new(), added_routes: HashSet::new(), default_routes: HashSet::new(), @@ -179,24 +183,72 @@ impl RouteManagerImplInner { monitor.best_default_node_v6 = Self::pick_best_default_node(&monitor.default_routes, IpVersion::V6); + monitor.add_required_routes(required_routes).await?; - for normal_route in required_normal_routes.into_iter() { - monitor.add_route(normal_route).await?; - } + Ok(monitor) + } - for route in monitor.required_default_routes.clone().into_iter() { + async fn add_required_default_routes( + &mut self, + required_default_routes: HashSet<RequiredDefaultRoute>, + ) -> Result<()> { + for route in required_default_routes.into_iter() { if let (false, _, Some(default_node)) | (true, Some(default_node), _) = ( route.destination.is_ipv4(), - &monitor.best_default_node_v4, - &monitor.best_default_node_v6, + &self.best_default_node_v4, + &self.best_default_node_v6, ) { // best to pick a single node identifier rather than device + ip let new_route = Route::new(default_node.clone(), route.destination).table(route.table_id); - monitor.add_route(new_route).await?; + self.add_route(new_route).await?; } + self.required_default_routes.insert(route); } - Ok(monitor) + Ok(()) + } + + async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { + let mut required_normal_routes = HashSet::new(); + let mut required_default_routes = HashSet::new(); + + for route in required_routes { + match route.node { + NetNode::RealNode(node) => { + required_normal_routes + .insert(Route::new(node, route.prefix).table(route.table_id)); + } + NetNode::DefaultNode => { + required_default_routes.insert(RequiredDefaultRoute { + table_id: route.table_id, + destination: route.prefix, + }); + } + } + } + + for normal_route in required_normal_routes.into_iter() { + self.add_route(normal_route).await?; + } + + if self + .add_required_default_routes(required_default_routes.clone()) + .await + .is_err() + { + log::trace!("Refreshing default routes which may be stale"); + + self.default_routes = self.get_default_routes().await?; + self.best_default_node_v4 = + Self::pick_best_default_node(&self.default_routes, IpVersion::V4); + self.best_default_node_v6 = + Self::pick_best_default_node(&self.default_routes, IpVersion::V6); + + self.add_required_default_routes(required_default_routes) + .await?; + } + + Ok(()) } async fn get_default_routes(&self) -> Result<HashSet<Route>> { @@ -398,27 +450,47 @@ impl RouteManagerImplInner { pub async fn into_future( mut self, - shutdown_rx: futures01::sync::oneshot::Receiver<futures01::sync::oneshot::Sender<()>>, + mut manage_rx: UnboundedReceiver<RouteManagerCommand>, ) -> Result<()> { - let mut shutdown = shutdown_rx.compat().fuse(); loop { futures::select! { - shutdown_signal = shutdown => { - log::trace!("Shutting down route manager"); - self.cleanup_routes().await; - log::trace!("Route manager done"); - if let Ok(shutdown_signal) = shutdown_signal { - let _ = shutdown_signal.send(()); - } - return Ok(()); + command = manage_rx.select_next_some().fuse() => { + self.process_command(command).await?; }, (route_change, socket) = self.messages.select_next_some().fuse() => { - self.process_netlink_message(route_change).await?; + if let Err(error) = self.process_netlink_message(route_change).await { + log::error!("{}", error.display_chain_with_msg("Failed to process netlink message")); + } } }; } } + async fn process_command(&mut self, command: RouteManagerCommand) -> Result<()> { + match command { + RouteManagerCommand::Shutdown(shutdown_signal) => { + log::trace!("Shutting down route manager"); + self.cleanup_routes().await; + log::trace!("Route manager done"); + let _ = shutdown_signal.send(()); + return Err(Error::Shutdown); + } + RouteManagerCommand::AddRoutes(routes, result_rx) => { + log::debug!("Adding routes: {:?}", routes); + if let Err(error) = self.add_required_routes(routes.clone()).await { + let _ = result_rx.send(Err(error)); + } else { + let _ = result_rx.send(Ok(())); + } + } + RouteManagerCommand::ClearRoutes => { + log::debug!("Clearing routes"); + self.cleanup_routes().await; + } + } + Ok(()) + } + async fn process_netlink_message(&mut self, msg: NetlinkMessage<RtnlMessage>) -> Result<()> { match msg.payload { NetlinkPayload::InnerMessage(RtnlMessage::NewLink(new_link)) => { diff --git a/talpid-core/src/routing/macos.rs b/talpid-core/src/routing/macos.rs index 4f05927824..56f7257024 100644 --- a/talpid-core/src/routing/macos.rs +++ b/talpid-core/src/routing/macos.rs @@ -1,4 +1,4 @@ -use crate::routing::{NetNode, Node, RequiredRoute, Route}; +use crate::routing::{imp::RouteManagerCommand, NetNode, Node, RequiredRoute, Route}; use ipnetwork::IpNetwork; use std::{ @@ -8,7 +8,11 @@ use std::{ process::{Command, ExitStatus, Stdio}, }; -use futures01::{stream, sync::oneshot, Async, Future, IntoFuture, Stream}; +use futures01::{ + stream, + sync::{mpsc, oneshot}, + Async, Future, IntoFuture, Stream, +}; use tokio_process::{Child, CommandExt}; @@ -69,18 +73,16 @@ pub struct RouteManagerImpl { current_state: RouteManagerState, v4_gateway: Option<Node>, v6_gateway: Option<Node>, - shutdown_rx: Option<oneshot::Receiver<oneshot::Sender<()>>>, + manage_rx: Option<mpsc::UnboundedReceiver<RouteManagerCommand>>, } impl RouteManagerImpl { pub fn new( required_routes: HashSet<RequiredRoute>, - shutdown_rx: oneshot::Receiver<oneshot::Sender<()>>, + manage_rx: mpsc::UnboundedReceiver<RouteManagerCommand>, ) -> Result<Self> { - let mut applied_routes = HashSet::new(); - let mut routes_to_apply = vec![]; - let mut default_destinations = HashSet::new(); + let change_listener = ChangeListener::new().map_err(Error::FailedToMonitorRoutes)?; let v4_gateway = Self::get_default_node_cmd("-inet").wait()?; let v6_gateway = Self::get_default_node_cmd("-inet6").wait()?; @@ -89,6 +91,24 @@ impl RouteManagerImpl { return Err(Error::NoDefaultRoute); } + let mut manager = Self { + default_destinations: HashSet::new(), + applied_routes: HashSet::new(), + current_state: RouteManagerState::Listening(change_listener), + manage_rx: Some(manage_rx), + v4_gateway, + v6_gateway, + }; + + manager.add_required_routes(required_routes)?; + + Ok(manager) + } + + fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> { + let mut routes_to_apply = vec![]; + let mut default_destinations = HashSet::new(); + for route in required_routes { match route.node { NetNode::DefaultNode => { @@ -99,47 +119,22 @@ impl RouteManagerImpl { } } - let apply_routes_fn = || -> Result<()> { - for route in routes_to_apply { - Self::add_route(&route).wait()?; - applied_routes.insert(route); - } - for destination in default_destinations.iter() { - match (&v4_gateway, &v6_gateway, destination.is_ipv4()) { - (Some(gateway), _, true) | (_, Some(gateway), false) => { - let route = Route::new(gateway.clone(), *destination); - Self::add_route(&route).wait()?; - applied_routes.insert(route); - } - _ => (), - }; - } - - Ok(()) - }; - - if let Err(e) = apply_routes_fn() { - log::error!("Failed to apply routes - {}", e); - for applied_route in applied_routes.iter() { - if let Err(removal_err) = Self::delete_route(applied_route.prefix).wait() { - log::error!( - "Failed to clean up routes after failing to set them up - {}", - removal_err - ); + for route in routes_to_apply { + Self::add_route(&route).wait()?; + self.applied_routes.insert(route); + } + for destination in default_destinations.iter() { + match (&self.v4_gateway, &self.v6_gateway, destination.is_ipv4()) { + (Some(gateway), _, true) | (_, Some(gateway), false) => { + let route = Route::new(gateway.clone(), *destination); + Self::add_route(&route).wait()?; + self.applied_routes.insert(route); } - } - return Err(e); + _ => (), + }; } - let change_listener = ChangeListener::new().map_err(Error::FailedToMonitorRoutes)?; - Ok(Self { - default_destinations, - applied_routes, - current_state: RouteManagerState::Listening(change_listener), - shutdown_rx: Some(shutdown_rx), - v4_gateway, - v6_gateway, - }) + Ok(()) } // Retrieves the node that's currently used to reach 0.0.0.0/0 @@ -230,10 +225,7 @@ impl RouteManagerImpl { .map_err(Error::FailedToAddRoute) } - fn shutdown_future( - &self, - shutdown_done_tx: Option<oneshot::Sender<()>>, - ) -> impl Future<Item = (), Error = ()> + Send { + fn cleanup_routes(&self) -> impl Future<Item = (), Error = ()> + Send { let remove_route_future = |route: &Route| { Self::delete_route(route.prefix).then(|removal| { match removal { @@ -261,16 +253,21 @@ impl RouteManagerImpl { _ => None, } })); - stream::futures_ordered(routes_to_remove) - .for_each(|_| Ok(())) - .and_then(|_| { - if let Some(tx) = shutdown_done_tx { - if tx.send(()).is_err() { - log::debug!("RouteManager already dropped") - } + stream::futures_ordered(routes_to_remove).for_each(|_| Ok(())) + } + + fn shutdown_future( + &self, + shutdown_done_tx: Option<oneshot::Sender<()>>, + ) -> impl Future<Item = (), Error = ()> + Send { + self.cleanup_routes().and_then(|_| { + if let Some(tx) = shutdown_done_tx { + if tx.send(()).is_err() { + log::debug!("RouteManager already dropped") } - Ok(()) - }) + } + Ok(()) + }) } fn apply_new_default_routes( @@ -323,20 +320,35 @@ impl Future for RouteManagerImpl { type Item = (); type Error = Error; fn poll(&mut self) -> Result<Async<()>> { - if let Some(mut shutdown_rx) = self.shutdown_rx.take() { - match shutdown_rx.poll() { - Ok(Async::Ready(shutdown_tx)) => { - self.current_state = RouteManagerState::Shutdown(Box::new( - self.shutdown_future(Some(shutdown_tx)), - )); - } + if let Some(mut manage_rx) = self.manage_rx.take() { + match manage_rx.poll() { + Ok(Async::Ready(Some(command))) => match command { + RouteManagerCommand::Shutdown(tx) => { + self.current_state = + RouteManagerState::Shutdown(Box::new(self.shutdown_future(Some(tx)))); + } + RouteManagerCommand::AddRoutes(routes, result_tx) => { + self.manage_rx = Some(manage_rx); + log::debug!("Adding routes: {:?}", routes); + if let Err(error) = self.add_required_routes(routes) { + let _ = result_tx.send(Err(error)); + } else { + let _ = result_tx.send(Ok(())); + } + } + RouteManagerCommand::ClearRoutes => { + self.manage_rx = Some(manage_rx); + log::debug!("Clearing routes"); + let _ = self.cleanup_routes().wait(); + } + }, // handle is already dropped - Err(_) => { + Ok(Async::Ready(None)) | Err(_) => { self.current_state = RouteManagerState::Shutdown(Box::new(self.shutdown_future(None))); } Ok(Async::NotReady) => { - self.shutdown_rx = Some(shutdown_rx); + self.manage_rx = Some(manage_rx); } }; } diff --git a/talpid-core/src/routing/unix.rs b/talpid-core/src/routing/unix.rs index 8dbfdf4d56..3f50ef3db5 100644 --- a/talpid-core/src/routing/unix.rs +++ b/talpid-core/src/routing/unix.rs @@ -2,8 +2,15 @@ #![cfg_attr(target_os = "windows", allow(dead_code))] // TODO: remove the allow(dead_code) for android once it's up to scratch. use super::RequiredRoute; -use futures01::{sync::oneshot, Future}; +use futures01::{ + sync::{ + mpsc::{unbounded, UnboundedSender}, + oneshot, + }, + Future, +}; use std::{collections::HashSet, sync::mpsc::sync_channel}; +use talpid_types::ErrorExt; #[cfg(target_os = "macos")] #[path = "macos.rs"] @@ -25,19 +32,32 @@ pub enum Error { /// Routing manager thread panicked before starting routing manager #[error(display = "Routing manager thread panicked before starting routing manager")] RoutingManagerThreadPanic, - /// Platform sepcific error occured - #[error(display = "Failed to create route manager")] - FailedToInitializeManager(#[error(source)] imp::Error), + /// Platform specific error occured + #[error(display = "Internal route manager error")] + PlatformError(#[error(source)] imp::Error), /// Failed to spawn route manager future #[error(display = "Failed to spawn route manager on the provided executor")] FailedToSpawnManager, + /// Attempt to use route manager that has been dropped + #[error(display = "Cannot send message to route manager since it is down")] + RouteManagerDown, +} + +#[derive(Debug)] +pub enum RouteManagerCommand { + AddRoutes( + HashSet<RequiredRoute>, + oneshot::Sender<Result<(), PlatformError>>, + ), + ClearRoutes, + Shutdown(oneshot::Sender<()>), } /// RouteManager applies a set of routes to the route table. /// If a destination has to be routed through the default node, /// the route will be adjusted dynamically when the default route changes. pub struct RouteManager { - tx: Option<oneshot::Sender<oneshot::Sender<()>>>, + manage_tx: Option<UnboundedSender<RouteManagerCommand>>, } impl RouteManager { @@ -45,11 +65,11 @@ impl RouteManager { /// Takes a set of network destinations and network nodes as an argument, and applies said /// routes. pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self, Error> { - let (tx, rx) = oneshot::channel(); + let (manage_tx, manage_rx) = unbounded(); let (start_tx, start_rx) = sync_channel(1); std::thread::spawn( - move || match imp::RouteManagerImpl::new(required_routes, rx) { + move || match imp::RouteManagerImpl::new(required_routes, manage_rx) { Ok(route_manager) => { let _ = start_tx.send(Ok(())); if let Err(e) = route_manager.wait() { @@ -57,12 +77,14 @@ impl RouteManager { } } Err(e) => { - let _ = start_tx.send(Err(Error::FailedToInitializeManager(e))); + let _ = start_tx.send(Err(Error::PlatformError(e))); } }, ); match start_rx.recv() { - Ok(Ok(())) => Ok(Self { tx: Some(tx) }), + Ok(Ok(())) => Ok(Self { + manage_tx: Some(manage_tx), + }), Ok(Err(e)) => Err(e), Err(_) => Err(Error::RoutingManagerThreadPanic), } @@ -70,9 +92,13 @@ impl RouteManager { /// Stops RouteManager and removes all of the applied routes. pub fn stop(&mut self) { - if let Some(tx) = self.tx.take() { + if let Some(tx) = self.manage_tx.take() { let (wait_tx, wait_rx) = oneshot::channel(); - if tx.send(wait_tx).is_err() { + + if tx + .unbounded_send(RouteManagerCommand::Shutdown(wait_tx)) + .is_err() + { log::error!("RouteManager already down!"); return; } @@ -82,6 +108,45 @@ impl RouteManager { } } } + + /// Applies the given routes until [`RouteManager::stop`] is called. + pub fn add_routes(&mut self, routes: HashSet<RequiredRoute>) -> Result<(), Error> { + if let Some(tx) = &self.manage_tx { + let (result_tx, result_rx) = oneshot::channel(); + if tx + .unbounded_send(RouteManagerCommand::AddRoutes(routes, result_tx)) + .is_err() + { + return Err(Error::RouteManagerDown); + } + + match result_rx.wait() { + Ok(result) => result.map_err(Error::PlatformError), + Err(error) => { + log::trace!( + "{}", + error.display_chain_with_msg("oneshot channel is closed") + ); + Ok(()) + } + } + } else { + Err(Error::RouteManagerDown) + } + } + + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_routes(&mut self) -> Result<(), Error> { + if let Some(tx) = &self.manage_tx { + if tx.unbounded_send(RouteManagerCommand::ClearRoutes).is_err() { + return Err(Error::RouteManagerDown); + } + Ok(()) + } else { + Err(Error::RouteManagerDown) + } + } } impl Drop for RouteManager { diff --git a/talpid-core/src/routing/windows.rs b/talpid-core/src/routing/windows.rs index 0e75cab953..412384574b 100644 --- a/talpid-core/src/routing/windows.rs +++ b/talpid-core/src/routing/windows.rs @@ -5,9 +5,15 @@ use std::collections::HashSet; /// Windows routing errors. #[derive(err_derive::Error, Debug)] pub enum Error { - /// Failure to apply a route + /// Failure to initialize route manager #[error(display = "Failed to start route manager")] FailedToStartManager, + /// Failure to add routes + #[error(display = "Failed to add routes")] + AddRoutesFailed, + /// Failure to clear routes + #[error(display = "Failed to clear applied routes")] + ClearRoutesFailed, } pub type Result<T> = std::result::Result<T, Error>; @@ -22,27 +28,15 @@ impl RouteManager { /// Creates a new route manager that will apply the provided routes and ensure they exist until /// it's stopped. pub fn new(required_routes: HashSet<RequiredRoute>) -> Result<Self> { - let routes: Vec<_> = required_routes - .iter() - .map(|route| { - let destination = winnet::WinNetIpNetwork::from(route.prefix); - match &route.node { - NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), - NetNode::RealNode(node) => { - winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) - } - } - }) - .collect(); - - if !winnet::activate_routing_manager(&routes) { + if !winnet::activate_routing_manager() { return Err(Error::FailedToStartManager); } - - Ok(Self { + let manager = Self { callback_handles: vec![], is_stopped: false, - }) + }; + manager.add_routes(required_routes)?; + Ok(manager) } /// Sets a callback that is called whenever the default route changes. @@ -67,6 +61,13 @@ impl RouteManager { } } + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_default_route_callbacks(&mut self) { + // `WinNetCallbackHandle::drop` removes these callbacks. + self.callback_handles.clear(); + } + /// Stops the routing manager and invalidates the route manager - no new default route callbacks /// can be added pub fn stop(&mut self) { @@ -76,6 +77,38 @@ impl RouteManager { self.is_stopped = true; } } + + /// Applies the given routes until [`RouteManager::stop`] is called. + pub fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> { + let routes: Vec<_> = routes + .iter() + .map(|route| { + let destination = winnet::WinNetIpNetwork::from(route.prefix); + match &route.node { + NetNode::DefaultNode => winnet::WinNetRoute::through_default_node(destination), + NetNode::RealNode(node) => { + winnet::WinNetRoute::new(winnet::WinNetNode::from(node), destination) + } + } + }) + .collect(); + + if winnet::routing_manager_add_routes(&routes) { + Ok(()) + } else { + Err(Error::AddRoutesFailed) + } + } + + /// Removes all routes previously applied in [`RouteManager::new`] or + /// [`RouteManager::add_routes`]. + pub fn clear_routes(&self) -> Result<()> { + if winnet::routing_manager_delete_applied_routes() { + Ok(()) + } else { + Err(Error::ClearRoutesFailed) + } + } } impl Drop for RouteManager { diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 6881b02311..8e19fe3813 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -1,5 +1,5 @@ use self::tun_provider::TunProvider; -use crate::logging; +use crate::{logging, routing::RouteManager}; #[cfg(not(target_os = "android"))] use std::collections::HashMap; use std::{ @@ -149,6 +149,7 @@ impl TunnelMonitor { resource_dir: &Path, on_event: L, tun_provider: &mut TunProvider, + route_manager: &mut RouteManager, ) -> Result<Self> where L: Fn(TunnelEvent) + Send + Clone + Sync + 'static, @@ -164,9 +165,13 @@ impl TunnelMonitor { #[cfg(target_os = "android")] TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform), - TunnelParameters::Wireguard(config) => { - Self::start_wireguard_tunnel(&config, log_file, on_event, tun_provider) - } + TunnelParameters::Wireguard(config) => Self::start_wireguard_tunnel( + &config, + log_file, + on_event, + tun_provider, + route_manager, + ), } } @@ -175,6 +180,7 @@ impl TunnelMonitor { log: Option<PathBuf>, on_event: L, tun_provider: &mut TunProvider, + route_manager: &mut RouteManager, ) -> Result<Self> where L: Fn(TunnelEvent) + Send + Sync + Clone + 'static, @@ -185,6 +191,7 @@ impl TunnelMonitor { log.as_ref().map(|p| p.as_path()), on_event, tun_provider, + route_manager, )?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), diff --git a/talpid-core/src/tunnel/wireguard/mod.rs b/talpid-core/src/tunnel/wireguard/mod.rs index d19eb54550..6324d13e80 100644 --- a/talpid-core/src/tunnel/wireguard/mod.rs +++ b/talpid-core/src/tunnel/wireguard/mod.rs @@ -46,8 +46,6 @@ pub enum Error { pub struct WireguardMonitor { /// Tunnel implementation tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>, - /// Route manager - route_handle: routing::RouteManager, /// Callback to signal tunnel events event_callback: Box<dyn Fn(TunnelEvent) + Send + Sync + 'static>, close_msg_sender: mpsc::Sender<CloseMsg>, @@ -62,6 +60,7 @@ impl WireguardMonitor { log_path: Option<&Path>, on_event: F, tun_provider: &mut TunProvider, + route_manager: &mut routing::RouteManager, ) -> Result<WireguardMonitor> { let tunnel = Box::new(WgGoTunnel::start_tunnel( &config, @@ -70,12 +69,12 @@ impl WireguardMonitor { Self::get_tunnel_routes(config), )?); let iface_name = tunnel.get_interface_name().to_string(); - #[cfg_attr(not(windows), allow(unused_mut))] - let mut route_handle = routing::RouteManager::new(Self::get_routes(&iface_name, &config)) + route_manager + .add_routes(Self::get_routes(&iface_name, &config)) .map_err(Error::SetupRoutingError)?; #[cfg(target_os = "windows")] - route_handle + route_manager .add_default_route_callback(Some(WgGoTunnel::default_route_changed_callback), ()); let event_callback = Box::new(on_event.clone()); @@ -83,7 +82,6 @@ impl WireguardMonitor { let (pinger_tx, pinger_rx) = mpsc::channel(); let monitor = WireguardMonitor { tunnel: Arc::new(Mutex::new(Some(tunnel))), - route_handle, event_callback, close_msg_sender, close_msg_receiver, @@ -144,11 +142,6 @@ impl WireguardMonitor { let _ = self.pinger_stop_sender.send(()); - // Clear routes manually - otherwise there will be some log spam since the tunnel device - // can be removed before the routes are cleared, which automatically clears some of the - // routes that were set. - self.route_handle.stop(); - self.stop_tunnel(); (self.event_callback)(TunnelEvent::Down); diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 47bc8d0f36..03cbc41cf6 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -89,12 +89,25 @@ impl ConnectedState { } } + fn reset_routes(shared_values: &mut SharedTunnelStateValues) { + #[cfg(windows)] + shared_values.route_manager.clear_default_route_callbacks(); + if let Err(error) = shared_values.route_manager.clear_routes() { + log::error!( + "Failed to clear routes: {:?}", + error.display_chain_with_msg("Failed to clear routes") + ); + } + } + fn disconnect( self, shared_values: &mut SharedTunnelStateValues, after_disconnect: AfterDisconnect, ) -> EventConsequence<Self> { Self::reset_dns(shared_values); + Self::reset_routes(shared_values); + EventConsequence::NewState(DisconnectingState::enter( shared_values, (self.close_handle, self.tunnel_close_event, after_disconnect), @@ -185,6 +198,7 @@ impl ConnectedState { match poll_result { Ok(Async::Ready(block_reason)) => { if let Some(reason) = block_reason { + Self::reset_routes(shared_values); return NewState(ErrorState::enter(shared_values, reason)); } } @@ -194,6 +208,7 @@ impl ConnectedState { log::info!("Tunnel closed. Reconnecting."); Self::reset_dns(shared_values); + Self::reset_routes(shared_values); NewState(ConnectingState::enter(shared_values, 0)) } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 14cd2cc51d..5d70568239 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -5,6 +5,7 @@ use super::{ }; use crate::{ firewall::FirewallPolicy, + routing::RouteManager, tunnel::{ self, tun_provider::TunProvider, CloseHandle, TunnelEvent, TunnelMetadata, TunnelMonitor, }, @@ -68,18 +69,21 @@ impl ConnectingState { log_dir: &Option<PathBuf>, resource_dir: &Path, tun_provider: &mut TunProvider, + route_manager: &mut RouteManager, retry_attempt: u32, ) -> crate::tunnel::Result<Self> { let (event_tx, event_rx) = mpsc::unbounded(); let on_tunnel_event = move |event| { let _ = event_tx.unbounded_send(event); }; + let monitor = TunnelMonitor::start( ¶meters, log_dir, resource_dir, on_tunnel_event, tun_provider, + route_manager, )?; let close_handle = Some(monitor.close_handle()); let tunnel_close_event = Self::spawn_tunnel_monitor_wait_thread(monitor); @@ -165,11 +169,24 @@ impl ConnectingState { } } + fn reset_routes(shared_values: &mut SharedTunnelStateValues) { + #[cfg(windows)] + shared_values.route_manager.clear_default_route_callbacks(); + if let Err(error) = shared_values.route_manager.clear_routes() { + log::error!( + "Failed to clear routes: {:?}", + error.display_chain_with_msg("Failed to clear routes") + ); + } + } + fn disconnect( self, shared_values: &mut SharedTunnelStateValues, after_disconnect: AfterDisconnect, ) -> EventConsequence<Self> { + Self::reset_routes(shared_values); + EventConsequence::NewState(DisconnectingState::enter( shared_values, (self.close_handle, self.tunnel_close_event, after_disconnect), @@ -270,6 +287,7 @@ impl ConnectingState { match poll_result { Ok(Async::Ready(block_reason)) => { if let Some(reason) = block_reason { + Self::reset_routes(shared_values); return EventConsequence::NewState(ErrorState::enter(shared_values, reason)); } } @@ -281,6 +299,7 @@ impl ConnectingState { "Tunnel closed. Reconnecting, attempt {}.", self.retry_attempt + 1 ); + Self::reset_routes(shared_values); EventConsequence::NewState(ConnectingState::enter( shared_values, self.retry_attempt + 1, @@ -359,6 +378,7 @@ impl TunnelState for ConnectingState { &shared_values.log_dir, &shared_values.resource_dir, &mut shared_values.tun_provider, + &mut shared_values.route_manager, retry_attempt, ) { Ok(connecting_state) => { diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index baf52c4b2b..2ffce2bec9 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -19,6 +19,7 @@ use crate::{ firewall::{Firewall, FirewallArguments}, mpsc::Sender, offline, + routing::RouteManager, tunnel::tun_provider::TunProvider, }; @@ -27,6 +28,7 @@ use futures01::{ Async, Future, Poll, Stream, }; use std::{ + collections::HashSet, io, path::{Path, PathBuf}, sync::{mpsc as sync_mpsc, Arc}, @@ -56,6 +58,10 @@ pub enum Error { #[error(display = "Failed to initialize the system DNS manager and monitor")] InitDnsMonitorError(#[error(source)] crate::dns::Error), + /// Failed to initialize the route manager. + #[error(display = "Failed to initialize the route manager")] + InitRouteManagerError(#[error(source)] crate::routing::Error), + /// Failed to initialize tunnel state machine event loop executor #[error(display = "Failed to initialize tunnel state machine event loop executor")] ReactorError(#[error(source)] io::Error), @@ -231,9 +237,12 @@ impl TunnelStateMachine { }; let firewall = Firewall::new(args).map_err(Error::InitFirewallError)?; let dns_monitor = DnsMonitor::new(cache_dir).map_err(Error::InitDnsMonitorError)?; + let route_manager = + RouteManager::new(HashSet::new()).map_err(Error::InitRouteManagerError)?; let mut shared_values = SharedTunnelStateValues { firewall, dns_monitor, + route_manager, allow_lan, block_when_disconnected, is_offline, @@ -317,6 +326,7 @@ pub trait TunnelParametersGenerator: Send + 'static { struct SharedTunnelStateValues { firewall: Firewall, dns_monitor: DnsMonitor, + route_manager: RouteManager, /// Should LAN access be allowed outside the tunnel. allow_lan: bool, /// Should network access be allowed when in the disconnected state. diff --git a/talpid-core/src/winnet.rs b/talpid-core/src/winnet.rs index a0b183d779..ab9dff5d06 100644 --- a/talpid-core/src/winnet.rs +++ b/talpid-core/src/winnet.rs @@ -289,17 +289,8 @@ impl Drop for WinNetRoute { } } -pub fn activate_routing_manager(routes: &[WinNetRoute]) -> bool { - if unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } { - if routing_manager_add_routes(routes) { - true - } else { - deactivate_routing_manager(); - false - } - } else { - false - } +pub fn activate_routing_manager() -> bool { + unsafe { WinNet_ActivateRouteManager(Some(log_sink), logging_context()) } } pub struct WinNetCallbackHandle { @@ -360,6 +351,10 @@ pub fn routing_manager_add_routes(routes: &[WinNetRoute]) -> bool { unsafe { WinNet_AddRoutes(ptr, length) } } +pub fn routing_manager_delete_applied_routes() -> bool { + unsafe { WinNet_DeleteAppliedRoutes() } +} + pub fn deactivate_routing_manager() { unsafe { WinNet_DeactivateRouteManager() } } @@ -400,6 +395,9 @@ mod api { // #[link_name = "WinNet_DeleteRoute"] // pub fn WinNet_DeleteRoute(route: *const super::WinNetRoute) -> bool; + #[link_name = "WinNet_DeleteAppliedRoutes"] + pub fn WinNet_DeleteAppliedRoutes() -> bool; + #[link_name = "WinNet_DeactivateRouteManager"] pub fn WinNet_DeactivateRouteManager(); diff --git a/windows/winnet/src/winnet/routing/routemanager.cpp b/windows/winnet/src/winnet/routing/routemanager.cpp index 81ee4e3d96..9d5b36ff98 100644 --- a/windows/winnet/src/winnet/routing/routemanager.cpp +++ b/windows/winnet/src/winnet/routing/routemanager.cpp @@ -209,27 +209,7 @@ RouteManager::~RouteManager() m_routeMonitorV4.reset(); m_routeMonitorV6.reset(); - // - // Delete all routes owned by us. - // - - for (const auto &record : m_routes) - { - try - { - deleteFromRoutingTable(record.registeredRoute); - } - catch (const std::exception &ex) - { - std::wstringstream ss; - - ss << L"Failed to delete route as part of cleaning up, Route: " - << FormatRegisteredRoute(record.registeredRoute); - - m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); - m_logSink->error(ex.what()); - } - } + deleteAppliedRoutes(); } void RouteManager::addRoutes(const std::vector<Route> &routes) @@ -302,6 +282,33 @@ void RouteManager::deleteRoutes(const std::vector<Route> &routes) } } +void RouteManager::deleteAppliedRoutes() +{ + // + // Delete all routes owned by us. + // + + for (const auto &record : m_routes) + { + try + { + deleteFromRoutingTable(record.registeredRoute); + } + catch (const std::exception & ex) + { + std::wstringstream ss; + + ss << L"Failed to delete route while clearing applied routes, Route: " + << FormatRegisteredRoute(record.registeredRoute); + + m_logSink->error(common::string::ToAnsi(ss.str()).c_str()); + m_logSink->error(ex.what()); + } + } + + m_routes.clear(); +} + RouteManager::CallbackHandle RouteManager::registerDefaultRouteChangedCallback(DefaultRouteChangedCallback callback) { AutoRecursiveLockType lock(m_defaultRouteCallbacksLock); diff --git a/windows/winnet/src/winnet/routing/routemanager.h b/windows/winnet/src/winnet/routing/routemanager.h index 92c712e25d..07cc7dbf40 100644 --- a/windows/winnet/src/winnet/routing/routemanager.h +++ b/windows/winnet/src/winnet/routing/routemanager.h @@ -32,6 +32,7 @@ public: void addRoutes(const std::vector<Route> &routes); void deleteRoutes(const std::vector<Route> &routes); + void deleteAppliedRoutes(); using DefaultRouteChangedEventType = DefaultRouteMonitor::EventType; diff --git a/windows/winnet/src/winnet/winnet.cpp b/windows/winnet/src/winnet/winnet.cpp index ac56e94ff6..c52fc57d60 100644 --- a/windows/winnet/src/winnet/winnet.cpp +++ b/windows/winnet/src/winnet/winnet.cpp @@ -385,6 +385,35 @@ extern "C" WINNET_LINKAGE
bool
WINNET_API
+WinNet_DeleteAppliedRoutes()
+{
+ AutoLockType lock(g_RouteManagerLock);
+
+ if (nullptr == g_RouteManager)
+ {
+ return false;
+ }
+
+ try
+ {
+ g_RouteManager->deleteAppliedRoutes();
+ return true;
+ }
+ catch (const std::exception & err)
+ {
+ common::error::UnwindException(err, g_RouteManagerLogSink);
+ return false;
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+extern "C"
+WINNET_LINKAGE
+bool
+WINNET_API
WinNet_DeleteRoute(
const WINNET_ROUTE *route
)
diff --git a/windows/winnet/src/winnet/winnet.h b/windows/winnet/src/winnet/winnet.h index 98b0083f03..3597b01a26 100644 --- a/windows/winnet/src/winnet/winnet.h +++ b/windows/winnet/src/winnet/winnet.h @@ -173,6 +173,13 @@ WinNet_DeleteRoute( const WINNET_ROUTE *route ); +extern "C" +WINNET_LINKAGE +bool +WINNET_API +WinNet_DeleteAppliedRoutes( +); + enum WINNET_DEFAULT_ROUTE_CHANGED_EVENT_TYPE { // Best default route changed. |
